RegressionSplitEvaluator.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 18k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  *    This program is free software; you can redistribute it and/or modify
  3.  *    it under the terms of the GNU General Public License as published by
  4.  *    the Free Software Foundation; either version 2 of the License, or
  5.  *    (at your option) any later version.
  6.  *
  7.  *    This program is distributed in the hope that it will be useful,
  8.  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  *    GNU General Public License for more details.
  11.  *
  12.  *    You should have received a copy of the GNU General Public License
  13.  *    along with this program; if not, write to the Free Software
  14.  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15.  */
  16. /*
  17.  *    RegressionSplitEvaluator.java
  18.  *    Copyright (C) 1999 Len Trigg
  19.  *
  20.  */
  21. package weka.experiment;
  22. import weka.core.Instance;
  23. import weka.core.Instances;
  24. import weka.core.OptionHandler;
  25. import weka.core.Option;
  26. import weka.core.Utils;
  27. import weka.core.Attribute;
  28. import weka.core.Summarizable;
  29. import weka.core.AdditionalMeasureProducer;
  30. import weka.classifiers.Classifier;
  31. import weka.classifiers.Evaluation;
  32. import java.util.Enumeration;
  33. import java.util.Vector;
  34. import java.io.Serializable;
  35. import java.io.ObjectStreamClass;
  36. /**
  37.  * A SplitEvaluator that produces results for a classification scheme
  38.  * on a numeric class attribute.
  39.  *
  40.  * @author Len Trigg (trigg@cs.waikato.ac.nz)
  41.  * @version $Revision: 1.13 $
  42.  */
  43. public class RegressionSplitEvaluator implements SplitEvaluator, 
  44.   OptionHandler, AdditionalMeasureProducer {
  45.   
  46.   /** The classifier used for evaluation */
  47.   protected Classifier m_Classifier = new weka.classifiers.ZeroR();
  48.   
  49.   /** The names of any additional measures to look for in SplitEvaluators */
  50.   protected String [] m_AdditionalMeasures = null;
  51.   /** Array of booleans corresponding to the measures in m_AdditionalMeasures
  52.       indicating which of the AdditionalMeasures the current classifier
  53.       can produce */
  54.   protected boolean [] m_doesProduce = null;
  55.   /** Holds the statistics for the most recent application of the classifier */
  56.   protected String m_result = null;
  57.   /** The classifier options (if any) */
  58.   protected String m_ClassifierOptions = "";
  59.   /** The classifier version */
  60.   protected String m_ClassifierVersion = "";
  61.   /** The length of a key */
  62.   private static final int KEY_SIZE = 3;
  63.   /** The length of a result */
  64.   private static final int RESULT_SIZE = 15;
  65.   /**
  66.    * No args constructor.
  67.    */
  68.   public RegressionSplitEvaluator() {
  69.     updateOptions();
  70.   }
  71.   /**
  72.    * Returns a string describing this split evaluator
  73.    * @return a description of the split evaluator suitable for
  74.    * displaying in the explorer/experimenter gui
  75.    */
  76.   public String globalInfo() {
  77.     return "A SplitEvaluator that produces results for a classification "
  78.       +"scheme on a numeric class attribute.";
  79.   }
  80.   /**
  81.    * Returns an enumeration describing the available options.
  82.    *
  83.    * @return an enumeration of all the available options
  84.    */
  85.   public Enumeration listOptions() {
  86.     Vector newVector = new Vector(1);
  87.     newVector.addElement(new Option(
  88.      "tThe full class name of the classifier.n"
  89.       +"teg: weka.classifiers.NaiveBayes", 
  90.      "W", 1, 
  91.      "-W <class name>"));
  92.     if ((m_Classifier != null) &&
  93. (m_Classifier instanceof OptionHandler)) {
  94.       newVector.addElement(new Option(
  95.      "",
  96.      "", 0, "nOptions specific to classifier "
  97.      + m_Classifier.getClass().getName() + ":"));
  98.       Enumeration enum = ((OptionHandler)m_Classifier).listOptions();
  99.       while (enum.hasMoreElements()) {
  100. newVector.addElement(enum.nextElement());
  101.       }
  102.     }
  103.     return newVector.elements();
  104.   }
  105.   /**
  106.    * Parses a given list of options. Valid options are:<p>
  107.    *
  108.    * -W classname <br>
  109.    * Specify the full class name of the classifier to evaluate. <p>
  110.    *
  111.    * All option after -- will be passed to the classifier.
  112.    *
  113.    * @param options the list of options as an array of strings
  114.    * @exception Exception if an option is not supported
  115.    */
  116.   public void setOptions(String[] options) throws Exception {
  117.     
  118.     String cName = Utils.getOption('W', options);
  119.     if (cName.length() == 0) {
  120.       throw new Exception("A classifier must be specified with"
  121.   + " the -W option.");
  122.     }
  123.     // Do it first without options, so if an exception is thrown during
  124.     // the option setting, listOptions will contain options for the actual
  125.     // Classifier.
  126.     setClassifier(Classifier.forName(cName, null));
  127.     if (getClassifier() instanceof OptionHandler) {
  128.       ((OptionHandler) getClassifier())
  129. .setOptions(Utils.partitionOptions(options));
  130.       updateOptions();
  131.     }
  132.   }
  133.   /**
  134.    * Gets the current settings of the Classifier.
  135.    *
  136.    * @return an array of strings suitable for passing to setOptions
  137.    */
  138.   public String [] getOptions() {
  139.     String [] classifierOptions = new String [0];
  140.     if ((m_Classifier != null) && 
  141. (m_Classifier instanceof OptionHandler)) {
  142.       classifierOptions = ((OptionHandler)m_Classifier).getOptions();
  143.     }
  144.     
  145.     String [] options = new String [classifierOptions.length + 3];
  146.     int current = 0;
  147.     if (getClassifier() != null) {
  148.       options[current++] = "-W";
  149.       options[current++] = getClassifier().getClass().getName();
  150.     }
  151.     options[current++] = "--";
  152.     System.arraycopy(classifierOptions, 0, options, current, 
  153.      classifierOptions.length);
  154.     current += classifierOptions.length;
  155.     while (current < options.length) {
  156.       options[current++] = "";
  157.     }
  158.     return options;
  159.   }
  160.   /**
  161.    * Set a list of method names for additional measures to look for
  162.    * in Classifiers. This could contain many measures (of which only a
  163.    * subset may be produceable by the current Classifier) if an experiment
  164.    * is the type that iterates over a set of properties.
  165.    * @param additionalMeasures an array of method names.
  166.    */
  167.   public void setAdditionalMeasures(String [] additionalMeasures) {
  168.     m_AdditionalMeasures = additionalMeasures;
  169.     // determine which (if any) of the additional measures this classifier
  170.     // can produce
  171.     if (m_AdditionalMeasures != null && m_AdditionalMeasures.length > 0) {
  172.       m_doesProduce = new boolean [m_AdditionalMeasures.length];
  173.       if (m_Classifier instanceof AdditionalMeasureProducer) {
  174. Enumeration en = ((AdditionalMeasureProducer)m_Classifier).
  175.   enumerateMeasures();
  176. while (en.hasMoreElements()) {
  177.   String mname = (String)en.nextElement();
  178.   for (int j=0;j<m_AdditionalMeasures.length;j++) {
  179.     if (mname.compareTo(m_AdditionalMeasures[j]) == 0) {
  180.       m_doesProduce[j] = true;
  181.     }
  182.   }
  183. }
  184.       }
  185.     } else {
  186.       m_doesProduce = null;
  187.     }
  188.   }
  189.   
  190.     /**
  191.    * Returns an enumeration of any additional measure names that might be
  192.    * in the classifier
  193.    * @return an enumeration of the measure names
  194.    */
  195.   public Enumeration enumerateMeasures() {
  196.     Vector newVector = new Vector();
  197.     if (m_Classifier instanceof AdditionalMeasureProducer) {
  198.       Enumeration en = ((AdditionalMeasureProducer)m_Classifier).
  199. enumerateMeasures();
  200.       while (en.hasMoreElements()) {
  201. String mname = (String)en.nextElement();
  202. newVector.addElement(mname);
  203.       }
  204.     }
  205.     return newVector.elements();
  206.   }
  207.   
  208.   /**
  209.    * Returns the value of the named measure
  210.    * @param measureName the name of the measure to query for its value
  211.    * @return the value of the named measure
  212.    * @exception IllegalArgumentException if the named measure is not supported
  213.    */
  214.   public double getMeasure(String additionalMeasureName) {
  215.     if (m_Classifier instanceof AdditionalMeasureProducer) {
  216.       return ((AdditionalMeasureProducer)m_Classifier).
  217. getMeasure(additionalMeasureName);
  218.     } else {
  219.       throw new IllegalArgumentException("RegressionSplitEvaluator: "
  220.   +"Can't return value for : "+additionalMeasureName
  221.   +". "+m_Classifier.getClass().getName()+" "
  222.   +"is not an AdditionalMeasureProducer");
  223.     }
  224.   }
  225.   /**
  226.    * Gets the data types of each of the key columns produced for a single run.
  227.    * The number of key fields must be constant
  228.    * for a given SplitEvaluator.
  229.    *
  230.    * @return an array containing objects of the type of each key column. The 
  231.    * objects should be Strings, or Doubles.
  232.    */
  233.   public Object [] getKeyTypes() {
  234.     Object [] keyTypes = new Object[KEY_SIZE];
  235.     keyTypes[0] = "";
  236.     keyTypes[1] = "";
  237.     keyTypes[2] = "";
  238.     return keyTypes;
  239.   }
  240.   /**
  241.    * Gets the names of each of the key columns produced for a single run.
  242.    * The number of key fields must be constant
  243.    * for a given SplitEvaluator.
  244.    *
  245.    * @return an array containing the name of each key column
  246.    */
  247.   public String [] getKeyNames() {
  248.     String [] keyNames = new String[KEY_SIZE];
  249.     keyNames[0] = "Scheme";
  250.     keyNames[1] = "Scheme_options";
  251.     keyNames[2] = "Scheme_version_ID";
  252.     return keyNames;
  253.   }
  254.   /**
  255.    * Gets the key describing the current SplitEvaluator. For example
  256.    * This may contain the name of the classifier used for classifier
  257.    * predictive evaluation. The number of key fields must be constant
  258.    * for a given SplitEvaluator.
  259.    *
  260.    * @return an array of objects containing the key.
  261.    */
  262.   public Object [] getKey(){
  263.     Object [] key = new Object[KEY_SIZE];
  264.     key[0] = m_Classifier.getClass().getName();
  265.     key[1] = m_ClassifierOptions;
  266.     key[2] = m_ClassifierVersion;
  267.     return key;
  268.   }
  269.   /**
  270.    * Gets the data types of each of the result columns produced for a 
  271.    * single run. The number of result fields must be constant
  272.    * for a given SplitEvaluator.
  273.    *
  274.    * @return an array containing objects of the type of each result column. 
  275.    * The objects should be Strings, or Doubles.
  276.    */
  277.   public Object [] getResultTypes() {
  278.     int addm = (m_AdditionalMeasures != null) 
  279.       ? m_AdditionalMeasures.length 
  280.       : 0;
  281.     Object [] resultTypes = new Object[RESULT_SIZE+addm];
  282.     Double doub = new Double(0);
  283.     int current = 0;
  284.     resultTypes[current++] = doub;
  285.     resultTypes[current++] = doub;
  286.     resultTypes[current++] = doub;
  287.     resultTypes[current++] = doub;
  288.     resultTypes[current++] = doub;
  289.     resultTypes[current++] = doub;
  290.     resultTypes[current++] = doub;
  291.     resultTypes[current++] = doub;
  292.     resultTypes[current++] = doub;
  293.     resultTypes[current++] = doub;
  294.     resultTypes[current++] = doub;
  295.     resultTypes[current++] = doub;
  296.     // Timing stats
  297.     resultTypes[current++] = doub;
  298.     resultTypes[current++] = doub;
  299.     resultTypes[current++] = "";
  300.     // add any additional measures
  301.     for (int i=0;i<addm;i++) {
  302.       resultTypes[current++] = doub;
  303.     }
  304.     if (current != RESULT_SIZE+addm) {
  305.       throw new Error("ResultTypes didn't fit RESULT_SIZE");
  306.     }
  307.     return resultTypes;
  308.   }
  309.   /**
  310.    * Gets the names of each of the result columns produced for a single run.
  311.    * The number of result fields must be constant
  312.    * for a given SplitEvaluator.
  313.    *
  314.    * @return an array containing the name of each result column
  315.    */
  316.   public String [] getResultNames() {
  317.     int addm = (m_AdditionalMeasures != null) 
  318.       ? m_AdditionalMeasures.length 
  319.       : 0;
  320.     String [] resultNames = new String[RESULT_SIZE+addm];
  321.     int current = 0;
  322.     resultNames[current++] = "Number_of_instances";
  323.     // Sensitive stats - certainty of predictions
  324.     resultNames[current++] = "Mean_absolute_error";
  325.     resultNames[current++] = "Root_mean_squared_error";
  326.     resultNames[current++] = "Relative_absolute_error";
  327.     resultNames[current++] = "Root_relative_squared_error";
  328.     resultNames[current++] = "Correlation_coefficient";
  329.     // SF stats
  330.     resultNames[current++] = "SF_prior_entropy";
  331.     resultNames[current++] = "SF_scheme_entropy";
  332.     resultNames[current++] = "SF_entropy_gain";
  333.     resultNames[current++] = "SF_mean_prior_entropy";
  334.     resultNames[current++] = "SF_mean_scheme_entropy";
  335.     resultNames[current++] = "SF_mean_entropy_gain";
  336.     // Timing stats
  337.     resultNames[current++] = "Time_training";
  338.     resultNames[current++] = "Time_testing";
  339.     // Classifier defined extras
  340.     resultNames[current++] = "Summary";
  341.     // add any additional measures
  342.     for (int i=0;i<addm;i++) {
  343.       resultNames[current++] = m_AdditionalMeasures[i];
  344.     }
  345.     if (current != RESULT_SIZE+addm) {
  346.       throw new Error("ResultNames didn't fit RESULT_SIZE");
  347.     }
  348.     return resultNames;
  349.   }
  350.   /**
  351.    * Gets the results for the supplied train and test datasets.
  352.    *
  353.    * @param train the training Instances.
  354.    * @param test the testing Instances.
  355.    * @return the results stored in an array. The objects stored in
  356.    * the array may be Strings, Doubles, or null (for the missing value).
  357.    * @exception Exception if a problem occurs while getting the results
  358.    */
  359.   public Object [] getResult(Instances train, Instances test) 
  360.     throws Exception {
  361.     if (train.classAttribute().type() != Attribute.NUMERIC) {
  362.       throw new Exception("Class attribute is not numeric!");
  363.     }
  364.     if (m_Classifier == null) {
  365.       throw new Exception("No classifier has been specified");
  366.     }
  367.     int addm = (m_AdditionalMeasures != null) 
  368.       ? m_AdditionalMeasures.length 
  369.       : 0;
  370.     Object [] result = new Object[RESULT_SIZE+addm];
  371.     Evaluation eval = new Evaluation(train);
  372.     long trainTimeStart = System.currentTimeMillis();
  373.     m_Classifier.buildClassifier(train);
  374.     long trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
  375.     long testTimeStart = System.currentTimeMillis();
  376.     eval.evaluateModel(m_Classifier, test);
  377.     long testTimeElapsed = System.currentTimeMillis() - testTimeStart;
  378.     m_result = eval.toSummaryString();
  379.     // The results stored are all per instance -- can be multiplied by the
  380.     // number of instances to get absolute numbers
  381.     int current = 0;
  382.     result[current++] = new Double(eval.numInstances());
  383.     result[current++] = new Double(eval.meanAbsoluteError());
  384.     result[current++] = new Double(eval.rootMeanSquaredError());
  385.     result[current++] = new Double(eval.relativeAbsoluteError());
  386.     result[current++] = new Double(eval.rootRelativeSquaredError());
  387.     result[current++] = new Double(eval.correlationCoefficient());
  388.     result[current++] = new Double(eval.SFPriorEntropy());
  389.     result[current++] = new Double(eval.SFSchemeEntropy());
  390.     result[current++] = new Double(eval.SFEntropyGain());
  391.     result[current++] = new Double(eval.SFMeanPriorEntropy());
  392.     result[current++] = new Double(eval.SFMeanSchemeEntropy());
  393.     result[current++] = new Double(eval.SFMeanEntropyGain());
  394.     // Timing stats
  395.     result[current++] = new Double(trainTimeElapsed / 1000.0);
  396.     result[current++] = new Double(testTimeElapsed / 1000.0);
  397.     if (m_Classifier instanceof Summarizable) {
  398.       result[current++] = ((Summarizable)m_Classifier).toSummaryString();
  399.     } else {
  400.       result[current++] = null;
  401.     }
  402.     for (int i=0;i<addm;i++) {
  403.       if (m_doesProduce[i]) {
  404. try {
  405.   double dv = ((AdditionalMeasureProducer)m_Classifier).
  406.     getMeasure(m_AdditionalMeasures[i]);
  407.   Double value = new Double(dv);
  408.   result[current++] = value;
  409. } catch (Exception ex) {
  410.   System.err.println(ex);
  411. }
  412.       } else {
  413. result[current++] = null;
  414.       }
  415.     }
  416.     if (current != RESULT_SIZE+addm) {
  417.       throw new Error("Results didn't fit RESULT_SIZE");
  418.     }
  419.     return result;
  420.   }
  421.   /**
  422.    * Returns the tip text for this property
  423.    * @return tip text for this property suitable for
  424.    * displaying in the explorer/experimenter gui
  425.    */
  426.   public String classifierTipText() {
  427.     return "The classifier to use.";
  428.   }
  429.   /**
  430.    * Get the value of Classifier.
  431.    *
  432.    * @return Value of Classifier.
  433.    */
  434.   public Classifier getClassifier() {
  435.     
  436.     return m_Classifier;
  437.   }
  438.   
  439.   /**
  440.    * Sets the classifier.
  441.    *
  442.    * @param newClassifier the new classifier to use.
  443.    */
  444.   public void setClassifier(Classifier newClassifier) {
  445.     
  446.     m_Classifier = newClassifier;
  447.     updateOptions();
  448.     System.err.println("RegressionSplitEvaluator: In set classifier");
  449.   }
  450.   /**
  451.    * Updates the options that the current classifier is using.
  452.    */
  453.   protected void updateOptions() {
  454.     
  455.     if (m_Classifier instanceof OptionHandler) {
  456.       m_ClassifierOptions = Utils.joinOptions(((OptionHandler)m_Classifier)
  457.       .getOptions());
  458.     } else {
  459.       m_ClassifierOptions = "";
  460.     }
  461.     if (m_Classifier instanceof Serializable) {
  462.       ObjectStreamClass obs = ObjectStreamClass.lookup(m_Classifier
  463.        .getClass());
  464.       m_ClassifierVersion = "" + obs.getSerialVersionUID();
  465.     } else {
  466.       m_ClassifierVersion = "";
  467.     }
  468.   }
  469.   /**
  470.    * Set the Classifier to use, given it's class name. A new classifier will be
  471.    * instantiated.
  472.    *
  473.    * @param newClassifier the Classifier class name.
  474.    * @exception Exception if the class name is invalid.
  475.    */
  476.   public void setClassifierName(String newClassifierName) throws Exception {
  477.     try {
  478.       setClassifier((Classifier)Class.forName(newClassifierName)
  479.     .newInstance());
  480.     } catch (Exception ex) {
  481.       throw new Exception("Can't find Classifier with class name: "
  482.   + newClassifierName);
  483.     }
  484.   }
  485.   /**
  486.    * Gets the raw output from the classifier
  487.    * @return the raw output from the classifier
  488.    */
  489.   public String getRawResultOutput() {
  490.     StringBuffer result = new StringBuffer();
  491.     if (m_Classifier == null) {
  492.       return "<null> classifier";
  493.     }
  494.     result.append(toString());
  495.     result.append("Classifier model: n"+m_Classifier.toString()+'n');
  496.     // append the performance statistics
  497.     if (m_result != null) {
  498.       result.append(m_result);
  499.       
  500.       if (m_doesProduce != null) {
  501. for (int i=0;i<m_doesProduce.length;i++) {
  502.   if (m_doesProduce[i]) {
  503.     try {
  504.       double dv = ((AdditionalMeasureProducer)m_Classifier).
  505. getMeasure(m_AdditionalMeasures[i]);
  506.       Double value = new Double(dv);
  507.       result.append(m_AdditionalMeasures[i]+" : "+value+'n');
  508.     } catch (Exception ex) {
  509.       System.err.println(ex);
  510.     }
  511.   } 
  512. }
  513.       }
  514.     }
  515.     return result.toString();
  516.   }
  517.   /**
  518.    * Returns a text description of the split evaluator.
  519.    *
  520.    * @return a text description of the split evaluator.
  521.    */
  522.   public String toString() {
  523.     String result = "RegressionSplitEvaluator: ";
  524.     if (m_Classifier == null) {
  525.       return result + "<null> classifier";
  526.     }
  527.     return result + m_Classifier.getClass().getName() + " " 
  528.       + m_ClassifierOptions + "(version " + m_ClassifierVersion + ")";
  529.   }
  530. } // RegressionSplitEvaluator