AdditiveRegression.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 18k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  *    This program is free software; you can redistribute it and/or modify
  3.  *    it under the terms of the GNU General Public License as published by
  4.  *    the Free Software Foundation; either version 2 of the License, or
  5.  *    (at your option) any later version.
  6.  *
  7.  *    This program is distributed in the hope that it will be useful,
  8.  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  *    GNU General Public License for more details.
  11.  *
  12.  *    You should have received a copy of the GNU General Public License
  13.  *    along with this program; if not, write to the Free Software
  14.  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15.  */
  16. /*
  17.  *    AdditiveRegression.java
  18.  *    Copyright (C) 2000 Mark Hall
  19.  *
  20.  */
  21. package weka.classifiers.meta;
  22. import weka.classifiers.Classifier;
  23. import weka.classifiers.Evaluation;
  24. import weka.classifiers.trees.DecisionStump;
  25. import weka.classifiers.rules.ZeroR;
  26. import java.io.*;
  27. import java.util.*;
  28. import weka.core.*;
  29. import weka.classifiers.meta.*;
  30. /**
  31.  * Meta classifier that enhances the performance of a regression base
  32.  * classifier. Each iteration fits a model to the residuals left by the
  33.  * classifier on the previous iteration. Prediction is accomplished by
  34.  * adding the predictions of each classifier. Smoothing is accomplished
  35.  * through varying the shrinkage (learning rate) parameter. <p>
  36.  *
  37.  * <pre>
  38.  * Analysing:  Root_relative_squared_error
  39.  * Datasets:   36
  40.  * Resultsets: 2
  41.  * Confidence: 0.05 (two tailed)
  42.  * Date:       10/13/00 10:00 AM
  43.  *
  44.  *
  45.  * Dataset                   (1) m5.M5Prim | (2) AdditiveRegression -S 0.7 
  46.  *                                         |    -B weka.classifiers.meta.m5.M5Prime 
  47.  *                          ----------------------------
  48.  * auto93.names              (10)    54.4  |    49.41 * 
  49.  * autoHorse.names           (10)    32.76 |    26.34 * 
  50.  * autoMpg.names             (10)    35.32 |    34.84 * 
  51.  * autoPrice.names           (10)    40.01 |    36.57 * 
  52.  * baskball                  (10)    79.46 |    79.85   
  53.  * bodyfat.names             (10)    10.38 |    11.41 v 
  54.  * bolts                     (10)    19.29 |    12.61 * 
  55.  * breastTumor               (10)    96.95 |    96.23 * 
  56.  * cholesterol               (10)   101.03 |    98.88 * 
  57.  * cleveland                 (10)    71.29 |    70.87 * 
  58.  * cloud                     (10)    38.82 |    39.18   
  59.  * cpu                       (10)    22.26 |    14.74 * 
  60.  * detroit                   (10)   228.16 |    83.7  * 
  61.  * echoMonths                (10)    71.52 |    69.15 * 
  62.  * elusage                   (10)    48.94 |    49.03   
  63.  * fishcatch                 (10)    16.61 |    15.36 * 
  64.  * fruitfly                  (10)   100    |   100    * 
  65.  * gascons                   (10)    18.72 |    14.26 * 
  66.  * housing                   (10)    38.62 |    36.53 * 
  67.  * hungarian                 (10)    74.67 |    72.19 * 
  68.  * longley                   (10)    31.23 |    28.26 * 
  69.  * lowbwt                    (10)    62.26 |    61.48 * 
  70.  * mbagrade                  (10)    89.2  |    89.2    
  71.  * meta                      (10)   163.15 |   188.28 v 
  72.  * pbc                       (10)    81.35 |    79.4  * 
  73.  * pharynx                   (10)   105.41 |   105.03   
  74.  * pollution                 (10)    72.24 |    68.16 * 
  75.  * pwLinear                  (10)    32.42 |    33.33 v 
  76.  * quake                     (10)   100.21 |    99.93   
  77.  * schlvote                  (10)    92.41 |    98.23 v 
  78.  * sensory                   (10)    88.03 |    87.94   
  79.  * servo                     (10)    37.07 |    35.5  * 
  80.  * sleep                     (10)    70.17 |    71.65   
  81.  * strike                    (10)    84.98 |    83.96 * 
  82.  * veteran                   (10)    90.61 |    88.77 * 
  83.  * vineyard                  (10)    79.41 |    73.95 * 
  84.  *                        ----------------------------
  85.  *                              (v| |*) |   (4|8|24) 
  86.  *
  87.  * </pre> <p>
  88.  *
  89.  * For more information see: <p>
  90.  *
  91.  * Friedman, J.H. (1999). Stochastic Gradient Boosting. Technical Report
  92.  * Stanford University. http://www-stat.stanford.edu/~jhf/ftp/stobst.ps. <p>
  93.  *
  94.  * Valid options from the command line are: <p>
  95.  * 
  96.  * -B classifierstring <br>
  97.  * Classifierstring should contain the full class name of a classifier
  98.  * followed by options to the classifier.
  99.  * (required).<p>
  100.  *
  101.  * -S shrinkage rate <br>
  102.  * Smaller values help prevent overfitting and have a smoothing effect 
  103.  * (but increase learning time).
  104.  * (default = 1.0, ie no shrinkage). <p>
  105.  *
  106.  * -M max models <br>
  107.  * Set the maximum number of models to generate. Values <= 0 indicate 
  108.  * no maximum, ie keep going until the reduction in error threshold is 
  109.  * reached.
  110.  * (default = -1). <p>
  111.  *
  112.  * -D <br>
  113.  * Debugging output. <p>
  114.  *
  115.  * @author Mark Hall (mhall@cs.waikato.ac.nz)
  116.  * @version $Revision: 1.10 $
  117.  */
  118. public class AdditiveRegression extends Classifier 
  119.   implements OptionHandler,
  120.      AdditionalMeasureProducer,
  121.      WeightedInstancesHandler {
  122.   
  123.   /**
  124.    * Base classifier.
  125.    */
  126.   protected Classifier m_Classifier = new weka.classifiers.trees.DecisionStump();
  127.   /**
  128.    * Class index.
  129.    */
  130.   private int m_classIndex;
  131.   /**
  132.    * Shrinkage (Learning rate). Default = no shrinkage.
  133.    */
  134.   protected double m_shrinkage = 1.0;
  135.   
  136.   /**
  137.    * The list of iteratively generated models.
  138.    */
  139.   private FastVector m_additiveModels = new FastVector();
  140.   /**
  141.    * Produce debugging output.
  142.    */
  143.   private boolean m_debug = false;
  144.   /**
  145.    * Maximum number of models to produce. -1 indicates keep going until the error
  146.    * threshold is met.
  147.    */
  148.   protected int m_maxModels = -1;
  149.   /**
  150.    * Returns a string describing this attribute evaluator
  151.    * @return a description of the evaluator suitable for
  152.    * displaying in the explorer/experimenter gui
  153.    */
  154.   public String globalInfo() {
  155.     return " Meta classifier that enhances the performance of a regression "
  156.       +"base classifier. Each iteration fits a model to the residuals left "
  157.       +"by the classifier on the previous iteration. Prediction is "
  158.       +"accomplished by adding the predictions of each classifier. "
  159.       +"Reducing the shrinkage (learning rate) parameter helps prevent "
  160.       +"overfitting and has a smoothing effect but increases the learning "
  161.       +"time.  For more information see: Friedman, J.H. (1999). Stochastic "
  162.       +"Gradient Boosting. Technical Report Stanford University. "
  163.       +"http://www-stat.stanford.edu/~jhf/ftp/stobst.ps.";
  164.   }
  165.   /**
  166.    * Default constructor specifying DecisionStump as the classifier
  167.    */
  168.   public AdditiveRegression() {
  169.     this(new weka.classifiers.trees.DecisionStump());
  170.   }
  171.   /**
  172.    * Constructor which takes base classifier as argument.
  173.    *
  174.    * @param classifier the base classifier to use
  175.    */
  176.   public AdditiveRegression(Classifier classifier) {
  177.     m_Classifier = classifier;
  178.   }
  179.   /**
  180.    * Returns an enumeration describing the available options.
  181.    *
  182.    * @return an enumeration of all the available options.
  183.    */
  184.   public Enumeration listOptions() {
  185.     Vector newVector = new Vector(4);
  186.     newVector.addElement(new Option(
  187.       "tFull class name of classifier to use, followedn"
  188.       + "tby scheme options. (required)n"
  189.       + "teg: "weka.classifiers.bayes.NaiveBayes -D"",
  190.       "B", 1, "-B <classifier specification>"));
  191.     newVector.addElement(new Option(
  192.       "tSpecify shrinkage rate. "
  193.       +"(default=1.0, ie. no shrinkage)n", 
  194.       "S", 1, "-S"));
  195.     newVector.addElement(new Option(
  196.       "tTurn on debugging output.",
  197.       "D", 0, "-D"));
  198.     newVector.addElement(new Option(
  199.       "tSpecify max models to generate. "
  200.       +"(default = -1, ie. no max; keep going until error reduction threshold "
  201.       +"is reached)n", 
  202.       "M", 1, "-M"));
  203.      
  204.     return newVector.elements();
  205.   }
  206.   /**
  207.    * Parses a given list of options. Valid options are:<p>
  208.    *
  209.    * -B classifierstring <br>
  210.    * Classifierstring should contain the full class name of a classifier
  211.    * followed by options to the classifier.
  212.    * (required).<p>
  213.    *
  214.    * -S shrinkage rate <br>
  215.    * Smaller values help prevent overfitting and have a smoothing effect 
  216.    * (but increase learning time).
  217.    * (default = 1.0, ie. no shrinkage). <p>
  218.    *
  219.    * -D <br>
  220.    * Debugging output. <p>
  221.    *
  222.    * -M max models <br>
  223.    * Set the maximum number of models to generate. Values <= 0 indicate 
  224.    * no maximum, ie keep going until the reduction in error threshold is 
  225.    * reached.
  226.    * (default = -1). <p>
  227.    *
  228.    * @param options the list of options as an array of strings
  229.    * @exception Exception if an option is not supported
  230.    */
  231.   public void setOptions(String[] options) throws Exception {
  232.     setDebug(Utils.getFlag('D', options));
  233.     String classifierString = Utils.getOption('B', options);
  234.     if (classifierString.length() == 0) {
  235.       throw new Exception("A classifier must be specified"
  236.   + " with the -B option.");
  237.     }
  238.     String [] classifierSpec = Utils.splitOptions(classifierString);
  239.     if (classifierSpec.length == 0) {
  240.       throw new Exception("Invalid classifier specification string");
  241.     }
  242.     String classifierName = classifierSpec[0];
  243.     classifierSpec[0] = "";
  244.     setClassifier(Classifier.forName(classifierName, classifierSpec));
  245.     String optionString = Utils.getOption('S', options);
  246.     if (optionString.length() != 0) {
  247.       Double temp;
  248.       temp = Double.valueOf(optionString);
  249.       setShrinkage(temp.doubleValue());
  250.     }
  251.     optionString = Utils.getOption('M', options);
  252.     if (optionString.length() != 0) {
  253.       setMaxModels(Integer.parseInt(optionString));
  254.     }
  255.     Utils.checkForRemainingOptions(options);
  256.   }
  257.   /**
  258.    * Gets the current settings of the Classifier.
  259.    *
  260.    * @return an array of strings suitable for passing to setOptions
  261.    */
  262.   public String [] getOptions() {
  263.     
  264.     String [] options = new String [7];
  265.     int current = 0;
  266.     if (getDebug()) {
  267.       options[current++] = "-D";
  268.     }
  269.     options[current++] = "-B";
  270.     options[current++] = "" + getClassifierSpec();
  271.     options[current++] = "-S"; options[current++] = ""+getShrinkage();
  272.     options[current++] = "-M"; options[current++] = ""+getMaxModels();
  273.     while (current < options.length) {
  274.       options[current++] = "";
  275.     }
  276.     return options;
  277.   }
  278.   
  279.   /**
  280.    * Returns the tip text for this property
  281.    * @return tip text for this property suitable for
  282.    * displaying in the explorer/experimenter gui
  283.    */
  284.   public String debugTipText() {
  285.     return "Turn on debugging output";
  286.   }
  287.   /**
  288.    * Set whether debugging output is produced.
  289.    *
  290.    * @param d true if debugging output is to be produced
  291.    */
  292.   public void setDebug(boolean d) {
  293.     m_debug = d;
  294.   }
  295.   /**
  296.    * Gets whether debugging has been turned on
  297.    *
  298.    * @return true if debugging has been turned on
  299.    */
  300.   public boolean getDebug() {
  301.     return m_debug;
  302.   }
  303.   /**
  304.    * Returns the tip text for this property
  305.    * @return tip text for this property suitable for
  306.    * displaying in the explorer/experimenter gui
  307.    */
  308.   public String classifierTipText() {
  309.     return "Classifier to use";
  310.   }
  311.   /**
  312.    * Sets the classifier
  313.    *
  314.    * @param classifier the classifier with all options set.
  315.    */
  316.   public void setClassifier(Classifier classifier) {
  317.     m_Classifier = classifier;
  318.   }
  319.   /**
  320.    * Gets the classifier used.
  321.    *
  322.    * @return the classifier
  323.    */
  324.   public Classifier getClassifier() {
  325.     return m_Classifier;
  326.   }
  327.   
  328.   /**
  329.    * Gets the classifier specification string, which contains the class name of
  330.    * the classifier and any options to the classifier
  331.    *
  332.    * @return the classifier string.
  333.    */
  334.   protected String getClassifierSpec() {
  335.     
  336.     Classifier c = getClassifier();
  337.     if (c instanceof OptionHandler) {
  338.       return c.getClass().getName() + " "
  339. + Utils.joinOptions(((OptionHandler)c).getOptions());
  340.     }
  341.     return c.getClass().getName();
  342.   }
  343.   
  344.   /**
  345.    * Returns the tip text for this property
  346.    * @return tip text for this property suitable for
  347.    * displaying in the explorer/experimenter gui
  348.    */
  349.   public String maxModelsTipText() {
  350.     return "Max models to generate. <= 0 indicates no maximum, ie. continue until "
  351.       +"error reduction threshold is reached.";
  352.   }
  353.   /**
  354.    * Set the maximum number of models to generate
  355.    * @param maxM the maximum number of models
  356.    */
  357.   public void setMaxModels(int maxM) {
  358.     m_maxModels = maxM;
  359.   }
  360.   /**
  361.    * Get the max number of models to generate
  362.    * @return the max number of models to generate
  363.    */
  364.   public int getMaxModels() {
  365.     return m_maxModels;
  366.   }
  367.   /**
  368.    * Returns the tip text for this property
  369.    * @return tip text for this property suitable for
  370.    * displaying in the explorer/experimenter gui
  371.    */
  372.   public String shrinkageTipText() {
  373.     return "Shrinkage rate. Smaller values help prevent overfitting and "
  374.       + "have a smoothing effect (but increase learning time). "
  375.       +"Default = 1.0, ie. no shrinkage."; 
  376.   }
  377.   /**
  378.    * Set the shrinkage parameter
  379.    *
  380.    * @param l the shrinkage rate.
  381.    */
  382.   public void setShrinkage(double l) {
  383.     m_shrinkage = l;
  384.   }
  385.   /**
  386.    * Get the shrinkage rate.
  387.    *
  388.    * @return the value of the learning rate
  389.    */
  390.   public double getShrinkage() {
  391.     return m_shrinkage;
  392.   }
  393.   /**
  394.    * Build the classifier on the supplied data
  395.    *
  396.    * @param data the training data
  397.    * @exception Exception if the classifier could not be built successfully
  398.    */
  399.   public void buildClassifier(Instances data) throws Exception {
  400.      m_additiveModels = new FastVector();
  401.     if (m_Classifier == null) {
  402.       throw new Exception("No base classifiers have been set!");
  403.     }
  404.     if (data.classAttribute().isNominal()) {
  405.       throw new UnsupportedClassTypeException("Class must be numeric!");
  406.     }
  407.     Instances newData = new Instances(data);
  408.     newData.deleteWithMissingClass();
  409.     m_classIndex = newData.classIndex();
  410.     double sum = 0;
  411.     double temp_sum = 0;
  412.     // Add the model for the mean first
  413.     ZeroR zr = new ZeroR();
  414.     zr.buildClassifier(newData);
  415.     m_additiveModels.addElement(zr);
  416.     newData = residualReplace(newData, zr);
  417.     for (int i = 0; i < newData.numInstances(); i++) {
  418.       sum += newData.instance(i).weight() *
  419. newData.instance(i).classValue() *
  420. newData.instance(i).classValue();
  421.     }
  422.     if (m_debug) {
  423.       System.err.println("Sum of squared residuals "
  424.  +"(predicting the mean) : "+sum);
  425.     }
  426.     int modelCount = 0;
  427.     do {
  428.       temp_sum = sum;
  429.       Classifier nextC = Classifier.makeCopies(m_Classifier, 1)[0];
  430.       nextC.buildClassifier(newData);
  431.       m_additiveModels.addElement(nextC);
  432.       newData = residualReplace(newData, nextC);
  433.       sum = 0;
  434.       for (int i = 0; i < newData.numInstances(); i++) {
  435. sum += newData.instance(i).weight() *
  436.   newData.instance(i).classValue() *
  437.   newData.instance(i).classValue();
  438.       }
  439.       if (m_debug) {
  440. System.err.println("Sum of squared residuals : "+sum);
  441.       }
  442.       modelCount++;
  443.     } while (((temp_sum - sum) > Utils.SMALL) && 
  444.      (m_maxModels > 0 ? (modelCount < m_maxModels) : true));
  445.     // remove last classifier
  446.     m_additiveModels.removeElementAt(m_additiveModels.size()-1);
  447.   }
  448.   /**
  449.    * Classify an instance.
  450.    *
  451.    * @param inst the instance to predict
  452.    * @return a prediction for the instance
  453.    * @exception Exception if an error occurs
  454.    */
  455.   public double classifyInstance(Instance inst) throws Exception {
  456.     double prediction = 0;
  457.     for (int i = 0; i < m_additiveModels.size(); i++) {
  458.       Classifier current = (Classifier)m_additiveModels.elementAt(i);
  459.       prediction += (current.classifyInstance(inst) * getShrinkage());
  460.     }
  461.     return prediction;
  462.   }
  463.   /**
  464.    * Replace the class values of the instances from the current iteration
  465.    * with residuals ater predicting with the supplied classifier.
  466.    *
  467.    * @param data the instances to predict
  468.    * @param c the classifier to use
  469.    * @return a new set of instances with class values replaced by residuals
  470.    */
  471.   private Instances residualReplace(Instances data, Classifier c) {
  472.     double pred,residual;
  473.     Instances newInst = new Instances(data);
  474.     for (int i = 0; i < newInst.numInstances(); i++) {
  475.       try {
  476. pred = c.classifyInstance(newInst.instance(i)) * getShrinkage();
  477. residual = newInst.instance(i).classValue() - pred;
  478. // System.err.println("Residual : "+residual);
  479. newInst.instance(i).setClassValue(residual);
  480.       } catch (Exception ex) {
  481. // continue
  482.       }
  483.     }
  484.     //    System.err.print(newInst);
  485.     return newInst;
  486.   }
  487.   /**
  488.    * Returns an enumeration of the additional measure names
  489.    * @return an enumeration of the measure names
  490.    */
  491.   public Enumeration enumerateMeasures() {
  492.     Vector newVector = new Vector(1);
  493.     newVector.addElement("measureNumIterations");
  494.     return newVector.elements();
  495.   }
  496.   /**
  497.    * Returns the value of the named measure
  498.    * @param measureName the name of the measure to query for its value
  499.    * @return the value of the named measure
  500.    * @exception IllegalArgumentException if the named measure is not supported
  501.    */
  502.   public double getMeasure(String additionalMeasureName) {
  503.     if (additionalMeasureName.compareTo("measureNumIterations") == 0) {
  504.       return measureNumIterations();
  505.     } else {
  506.       throw new IllegalArgumentException(additionalMeasureName 
  507.   + " not supported (AdditiveRegression)");
  508.     }
  509.   }
  510.   /**
  511.    * return the number of iterations (base classifiers) completed
  512.    * @return the number of iterations (same as number of base classifier
  513.    * models)
  514.    */
  515.   public double measureNumIterations() {
  516.     return m_additiveModels.size();
  517.   }
  518.   /**
  519.    * Returns textual description of the classifier.
  520.    *
  521.    * @return a description of the classifier as a string
  522.    */
  523.   public String toString() {
  524.     StringBuffer text = new StringBuffer();
  525.     if (m_additiveModels.size() == 0) {
  526.       return "Classifier hasn't been built yet!";
  527.     }
  528.     text.append("Additive Regressionnn");
  529.     text.append("Base classifier " 
  530. + getClassifier().getClass().getName()
  531. + "nn");
  532.     text.append(""+m_additiveModels.size()+" models generated.n");
  533.     return text.toString();
  534.   }
  535.   /**
  536.    * Main method for testing this class.
  537.    *
  538.    * @param argv should contain the following arguments:
  539.    * -t training file [-T test file] [-c class index]
  540.    */
  541.   public static void main(String [] argv) {
  542.     try {
  543.       System.out.println(Evaluation.evaluateModel(new AdditiveRegression(),
  544.   argv));
  545.     } catch (Exception e) {
  546.       System.err.println(e.getMessage());
  547.     }
  548.   }
  549. }