M5Prime.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 14k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  *    This program is free software; you can redistribute it and/or modify
  3.  *    it under the terms of the GNU General Public License as published by
  4.  *    the Free Software Foundation; either version 2 of the License, or
  5.  *    (at your option) any later version.
  6.  *
  7.  *    This program is distributed in the hope that it will be useful,
  8.  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  *    GNU General Public License for more details.
  11.  *
  12.  *    You should have received a copy of the GNU General Public License
  13.  *    along with this program; if not, write to the Free Software
  14.  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15.  */
  16. /*
  17.  *    M5Prime.java
  18.  *    Copyright (C) 1999 Yong Wang
  19.  *
  20.  */
  21. package weka.classifiers.m5;
  22. import java.io.*;
  23. import java.util.*;
  24. import weka.core.*;
  25. import weka.classifiers.*;
  26. import weka.filters.*;
  27. /**
  28.  * Class for contructing and evaluating model trees; M5' algorithm. <p>
  29.  *
  30.  * Reference: Wang, Y. and Witten, I.H. (1997). <i> Induction of model
  31.  * trees for predicting continuous classes.</i> Proceedings of the poster
  32.  * papers of the European Conference on Machine Learning. University of
  33.  * Economics, Faculty of Informatics and Statistics, Prague. <p>
  34.  *
  35.  * Valid options are: <p>
  36.  *
  37.  * -O <l|r|m> <br>
  38.  * Type of model to be used. (l: linear regression, 
  39.  * r: regression tree, m: model tree) (default: m) <p>
  40.  *
  41.  * -U <br>
  42.  * Use unsmoothed tree. <p>
  43.  *
  44.  * -F factor <br>
  45.  * Set pruning factor (default: 2). <p>
  46.  *
  47.  * -V <0|1|2> <br>
  48.  * Verbosity (default: 0). <p>
  49.  *
  50.  * @author Yong Wang (yongwang@cs.waikato.ac.nz)
  51.  * @version $Revision: 1.15.2.1 $
  52.  */
  53. public final class  M5Prime extends Classifier implements OptionHandler,
  54.  AdditionalMeasureProducer {
  55.   
  56.   /** The root node */
  57.   private Node m_root[];
  58.   /** The options */
  59.   private Options options;
  60.   /** No smoothing? */
  61.   private boolean m_UseUnsmoothed = false;
  62.   /** Pruning factor */
  63.   private double m_PruningFactor = 2;
  64.   /** Type of model */
  65.   private int m_Model = Node.MODEL_TREE;
  66.   /** Verbosity */
  67.   private int m_Verbosity = 0;
  68.   /** Filter for replacing missing values. */
  69.   private ReplaceMissingValuesFilter m_ReplaceMissingValuesFilter;
  70.   /** Filter for replacing nominal attributes with numeric binary ones. */
  71.   private NominalToBinaryFilter m_NominalToBinaryFilter;
  72.   public static final int MODEL_LINEAR_REGRESSION = Node.LINEAR_REGRESSION;
  73.   public static final int MODEL_REGRESSION_TREE = Node.REGRESSION_TREE;
  74.   public static final int MODEL_MODEL_TREE = Node.MODEL_TREE;
  75.   public static final Tag [] TAGS_MODEL_TYPES = {
  76.     new Tag(MODEL_LINEAR_REGRESSION, "Simple linear regression"),
  77.     new Tag(MODEL_REGRESSION_TREE, "Regression tree"),
  78.     new Tag(MODEL_MODEL_TREE, "Model tree")
  79.   };
  80.   
  81.   /**
  82.    * Construct a model tree by training instances
  83.    *
  84.    * @param inst training instances
  85.    * @param options information for constructing the model tree, 
  86.    * mostly from command line options
  87.    * @return the root of the model tree
  88.    * @exception Exception if the classifier can't be built
  89.    */
  90.   public final void buildClassifier(Instances inst) throws Exception{
  91.     if (inst.checkForStringAttributes()) {
  92.       throw new Exception("Can't handle string attributes!");
  93.     }
  94.     options = new Options(inst);
  95.     options.model = m_Model;
  96.     options.smooth = !m_UseUnsmoothed;
  97.     options.pruningFactor = m_PruningFactor;
  98.     options.verbosity = m_Verbosity;
  99.     if(!inst.classAttribute().isNumeric()) 
  100.       throw new Exception("Class has to be numeric."); 
  101.     inst = new Instances(inst);
  102.     inst.deleteWithMissingClass();
  103.     m_ReplaceMissingValuesFilter = new ReplaceMissingValuesFilter();
  104.     m_ReplaceMissingValuesFilter.setInputFormat(inst);
  105.     inst = Filter.useFilter(inst, m_ReplaceMissingValuesFilter);
  106.     m_NominalToBinaryFilter = new NominalToBinaryFilter();
  107.     m_NominalToBinaryFilter.setInputFormat(inst);
  108.     inst = Filter.useFilter(inst, m_NominalToBinaryFilter);
  109.     
  110.     m_root = new Node[2];
  111.     options.deviation = M5Utils.stdDev(inst.classIndex(),inst);
  112.     m_root[0] = new Node(inst,null,options);       // build an empty tree
  113.     m_root[0].split(inst);         // build the unpruned initial tree
  114.     m_root[0].numLeaves(0);       // set tree leaves' number of the unpruned treee
  115.     m_root[1] = m_root[0].copy(null);  // make a copy of the unpruned tree
  116.     m_root[1].prune();            // prune the tree
  117.     if(options.model != Node.LINEAR_REGRESSION){
  118.       m_root[1].smoothen();    // compute the smoothed linear models at the leaves
  119.       m_root[1].numLeaves(0);  // set tree leaves' number of the pruned tree
  120.     }
  121.   }
  122.   /**
  123.    * Classifies the given test instance.
  124.    *
  125.    * @param instance the instance to be classified
  126.    * @return the predicted class for the instance 
  127.    * @exception Exception if the instance can't be classified
  128.    */
  129.   public double classifyInstance(Instance ins) throws Exception {
  130.     m_ReplaceMissingValuesFilter.input(ins);
  131.     m_ReplaceMissingValuesFilter.batchFinished();
  132.     ins = m_ReplaceMissingValuesFilter.output();
  133.     m_NominalToBinaryFilter.input(ins);
  134.     m_NominalToBinaryFilter.batchFinished();
  135.     ins = m_NominalToBinaryFilter.output();
  136.     return m_root[1].predict(ins,!m_UseUnsmoothed);
  137.   }
  138.   /**
  139.    * Returns an enumeration describing the available options.
  140.    *
  141.    * Valid options are: <p>
  142.    *
  143.    * -O <l|r|m> <br>
  144.    * Type of model to be used. (l: linear regression, 
  145.    * r: regression tree, m: model tree) (default: m) <p>
  146.    *
  147.    * -U <br>
  148.    * Use unsmoothed tree. <p>
  149.    *
  150.    * -F factor <br>
  151.    * Set pruning factor (default: 2). <p>
  152.    *
  153.    * -V <0|1|2> <br>
  154.    * Verbosity (default: 0). <p>
  155.    *
  156.    * @return an enumeration of all the available options
  157.    */
  158.   public Enumeration listOptions() {
  159.     Vector newVector = new Vector(4);
  160.     newVector.addElement(new Option("tType of model to be used.n"+
  161.     "tl: linear regressionn"+
  162.     "tr: regression treen"+
  163.     "tm: model treen"+
  164.     "t(default: m)",
  165.     "-O", 1, "-O <l|r|m>"));
  166.     newVector.addElement(new Option("tUse unsmoothed tree.", "C", 0, 
  167.     "-U"));
  168.     newVector.addElement(new Option("tPruning factor (default: 2).",
  169.     "-F", 1, "-F <double>"));
  170.     newVector.addElement(new Option("tVerbosity (default: 0).",
  171.     "-V", 1, "-V <0|1|2>"));
  172.     return newVector.elements();
  173.   }
  174.   /**
  175.    * Parses a given list of options.
  176.    *
  177.    * @param options the list of options as an array of strings
  178.    * @exception Exception if an option is not supported
  179.    */
  180.   public void setOptions(String[] options) throws Exception{
  181.     String modelString = Utils.getOption('O', options);
  182.     if (modelString.length() != 0) {
  183.       if (modelString.equals("l"))
  184. setModelType(new SelectedTag(MODEL_LINEAR_REGRESSION,
  185.      TAGS_MODEL_TYPES));
  186.       else if (modelString.equals("r"))
  187. setModelType(new SelectedTag(MODEL_REGRESSION_TREE,
  188.      TAGS_MODEL_TYPES));
  189.       else if (modelString.equals("m"))
  190. setModelType(new SelectedTag(MODEL_MODEL_TREE,
  191.      TAGS_MODEL_TYPES));
  192.       else
  193. throw new Exception("Don't know model type " + modelString);
  194.     } else {
  195.       setModelType(new SelectedTag(MODEL_MODEL_TREE,
  196.    TAGS_MODEL_TYPES));
  197.     }
  198.     
  199.     setUseUnsmoothed(Utils.getFlag('U', options));
  200.     if (m_Model != Node.MODEL_TREE) {
  201.       setUseUnsmoothed(true);
  202.     }
  203.     String pruningString = Utils.getOption('F', options);
  204.     if (pruningString.length() != 0) {
  205.       setPruningFactor((new Double(pruningString)).doubleValue());
  206.     } else {
  207.       setPruningFactor(2);
  208.     }
  209.     
  210.     String verbosityString = Utils.getOption('V', options);
  211.     if (verbosityString.length() != 0) {
  212.       setVerbosity(Integer.parseInt(verbosityString));
  213.     } else {
  214.       setVerbosity(0);
  215.     }
  216.   }
  217.   /**
  218.    * Gets the current settings of the Classifier.
  219.    *
  220.    * @return an array of strings suitable for passing to setOptions
  221.    */
  222.   public String [] getOptions() {
  223.     String [] options = new String [7];
  224.     int current = 0;
  225.     switch (m_Model) {
  226.     case MODEL_MODEL_TREE:
  227.       options[current++] = "-O"; options[current++] = "m";
  228.       if (m_UseUnsmoothed) {
  229. options[current++] = "-U";
  230.       }
  231.       break;
  232.     case MODEL_REGRESSION_TREE:
  233.       options[current++] = "-O"; options[current++] = "r";
  234.       break;
  235.     case MODEL_LINEAR_REGRESSION:
  236.       options[current++] = "-O"; options[current++] = "l";
  237.       break;
  238.     }
  239.     options[current++] = "-F"; options[current++] = "" + m_PruningFactor;
  240.     options[current++] = "-V"; options[current++] = "" + m_Verbosity;
  241.     while (current < options.length) {
  242.       options[current++] = "";
  243.     }
  244.     return options;
  245.   }
  246.   /**
  247.    * Converts the output of the training process into a string
  248.    *
  249.    * @return the converted string
  250.    */
  251.   public final String toString() {
  252.     try{
  253.       StringBuffer text = new StringBuffer();
  254.       double absDev = M5Utils.absDev(m_root[0].instances.classIndex(),
  255.      m_root[0].instances);
  256.       
  257.       if(options.verbosity >= 1 && options.model != Node.LINEAR_REGRESSION){
  258. switch(m_root[0].model){
  259. case Node.LINEAR_REGRESSION: 
  260.   break;
  261. case Node.REGRESSION_TREE  : 
  262.   text.append("Unpruned training regression tree:n"); break;
  263. case Node.MODEL_TREE       : 
  264.   text.append("Unpruned training model tree:n"); break;
  265. }     
  266. if(m_root[0].type == false)text.append("n");
  267. text.append(m_root[0].treeToString(0,absDev)+ "n");
  268. text.append("Models at the leaves:nn");
  269. //    the linear models at the leaves of the unpruned tree
  270. text.append(m_root[0].formulaeToString(false) + "n");;  
  271.       }
  272.       
  273.       if(m_root[0].model != Node.LINEAR_REGRESSION){
  274. switch(m_root[0].model){
  275. case Node.LINEAR_REGRESSION: 
  276.   break;
  277. case Node.REGRESSION_TREE  : 
  278.   text.append("Pruned training regression tree:n"); break;
  279. case Node.MODEL_TREE       : 
  280.   text.append("Pruned training model tree:n"); break;
  281. }
  282. if(m_root[1].type == false)text.append("n");
  283. text.append(m_root[1].treeToString(0,absDev) + "n"); //the pruned tree
  284. text.append("Models at the leaves:nn");
  285. if ((m_root[0].model != Node.LINEAR_REGRESSION) &&
  286.     (m_UseUnsmoothed)) {
  287.   text.append("  Unsmoothed (simple):nn");
  288.   //     the unsmoothed linear models at the leaves of the pruned tree
  289.   text.append(m_root[1].formulaeToString(false) + "n");
  290. }
  291. if ((m_root[0].model == Node.MODEL_TREE) &&
  292.     (!m_UseUnsmoothed)) {
  293.   text.append("  Smoothed (complex):nn");
  294.   text.append(m_root[1].formulaeToString(true) + "n");
  295.   //   the smoothed linear models at the leaves of the pruned tree
  296. }
  297.       }
  298.       else {
  299. text.append("Training linear regression model:nn");
  300. text.append(m_root[1].unsmoothed.toString(m_root[1].instances,0) + "nn");
  301. //       print the linear regression model
  302.       }
  303.       
  304.       text.append("Number of Rules : "+m_root[1].numberOfLinearModels());
  305.       return text.toString();
  306.     } catch (Exception e) {
  307.       return "can't print m5' tree";
  308.     }
  309.   }
  310.   /**
  311.    * return the number of linear models
  312.    * @return the number of linear models
  313.    */
  314.   public double measureNumLinearModels() {
  315.     return m_root[1].numberOfLinearModels();
  316.   }
  317.   /**
  318.    * return the number of leaves in the tree
  319.    * @return the number leaves in the tree (same as # linear models &
  320.    * # rules)
  321.    */
  322.   public double measureNumLeaves() {
  323.     return measureNumLinearModels();
  324.   }
  325.   /**
  326.    * return the number of rules
  327.    * @return the number of rules (same as # linear models &
  328.    * # leaves in the tree)
  329.    */
  330.   public double measureNumRules() {
  331.     return measureNumLinearModels();
  332.   }
  333.   /**
  334.    * Returns an enumeration of the additional measure names
  335.    * @return an enumeration of the measure names
  336.    */
  337.   public Enumeration enumerateMeasures() {
  338.     Vector newVector = new Vector(3);
  339.     newVector.addElement("measureNumLinearModels");
  340.     newVector.addElement("measureNumLeaves");
  341.     newVector.addElement("measureNumRules");
  342.     return newVector.elements();
  343.   }
  344.   /**
  345.    * Returns the value of the named measure
  346.    * @param measureName the name of the measure to query for its value
  347.    * @return the value of the named measure
  348.    * @exception IllegalArgumentException if the named measure is not supported
  349.    */
  350.   public double getMeasure(String additionalMeasureName) {
  351.     if (additionalMeasureName.compareTo("measureNumRules") == 0) {
  352.       return measureNumRules();
  353.     } else if (additionalMeasureName.compareTo("measureNumLinearModels") == 0){
  354.       return measureNumLinearModels();
  355.     } else if (additionalMeasureName.compareTo("measureNumLeaves") == 0) {
  356.       return measureNumLeaves();
  357.     } else {
  358.       throw new IllegalArgumentException(additionalMeasureName 
  359.   + " not supported (M5)");
  360.     }
  361.   }
  362.   
  363.   /**
  364.    * Get the value of UseUnsmoothed.
  365.    *
  366.    * @return Value of UseUnsmoothed.
  367.    */
  368.   public boolean getUseUnsmoothed() {
  369.     
  370.     return m_UseUnsmoothed;
  371.   }
  372.   
  373.   /**
  374.    * Set the value of UseUnsmoothed.
  375.    *
  376.    * @param v  Value to assign to UseUnsmoothed.
  377.    */
  378.   public void setUseUnsmoothed(boolean v) {
  379.     
  380.     if (m_Model != Node.MODEL_TREE) m_UseUnsmoothed = true;
  381.     else m_UseUnsmoothed = v;
  382.   }
  383.   
  384.   /**
  385.    * Get the value of PruningFactor.
  386.    *
  387.    * @return Value of PruningFactor.
  388.    */
  389.   public double getPruningFactor() {
  390.     
  391.     return m_PruningFactor;
  392.   }
  393.   
  394.   /**
  395.    * Set the value of PruningFactor.
  396.    *
  397.    * @param v  Value to assign to PruningFactor.
  398.    */
  399.   public void setPruningFactor(double v) {
  400.     
  401.     m_PruningFactor = v;
  402.   }
  403.   
  404.   /**
  405.    * Get the value of Model.
  406.    *
  407.    * @return Value of Model.
  408.    */
  409.   public SelectedTag getModelType() {
  410.     
  411.     return new SelectedTag(m_Model, TAGS_MODEL_TYPES);
  412.   }
  413.   
  414.   /**
  415.    * Set the value of Model.
  416.    *
  417.    * @param v  Value to assign to Model.
  418.    */
  419.   public void setModelType(SelectedTag newMethod) {
  420.     
  421.     if (newMethod.getTags() == TAGS_MODEL_TYPES) {
  422.       m_Model = newMethod.getSelectedTag().getID();
  423.       if (m_Model != Node.MODEL_TREE) setUseUnsmoothed(true);
  424.     }
  425.   }
  426.   
  427.   /**
  428.    * Get the value of Verbosity.
  429.    *
  430.    * @return Value of Verbosity.
  431.    */
  432.   public int getVerbosity() {
  433.     
  434.     return m_Verbosity;
  435.   }
  436.   
  437.   /**
  438.    * Set the value of Verbosity.
  439.    *
  440.    * @param v  Value to assign to Verbosity.
  441.    */
  442.   public void setVerbosity(int v) {
  443.     
  444.     m_Verbosity = v;
  445.   }
  446.   /**
  447.    * Main method for M5' algorithm
  448.    *
  449.    * @param argv command line arguments
  450.    */
  451.   public static void  main(String [] argv){
  452.     try {
  453.       System.out.println(Evaluation.evaluateModel(new M5Prime(), argv));
  454.     } catch (Exception e) {
  455.       System.err.println(e.getMessage());
  456.     }
  457.   }
  458. }
  459.