AdditiveRegression.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 18k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  *    This program is free software; you can redistribute it and/or modify
  3.  *    it under the terms of the GNU General Public License as published by
  4.  *    the Free Software Foundation; either version 2 of the License, or
  5.  *    (at your option) any later version.
  6.  *
  7.  *    This program is distributed in the hope that it will be useful,
  8.  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  *    GNU General Public License for more details.
  11.  *
  12.  *    You should have received a copy of the GNU General Public License
  13.  *    along with this program; if not, write to the Free Software
  14.  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15.  */
  16. /*
  17.  *    AdditiveRegression.java
  18.  *    Copyright (C) 2000 Mark Hall
  19.  *
  20.  */
  21. package weka.classifiers;
  22. import java.io.*;
  23. import java.util.*;
  24. import weka.core.*;
  25. import weka.classifiers.*;
  26. /**
  27.  * Meta classifier that enhances the performance of a regression base
  28.  * classifier. Each iteration fits a model to the residuals left by the
  29.  * classifier on the previous iteration. Prediction is accomplished by
  30.  * adding the predictions of each classifier. Smoothing is accomplished
  31.  * through varying the shrinkage (learning rate) parameter. <p>
  32.  *
  33.  * <pre>
  34.  * Analysing:  Root_relative_squared_error
  35.  * Datasets:   36
  36.  * Resultsets: 2
  37.  * Confidence: 0.05 (two tailed)
  38.  * Date:       10/13/00 10:00 AM
  39.  *
  40.  *
  41.  * Dataset                   (1) m5.M5Prim | (2) AdditiveRegression -S 0.7 
  42.  *                                         |    -B weka.classifiers.m5.M5Prime 
  43.  *                          ----------------------------
  44.  * auto93.names              (10)    54.4  |    49.41 * 
  45.  * autoHorse.names           (10)    32.76 |    26.34 * 
  46.  * autoMpg.names             (10)    35.32 |    34.84 * 
  47.  * autoPrice.names           (10)    40.01 |    36.57 * 
  48.  * baskball                  (10)    79.46 |    79.85   
  49.  * bodyfat.names             (10)    10.38 |    11.41 v 
  50.  * bolts                     (10)    19.29 |    12.61 * 
  51.  * breastTumor               (10)    96.95 |    96.23 * 
  52.  * cholesterol               (10)   101.03 |    98.88 * 
  53.  * cleveland                 (10)    71.29 |    70.87 * 
  54.  * cloud                     (10)    38.82 |    39.18   
  55.  * cpu                       (10)    22.26 |    14.74 * 
  56.  * detroit                   (10)   228.16 |    83.7  * 
  57.  * echoMonths                (10)    71.52 |    69.15 * 
  58.  * elusage                   (10)    48.94 |    49.03   
  59.  * fishcatch                 (10)    16.61 |    15.36 * 
  60.  * fruitfly                  (10)   100    |   100    * 
  61.  * gascons                   (10)    18.72 |    14.26 * 
  62.  * housing                   (10)    38.62 |    36.53 * 
  63.  * hungarian                 (10)    74.67 |    72.19 * 
  64.  * longley                   (10)    31.23 |    28.26 * 
  65.  * lowbwt                    (10)    62.26 |    61.48 * 
  66.  * mbagrade                  (10)    89.2  |    89.2    
  67.  * meta                      (10)   163.15 |   188.28 v 
  68.  * pbc                       (10)    81.35 |    79.4  * 
  69.  * pharynx                   (10)   105.41 |   105.03   
  70.  * pollution                 (10)    72.24 |    68.16 * 
  71.  * pwLinear                  (10)    32.42 |    33.33 v 
  72.  * quake                     (10)   100.21 |    99.93   
  73.  * schlvote                  (10)    92.41 |    98.23 v 
  74.  * sensory                   (10)    88.03 |    87.94   
  75.  * servo                     (10)    37.07 |    35.5  * 
  76.  * sleep                     (10)    70.17 |    71.65   
  77.  * strike                    (10)    84.98 |    83.96 * 
  78.  * veteran                   (10)    90.61 |    88.77 * 
  79.  * vineyard                  (10)    79.41 |    73.95 * 
  80.  *                        ----------------------------
  81.  *                              (v| |*) |   (4|8|24) 
  82.  *
  83.  * </pre> <p>
  84.  *
  85.  * For more information see: <p>
  86.  *
  87.  * Friedman, J.H. (1999). Stochastic Gradient Boosting. Technical Report
  88.  * Stanford University. http://www-stat.stanford.edu/~jhf/ftp/stobst.ps. <p>
  89.  *
  90.  * Valid options from the command line are: <p>
  91.  * 
  92.  * -B classifierstring <br>
  93.  * Classifierstring should contain the full class name of a classifier
  94.  * followed by options to the classifier.
  95.  * (required).<p>
  96.  *
  97.  * -S shrinkage rate <br>
  98.  * Smaller values help prevent overfitting and have a smoothing effect 
  99.  * (but increase learning time).
  100.  * (default = 1.0, ie no shrinkage). <p>
  101.  *
  102.  * -M max models <br>
  103.  * Set the maximum number of models to generate. Values <= 0 indicate 
  104.  * no maximum, ie keep going until the reduction in error threshold is 
  105.  * reached.
  106.  * (default = -1). <p>
  107.  *
  108.  * -D <br>
  109.  * Debugging output. <p>
  110.  *
  111.  * @author Mark Hall (mhall@cs.waikato.ac.nz)
  112.  * @version $Revision: 1.6 $
  113.  */
  114. public class AdditiveRegression extends Classifier 
  115.   implements OptionHandler,
  116.      AdditionalMeasureProducer {
  117.   
  118.   /**
  119.    * Base classifier.
  120.    */
  121.   protected Classifier m_Classifier = new weka.classifiers.DecisionStump();
  122.   /**
  123.    * Class index.
  124.    */
  125.   private int m_classIndex;
  126.   /**
  127.    * Shrinkage (Learning rate). Default = no shrinkage.
  128.    */
  129.   protected double m_shrinkage = 1.0;
  130.   
  131.   /**
  132.    * The list of iteratively generated models.
  133.    */
  134.   private FastVector m_additiveModels = new FastVector();
  135.   /**
  136.    * Produce debugging output.
  137.    */
  138.   private boolean m_debug = false;
  139.   /**
  140.    * Maximum number of models to produce. -1 indicates keep going until the error
  141.    * threshold is met.
  142.    */
  143.   protected int m_maxModels = -1;
  144.   /**
  145.    * Returns a string describing this attribute evaluator
  146.    * @return a description of the evaluator suitable for
  147.    * displaying in the explorer/experimenter gui
  148.    */
  149.   public String globalInfo() {
  150.     return " Meta classifier that enhances the performance of a regression "
  151.       +"base classifier. Each iteration fits a model to the residuals left "
  152.       +"by the classifier on the previous iteration. Prediction is "
  153.       +"accomplished by adding the predictions of each classifier. "
  154.       +"Reducing the shrinkage (learning rate) parameter helps prevent "
  155.       +"overfitting and has a smoothing effect but increases the learning "
  156.       +"time.  For more information see: Friedman, J.H. (1999). Stochastic "
  157.       +"Gradient Boosting. Technical Report Stanford University. "
  158.       +"http://www-stat.stanford.edu/~jhf/ftp/stobst.ps.";
  159.   }
  160.   /**
  161.    * Default constructor specifying DecisionStump as the classifier
  162.    */
  163.   public AdditiveRegression() {
  164.     this(new weka.classifiers.DecisionStump());
  165.   }
  166.   /**
  167.    * Constructor which takes base classifier as argument.
  168.    *
  169.    * @param classifier the base classifier to use
  170.    */
  171.   public AdditiveRegression(Classifier classifier) {
  172.     m_Classifier = classifier;
  173.   }
  174.   /**
  175.    * Returns an enumeration describing the available options
  176.    *
  177.    * @return an enumeration of all the available options
  178.    */
  179.   public Enumeration listOptions() {
  180.     Vector newVector = new Vector(4);
  181.     newVector.addElement(new Option(
  182.       "tFull class name of classifier to use, followedn"
  183.       + "tby scheme options. (required)n"
  184.       + "teg: "weka.classifiers.NaiveBayes -D"",
  185.       "B", 1, "-B <classifier specification>"));
  186.     newVector.addElement(new Option(
  187.       "tSpecify shrinkage rate. "
  188.       +"(default=1.0, ie. no shrinkage)n", 
  189.       "S", 1, "-S"));
  190.     newVector.addElement(new Option(
  191.       "tTurn on debugging output.",
  192.       "D", 0, "-D"));
  193.     newVector.addElement(new Option(
  194.       "tSpecify max models to generate. "
  195.       +"(default = -1, ie. no max; keep going until error reduction threshold "
  196.       +"is reached)n", 
  197.       "M", 1, "-M"));
  198.      
  199.     return newVector.elements();
  200.   }
  201.   /**
  202.    * Parses a given list of options. Valid options are:<p>
  203.    *
  204.    * -B classifierstring <br>
  205.    * Classifierstring should contain the full class name of a classifier
  206.    * followed by options to the classifier.
  207.    * (required).<p>
  208.    *
  209.    * -S shrinkage rate <br>
  210.    * Smaller values help prevent overfitting and have a smoothing effect 
  211.    * (but increase learning time).
  212.    * (default = 1.0, ie. no shrinkage). <p>
  213.    *
  214.    * -D <br>
  215.    * Debugging output. <p>
  216.    *
  217.    * -M max models <br>
  218.    * Set the maximum number of models to generate. Values <= 0 indicate 
  219.    * no maximum, ie keep going until the reduction in error threshold is 
  220.    * reached.
  221.    * (default = -1). <p>
  222.    *
  223.    * @param options the list of options as an array of strings
  224.    * @exception Exception if an option is not supported
  225.    */
  226.   public void setOptions(String[] options) throws Exception {
  227.     setDebug(Utils.getFlag('D', options));
  228.     String classifierString = Utils.getOption('B', options);
  229.     if (classifierString.length() == 0) {
  230.       throw new Exception("A classifier must be specified"
  231.   + " with the -B option.");
  232.     }
  233.     String [] classifierSpec = Utils.splitOptions(classifierString);
  234.     if (classifierSpec.length == 0) {
  235.       throw new Exception("Invalid classifier specification string");
  236.     }
  237.     String classifierName = classifierSpec[0];
  238.     classifierSpec[0] = "";
  239.     setClassifier(Classifier.forName(classifierName, classifierSpec));
  240.     String optionString = Utils.getOption('S', options);
  241.     if (optionString.length() != 0) {
  242.       Double temp;
  243.       temp = Double.valueOf(optionString);
  244.       setShrinkage(temp.doubleValue());
  245.     }
  246.     optionString = Utils.getOption('M', options);
  247.     if (optionString.length() != 0) {
  248.       setMaxModels(Integer.parseInt(optionString));
  249.     }
  250.     Utils.checkForRemainingOptions(options);
  251.   }
  252.   /**
  253.    * Gets the current settings of the Classifier.
  254.    *
  255.    * @return an array of strings suitable for passing to setOptions
  256.    */
  257.   public String [] getOptions() {
  258.     
  259.     String [] options = new String [7];
  260.     int current = 0;
  261.     if (getDebug()) {
  262.       options[current++] = "-D";
  263.     }
  264.     options[current++] = "-B";
  265.     options[current++] = "" + getClassifierSpec();
  266.     options[current++] = "-S"; options[current++] = ""+getShrinkage();
  267.     options[current++] = "-M"; options[current++] = ""+getMaxModels();
  268.     while (current < options.length) {
  269.       options[current++] = "";
  270.     }
  271.     return options;
  272.   }
  273.   
  274.   /**
  275.    * Returns the tip text for this property
  276.    * @return tip text for this property suitable for
  277.    * displaying in the explorer/experimenter gui
  278.    */
  279.   public String debugTipText() {
  280.     return "Turn on debugging output";
  281.   }
  282.   /**
  283.    * Set whether debugging output is produced.
  284.    *
  285.    * @param d true if debugging output is to be produced
  286.    */
  287.   public void setDebug(boolean d) {
  288.     m_debug = d;
  289.   }
  290.   /**
  291.    * Gets whether debugging has been turned on
  292.    *
  293.    * @return true if debugging has been turned on
  294.    */
  295.   public boolean getDebug() {
  296.     return m_debug;
  297.   }
  298.   /**
  299.    * Returns the tip text for this property
  300.    * @return tip text for this property suitable for
  301.    * displaying in the explorer/experimenter gui
  302.    */
  303.   public String classifierTipText() {
  304.     return "Classifier to use";
  305.   }
  306.   /**
  307.    * Sets the classifier
  308.    *
  309.    * @param classifier the classifier with all options set.
  310.    */
  311.   public void setClassifier(Classifier classifier) {
  312.     m_Classifier = classifier;
  313.   }
  314.   /**
  315.    * Gets the classifier used.
  316.    *
  317.    * @return the classifier
  318.    */
  319.   public Classifier getClassifier() {
  320.     return m_Classifier;
  321.   }
  322.   
  323.   /**
  324.    * Gets the classifier specification string, which contains the class name of
  325.    * the classifier and any options to the classifier
  326.    *
  327.    * @return the classifier string.
  328.    */
  329.   protected String getClassifierSpec() {
  330.     
  331.     Classifier c = getClassifier();
  332.     if (c instanceof OptionHandler) {
  333.       return c.getClass().getName() + " "
  334. + Utils.joinOptions(((OptionHandler)c).getOptions());
  335.     }
  336.     return c.getClass().getName();
  337.   }
  338.   
  339.   /**
  340.    * Returns the tip text for this property
  341.    * @return tip text for this property suitable for
  342.    * displaying in the explorer/experimenter gui
  343.    */
  344.   public String maxModelsTipText() {
  345.     return "Max models to generate. <= 0 indicates no maximum, ie. continue until "
  346.       +"error reduction threshold is reached.";
  347.   }
  348.   /**
  349.    * Set the maximum number of models to generate
  350.    * @param maxM the maximum number of models
  351.    */
  352.   public void setMaxModels(int maxM) {
  353.     m_maxModels = maxM;
  354.   }
  355.   /**
  356.    * Get the max number of models to generate
  357.    * @return the max number of models to generate
  358.    */
  359.   public int getMaxModels() {
  360.     return m_maxModels;
  361.   }
  362.   /**
  363.    * Returns the tip text for this property
  364.    * @return tip text for this property suitable for
  365.    * displaying in the explorer/experimenter gui
  366.    */
  367.   public String shrinkageTipText() {
  368.     return "Shrinkage rate. Smaller values help prevent overfitting and "
  369.       + "have a smoothing effect (but increase learning time). "
  370.       +"Default = 1.0, ie. no shrinkage."; 
  371.   }
  372.   /**
  373.    * Set the shrinkage parameter
  374.    *
  375.    * @param l the shrinkage rate.
  376.    */
  377.   public void setShrinkage(double l) {
  378.     m_shrinkage = l;
  379.   }
  380.   /**
  381.    * Get the shrinkage rate.
  382.    *
  383.    * @return the value of the learning rate
  384.    */
  385.   public double getShrinkage() {
  386.     return m_shrinkage;
  387.   }
  388.   /**
  389.    * Build the classifier on the supplied data
  390.    *
  391.    * @param data the training data
  392.    * @exception Exception if the classifier could not be built successfully
  393.    */
  394.   public void buildClassifier(Instances data) throws Exception {
  395.      m_additiveModels = new FastVector();
  396.     if (m_Classifier == null) {
  397.       throw new Exception("No base classifiers have been set!");
  398.     }
  399.     if (data.classAttribute().isNominal()) {
  400.       throw new Exception("Class must be numeric!");
  401.     }
  402.     Instances newData = new Instances(data);
  403.     newData.deleteWithMissingClass();
  404.     m_classIndex = newData.classIndex();
  405.     double sum = 0;
  406.     double temp_sum = 0;
  407.     // Add the model for the mean first
  408.     ZeroR zr = new ZeroR();
  409.     zr.buildClassifier(newData);
  410.     m_additiveModels.addElement(zr);
  411.     newData = residualReplace(newData, zr);
  412.     sum = newData.attributeStats(m_classIndex).numericStats.sumSq;
  413.     if (m_debug) {
  414.       System.err.println("Sum of squared residuals "
  415.  +"(predicting the mean) : "+sum);
  416.     }
  417.     int modelCount = 0;
  418.     do {
  419.       temp_sum = sum;
  420.       Classifier nextC = Classifier.makeCopies(m_Classifier, 1)[0];
  421.       nextC.buildClassifier(newData);
  422.       m_additiveModels.addElement(nextC);
  423.       newData = residualReplace(newData, nextC);
  424.       sum = newData.attributeStats(m_classIndex).numericStats.sumSq;
  425.       if (m_debug) {
  426. System.err.println("Sum of squared residuals : "+sum);
  427.       }
  428.       modelCount++;
  429.     } while (((temp_sum - 
  430.    (sum = newData.attributeStats(m_classIndex).
  431.     numericStats.sumSq)) > Utils.SMALL) && 
  432.      (m_maxModels > 0 ? (modelCount < m_maxModels) : true));
  433.     // remove last classifier
  434.     m_additiveModels.removeElementAt(m_additiveModels.size()-1);
  435.   }
  436.   /**
  437.    * Classify an instance.
  438.    *
  439.    * @param inst the instance to predict
  440.    * @return a prediction for the instance
  441.    * @exception Exception if an error occurs
  442.    */
  443.   public double classifyInstance(Instance inst) throws Exception {
  444.     double prediction = 0;
  445.     for (int i = 0; i < m_additiveModels.size(); i++) {
  446.       Classifier current = (Classifier)m_additiveModels.elementAt(i);
  447.       prediction += (current.classifyInstance(inst) * getShrinkage());
  448.     }
  449.     return prediction;
  450.   }
  451.   /**
  452.    * Replace the class values of the instances from the current iteration
  453.    * with residuals ater predicting with the supplied classifier.
  454.    *
  455.    * @param data the instances to predict
  456.    * @param c the classifier to use
  457.    * @return a new set of instances with class values replaced by residuals
  458.    */
  459.   private Instances residualReplace(Instances data, Classifier c) {
  460.     double pred,residual;
  461.     Instances newInst = new Instances(data);
  462.     for (int i = 0; i < newInst.numInstances(); i++) {
  463.       try {
  464. pred = c.classifyInstance(newInst.instance(i)) * getShrinkage();
  465. residual = newInst.instance(i).classValue() - pred;
  466. // System.err.println("Residual : "+residual);
  467. newInst.instance(i).setClassValue(residual);
  468.       } catch (Exception ex) {
  469. // continue
  470.       }
  471.     }
  472.     //    System.err.print(newInst);
  473.     return newInst;
  474.   }
  475.   /**
  476.    * Returns an enumeration of the additional measure names
  477.    * @return an enumeration of the measure names
  478.    */
  479.   public Enumeration enumerateMeasures() {
  480.     Vector newVector = new Vector(1);
  481.     newVector.addElement("measureNumIterations");
  482.     return newVector.elements();
  483.   }
  484.   /**
  485.    * Returns the value of the named measure
  486.    * @param measureName the name of the measure to query for its value
  487.    * @return the value of the named measure
  488.    * @exception IllegalArgumentException if the named measure is not supported
  489.    */
  490.   public double getMeasure(String additionalMeasureName) {
  491.     if (additionalMeasureName.compareTo("measureNumIterations") == 0) {
  492.       return measureNumIterations();
  493.     } else {
  494.       throw new IllegalArgumentException(additionalMeasureName 
  495.   + " not supported (AdditiveRegression)");
  496.     }
  497.   }
  498.   /**
  499.    * return the number of iterations (base classifiers) completed
  500.    * @return the number of iterations (same as number of base classifier
  501.    * models)
  502.    */
  503.   public double measureNumIterations() {
  504.     return m_additiveModels.size();
  505.   }
  506.   /**
  507.    * Returns textual description of the classifier.
  508.    *
  509.    * @return a description of the classifier as a string
  510.    */
  511.   public String toString() {
  512.     StringBuffer text = new StringBuffer();
  513.     if (m_additiveModels.size() == 0) {
  514.       return "Classifier hasn't been built yet!";
  515.     }
  516.     text.append("Additive Regressionnn");
  517.     text.append("Base classifier " 
  518. + getClassifier().getClass().getName()
  519. + "nn");
  520.     text.append(""+m_additiveModels.size()+" models generated.n");
  521.     return text.toString();
  522.   }
  523.   /**
  524.    * Main method for testing this class.
  525.    *
  526.    * @param argv should contain the following arguments:
  527.    * -t training file [-T test file] [-c class index]
  528.    */
  529.   public static void main(String [] argv) {
  530.     try {
  531.       System.out.println(Evaluation.evaluateModel(new AdditiveRegression(),
  532.   argv));
  533.     } catch (Exception e) {
  534.       System.err.println(e.getMessage());
  535.     }
  536.   }
  537. }