CVParameterSelection.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 20k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  *    This program is free software; you can redistribute it and/or modify
  3.  *    it under the terms of the GNU General Public License as published by
  4.  *    the Free Software Foundation; either version 2 of the License, or
  5.  *    (at your option) any later version.
  6.  *
  7.  *    This program is distributed in the hope that it will be useful,
  8.  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  *    GNU General Public License for more details.
  11.  *
  12.  *    You should have received a copy of the GNU General Public License
  13.  *    along with this program; if not, write to the Free Software
  14.  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15.  */
  16. /*
  17.  *    CVParameterSelection.java
  18.  *    Copyright (C) 1999 Len Trigg
  19.  *
  20.  */
  21. package weka.classifiers;
  22. import java.io.*;
  23. import java.util.*;
  24. import weka.core.*;
  25. /**
  26.  * Class for performing parameter selection by cross-validation for any
  27.  * classifier. For more information, see<p>
  28.  *
  29.  * R. Kohavi (1995). <i>Wrappers for Performance
  30.  * Enhancement and Oblivious Decision Graphs</i>. PhD
  31.  * Thesis. Department of Computer Science, Stanford University. <p>
  32.  *
  33.  * Valid options are:<p>
  34.  *
  35.  * -D <br>
  36.  * Turn on debugging output.<p>
  37.  *
  38.  * -W classname <br>
  39.  * Specify the full class name of classifier to perform cross-validation
  40.  * selection on.<p>
  41.  *
  42.  * -X num <br>
  43.  * Number of folds used for cross validation (default 10). <p>
  44.  *
  45.  * -S seed <br>
  46.  * Random number seed (default 1).<p>
  47.  *
  48.  * -P "N 1 5 10" <br>
  49.  * Sets an optimisation parameter for the classifier with name -N,
  50.  * lower bound 1, upper bound 5, and 10 optimisation steps.
  51.  * The upper bound may be the character 'A' or 'I' to substitute 
  52.  * the number of attributes or instances in the training data,
  53.  * respectively.
  54.  * This parameter may be supplied more than once to optimise over
  55.  * several classifier options simultaneously. <p>
  56.  *
  57.  * Options after -- are passed to the designated sub-classifier. <p>
  58.  *
  59.  * @author Len Trigg (trigg@cs.waikato.ac.nz)
  60.  * @version $Revision: 1.12 $ 
  61. */
  62. public class CVParameterSelection extends Classifier 
  63.   implements OptionHandler, Summarizable {
  64.   /*
  65.    * A data structure to hold values associated with a single
  66.    * cross-validation search parameter
  67.    */
  68.   protected class CVParameter {
  69.     /**  Char used to identify the option of interest */
  70.     private char m_ParamChar;    
  71.     /**  Lower bound for the CV search */
  72.     private double m_Lower;      
  73.     /**  Upper bound for the CV search */
  74.     private double m_Upper;      
  75.     /**  Increment during the search */
  76.     private double m_Steps;      
  77.     /**  The parameter value with the best performance */
  78.     private double m_ParamValue; 
  79.     /**  True if the parameter should be added at the end of the argument list */
  80.     private boolean m_AddAtEnd;  
  81.     /**  True if the parameter should be rounded to an integer */
  82.     private boolean m_RoundParam;
  83.     /**
  84.      * Constructs a CVParameter.
  85.      */
  86.     public CVParameter(String param) throws Exception {
  87.      
  88.       // Tokenize the string into it's parts
  89.       StreamTokenizer st = new StreamTokenizer(new StringReader(param));
  90.       if (st.nextToken() != StreamTokenizer.TT_WORD) {
  91. throw new Exception("CVParameter " + param 
  92.     + ": Character parameter identifier expected");
  93.       }
  94.       m_ParamChar = st.sval.charAt(0);
  95.       if (st.nextToken() != StreamTokenizer.TT_NUMBER) {
  96. throw new Exception("CVParameter " + param 
  97.     + ": Numeric lower bound expected");
  98.       }
  99.       m_Lower = st.nval;
  100.       if (st.nextToken() == StreamTokenizer.TT_NUMBER) {
  101. m_Upper = st.nval;
  102. if (m_Upper < m_Lower) {
  103.   throw new Exception("CVParameter " + param
  104.       + ": Upper bound is less than lower bound");
  105. }
  106.       } else if (st.ttype == StreamTokenizer.TT_WORD) {
  107. if (st.sval.toUpperCase().charAt(0) == 'A') {
  108.   m_Upper = m_Lower - 1;
  109. } else if (st.sval.toUpperCase().charAt(0) == 'I') {
  110.   m_Upper = m_Lower - 2;
  111. } else {
  112.   throw new Exception("CVParameter " + param 
  113.       + ": Upper bound must be numeric, or 'A' or 'N'");
  114. }
  115.       } else {
  116. throw new Exception("CVParameter " + param 
  117.       + ": Upper bound must be numeric, or 'A' or 'N'");
  118.       }
  119.       if (st.nextToken() != StreamTokenizer.TT_NUMBER) {
  120. throw new Exception("CVParameter " + param 
  121.     + ": Numeric number of steps expected");
  122.       }
  123.       m_Steps = st.nval;
  124.       if (st.nextToken() == StreamTokenizer.TT_WORD) {
  125. if (st.sval.toUpperCase().charAt(0) == 'R') {
  126.   m_RoundParam = true;
  127. }
  128.       }
  129.     }
  130.     /**
  131.      * Returns a CVParameter as a string.
  132.      */
  133.     public String toString() {
  134.       String result = m_ParamChar + " " + m_Lower + " ";
  135.       switch ((int)(m_Lower - m_Upper + 0.5)) {
  136.       case 1:
  137. result += "A";
  138. break;
  139.       case 2:
  140. result += "I";
  141. break;
  142.       default:
  143. result += m_Upper;
  144. break;
  145.       }
  146.       result += " " + m_Steps;
  147.       if (m_RoundParam) {
  148. result += " R";
  149.       }
  150.       return result;
  151.     }
  152.   }
  153.   /** The generated base classifier */
  154.   protected Classifier m_Classifier = new weka.classifiers.ZeroR();
  155.   /**
  156.    * The base classifier options (not including those being set
  157.    * by cross-validation)
  158.    */
  159.   protected String [] m_ClassifierOptions;
  160.   /** The set of all classifier options as determined by cross-validation */
  161.   protected String [] m_BestClassifierOptions;
  162.   /** The cross-validated performance of the best options */
  163.   protected double m_BestPerformance;
  164.   /** The set of parameters to cross-validate over */
  165.   protected FastVector m_CVParams;
  166.   /** The number of attributes in the data */
  167.   protected int m_NumAttributes;
  168.   /** The number of instances in a training fold */
  169.   protected int m_TrainFoldSize;
  170.   
  171.   /** The number of folds used in cross-validation */
  172.   protected int m_NumFolds = 10;
  173.   /** Random number seed */
  174.   protected int m_Seed = 1;
  175.   /** Debugging mode, gives extra output if true */
  176.   protected boolean m_Debug;
  177.   /**
  178.    * Create the options array to pass to the classifier. The parameter
  179.    * values and positions are taken from m_ClassifierOptions and
  180.    * m_CVParams.
  181.    *
  182.    * @return the options array
  183.    */
  184.   protected String [] createOptions() {
  185.     
  186.     String [] options = new String [m_ClassifierOptions.length 
  187.    + 2 * m_CVParams.size()];
  188.     int start = 0, end = options.length;
  189.     // Add the cross-validation parameters and their values
  190.     for (int i = 0; i < m_CVParams.size(); i++) {
  191.       CVParameter cvParam = (CVParameter)m_CVParams.elementAt(i);
  192.       double paramValue = cvParam.m_ParamValue;
  193.       if (cvParam.m_RoundParam) {
  194. paramValue = (double)((int) (paramValue + 0.5));
  195.       }
  196.       if (cvParam.m_AddAtEnd) {
  197. options[--end] = "" + 
  198. Utils.doubleToString(paramValue,4);
  199. options[--end] = "-" + cvParam.m_ParamChar;
  200.       } else {
  201. options[start++] = "-" + cvParam.m_ParamChar;
  202. options[start++] = "" 
  203. + Utils.doubleToString(paramValue,4);
  204.       }
  205.     }
  206.     // Add the static parameters
  207.     System.arraycopy(m_ClassifierOptions, 0,
  208.      options, start,
  209.      m_ClassifierOptions.length);
  210.     return options;
  211.   }
  212.   /**
  213.    * Finds the best parameter combination. (recursive for each parameter
  214.    * being optimised).
  215.    *
  216.    * @param depth the index of the parameter to be optimised at this level
  217.    * @exception Exception if an error occurs
  218.    */
  219.   protected void findParamsByCrossValidation(int depth, Instances trainData)
  220.     throws Exception {
  221.     if (depth < m_CVParams.size()) {
  222.       CVParameter cvParam = (CVParameter)m_CVParams.elementAt(depth);
  223.       double upper;
  224.       switch ((int)(cvParam.m_Lower - cvParam.m_Upper + 0.5)) {
  225.       case 1:
  226. upper = m_NumAttributes;
  227. break;
  228.       case 2:
  229. upper = m_TrainFoldSize;
  230. break;
  231.       default:
  232. upper = cvParam.m_Upper;
  233. break;
  234.       }
  235.       double increment = (upper - cvParam.m_Lower) / (cvParam.m_Steps - 1);
  236.       for(cvParam.m_ParamValue = cvParam.m_Lower; 
  237.   cvParam.m_ParamValue <= upper; 
  238.   cvParam.m_ParamValue += increment) {
  239. findParamsByCrossValidation(depth + 1, trainData);
  240.       }
  241.     } else {
  242.       
  243.       Evaluation evaluation = new Evaluation(trainData);
  244.       // Set the classifier options
  245.       String [] options = createOptions();
  246.       if (m_Debug) {
  247. System.err.print("Setting options for " 
  248.  + m_Classifier.getClass().getName() + ":");
  249. for (int i = 0; i < options.length; i++) {
  250.   System.err.print(" " + options[i]);
  251. }
  252. System.err.println("");
  253.       }
  254.       ((OptionHandler)m_Classifier).setOptions(options);
  255.       for (int j = 0; j < m_NumFolds; j++) {
  256. Instances train = trainData.trainCV(m_NumFolds, j);
  257. Instances test = trainData.testCV(m_NumFolds, j);
  258. m_Classifier.buildClassifier(train);
  259. evaluation.setPriors(train);
  260. evaluation.evaluateModel(m_Classifier, test);
  261.       }
  262.       double error = evaluation.errorRate();
  263.       if (m_Debug) {
  264. System.err.println("Cross-validated error rate: " 
  265.    + Utils.doubleToString(error, 6, 4));
  266.       }
  267.       if ((m_BestPerformance == -99) || (error < m_BestPerformance)) {
  268. m_BestPerformance = error;
  269. m_BestClassifierOptions = createOptions();
  270.       }
  271.     }
  272.   }
  273.   /**
  274.    * Returns an enumeration describing the available options
  275.    *
  276.    * @return an enumeration of all the available options
  277.    */
  278.   public Enumeration listOptions() {
  279.     Vector newVector = new Vector(5);
  280.     newVector.addElement(new Option(
  281.       "tTurn on debugging output.",
  282.       "D", 0, "-D"));
  283.     newVector.addElement(new Option(
  284.       "tFull name of classifier to perform parameter selection on.n"
  285.       + "teg: weka.classifiers.NaiveBayes",
  286.       "W", 1, "-W <classifier class name>"));
  287.     newVector.addElement(new Option(
  288.       "tNumber of folds used for cross validation (default 10).",
  289.       "X", 1, "-X <number of folds>"));
  290.     newVector.addElement(new Option(
  291.       "tClassifier parameter options.n"
  292.       + "teg: "N 1 5 10" Sets an optimisation parameter for then"
  293.       + "tclassifier with name -N, with lower bound 1, upper boundn"
  294.       + "t5, and 10 optimisation steps. The upper bound may be then"
  295.       + "tcharacter 'A' or 'I' to substitute the number ofn"
  296.       + "tattributes or instances in the training data,n"
  297.       + "trespectively. This parameter may be supplied more thann"
  298.       + "tonce to optimise over several classifier optionsn"
  299.       + "tsimultaneously.",
  300.       "P", 1, "-P <classifier parameter>"));
  301.     newVector.addElement(new Option(
  302.       "tSets the random number seed (default 1).",
  303.       "S", 1, "-S <random number seed>"));
  304.     if ((m_Classifier != null) &&
  305. (m_Classifier instanceof OptionHandler)) {
  306.       newVector.addElement(new Option("",
  307.         "", 0,
  308. "nOptions specific to sub-classifier "
  309.         + m_Classifier.getClass().getName()
  310. + ":n(use -- to signal start of sub-classifier options)"));
  311.       Enumeration enum = ((OptionHandler)m_Classifier).listOptions();
  312.       while (enum.hasMoreElements()) {
  313. newVector.addElement(enum.nextElement());
  314.       }
  315.     }
  316.     return newVector.elements();
  317.   }
  318.   /**
  319.    * Parses a given list of options. Valid options are:<p>
  320.    *
  321.    * -D <br>
  322.    * Turn on debugging output.<p>
  323.    *
  324.    * -W classname <br>
  325.    * Specify the full class name of classifier to perform cross-validation
  326.    * selection on.<p>
  327.    *
  328.    * -X num <br>
  329.    * Number of folds used for cross validation (default 10). <p>
  330.    *
  331.    * -S seed <br>
  332.    * Random number seed (default 1).<p>
  333.    *
  334.    * -P "N 1 5 10" <br>
  335.    * Sets an optimisation parameter for the classifier with name -N,
  336.    * lower bound 1, upper bound 5, and 10 optimisation steps.
  337.    * The upper bound may be the character 'A' or 'I' to substitute 
  338.    * the number of attributes or instances in the training data,
  339.    * respectively.
  340.    * This parameter may be supplied more than once to optimise over
  341.    * several classifier options simultaneously. <p>
  342.    *
  343.    * Options after -- are passed to the designated sub-classifier. <p>
  344.    *
  345.    * @param options the list of options as an array of strings
  346.    * @exception Exception if an option is not supported
  347.    */
  348.   public void setOptions(String[] options) throws Exception {
  349.     
  350.     setDebug(Utils.getFlag('D', options));
  351.     String foldsString = Utils.getOption('X', options);
  352.     if (foldsString.length() != 0) {
  353.       setNumFolds(Integer.parseInt(foldsString));
  354.     } else {
  355.       setNumFolds(10);
  356.     }
  357.     String randomString = Utils.getOption('S', options);
  358.     if (randomString.length() != 0) {
  359.       setSeed(Integer.parseInt(randomString));
  360.     } else {
  361.       setSeed(1);
  362.     }
  363.     String cvParam;
  364.     m_CVParams = new FastVector();
  365.     do {
  366.       cvParam = Utils.getOption('P', options);
  367.       if (cvParam.length() != 0) {
  368. addCVParameter(cvParam);
  369.       }
  370.     } while (cvParam.length() != 0);
  371.     if (m_CVParams.size() == 0) {
  372.       throw new Exception("A parameter specifier must be given with"
  373.   + " the -P option.");
  374.     }
  375.     String classifierName = Utils.getOption('W', options);
  376.     if (classifierName.length() == 0) {
  377.       throw new Exception("A classifier must be specified with"
  378.   + " the -W option.");
  379.     }
  380.     setClassifier(Classifier.forName(classifierName,
  381.      Utils.partitionOptions(options)));
  382.     if (!(m_Classifier instanceof OptionHandler)) {
  383.       throw new Exception("Base classifier must accept options");
  384.     }
  385.   }
  386.   /**
  387.    * Gets the current settings of the Classifier.
  388.    *
  389.    * @return an array of strings suitable for passing to setOptions
  390.    */
  391.   public String [] getOptions() {
  392.     String [] classifierOptions = new String [0];
  393.     if ((m_Classifier != null) && 
  394. (m_Classifier instanceof OptionHandler)) {
  395.       classifierOptions = ((OptionHandler)m_Classifier).getOptions();
  396.     }
  397.     int current = 0;
  398.     String [] options = new String [classifierOptions.length + 8];
  399.     if (m_CVParams != null) {
  400.       options = new String [m_CVParams.size() * 2 + options.length];
  401.       for (int i = 0; i < m_CVParams.size(); i++) {
  402. options[current++] = "-P"; options[current++] = "" + getCVParameter(i);
  403.       }
  404.     }
  405.     if (getDebug()) {
  406.       options[current++] = "-D";
  407.     }
  408.     options[current++] = "-X"; options[current++] = "" + getNumFolds();
  409.     options[current++] = "-S"; options[current++] = "" + getSeed();
  410.     if (getClassifier() != null) {
  411.       options[current++] = "-W";
  412.       options[current++] = getClassifier().getClass().getName();
  413.     }
  414.     options[current++] = "--";
  415.     System.arraycopy(classifierOptions, 0, options, current, 
  416.      classifierOptions.length);
  417.     current += classifierOptions.length;
  418.     while (current < options.length) {
  419.       options[current++] = "";
  420.     }
  421.     return options;
  422.   }
  423.   /**
  424.    * Generates the classifier.
  425.    *
  426.    * @param instances set of instances serving as training data 
  427.    * @exception Exception if the classifier has not been generated successfully
  428.    */
  429.   public void buildClassifier(Instances instances) throws Exception {
  430.     if (instances.checkForStringAttributes()) {
  431.       throw new Exception("Can't handle string attributes!");
  432.     }
  433.     Instances trainData = new Instances(instances);
  434.     trainData.deleteWithMissingClass();
  435.     if (trainData.numInstances() == 0) {
  436.       throw new Exception("No training instances without missing class.");
  437.     }
  438.     if (trainData.numInstances() < m_NumFolds) {
  439.       throw new Exception("Number of training instances smaller than number of folds.");
  440.     }
  441.     trainData.randomize(new Random(m_Seed));
  442.     if (trainData.classAttribute().isNominal()) {
  443.       trainData.stratify(m_NumFolds);
  444.     }
  445.     m_BestPerformance = -99;
  446.     m_BestClassifierOptions = null;
  447.     m_NumAttributes = trainData.numAttributes();
  448.     m_TrainFoldSize = trainData.trainCV(m_NumFolds, 0).numInstances();
  449.     
  450.     // Set up m_ClassifierOptions -- take getOptions() and remove
  451.     // those being optimised.
  452.     m_ClassifierOptions = ((OptionHandler)m_Classifier).getOptions();
  453.     for (int i = 0; i < m_CVParams.size(); i++) {
  454.       Utils.getOption(((CVParameter)m_CVParams.elementAt(i)).m_ParamChar,
  455.       m_ClassifierOptions);
  456.     }
  457.     findParamsByCrossValidation(0, trainData);
  458.     String [] options = (String [])m_BestClassifierOptions.clone();
  459.     ((OptionHandler)m_Classifier).setOptions(options);
  460.     m_Classifier.buildClassifier(trainData);
  461.   }
  462.   /**
  463.    * Predicts the class value for the given test instance.
  464.    *
  465.    * @param instance the instance to be classified
  466.    * @return the predicted class value
  467.    * @exception Exception if an error occurred during the prediction
  468.    */
  469.   public double classifyInstance(Instance instance) throws Exception {
  470.     
  471.     return m_Classifier.classifyInstance(instance);
  472.   }
  473.   /**
  474.    * Sets the seed for random number generation.
  475.    *
  476.    * @param seed the random number seed
  477.    */
  478.   public void setSeed(int seed) {
  479.     
  480.     m_Seed = seed;;
  481.   }
  482.   /**
  483.    * Gets the random number seed.
  484.    * 
  485.    * @return the random number seed
  486.    */
  487.   public int getSeed() {
  488.     return m_Seed;
  489.   }
  490.   /**
  491.    * Adds a scheme parameter to the list of parameters to be set
  492.    * by cross-validation
  493.    *
  494.    * @param cvParam the string representation of a scheme parameter. The
  495.    * format is: <br>
  496.    * param_char lower_bound upper_bound increment <br>
  497.    * eg to search a parameter -P from 1 to 10 by increments of 2: <br>
  498.    * P 1 10 2 <br>
  499.    * @exception Exception if the parameter specifier is of the wrong format
  500.    */
  501.   public void addCVParameter(String cvParam) throws Exception {
  502.     CVParameter newCV = new CVParameter(cvParam);
  503.     
  504.     m_CVParams.addElement(newCV);
  505.   }
  506.   /**
  507.    * Gets the scheme paramter with the given index.
  508.    */
  509.   public String getCVParameter(int index) {
  510.     if (m_CVParams.size() <= index) {
  511.       return "";
  512.     }
  513.     return ((CVParameter)m_CVParams.elementAt(index)).toString();
  514.   }
  515.   /**
  516.    * Sets debugging mode
  517.    *
  518.    * @param debug true if debug output should be printed
  519.    */
  520.   public void setDebug(boolean debug) {
  521.     m_Debug = debug;
  522.   }
  523.   /**
  524.    * Gets whether debugging is turned on
  525.    *
  526.    * @return true if debugging output is on
  527.    */
  528.   public boolean getDebug() {
  529.     return m_Debug;
  530.   }
  531.   /**
  532.    * Get the number of folds used for cross-validation.
  533.    *
  534.    * @return the number of folds used for cross-validation.
  535.    */
  536.   public int getNumFolds() {
  537.     
  538.     return m_NumFolds;
  539.   }
  540.   
  541.   /**
  542.    * Set the number of folds used for cross-validation.
  543.    *
  544.    * @param newNumFolds the number of folds used for cross-validation.
  545.    */
  546.   public void setNumFolds(int newNumFolds) {
  547.     
  548.     m_NumFolds = newNumFolds;
  549.   }
  550.   /**
  551.    * Set the classifier for boosting. 
  552.    *
  553.    * @param newClassifier the Classifier to use.
  554.    */
  555.   public void setClassifier(Classifier newClassifier) {
  556.     m_Classifier = newClassifier;
  557.   }
  558.   /**
  559.    * Get the classifier used as the classifier
  560.    *
  561.    * @return the classifier used as the classifier
  562.    */
  563.   public Classifier getClassifier() {
  564.     return m_Classifier;
  565.   }
  566.  
  567.   /**
  568.    * Returns description of the cross-validated classifier.
  569.    *
  570.    * @return description of the cross-validated classifier as a string
  571.    */
  572.   public String toString() {
  573.     if (m_BestClassifierOptions == null)
  574.       return "CVParameterSelection: No model built yet.";
  575.     String result = "Cross-validated Parameter selection.n"
  576.     + "Classifier: " + m_Classifier.getClass().getName() + "n";
  577.     try {
  578.       for (int i = 0; i < m_CVParams.size(); i++) {
  579. CVParameter cvParam = (CVParameter)m_CVParams.elementAt(i);
  580. result += "Cross-validation Parameter: '-" 
  581.   + cvParam.m_ParamChar + "'"
  582.   + " ranged from " + cvParam.m_Lower 
  583.   + " to ";
  584. switch ((int)(cvParam.m_Lower - cvParam.m_Upper + 0.5)) {
  585. case 1:
  586.   result += m_NumAttributes;
  587.   break;
  588. case 2:
  589.   result += m_TrainFoldSize;
  590.   break;
  591. default:
  592.   result += cvParam.m_Upper;
  593.   break;
  594. }
  595. result += " with " + cvParam.m_Steps + " stepsn";
  596.       }
  597.     } catch (Exception ex) {
  598.       result += ex.getMessage();
  599.     }
  600.     result += "Classifier Options: "
  601.       + Utils.joinOptions(m_BestClassifierOptions)
  602.       + "nn" + m_Classifier.toString();
  603.     return result;
  604.   }
  605.   public String toSummaryString() {
  606.     String result = "Selected values: "
  607.       + Utils.joinOptions(m_BestClassifierOptions);
  608.     return result + 'n';
  609.   }
  610.   
  611.   /**
  612.    * Main method for testing this class.
  613.    *
  614.    * @param argv the options
  615.    */
  616.   public static void main(String [] argv) {
  617.     try {
  618.       System.out.println(Evaluation.evaluateModel(new CVParameterSelection(), 
  619.   argv));
  620.     } catch (Exception e) {
  621.       System.err.println(e.getMessage());
  622.     }
  623.   }
  624. }
  625.