Rule.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 14k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  *    Rule.java
  3.  *    Copyright (C) 2000 Mark Hall
  4.  *
  5.  *    This program is free software; you can redistribute it and/or modify
  6.  *    it under the terms of the GNU General Public License as published by
  7.  *    the Free Software Foundation; either version 2 of the License, or
  8.  *    (at your option) any later version.
  9.  *
  10.  *    This program is distributed in the hope that it will be useful,
  11.  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
  12.  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  13.  *    GNU General Public License for more details.
  14.  *
  15.  *    You should have received a copy of the GNU General Public License
  16.  *    along with this program; if not, write to the Free Software
  17.  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  18.  */
  19. package weka.classifiers.trees.m5;
  20. import java.io.*;
  21. import java.util.*;
  22. import weka.core.*;
  23. import weka.classifiers.*;
  24. import weka.filters.*;
  25. /**
  26.  * Generates a single m5 tree or rule
  27.  *
  28.  * @author Mark Hall
  29.  * @version $Revision: 1.2 $
  30.  */
  31. public class Rule {
  32.   protected static int LEFT = 0;
  33.   protected static int RIGHT = 1;
  34.   /**
  35.    * the instances covered by this rule
  36.    */
  37.   private Instances  m_instances;
  38.   /**
  39.    * the class index
  40.    */
  41.   private int        m_classIndex;
  42.   /**
  43.    * the number of attributes
  44.    */
  45.   private int        m_numAttributes;
  46.   /**
  47.    * the number of instances in the dataset
  48.    */
  49.   private int        m_numInstances;
  50.   /**
  51.    * the indexes of the attributes used to split on for this rule
  52.    */
  53.   private int[]      m_splitAtts;
  54.   /**
  55.    * the corresponding values of the split points
  56.    */
  57.   private double[]   m_splitVals;
  58.   /**
  59.    * the corresponding internal nodes. Used for smoothing rules.
  60.    */
  61.   private RuleNode[] m_internalNodes;
  62.   /**
  63.    * the corresponding relational operators (0 = "<=", 1 = ">")
  64.    */
  65.   private int[]      m_relOps;
  66.   /**
  67.    * the leaf encapsulating the linear model for this rule
  68.    */
  69.   private RuleNode   m_ruleModel;
  70.   /**
  71.    * the top of the m5 tree for this rule
  72.    */
  73.   protected RuleNode   m_topOfTree;
  74.   /**
  75.    * the standard deviation of the class for all the instances
  76.    */
  77.   private double     m_globalStdDev;
  78.   /**
  79.    * the absolute deviation of the class for all the instances
  80.    */
  81.   private double     m_globalAbsDev;
  82.   /**
  83.    * the instances covered by this rule
  84.    */
  85.   private Instances  m_covered;
  86.   /**
  87.    * the number of instances covered by this rule
  88.    */
  89.   private int        m_numCovered;
  90.   /**
  91.    * the instances not covered by this rule
  92.    */
  93.   private Instances  m_notCovered;
  94.   /**
  95.    * use a pruned m5 tree rather than make a rule
  96.    */
  97.   private boolean    m_useTree;
  98.   /**
  99.    * grow and prune a full m5 tree rather than use the PART heuristic
  100.    */
  101.   private boolean    m_growFullTree;
  102.   /**
  103.    * use the original m5 smoothing procedure
  104.    */
  105.   private boolean    m_smoothPredictions;
  106.   /**
  107.    * Save instances at each node in an M5 tree for visualization purposes.
  108.    */
  109.   private boolean m_saveInstances;
  110.   /**
  111.    * Make a regression tree instead of a model tree
  112.    */
  113.   private boolean m_regressionTree;
  114.   /**
  115.    * Constructor declaration
  116.    *
  117.    */
  118.   public Rule() {
  119.     m_useTree = false;
  120.     m_growFullTree = true;
  121.     m_smoothPredictions = false;
  122.   }
  123.   /**
  124.    * Generates a single rule or m5 model tree.
  125.    * 
  126.    * @param data set of instances serving as training data
  127.    * @exception Exception if the rule has not been generated
  128.    * successfully
  129.    */
  130.   public void buildClassifier(Instances data) throws Exception {
  131.     m_instances = null;
  132.     m_topOfTree = null;
  133.     m_covered = null;
  134.     m_notCovered = null;
  135.     m_ruleModel = null;
  136.     m_splitAtts = null;
  137.     m_splitVals = null;
  138.     m_relOps = null;
  139.     m_internalNodes = null;
  140.     m_instances = data;
  141.     m_classIndex = m_instances.classIndex();
  142.     m_numAttributes = m_instances.numAttributes();
  143.     m_numInstances = m_instances.numInstances();
  144.     // first calculate global deviation of class attribute
  145.     m_globalStdDev = Rule.stdDev(m_classIndex, m_instances);
  146.     m_globalAbsDev = Rule.absDev(m_classIndex, m_instances);
  147.     m_topOfTree = new RuleNode(m_globalStdDev, m_globalAbsDev, null);
  148.     m_topOfTree.setSmoothing(m_smoothPredictions);
  149.     m_topOfTree.setSaveInstances(m_saveInstances);
  150.     m_topOfTree.setRegressionTree(m_regressionTree);
  151.     m_topOfTree.buildClassifier(m_instances);
  152.     // m_topOfTree.numLeaves(0);
  153.     if (m_growFullTree) {
  154. m_topOfTree.prune();
  155. // m_topOfTree.printAllModels();
  156. m_topOfTree.numLeaves(0);
  157.     } 
  158.     if (!m_useTree) {      
  159.       makeRule();
  160.       // save space
  161.       //      m_topOfTree = null;
  162.     }
  163.     // save space
  164.     m_instances = new Instances(m_instances, 0);
  165.   } 
  166.   /**
  167.    * Calculates a prediction for an instance using this rule
  168.    * or M5 model tree
  169.    * 
  170.    * @param inst the instance whos class value is to be predicted
  171.    * @return the prediction
  172.    * @exception if a prediction can't be made.
  173.    */
  174.   public double classifyInstance(Instance instance) throws Exception {
  175.     if (m_useTree) {
  176.       return m_topOfTree.classifyInstance(instance);
  177.     } 
  178.     // does the instance pass the rule's conditions?
  179.     if (m_splitAtts.length > 0) {
  180.       for (int i = 0; i < m_relOps.length; i++) {
  181. if (m_relOps[i] == LEFT)    // left
  182.  {
  183.   if (instance.value(m_splitAtts[i]) > m_splitVals[i]) {
  184.     throw new Exception("Rule does not classify instance");
  185.   } 
  186. } else {
  187.   if (instance.value(m_splitAtts[i]) <= m_splitVals[i]) {
  188.     throw new Exception("Rule does not classify instance");
  189.   } 
  190.       } 
  191.     } 
  192.     // the linear model's prediction for this rule
  193.     // add smoothing code here
  194.     if (m_smoothPredictions) {
  195.       double pred = m_ruleModel.classifyInstance(instance);
  196.       int n = m_ruleModel.m_numInstances;
  197.       double supportPred;
  198.       Instance tempInst;
  199.       for (int i = 0; i < m_internalNodes.length; i++) {
  200. tempInst = m_internalNodes[i].applyNodeFilter(instance);
  201. supportPred = m_internalNodes[i].getModel().classifyInstance(tempInst);
  202. pred = RuleNode.smoothingOriginal(n, pred, supportPred);
  203. n = m_internalNodes[i].m_numInstances;
  204.       }
  205.       return pred;
  206.     }
  207.     return m_ruleModel.classifyInstance(instance);
  208.   } 
  209.   /**
  210.    * Make the single best rule from a pruned m5 model tree
  211.    * 
  212.    * @exception if something goes wrong.
  213.    */
  214.   private void makeRule() throws Exception {
  215.     RuleNode[] best_leaf = new RuleNode[1];
  216.     double[]   best_cov = new double[1];
  217.     RuleNode   temp;
  218.     m_notCovered = new Instances(m_instances, 0);
  219.     m_covered = new Instances(m_instances, 0);
  220.     best_cov[0] = -1;
  221.     best_leaf[0] = null;
  222.     m_topOfTree.findBestLeaf(best_cov, best_leaf);
  223.     temp = best_leaf[0];
  224.     if (temp == null) {
  225.       throw new Exception("Unable to generate rule!");
  226.     } 
  227.     // save the linear model for this rule
  228.     m_ruleModel = temp;
  229.     int count = 0;
  230.     while (temp.parentNode() != null) {
  231.       count++;
  232.       temp = temp.parentNode();
  233.     } 
  234.     temp = best_leaf[0];
  235.     m_relOps = new int[count];
  236.     m_splitAtts = new int[count];
  237.     m_splitVals = new double[count];
  238.     if (m_smoothPredictions) {
  239.       m_internalNodes = new RuleNode[count];
  240.     }
  241.     // trace back to the root
  242.     int i = 0;
  243.     while (temp.parentNode() != null) {
  244.       m_splitAtts[i] = temp.parentNode().splitAtt();
  245.       m_splitVals[i] = temp.parentNode().splitVal();
  246.       if (temp.parentNode().leftNode() == temp) {
  247. m_relOps[i] = LEFT;
  248. // temp.parentNode().m_right = null;
  249.       } else {
  250. m_relOps[i] = RIGHT;
  251. // temp.parentNode().m_left = null;
  252.       }
  253.       if (m_smoothPredictions) {
  254. m_internalNodes[i] = temp.parentNode();
  255.       }
  256.       temp = temp.parentNode();
  257.       i++;
  258.     } 
  259.     // now assemble the covered and uncovered instances
  260.     boolean ok;
  261.     for (i = 0; i < m_numInstances; i++) {
  262.       ok = true;
  263.       for (int j = 0; j < m_relOps.length; j++) {
  264. if (m_relOps[j] == LEFT)
  265.  {
  266.   if (m_instances.instance(i).value(m_splitAtts[j]) 
  267.   > m_splitVals[j]) {
  268.     m_notCovered.add(m_instances.instance(i));
  269.     ok = false;
  270.     break;
  271.   } 
  272. } else {
  273.   if (m_instances.instance(i).value(m_splitAtts[j]) 
  274.   <= m_splitVals[j]) {
  275.     m_notCovered.add(m_instances.instance(i));
  276.     ok = false;
  277.     break;
  278.   } 
  279.       } 
  280.       if (ok) {
  281. m_numCovered++;
  282. // m_covered.add(m_instances.instance(i));
  283.       } 
  284.     } 
  285.   } 
  286.   /**
  287.    * Return a description of the m5 tree or rule
  288.    * 
  289.    * @return a description of the m5 tree or rule as a String
  290.    */
  291.   public String toString() {
  292.     if (m_useTree) {
  293.       return treeToString();
  294.     } else {
  295.       return ruleToString();
  296.     } 
  297.   } 
  298.   /**
  299.    * Return a description of the m5 tree
  300.    * 
  301.    * @return a description of the m5 tree as a String
  302.    */
  303.   private String treeToString() {
  304.     StringBuffer text = new StringBuffer();
  305.     if (m_topOfTree == null) {
  306.       return "Tree/Rule has not been built yet!";
  307.     } 
  308.     text.append("Pruned training "
  309. + ((m_regressionTree) 
  310.    ? "regression "
  311.    : "model ")
  312. +"tree:n");
  313.     if (m_smoothPredictions == true) {
  314.       text.append("(using smoothed predictions)n");
  315.     } 
  316.     text.append(m_topOfTree.treeToString(0));
  317.     text.append(m_topOfTree.printLeafModels());
  318.     text.append("nNumber of Rules : " + m_topOfTree.numberOfLinearModels());
  319.     return text.toString();
  320.   } 
  321.   /**
  322.    * Return a description of the rule
  323.    * 
  324.    * @return a description of the rule as a String
  325.    */
  326.   private String ruleToString() {
  327.     StringBuffer text = new StringBuffer();
  328.     if (m_splitAtts.length > 0) {
  329.       text.append("IFn");
  330.       for (int i = m_splitAtts.length - 1; i >= 0; i--) {
  331. text.append("t" + m_covered.attribute(m_splitAtts[i]).name() + " ");
  332. if (m_relOps[i] == 0) {
  333.   text.append("<= ");
  334. } else {
  335.   text.append("> ");
  336. text.append(Utils.doubleToString(m_splitVals[i], 1, 3) + "n");
  337.       } 
  338.       text.append("THENn");
  339.     } 
  340.     if (m_ruleModel != null) {
  341.       try {
  342. text.append(m_ruleModel.printNodeLinearModel());
  343. text.append(" [" + m_numCovered/*m_covered.numInstances()*/);
  344. if (m_globalAbsDev > 0.0) {
  345.   text.append("/"+Utils.doubleToString((100 * 
  346.    m_ruleModel.
  347.    rootMeanSquaredError() / 
  348.    m_globalAbsDev), 1, 3) 
  349.       + "%]nn");
  350. } else {
  351.   text.append("]nn");
  352.       } catch (Exception e) {
  353. return "Can't print rule";
  354.       } 
  355.     } 
  356.     
  357.     //    System.out.println(m_instances);
  358.     return text.toString();
  359.   } 
  360.   /**
  361.    * Use an m5 tree rather than generate rules
  362.    * 
  363.    * @param u true if m5 tree is to be used
  364.    */
  365.   public void setUseTree(boolean u) {
  366.     m_useTree = u;
  367.   } 
  368.   /**
  369.    * get whether an m5 tree is being used rather than rules
  370.    * 
  371.    * @return true if an m5 tree is being used.
  372.    */
  373.   public boolean getUseTree() {
  374.     return m_useTree;
  375.   } 
  376.   /**
  377.    * Grow a full tree instead of using the PART heuristic
  378.    * 
  379.    * @param g true if a full tree is to be grown rather than using
  380.    * the part heuristic
  381.    */
  382.   public void setGrowFullTree(boolean g) {
  383.     m_growFullTree = g;
  384.   } 
  385.   /**
  386.    * Get whether or not a full tree has been grown
  387.    * 
  388.    * @return true if a full tree has been grown
  389.    */
  390.   public boolean getGrowFullTree() {
  391.     return m_growFullTree;
  392.   } 
  393.   /**
  394.    * Smooth predictions
  395.    * 
  396.    * @param s true if smoothing is to be used
  397.    */
  398.   public void setSmoothing(boolean s) {
  399.     m_smoothPredictions = s;
  400.   } 
  401.   /**
  402.    * Get whether or not smoothing has been turned on
  403.    * 
  404.    * @return true if smoothing is being used
  405.    */
  406.   public boolean getSmoothing() {
  407.     return m_smoothPredictions;
  408.   } 
  409.   /**
  410.    * Get the instances not covered by this rule
  411.    * 
  412.    * @return the instances not covered
  413.    */
  414.   public Instances notCoveredInstances() {
  415.     return m_notCovered;
  416.   } 
  417. //    /**
  418. //     * Get the instances covered by this rule
  419. //     * 
  420. //     * @return the instances covered by this rule
  421. //     */
  422. //    public Instances coveredInstances() {
  423. //      return m_covered;
  424. //    } 
  425.   /**
  426.    * Returns the standard deviation value of the supplied attribute index.
  427.    *
  428.    * @param attr an attribute index
  429.    * @param inst the instances
  430.    * @return the standard deviation value
  431.    */
  432.   protected final static double stdDev(int attr, Instances inst) {
  433.     int i,count=0;
  434.     double sd,va,sum=0.0,sqrSum=0.0,value;
  435.     
  436.     for(i = 0; i <= inst.numInstances() - 1; i++) {
  437.       count++;
  438.       value = inst.instance(i).value(attr);
  439.       sum +=  value;
  440.       sqrSum += value * value;
  441.     }
  442.     
  443.     if(count > 1) {
  444.       va = (sqrSum - sum * sum / count) / count;
  445.       va = Math.abs(va);
  446.       sd = Math.sqrt(va);
  447.     } else {
  448.       sd = 0.0;
  449.     }
  450.     return sd;
  451.   }
  452.   /**
  453.    * Returns the absolute deviation value of the supplied attribute index.
  454.    *
  455.    * @param attr an attribute index
  456.    * @param inst the instances
  457.    * @return the absolute deviation value
  458.    */
  459.   protected final static double absDev(int attr, Instances inst) {
  460.     int i;
  461.     double average=0.0,absdiff=0.0,absDev;
  462.     
  463.     for(i = 0; i <= inst.numInstances()-1; i++) {
  464.       average  += inst.instance(i).value(attr);
  465.     }
  466.     if(inst.numInstances() > 1) {
  467.       average /= (double)inst.numInstances();
  468.       for(i=0; i <= inst.numInstances()-1; i++) {
  469. absdiff += Math.abs(inst.instance(i).value(attr) - average);
  470.       }
  471.       absDev = absdiff / (double)inst.numInstances();
  472.     } else {
  473.       absDev = 0.0;
  474.     }
  475.    
  476.     return absDev;
  477.   }
  478.   /**
  479.    * Sets whether instances at each node in an M5 tree should be saved
  480.    * for visualization purposes. Default is to save memory.
  481.    *
  482.    * @param save a <code>boolean</code> value
  483.    */
  484.   protected void setSaveInstances(boolean save) {
  485.     m_saveInstances = save;
  486.   }
  487.   /**
  488.    * Get the value of regressionTree.
  489.    *
  490.    * @return Value of regressionTree.
  491.    */
  492.   public boolean getRegressionTree() {
  493.     
  494.     return m_regressionTree;
  495.   }
  496.   
  497.   /**
  498.    * Set the value of regressionTree.
  499.    *
  500.    * @param newregressionTree Value to assign to regressionTree.
  501.    */
  502.   public void setRegressionTree(boolean newregressionTree) {
  503.     
  504.     m_regressionTree = newregressionTree;
  505.   }
  506. }