ClassifierTree.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 16k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  *    This program is free software; you can redistribute it and/or modify
  3.  *    it under the terms of the GNU General Public License as published by
  4.  *    the Free Software Foundation; either version 2 of the License, or
  5.  *    (at your option) any later version.
  6.  *
  7.  *    This program is distributed in the hope that it will be useful,
  8.  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  *    GNU General Public License for more details.
  11.  *
  12.  *    You should have received a copy of the GNU General Public License
  13.  *    along with this program; if not, write to the Free Software
  14.  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15.  */
  16. /*
  17.  *    ClassifierTree.java
  18.  *    Copyright (C) 1999 Eibe Frank
  19.  *
  20.  */
  21. package weka.classifiers.j48;
  22. import weka.core.*;
  23. import weka.classifiers.*;
  24. import java.io.*;
  25. /**
  26.  * Class for handling a tree structure used for
  27.  * classification.
  28.  *
  29.  * @author Eibe Frank (eibe@cs.waikato.ac.nz)
  30.  * @version $Revision: 1.12 $
  31.  */
  32. public class ClassifierTree implements Drawable, Serializable {
  33.   /** The model selection method. */  
  34.   protected ModelSelection m_toSelectModel;     
  35.   /** Local model at node. */
  36.   protected ClassifierSplitModel m_localModel;  
  37.   /** References to sons. */
  38.   protected ClassifierTree [] m_sons;           
  39.   /** True if node is leaf. */
  40.   protected boolean m_isLeaf;                   
  41.   /** True if node is empty. */
  42.   protected boolean m_isEmpty;                  
  43.   /** The training instances. */
  44.   protected Instances m_train;                  
  45.   /** The pruning instances. */
  46.   protected Distribution m_test;     
  47.   /** The id for the node. */
  48.   protected int m_id;
  49.   /** 
  50.    * For getting a unique ID when outputting the tree (hashcode isn't
  51.    * guaranteed unique) 
  52.    */
  53.   private static long PRINTED_NODES = 0;
  54.   /**
  55.    * Gets the next unique node ID.
  56.    *
  57.    * @return the next unique node ID.
  58.    */
  59.   protected static long nextID() {
  60.     return PRINTED_NODES ++;
  61.   }
  62.   /**
  63.    * Resets the unique node ID counter (e.g.
  64.    * between repeated separate print types)
  65.    */
  66.   protected static void resetID() {
  67.     PRINTED_NODES = 0;
  68.   }
  69.   /**
  70.    * Constructor. 
  71.    */
  72.   public ClassifierTree(ModelSelection toSelectLocModel) {
  73.     
  74.     m_toSelectModel = toSelectLocModel;
  75.   }
  76.   /**
  77.    * Method for building a classifier tree.
  78.    *
  79.    * @exception Exception if something goes wrong
  80.    */
  81.   public void buildClassifier(Instances data) throws Exception{
  82.     if (data.checkForStringAttributes()) {
  83.       throw new Exception("Can't handle string attributes!");
  84.     }
  85.     data = new Instances(data);
  86.     data.deleteWithMissingClass();
  87.     buildTree(data, false);
  88.   }
  89.   /**
  90.    * Builds the tree structure.
  91.    *
  92.    * @param data the data for which the tree structure is to be
  93.    * generated.
  94.    * @param keepData is training data to be kept?
  95.    * @exception Exception if something goes wrong
  96.    */
  97.   public void buildTree(Instances data, boolean keepData) throws Exception{
  98.     
  99.     Instances [] localInstances;
  100.     if (keepData) {
  101.       m_train = data;
  102.     }
  103.     m_test = null;
  104.     m_isLeaf = false;
  105.     m_isEmpty = false;
  106.     m_sons = null;
  107.     m_localModel = m_toSelectModel.selectModel(data);
  108.     if (m_localModel.numSubsets() > 1) {
  109.       localInstances = m_localModel.split(data);
  110.       data = null;
  111.       m_sons = new ClassifierTree [m_localModel.numSubsets()];
  112.       for (int i = 0; i < m_sons.length; i++) {
  113. m_sons[i] = getNewTree(localInstances[i]);
  114. localInstances[i] = null;
  115.       }
  116.     }else{
  117.       m_isLeaf = true;
  118.       if (Utils.eq(data.sumOfWeights(), 0))
  119. m_isEmpty = true;
  120.       data = null;
  121.     }
  122.   }
  123.   /**
  124.    * Builds the tree structure with hold out set
  125.    *
  126.    * @param train the data for which the tree structure is to be
  127.    * generated.
  128.    * @param test the test data for potential pruning
  129.    * @param keepData is training Data to be kept?
  130.    * @exception Exception if something goes wrong
  131.    */
  132.   public void buildTree(Instances train, Instances test, boolean keepData)
  133.        throws Exception{
  134.     
  135.     Instances [] localTrain, localTest;
  136.     int i;
  137.     
  138.     if (keepData) {
  139.       m_train = train;
  140.     }
  141.     m_isLeaf = false;
  142.     m_isEmpty = false;
  143.     m_sons = null;
  144.     m_localModel = m_toSelectModel.selectModel(train, test);
  145.     m_test = new Distribution(test, m_localModel);
  146.     if (m_localModel.numSubsets() > 1) {
  147.       localTrain = m_localModel.split(train);
  148.       localTest = m_localModel.split(test);
  149.       train = test = null;
  150.       m_sons = new ClassifierTree [m_localModel.numSubsets()];
  151.       for (i=0;i<m_sons.length;i++) {
  152. m_sons[i] = getNewTree(localTrain[i], localTest[i]);
  153. localTrain[i] = null;
  154. localTest[i] = null;
  155.       }
  156.     }else{
  157.       m_isLeaf = true;
  158.       if (Utils.eq(train.sumOfWeights(), 0))
  159. m_isEmpty = true;
  160.       train = test = null;
  161.     }
  162.   }
  163.   /** 
  164.    * Classifies an instance.
  165.    *
  166.    * @exception Exception if something goes wrong
  167.    */
  168.   public double classifyInstance(Instance instance) 
  169.     throws Exception {
  170.     double maxProb = -1;
  171.     double currentProb;
  172.     int maxIndex = 0;
  173.     int j;
  174.     for (j = 0; j < instance.numClasses(); j++) {
  175.       currentProb = getProbs(j, instance, 1);
  176.       if (Utils.gr(currentProb,maxProb)) {
  177. maxIndex = j;
  178. maxProb = currentProb;
  179.       }
  180.     }
  181.     return (double)maxIndex;
  182.   }
  183.   /**
  184.    * Cleanup in order to save memory.
  185.    */
  186.   public final void cleanup(Instances justHeaderInfo) {
  187.     m_train = justHeaderInfo;
  188.     m_test = null;
  189.     if (!m_isLeaf)
  190.       for (int i = 0; i < m_sons.length; i++)
  191. m_sons[i].cleanup(justHeaderInfo);
  192.   }
  193.   /** 
  194.    * Returns class probabilities for a weighted instance.
  195.    *
  196.    * @exception Exception if something goes wrong
  197.    */
  198.   public final double [] distributionForInstance(Instance instance,
  199.  boolean useLaplace) 
  200.        throws Exception {
  201.     double [] doubles = new double[instance.numClasses()];
  202.     for (int i = 0; i < doubles.length; i++) {
  203.       if (!useLaplace) {
  204. doubles[i] = getProbs(i, instance, 1);
  205.       } else {
  206. doubles[i] = getProbsLaplace(i, instance, 1);
  207.       }
  208.     }
  209.     return doubles;
  210.   }
  211.   /**
  212.    * Assigns a uniqe id to every node in the tree.
  213.    */
  214.   public int assignIDs(int lastID) {
  215.     int currLastID = lastID + 1;
  216.     m_id = currLastID;
  217.     if (m_sons != null) {
  218.       for (int i = 0; i < m_sons.length; i++) {
  219. currLastID = m_sons[i].assignIDs(currLastID);
  220.       }
  221.     }
  222.     return currLastID;
  223.   }
  224.   /**
  225.    * Returns graph describing the tree.
  226.    *
  227.    * @exception Exception if something goes wrong
  228.    */
  229.   public String graph() throws Exception {
  230.     StringBuffer text = new StringBuffer();
  231.     assignIDs(-1);
  232.     text.append("digraph J48Tree {n");
  233.     if (m_isLeaf) {
  234.       text.append("N" + m_id 
  235.   + " [label="" + 
  236.   m_localModel.dumpLabel(0,m_train) + "" " + 
  237.   "shape=box style=filled ");
  238.       if (m_train != null) {
  239. text.append("data =n" + m_train + "n");
  240. text.append(",n");
  241.       }
  242.       text.append("]n");
  243.     }else {
  244.       text.append("N" + m_id 
  245.   + " [label="" + 
  246.   m_localModel.leftSide(m_train) + "" ");
  247.       if (m_train != null) {
  248. text.append("data =n" + m_train + "n");
  249. text.append(",n");
  250.      }
  251.       text.append("]n");
  252.       graphTree(text);
  253.     }
  254.     
  255.     return text.toString() +"}n";
  256.   }
  257.   /**
  258.    * Returns tree in prefix order.
  259.    *
  260.    * @exception Exception if something goes wrong
  261.    */
  262.   public String prefix() throws Exception {
  263.     
  264.     StringBuffer text;
  265.     text = new StringBuffer();
  266.     if (m_isLeaf) {
  267.       text.append("["+m_localModel.dumpLabel(0,m_train)+"]");
  268.     }else {
  269.       prefixTree(text);
  270.     }
  271.     
  272.     return text.toString();
  273.   }
  274.   /**
  275.    * Returns source code for the tree as an if-then statement. The 
  276.    * class is assigned to variable "p", and assumes the tested 
  277.    * instance is named "i". The results are returned as two stringbuffers: 
  278.    * a section of code for assignment of the class, and a section of
  279.    * code containing support code (eg: other support methods).
  280.    *
  281.    * @param className the classname that this static classifier has
  282.    * @return an array containing two stringbuffers, the first string containing
  283.    * assignment code, and the second containing source for support code.
  284.    * @exception Exception if something goes wrong
  285.    */
  286.   public StringBuffer [] toSource(String className) throws Exception {
  287.     
  288.     StringBuffer [] result = new StringBuffer [2];
  289.     if (m_isLeaf) {
  290.       result[0] = new StringBuffer("    p = " 
  291. + m_localModel.distribution().maxClass(0) + ";n");
  292.       result[1] = new StringBuffer("");
  293.     } else {
  294.       StringBuffer text = new StringBuffer();
  295.       String nextIndent = "      ";
  296.       StringBuffer atEnd = new StringBuffer();
  297.       long printID = ClassifierTree.nextID();
  298.       text.append("  static double N") 
  299. .append(Integer.toHexString(m_localModel.hashCode()) + printID)
  300. .append("(Object []i) {n")
  301. .append("    double p = Double.NaN;n");
  302.       text.append("    if (")
  303. .append(m_localModel.sourceExpression(-1, m_train))
  304. .append(") {n");
  305.       text.append("      p = ")
  306. .append(m_localModel.distribution().maxClass(0))
  307. .append(";n");
  308.       text.append("    } ");
  309.       for (int i = 0; i < m_sons.length; i++) {
  310. text.append("else if (" + m_localModel.sourceExpression(i, m_train) 
  311.     + ") {n");
  312. if (m_sons[i].m_isLeaf) {
  313.   text.append("      p = " 
  314.       + m_localModel.distribution().maxClass(i) + ";n");
  315. } else {
  316.   StringBuffer [] sub = m_sons[i].toSource(className);
  317.   text.append(sub[0]);
  318.   atEnd.append(sub[1]);
  319. }
  320. text.append("    } ");
  321. if (i == m_sons.length - 1) {
  322.   text.append('n');
  323. }
  324.       }
  325.       text.append("    return p;n  }n");
  326.       result[0] = new StringBuffer("    p = " + className + ".N");
  327.       result[0].append(Integer.toHexString(m_localModel.hashCode()) +  printID)
  328. .append("(i);n");
  329.       result[1] = text.append(atEnd);
  330.     }
  331.     return result;
  332.   }
  333.   /**
  334.    * Returns number of leaves in tree structure.
  335.    */
  336.   public int numLeaves() {
  337.     
  338.     int num = 0;
  339.     int i;
  340.     
  341.     if (m_isLeaf)
  342.       return 1;
  343.     else
  344.       for (i=0;i<m_sons.length;i++)
  345. num = num+m_sons[i].numLeaves();
  346.         
  347.     return num;
  348.   }
  349.   /**
  350.    * Returns number of nodes in tree structure.
  351.    */
  352.   public int numNodes() {
  353.     
  354.     int no = 1;
  355.     int i;
  356.     
  357.     if (!m_isLeaf)
  358.       for (i=0;i<m_sons.length;i++)
  359. no = no+m_sons[i].numNodes();
  360.     
  361.     return no;
  362.   }
  363.   /**
  364.    * Prints tree structure.
  365.    */
  366.   public String toString() {
  367.     try {
  368.       StringBuffer text = new StringBuffer();
  369.       
  370.       if (m_isLeaf) {
  371. text.append(": ");
  372. text.append(m_localModel.dumpLabel(0,m_train));
  373.       }else
  374. dumpTree(0,text);
  375.       text.append("nnNumber of Leaves  : t"+numLeaves()+"n");
  376.       text.append("nSize of the tree : t"+numNodes()+"n");
  377.  
  378.       return text.toString();
  379.     } catch (Exception e) {
  380.       return "Can't print classification tree.";
  381.     }
  382.   }
  383.   /**
  384.    * Returns a newly created tree.
  385.    *
  386.    * @param data the training data
  387.    * @exception Exception if something goes wrong
  388.    */
  389.   protected ClassifierTree getNewTree(Instances data) throws Exception{
  390.  
  391.     ClassifierTree newTree = new ClassifierTree(m_toSelectModel);
  392.     newTree.buildTree(data, false);
  393.     
  394.     return newTree;
  395.   }
  396.   /**
  397.    * Returns a newly created tree.
  398.    *
  399.    * @param data the training data
  400.    * @param test the pruning data.
  401.    * @exception Exception if something goes wrong
  402.    */
  403.   protected ClassifierTree getNewTree(Instances train, Instances test) 
  404.        throws Exception{
  405.  
  406.     ClassifierTree newTree = new ClassifierTree(m_toSelectModel);
  407.     newTree.buildTree(train, test, false);
  408.     
  409.     return newTree;
  410.   }
  411.   /**
  412.    * Help method for printing tree structure.
  413.    *
  414.    * @exception Exception if something goes wrong
  415.    */
  416.   private void dumpTree(int depth,StringBuffer text) 
  417.        throws Exception {
  418.     
  419.     int i,j;
  420.     
  421.     for (i=0;i<m_sons.length;i++) {
  422.       text.append("n");;
  423.       for (j=0;j<depth;j++)
  424. text.append("|   ");
  425.       text.append(m_localModel.leftSide(m_train));
  426.       text.append(m_localModel.rightSide(i, m_train));
  427.       if (m_sons[i].m_isLeaf) {
  428. text.append(": ");
  429. text.append(m_localModel.dumpLabel(i,m_train));
  430.       }else
  431. m_sons[i].dumpTree(depth+1,text);
  432.     }
  433.   }
  434.   /**
  435.    * Help method for printing tree structure as a graph.
  436.    *
  437.    * @exception Exception if something goes wrong
  438.    */
  439.   private void graphTree(StringBuffer text) throws Exception {
  440.     
  441.     for (int i = 0; i < m_sons.length; i++) {
  442.       text.append("N" + m_id  
  443.   + "->" + 
  444.   "N" + m_sons[i].m_id +
  445.   " [label="" + m_localModel.rightSide(i,m_train).trim() + 
  446.   ""]n");
  447.       if (m_sons[i].m_isLeaf) {
  448. text.append("N" + m_sons[i].m_id +
  449.     " [label=""+m_localModel.dumpLabel(i,m_train)+"" "+ 
  450.     "shape=box style=filled ");
  451. if (m_train != null) {
  452.   text.append("data =n" + m_sons[i].m_train + "n");
  453.   text.append(",n");
  454. }
  455. text.append("]n");
  456.       } else {
  457. text.append("N" + m_sons[i].m_id +
  458.     " [label=""+m_sons[i].m_localModel.leftSide(m_train) + 
  459.     "" ");
  460. if (m_train != null) {
  461.   text.append("data =n" + m_sons[i].m_train + "n");
  462.   text.append(",n");
  463. }
  464. text.append("]n");
  465. m_sons[i].graphTree(text);
  466.       }
  467.     }
  468.   }
  469.   /**
  470.    * Prints the tree in prefix form
  471.    */
  472.   private void prefixTree(StringBuffer text) throws Exception {
  473.     text.append("[");
  474.     text.append(m_localModel.leftSide(m_train)+":");
  475.     for (int i = 0; i < m_sons.length; i++) {
  476.       if (i > 0) {
  477. text.append(",n");
  478.       }
  479.       text.append(m_localModel.rightSide(i, m_train));
  480.     }
  481.     for (int i = 0; i < m_sons.length; i++) {
  482.       if (m_sons[i].m_isLeaf) {
  483. text.append("[");
  484. text.append(m_localModel.dumpLabel(i,m_train));
  485. text.append("]");
  486.       } else {
  487. m_sons[i].prefixTree(text);
  488.       }
  489.     }
  490.     text.append("]");
  491.   }
  492.   /**
  493.    * Help method for computing class probabilities of 
  494.    * a given instance.
  495.    *
  496.    * @exception Exception if something goes wrong
  497.    */
  498.   private double getProbsLaplace(int classIndex, Instance instance, double weight) 
  499.     throws Exception {
  500.     
  501.     double [] weights;
  502.     double prob = 0;
  503.     int treeIndex;
  504.     int i,j;
  505.     
  506.     if (m_isLeaf) {
  507.       return weight * localModel().classProbLaplace(classIndex, instance, -1);
  508.     } else {
  509.       treeIndex = localModel().whichSubset(instance);
  510.       if (treeIndex == -1) {
  511. weights = localModel().weights(instance);
  512. for (i = 0; i < m_sons.length; i++) {
  513.   if (!son(i).m_isEmpty) {
  514.     if (!son(i).m_isLeaf) {
  515.       prob += son(i).getProbsLaplace(classIndex, instance, 
  516.      weights[i] * weight);
  517.     } else {
  518.       prob += weight * weights[i] * 
  519. localModel().classProbLaplace(classIndex, instance, i);
  520.     }
  521.   }
  522. }
  523. return prob;
  524.       } else {
  525. if (son(treeIndex).m_isLeaf) {
  526.   return weight * localModel().classProbLaplace(classIndex, instance, 
  527. treeIndex);
  528. } else {
  529.   return son(treeIndex).getProbsLaplace(classIndex, instance, weight);
  530. }
  531.       }
  532.     }
  533.   }
  534.   /**
  535.    * Help method for computing class probabilities of 
  536.    * a given instance.
  537.    *
  538.    * @exception Exception if something goes wrong
  539.    */
  540.   private double getProbs(int classIndex, Instance instance, double weight) 
  541.     throws Exception {
  542.     
  543.     double [] weights;
  544.     double prob = 0;
  545.     int treeIndex;
  546.     int i,j;
  547.     
  548.     if (m_isLeaf) {
  549.       return weight * localModel().classProb(classIndex, instance, -1);
  550.     } else {
  551.       treeIndex = localModel().whichSubset(instance);
  552.       if (treeIndex == -1) {
  553. weights = localModel().weights(instance);
  554. for (i = 0; i < m_sons.length; i++) {
  555.   if (!son(i).m_isEmpty) {
  556.     prob += son(i).getProbs(classIndex, instance, 
  557.     weights[i] * weight);
  558.   }
  559. }
  560. return prob;
  561.       } else {
  562. if (son(treeIndex).m_isEmpty) {
  563.   return weight * localModel().classProb(classIndex, instance, 
  564.  treeIndex);
  565. } else {
  566.   return son(treeIndex).getProbs(classIndex, instance, weight);
  567. }
  568.       }
  569.     }
  570.   }
  571.   /**
  572.    * Method just exists to make program easier to read.
  573.    */
  574.   private ClassifierSplitModel localModel() {
  575.     
  576.     return (ClassifierSplitModel)m_localModel;
  577.   }
  578.   
  579.   /**
  580.    * Method just exists to make program easier to read.
  581.    */
  582.   private ClassifierTree son(int index) {
  583.     
  584.     return (ClassifierTree)m_sons[index];
  585.   }
  586. }