WrapperSubsetEval.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 15k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  *    This program is free software; you can redistribute it and/or modify
  3.  *    it under the terms of the GNU General Public License as published by
  4.  *    the Free Software Foundation; either version 2 of the License, or
  5.  *    (at your option) any later version.
  6.  *
  7.  *    This program is distributed in the hope that it will be useful,
  8.  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  *    GNU General Public License for more details.
  11.  *
  12.  *    You should have received a copy of the GNU General Public License
  13.  *    along with this program; if not, write to the Free Software
  14.  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15.  */
  16. /*
  17.  *    WrapperSubsetEval.java
  18.  *    Copyright (C) 1999 Mark Hall
  19.  *
  20.  */
  21. package  weka.attributeSelection;
  22. import  java.io.*;
  23. import  java.util.*;
  24. import  weka.core.*;
  25. import  weka.classifiers.*;
  26. import  weka.classifiers.rules.ZeroR;
  27. import  weka.filters.unsupervised.attribute.Remove;
  28. import  weka.filters.Filter;
  29. /** 
  30.  * Wrapper attribute subset evaluator. <p>
  31.  * For more information see: <br>
  32.  * 
  33.  * Kohavi, R., John G., Wrappers for Feature Subset Selection. 
  34.  * In <i>Artificial Intelligence journal</i>, special issue on relevance, 
  35.  * Vol. 97, Nos 1-2, pp.273-324. <p>
  36.  *
  37.  * Valid options are:<p>
  38.  *
  39.  * -B <base learner> <br>
  40.  * Class name of base learner to use for accuracy estimation.
  41.  * Place any classifier options last on the command line following a
  42.  * "--". Eg  -B weka.classifiers.bayes.NaiveBayes ... -- -K <p>
  43.  *
  44.  * -F <num> <br>
  45.  * Number of cross validation folds to use for estimating accuracy.
  46.  * <default=5> <p>
  47.  *
  48.  * -T <num> <br>
  49.  * Threshold by which to execute another cross validation (standard deviation
  50.  * ---expressed as a percentage of the mean). <p>
  51.  *
  52.  * @author Mark Hall (mhall@cs.waikato.ac.nz)
  53.  * @version $Revision: 1.17 $
  54.  */
  55. public class WrapperSubsetEval
  56.   extends SubsetEvaluator
  57.   implements OptionHandler
  58. {
  59.   /** training instances */
  60.   private Instances m_trainInstances;
  61.   /** class index */
  62.   private int m_classIndex;
  63.   /** number of attributes in the training data */
  64.   private int m_numAttribs;
  65.   /** number of instances in the training data */
  66.   private int m_numInstances;
  67.   /** holds an evaluation object */
  68.   private Evaluation m_Evaluation;
  69.   /** holds the base classifier object */
  70.   private Classifier m_BaseClassifier;
  71.   /** number of folds to use for cross validation */
  72.   private int m_folds;
  73.   /** random number seed */
  74.   private int m_seed;
  75.   /** 
  76.    * the threshold by which to do further cross validations when
  77.    * estimating the accuracy of a subset
  78.    */
  79.   private double m_threshold;
  80.   /**
  81.    * Returns a string describing this attribute evaluator
  82.    * @return a description of the evaluator suitable for
  83.    * displaying in the explorer/experimenter gui
  84.    */
  85.   public String globalInfo() {
  86.     return "WrapperSubsetEval:nn"
  87.       +"Evaluates attribute sets by using a learning scheme. Cross "
  88.       +"validation is used to estimate the accuracy of the learning "
  89.       +"scheme for a set of attributes.n";
  90.   }
  91.   /**
  92.    * Constructor. Calls restOptions to set default options
  93.    **/
  94.   public WrapperSubsetEval () {
  95.     resetOptions();
  96.   }
  97.   /**
  98.    * Returns an enumeration describing the available options.
  99.    * @return an enumeration of all the available options.
  100.    **/
  101.   public Enumeration listOptions () {
  102.     Vector newVector = new Vector(4);
  103.     newVector.addElement(new Option("tclass name of base learner to use for" 
  104.     + "ntaccuracy estimation. Place any" 
  105.     + "ntclassifier options LAST on the" 
  106.     + "ntcommand line following a "--"." 
  107.     + "nteg. -B weka.classifiers.bayes.NaiveBayes ... " 
  108.     + "-- -K", "B", 1, "-B <base learner>"));
  109.     newVector.addElement(new Option("tnumber of cross validation folds to " 
  110.     + "usentfor estimating accuracy." 
  111.     + "nt(default=5)", "F", 1, "-F <num>"));
  112.     newVector.addElement(new Option("tSeed for cross validation accuracy "
  113.     +"ntestimation."
  114.     +"nt(default = 1)", "S", 1,"-S <seed>"));
  115.     newVector.addElement(new Option("tthreshold by which to execute " 
  116.     + "another cross validation" 
  117.     + "nt(standard deviation---" 
  118.     + "expressed as a percentage of the " 
  119.     + "mean).nt(default=0.01(1%))"
  120.     , "T", 1, "-T <num>"));
  121.     if ((m_BaseClassifier != null) && 
  122. (m_BaseClassifier instanceof OptionHandler)) {
  123.       newVector.addElement(new Option("", "", 0, "nOptions specific to" 
  124.       + "scheme " 
  125.       + m_BaseClassifier.getClass().getName() 
  126.       + ":"));
  127.       Enumeration enum = ((OptionHandler)m_BaseClassifier).listOptions();
  128.       while (enum.hasMoreElements()) {
  129.         newVector.addElement(enum.nextElement());
  130.       }
  131.     }
  132.     return  newVector.elements();
  133.   }
  134.   /**
  135.    * Parses a given list of options.
  136.    *
  137.    * Valid options are:<p>
  138.    *
  139.    * -B <base learner> <br>
  140.    * Class name of base learner to use for accuracy estimation.
  141.    * Place any classifier options last on the command line following a
  142.    * "--". Eg  -B weka.classifiers.bayes.NaiveBayes ... -- -K <p>
  143.    *
  144.    * -F <num> <br>
  145.    * Number of cross validation folds to use for estimating accuracy.
  146.    * <default=5> <p>
  147.    *
  148.    * -T <num> <br>
  149.    * Threshold by which to execute another cross validation (standard deviation
  150.    * ---expressed as a percentage of the mean). <p>
  151.    *
  152.    * @param options the list of options as an array of strings
  153.    * @exception Exception if an option is not supported
  154.    *
  155.    **/
  156.   public void setOptions (String[] options)
  157.     throws Exception
  158.   {
  159.     String optionString;
  160.     resetOptions();
  161.     optionString = Utils.getOption('B', options);
  162.     if (optionString.length() == 0) {
  163.       throw  new Exception("A learning scheme must be specified with" 
  164.    + "-B option");
  165.     }
  166.     setClassifier(Classifier.forName(optionString, 
  167.      Utils.partitionOptions(options)));
  168.     optionString = Utils.getOption('F', options);
  169.     if (optionString.length() != 0) {
  170.       setFolds(Integer.parseInt(optionString));
  171.     }
  172.     optionString = Utils.getOption('S', options);
  173.     if (optionString.length() != 0) {
  174.       setSeed(Integer.parseInt(optionString));
  175.     }
  176.     //       optionString = Utils.getOption('S',options);
  177.     //       if (optionString.length() != 0)
  178.     //         {
  179.     //    seed = Integer.parseInt(optionString);
  180.     //         }
  181.     optionString = Utils.getOption('T', options);
  182.     if (optionString.length() != 0) {
  183.       Double temp;
  184.       temp = Double.valueOf(optionString);
  185.       setThreshold(temp.doubleValue());
  186.     }
  187.   }
  188.   
  189.   /**
  190.    * Returns the tip text for this property
  191.    * @return tip text for this property suitable for
  192.    * displaying in the explorer/experimenter gui
  193.    */
  194.   public String thresholdTipText() {
  195.     return "Repeat xval if stdev of mean exceeds this value.";
  196.   }
  197.   /**
  198.    * Set the value of the threshold for repeating cross validation
  199.    *
  200.    * @param t the value of the threshold
  201.    */
  202.   public void setThreshold (double t) {
  203.     m_threshold = t;
  204.   }
  205.   /**
  206.    * Get the value of the threshold
  207.    *
  208.    * @return the threshold as a double
  209.    */
  210.   public double getThreshold () {
  211.     return  m_threshold;
  212.   }
  213.   /**
  214.    * Returns the tip text for this property
  215.    * @return tip text for this property suitable for
  216.    * displaying in the explorer/experimenter gui
  217.    */
  218.   public String foldsTipText() {
  219.     return "Number of xval folds to use when estimating subset accuracy.";
  220.   }
  221.   /**
  222.    * Set the number of folds to use for accuracy estimation
  223.    *
  224.    * @param f the number of folds
  225.    */
  226.   public void setFolds (int f) {
  227.     m_folds = f;
  228.   }
  229.   /**
  230.    * Get the number of folds used for accuracy estimation
  231.    *
  232.    * @return the number of folds
  233.    */
  234.   public int getFolds () {
  235.     return  m_folds;
  236.   }
  237.   /**
  238.    * Returns the tip text for this property
  239.    * @return tip text for this property suitable for
  240.    * displaying in the explorer/experimenter gui
  241.    */
  242.   public String seedTipText() {
  243.     return "Seed to use for randomly generating xval splits.";
  244.   }
  245.   /**
  246.    * Set the seed to use for cross validation
  247.    *
  248.    * @param s the seed
  249.    */
  250.   public void setSeed (int s) {
  251.     m_seed = s;
  252.   }
  253.   /**
  254.    * Get the random number seed used for cross validation
  255.    *
  256.    * @return the seed
  257.    */
  258.   public int getSeed () {
  259.     return  m_seed;
  260.   }
  261.   /**
  262.    * Returns the tip text for this property
  263.    * @return tip text for this property suitable for
  264.    * displaying in the explorer/experimenter gui
  265.    */
  266.   public String classifierTipText() {
  267.     return "Classifier to use for estimating the accuracy of subsets";
  268.   }
  269.   /**
  270.    * Set the classifier to use for accuracy estimation
  271.    *
  272.    * @param newClassifier the Classifier to use.
  273.    */
  274.   public void setClassifier (Classifier newClassifier) {
  275.     m_BaseClassifier = newClassifier;
  276.   }
  277.   /**
  278.    * Get the classifier used as the base learner.
  279.    *
  280.    * @return the classifier used as the classifier
  281.    */
  282.   public Classifier getClassifier () {
  283.     return  m_BaseClassifier;
  284.   }
  285.   /**
  286.    * Gets the current settings of WrapperSubsetEval.
  287.    *
  288.    * @return an array of strings suitable for passing to setOptions()
  289.    */
  290.   public String[] getOptions () {
  291.     String[] classifierOptions = new String[0];
  292.     if ((m_BaseClassifier != null) && 
  293. (m_BaseClassifier instanceof OptionHandler)) {
  294.       classifierOptions = ((OptionHandler)m_BaseClassifier).getOptions();
  295.     }
  296.     String[] options = new String[9 + classifierOptions.length];
  297.     int current = 0;
  298.     if (getClassifier() != null) {
  299.       options[current++] = "-B";
  300.       options[current++] = getClassifier().getClass().getName();
  301.     }
  302.     options[current++] = "-F";
  303.     options[current++] = "" + getFolds();
  304.     options[current++] = "-T";
  305.     options[current++] = "" + getThreshold();
  306.     options[current++] = "-S";
  307.     options[current++] = "" + getSeed();
  308.     options[current++] = "--";
  309.     System.arraycopy(classifierOptions, 0, options, current, 
  310.      classifierOptions.length);
  311.     current += classifierOptions.length;
  312.     while (current < options.length) {
  313.       options[current++] = "";
  314.     }
  315.     return  options;
  316.   }
  317.   protected void resetOptions () {
  318.     m_trainInstances = null;
  319.     m_Evaluation = null;
  320.     m_BaseClassifier = new ZeroR();
  321.     m_folds = 5;
  322.     m_seed = 1;
  323.     m_threshold = 0.01;
  324.   }
  325.   /**
  326.    * Generates a attribute evaluator. Has to initialize all fields of the 
  327.    * evaluator that are not being set via options.
  328.    *
  329.    * @param data set of instances serving as training data 
  330.    * @exception Exception if the evaluator has not been 
  331.    * generated successfully
  332.    */
  333.   public void buildEvaluator (Instances data)
  334.     throws Exception
  335.   {
  336.     if (data.checkForStringAttributes()) {
  337.       throw  new UnsupportedAttributeTypeException("Can't handle string attributes!");
  338.     }
  339.     m_trainInstances = data;
  340.     m_classIndex = m_trainInstances.classIndex();
  341.     m_numAttribs = m_trainInstances.numAttributes();
  342.     m_numInstances = m_trainInstances.numInstances();
  343.   }
  344.   /**
  345.    * Evaluates a subset of attributes
  346.    *
  347.    * @param subset a bitset representing the attribute subset to be 
  348.    * evaluated 
  349.    * @exception Exception if the subset could not be evaluated
  350.    */
  351.   public double evaluateSubset (BitSet subset)
  352.     throws Exception
  353.   {
  354.     double errorRate = 0;
  355.     double[] repError = new double[5];
  356.     boolean ok = true;
  357.     int numAttributes = 0;
  358.     int i, j;
  359.     Random Rnd = new Random(m_seed);
  360.     Remove delTransform = new Remove();
  361.     delTransform.setInvertSelection(true);
  362.     // copy the instances
  363.     Instances trainCopy = new Instances(m_trainInstances);
  364.     // count attributes set in the BitSet
  365.     for (i = 0; i < m_numAttribs; i++) {
  366.       if (subset.get(i)) {
  367.         numAttributes++;
  368.       }
  369.     }
  370.     // set up an array of attribute indexes for the filter (+1 for the class)
  371.     int[] featArray = new int[numAttributes + 1];
  372.     for (i = 0, j = 0; i < m_numAttribs; i++) {
  373.       if (subset.get(i)) {
  374.         featArray[j++] = i;
  375.       }
  376.     }
  377.     featArray[j] = m_classIndex;
  378.     delTransform.setAttributeIndicesArray(featArray);
  379.     delTransform.setInputFormat(trainCopy);
  380.     trainCopy = Filter.useFilter(trainCopy, delTransform);
  381.     // max of 5 repititions ofcross validation
  382.     for (i = 0; i < 5; i++) {
  383.       trainCopy.randomize(Rnd); // randomize instances
  384.       m_Evaluation = new Evaluation(trainCopy);
  385.       m_Evaluation.crossValidateModel(m_BaseClassifier, trainCopy, m_folds);
  386.       repError[i] = m_Evaluation.errorRate();
  387.       // check on the standard deviation
  388.       if (!repeat(repError, i + 1)) {
  389.         break;
  390.       }
  391.     }
  392.     for (j = 0; j < i; j++) {
  393.       errorRate += repError[j];
  394.     }
  395.     errorRate /= (double)i;
  396.     return  -errorRate;
  397.   }
  398.   /**
  399.    * Returns a string describing the wrapper
  400.    *
  401.    * @return the description as a string
  402.    */
  403.   public String toString () {
  404.     StringBuffer text = new StringBuffer();
  405.     if (m_trainInstances == null) {
  406.       text.append("tWrapper subset evaluator has not been built yetn");
  407.     }
  408.     else {
  409.       text.append("tWrapper Subset Evaluatorn");
  410.       text.append("tLearning scheme: " 
  411.   + getClassifier().getClass().getName() + "n");
  412.       text.append("tScheme options: ");
  413.       String[] classifierOptions = new String[0];
  414.       if (m_BaseClassifier instanceof OptionHandler) {
  415.         classifierOptions = ((OptionHandler)m_BaseClassifier).getOptions();
  416.         for (int i = 0; i < classifierOptions.length; i++) {
  417.           text.append(classifierOptions[i] + " ");
  418.         }
  419.       }
  420.       text.append("n");
  421.       if (m_trainInstances.attribute(m_classIndex).isNumeric()) {
  422. text.append("tAccuracy estimation: RMSEn");
  423.       } else {
  424. text.append("tAccuracy estimation: classification errorn");
  425.       }
  426.       
  427.       text.append("tNumber of folds for accuracy estimation: " 
  428.   + m_folds 
  429.   + "n");
  430.     }
  431.     return  text.toString();
  432.   }
  433.   /**
  434.    * decides whether to do another repeat of cross validation. If the
  435.    * standard deviation of the cross validations
  436.    * is greater than threshold% of the mean (default 1%) then another 
  437.    * repeat is done. 
  438.    *
  439.    * @param repError an array of cross validation results
  440.    * @param entries the number of cross validations done so far
  441.    * @return true if another cv is to be done
  442.    */
  443.   private boolean repeat (double[] repError, int entries) {
  444.     int i;
  445.     double mean = 0;
  446.     double variance = 0;
  447.     if (entries == 1) {
  448.       return  true;
  449.     }
  450.     for (i = 0; i < entries; i++) {
  451.       mean += repError[i];
  452.     }
  453.     mean /= (double)entries;
  454.     for (i = 0; i < entries; i++) {
  455.       variance += ((repError[i] - mean)*(repError[i] - mean));
  456.     }
  457.     variance /= (double)entries;
  458.     if (variance > 0) {
  459.       variance = Math.sqrt(variance);
  460.     }
  461.     if ((variance/mean) > m_threshold) {
  462.       return  true;
  463.     }
  464.     return  false;
  465.   }
  466.   /**
  467.    * Main method for testing this class.
  468.    *
  469.    * @param args the options
  470.    */
  471.   public static void main (String[] args) {
  472.     try {
  473.       System.out.println(AttributeSelection.
  474.  SelectAttributes(new WrapperSubsetEval(), args));
  475.     }
  476.     catch (Exception e) {
  477.       e.printStackTrace();
  478.       System.out.println(e.getMessage());
  479.     }
  480.   }
  481. }