M5Base.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 10k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  *    M5Base.java
  3.  *    Copyright (C) 2000 Mark Hall
  4.  *
  5.  *    This program is free software; you can redistribute it and/or modify
  6.  *    it under the terms of the GNU General Public License as published by
  7.  *    the Free Software Foundation; either version 2 of the License, or
  8.  *    (at your option) any later version.
  9.  *
  10.  *    This program is distributed in the hope that it will be useful,
  11.  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
  12.  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  13.  *    GNU General Public License for more details.
  14.  *
  15.  *    You should have received a copy of the GNU General Public License
  16.  *    along with this program; if not, write to the Free Software
  17.  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  18.  */
  19. package weka.classifiers.trees.m5;
  20. import java.io.*;
  21. import java.util.*;
  22. import weka.core.*;
  23. import weka.classifiers.*;
  24. import weka.filters.unsupervised.attribute.ReplaceMissingValues;
  25. import weka.filters.supervised.attribute.NominalToBinary;
  26. import weka.filters.Filter;
  27. /**
  28.  * M5Base. Implements base routines
  29.  * for generating M5 Model trees and rules. <p>
  30.  * 
  31.  * Valid options are:<p>
  32.  * 
  33.  * -U <br>
  34.  * Use unsmoothed predictions. <p>
  35.  *
  36.  * -R <br>
  37.  * Build regression tree/rule rather than model tree/rule
  38.  *
  39.  * @version $Revision: 1.4 $
  40.  */
  41. public abstract class M5Base extends Classifier 
  42.   implements OptionHandler,
  43.      AdditionalMeasureProducer {
  44.   /**
  45.    * the instances covered by the tree/rules
  46.    */
  47.   private Instances      m_instances;
  48.   /**
  49.    * the class index
  50.    */
  51.   private int      m_classIndex;
  52.   /**
  53.    * the number of attributes
  54.    */
  55.   private int      m_numAttributes;
  56.   /**
  57.    * the number of instances in the dataset
  58.    */
  59.   private int      m_numInstances;
  60.   /**
  61.    * the rule set
  62.    */
  63.   protected FastVector      m_ruleSet;
  64.   /**
  65.    * generate a decision list instead of a single tree.
  66.    */
  67.   private boolean      m_generateRules;
  68.   /**
  69.    * use unsmoothed predictions
  70.    */
  71.   private boolean      m_unsmoothedPredictions;
  72.   /**
  73.    * filter to fill in missing values
  74.    */
  75.   private ReplaceMissingValues m_replaceMissing;
  76.   /**
  77.    * filter to convert nominal attributes to binary
  78.    */
  79.   private NominalToBinary      m_nominalToBinary;
  80.   /**
  81.    * Save instances at each node in an M5 tree for visualization purposes.
  82.    */
  83.   protected boolean m_saveInstances = false;
  84.   /**
  85.    * Make a regression tree/rule instead of a model tree/rule
  86.    */
  87.   protected boolean m_regressionTree;
  88.   /**
  89.    * Constructor
  90.    */
  91.   public M5Base() {
  92.     m_generateRules = false;
  93.     m_unsmoothedPredictions = false;
  94.   }
  95.   /**
  96.    * Returns an enumeration describing the available options
  97.    * 
  98.    * @return an enumeration of all the available options
  99.    */
  100.   public Enumeration listOptions() {
  101.     Vector newVector = new Vector(2);
  102.     newVector.addElement(new Option("tUse unsmoothed predictionsn", 
  103.     "U", 0, "-U"));
  104.     newVector.addElement(new Option("tBuild regression tree/rule rather "
  105.     +"than a model tree/rulen", 
  106.     "R", 0, "-R"));
  107.     return newVector.elements();
  108.   } 
  109.   /**
  110.    * Parses a given list of options. <p>
  111.    * 
  112.    * Valid options are:<p>
  113.    * 
  114.    * -U <br>
  115.    * Use unsmoothed predictions. <p>
  116.    * 
  117.    * @param options the list of options as an array of strings
  118.    * @exception Exception if an option is not supported
  119.    */
  120.   public void setOptions(String[] options) throws Exception {
  121.     setUseUnsmoothed(Utils.getFlag('U', options));
  122.     setBuildRegressionTree(Utils.getFlag('R', options));
  123.     
  124.     Utils.checkForRemainingOptions(options);
  125.   } 
  126.   /**
  127.    * Gets the current settings of the classifier.
  128.    * 
  129.    * @return an array of strings suitable for passing to setOptions
  130.    */
  131.   public String[] getOptions() {
  132.     String[] options = new String[2];
  133.     int      current = 0;
  134.     if (getUseUnsmoothed()) {
  135.       options[current++] = "-U";
  136.     } 
  137.     if (getBuildRegressionTree()) {
  138.       options[current++] = "-R";
  139.     }
  140.     while (current < options.length) {
  141.       options[current++] = "";
  142.     } 
  143.     return options;
  144.   } 
  145.   /**
  146.    * Generate rules (decision list) rather than a tree
  147.    * 
  148.    * @param u true if rules are to be generated
  149.    */
  150.   protected void setGenerateRules(boolean u) {
  151.     m_generateRules = u;
  152.   } 
  153.   /**
  154.    * get whether rules are being generated rather than a tree
  155.    * 
  156.    * @return true if rules are to be generated
  157.    */
  158.   protected boolean getGenerateRules() {
  159.     return m_generateRules;
  160.   } 
  161.   /**
  162.    * Use unsmoothed predictions
  163.    * 
  164.    * @param s true if unsmoothed predictions are to be used
  165.    */
  166.   public void setUseUnsmoothed(boolean s) {
  167.     m_unsmoothedPredictions = s;
  168.   } 
  169.   /**
  170.    * Get whether or not smoothing is being used
  171.    * 
  172.    * @return true if unsmoothed predictions are to be used
  173.    */
  174.   public boolean getUseUnsmoothed() {
  175.     return m_unsmoothedPredictions;
  176.   } 
  177.   /**
  178.    * Get the value of regressionTree.
  179.    *
  180.    * @return Value of regressionTree.
  181.    */
  182.   public boolean getBuildRegressionTree() {
  183.     
  184.     return m_regressionTree;
  185.   }
  186.   
  187.   /**
  188.    * Set the value of regressionTree.
  189.    *
  190.    * @param newregressionTree Value to assign to regressionTree.
  191.    */
  192.   public void setBuildRegressionTree(boolean newregressionTree) {
  193.     
  194.     m_regressionTree = newregressionTree;
  195.   }
  196.   /**
  197.    * Generates the classifier.
  198.    * 
  199.    * @param data set of instances serving as training data
  200.    * @exception Exception if the classifier has not been generated
  201.    * successfully
  202.    */
  203.   public void buildClassifier(Instances data) throws Exception {
  204.     if (data.checkForStringAttributes()) {
  205.       throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
  206.     } 
  207.     m_instances = new Instances(data);
  208.     m_replaceMissing = new ReplaceMissingValues();
  209.     m_instances.deleteWithMissingClass();
  210.     m_replaceMissing.setInputFormat(m_instances);
  211.     m_instances = Filter.useFilter(m_instances, m_replaceMissing);
  212.     m_nominalToBinary = new NominalToBinary();
  213.     m_nominalToBinary.setInputFormat(m_instances);
  214.     m_instances = Filter.useFilter(m_instances, m_nominalToBinary);
  215.     // 
  216.     m_instances.randomize(new Random(1));
  217.     m_classIndex = m_instances.classIndex();
  218.     m_numAttributes = m_instances.numAttributes();
  219.     m_numInstances = m_instances.numInstances();
  220.     m_ruleSet = new FastVector();
  221.     Rule tempRule;
  222.     if (m_generateRules) {
  223.       Instances tempInst = m_instances;
  224.       double sum = 0;
  225.       double temp_sum = 0;
  226.      
  227.       do {
  228. tempRule = new Rule();
  229. tempRule.setSmoothing(!m_unsmoothedPredictions);
  230. tempRule.setRegressionTree(m_regressionTree);
  231. tempRule.buildClassifier(tempInst);
  232. m_ruleSet.addElement(tempRule);
  233. // System.err.println("Built rule : "+tempRule.toString());
  234. tempInst = tempRule.notCoveredInstances();
  235.       } while (tempInst.numInstances() > 0);
  236.     } else {
  237.       // just build a single tree
  238.       tempRule = new Rule();
  239.       tempRule.setUseTree(true);
  240.       tempRule.setGrowFullTree(true);
  241.       tempRule.setSmoothing(!m_unsmoothedPredictions);
  242.       tempRule.setSaveInstances(m_saveInstances);
  243.       tempRule.setRegressionTree(m_regressionTree);
  244.       Instances temp_train;
  245.       temp_train = m_instances;
  246.       tempRule.buildClassifier(temp_train);
  247.       m_ruleSet.addElement(tempRule);      
  248.       // save space
  249.       m_instances = new Instances(m_instances, 0);
  250.       //      System.err.print(tempRule.m_topOfTree.treeToString(0));
  251.     } 
  252.   } 
  253.   /**
  254.    * Calculates a prediction for an instance using a set of rules
  255.    * or an M5 model tree
  256.    * 
  257.    * @param inst the instance whos class value is to be predicted
  258.    * @return the prediction
  259.    * @exception if a prediction can't be made.
  260.    */
  261.   public double classifyInstance(Instance inst) throws Exception {
  262.     Rule   temp;
  263.     double prediction = 0;
  264.     boolean success = false;
  265.     m_replaceMissing.input(inst);
  266.     inst = m_replaceMissing.output();
  267.     m_nominalToBinary.input(inst);
  268.     inst = m_nominalToBinary.output();
  269.     if (m_ruleSet == null) {
  270.       throw new Exception("Classifier has not been built yet!");
  271.     } 
  272.     if (!m_generateRules) {
  273.       temp = (Rule) m_ruleSet.elementAt(0);
  274.       return temp.classifyInstance(inst);
  275.     } 
  276.     boolean cont;
  277.     int     i;
  278.     for (i = 0; i < m_ruleSet.size(); i++) {
  279.       cont = false;
  280.       temp = (Rule) m_ruleSet.elementAt(i);
  281.       try {
  282. prediction = temp.classifyInstance(inst);
  283. success = true;
  284.       } catch (Exception e) {
  285. cont = true;
  286.       } 
  287.       if (!cont) {
  288. break;
  289.       } 
  290.     } 
  291.     if (!success) {
  292.       System.out.println("Error in predicting (DecList)");
  293.     } 
  294.     return prediction;
  295.   } 
  296.   /**
  297.    * Returns a description of the classifier
  298.    * 
  299.    * @return a description of the classifier as a String
  300.    */
  301.   public String toString() {
  302.     StringBuffer text = new StringBuffer();
  303.     Rule  temp;
  304.     if (m_ruleSet == null) {
  305.       return "Classifier hasn't been built yet!";
  306.     } 
  307.     if (m_generateRules) {
  308.       text.append("M5 Rules  ");
  309.       if (!m_unsmoothedPredictions) {
  310. text.append("(smoothed) ");
  311.       }
  312.       text.append(":n");
  313.       text.append("Number of Rules : " + m_ruleSet.size() + "nn");
  314.       for (int j = 0; j < m_ruleSet.size(); j++) {
  315. temp = (Rule) m_ruleSet.elementAt(j);
  316. text.append("Rule: " + (j + 1) + "n");
  317. text.append(temp.toString());
  318.       } 
  319.     } else {
  320.       temp = (Rule) m_ruleSet.elementAt(0);
  321.       text.append(temp.toString());
  322.     } 
  323.     return text.toString();
  324.   } 
  325.   /**
  326.    * Returns an enumeration of the additional measure names
  327.    * @return an enumeration of the measure names
  328.    */
  329.   public Enumeration enumerateMeasures() {
  330.     Vector newVector = new Vector(1);
  331.     newVector.addElement("measureNumRules");
  332.     return newVector.elements();
  333.   }
  334.   /**
  335.    * Returns the value of the named measure
  336.    * @param measureName the name of the measure to query for its value
  337.    * @return the value of the named measure
  338.    * @exception Exception if the named measure is not supported
  339.    */
  340.   public double getMeasure(String additionalMeasureName) 
  341.     throws IllegalArgumentException {
  342.     if (additionalMeasureName.compareTo("measureNumRules") == 0) {
  343.       return measureNumRules();
  344.     } else {
  345.       throw new IllegalArgumentException(additionalMeasureName 
  346.  + " not supported (M5)");
  347.     }
  348.   }
  349.   /**
  350.    * return the number of rules
  351.    * @return the number of rules (same as # linear models &
  352.    * # leaves in the tree)
  353.    */
  354.   public double measureNumRules() {
  355.     if (m_generateRules) {
  356.       return m_ruleSet.size();
  357.     }
  358.     return ((Rule)m_ruleSet.elementAt(0)).m_topOfTree.numberOfLinearModels();
  359.   }
  360. }