ADTree.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 46k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  *    This program is free software; you can redistribute it and/or modify
  3.  *    it under the terms of the GNU General Public License as published by
  4.  *    the Free Software Foundation; either version 2 of the License, or
  5.  *    (at your option) any later version.
  6.  *
  7.  *    This program is distributed in the hope that it will be useful,
  8.  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  *    GNU General Public License for more details.
  11.  *
  12.  *    You should have received a copy of the GNU General Public License
  13.  *    along with this program; if not, write to the Free Software
  14.  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15.  */
  16. /*
  17.  *    ADTree.java
  18.  *    Copyright (C) 2001 Richard Kirkby, Bernhard Pfahringer
  19.  *
  20.  */
  21. package weka.classifiers.trees.adtree;
  22. import weka.classifiers.*;
  23. import weka.core.*;
  24. import java.io.*;
  25. import java.util.*;
  26. /**
  27.  * Class for generating an alternating decision tree. The basic algorithm is based on:<p>
  28.  *
  29.  * Freund, Y., Mason, L.: The alternating decision tree learning algorithm.
  30.  * Proceeding of the Sixteenth International Conference on Machine Learning,
  31.  * Bled, Slovenia, (1999) 124-133.</p>
  32.  *
  33.  * This version currently only supports two-class problems. The number of boosting
  34.  * iterations needs to be manually tuned to suit the dataset and the desired 
  35.  * complexity/accuracy tradeoff. Induction of the trees has been optimized, and heuristic
  36.  * search methods have been introduced to speed learning.<p>
  37.  *
  38.  * Valid options are: <p>
  39.  *
  40.  * -B num <br>
  41.  * Set the number of boosting iterations
  42.  * (default 10) <p>
  43.  *
  44.  * -E num <br>
  45.  * Set the nodes to expand: -3(all), -2(weight), -1(z_pure), >=0 seed for random walk
  46.  * (default -3) <p>
  47.  *
  48.  * -D <br>
  49.  * Save the instance data with the model <p>
  50.  *
  51.  * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
  52.  * @author Bernhard Pfahringer (bernhard@cs.waikato.ac.nz)
  53.  * @version $Revision: 1.4 $
  54.  */
  55. public class ADTree
  56.   extends DistributionClassifier implements OptionHandler, Drawable,
  57.     AdditionalMeasureProducer,
  58.     WeightedInstancesHandler,
  59.     IterativeClassifier
  60. {
  61.   /** The search modes */
  62.   public final static int SEARCHPATH_ALL = 0;
  63.   public final static int SEARCHPATH_HEAVIEST = 1;
  64.   public final static int SEARCHPATH_ZPURE = 2;
  65.   public final static int SEARCHPATH_RANDOM = 3;
  66.   public static final Tag [] TAGS_SEARCHPATH = {
  67.     new Tag(SEARCHPATH_ALL, "Expand all paths"),
  68.     new Tag(SEARCHPATH_HEAVIEST, "Expand the heaviest path"),
  69.     new Tag(SEARCHPATH_ZPURE, "Expand the best z-pure path"),
  70.     new Tag(SEARCHPATH_RANDOM, "Expand a random path")
  71.   };
  72.   /** The instances used to train the tree */
  73.   protected Instances m_trainInstances;
  74.   /** The root of the tree */
  75.   protected PredictionNode m_root = null;
  76.   /** The random number generator - used for the random search heuristic */
  77.   protected Random m_random = null; 
  78.   /** The number of the last splitter added to the tree */
  79.   protected int m_lastAddedSplitNum = 0;
  80.   /** An array containing the inidices to the numeric attributes in the data */
  81.   protected int[] m_numericAttIndices;
  82.   /** An array containing the inidices to the nominal attributes in the data */
  83.   protected int[] m_nominalAttIndices;
  84.   /** The total weight of the instances - used to speed Z calculations */
  85.   protected double m_trainTotalWeight;
  86.   /** The training instances with positive class - referencing the training dataset */
  87.   protected ReferenceInstances m_posTrainInstances;
  88.   /** The training instances with negative class - referencing the training dataset */
  89.   protected ReferenceInstances m_negTrainInstances;
  90.   /** The best node to insert under, as found so far by the latest search */
  91.   protected PredictionNode m_search_bestInsertionNode;
  92.   /** The best splitter to insert, as found so far by the latest search */
  93.   protected Splitter m_search_bestSplitter;
  94.   /** The smallest Z value found so far by the latest search */
  95.   protected double m_search_smallestZ;
  96.   /** The positive instances that apply to the best path found so far */
  97.   protected Instances m_search_bestPathPosInstances;
  98.   /** The negative instances that apply to the best path found so far */
  99.   protected Instances m_search_bestPathNegInstances;
  100.   /** Statistics - the number of prediction nodes investigated during search */
  101.   protected int m_nodesExpanded = 0;
  102.   /** Statistics - the number of instances processed during search */
  103.   protected int m_examplesCounted = 0;
  104.   /** Option - the number of boosting iterations o perform */
  105.   protected int m_boostingIterations = 10;
  106.   /** Option - the search mode */
  107.   protected int m_searchPath = 0;
  108.   /** Option - the seed to use for a random search */
  109.   protected int m_randomSeed = 0; 
  110.   /** Option - whether the tree should remember the instance data */
  111.   protected boolean m_saveInstanceData = false; 
  112.   /**
  113.    * Sets up the tree ready to be trained, using two-class optimized method.
  114.    *
  115.    * @param instances the instances to train the tree with
  116.    * @exception Exception if training data is unsuitable
  117.    */
  118.   public void initClassifier(Instances instances) throws Exception {
  119.     // clear stats
  120.     m_nodesExpanded = 0;
  121.     m_examplesCounted = 0;
  122.     m_lastAddedSplitNum = 0;
  123.     // make sure training data is suitable
  124.     if (instances.classIndex() < 0) {
  125.       throw new UnassignedClassException("ADTree: Needs a class to be assigned");
  126.     }
  127.     if (instances.checkForStringAttributes()) {
  128.       throw new UnsupportedAttributeTypeException("ADTree: Can't handle string attributes");
  129.     }
  130.     if (!instances.classAttribute().isNominal()) {
  131.       throw new UnsupportedClassTypeException("ADTree: Class must be nominal");
  132.     }
  133.     if (instances.numClasses() != 2) {
  134.       throw new UnsupportedClassTypeException("ADTree: Must be a two-class problem");
  135.     }
  136.     // prepare the random generator
  137.     m_random = new Random(m_randomSeed);
  138.     // create training set
  139.     m_trainInstances = new Instances(instances);
  140.     m_trainInstances.deleteWithMissingClass();
  141.     // create positive/negative subsets
  142.     m_posTrainInstances = new ReferenceInstances(m_trainInstances,
  143.  m_trainInstances.numInstances());
  144.     m_negTrainInstances = new ReferenceInstances(m_trainInstances,
  145.  m_trainInstances.numInstances());
  146.     for (Enumeration e = m_trainInstances.enumerateInstances(); e.hasMoreElements(); ) {
  147.       Instance inst = (Instance) e.nextElement();
  148.       if ((int) inst.classValue() == 0)
  149. m_negTrainInstances.addReference(inst); // belongs in negative class
  150.       else
  151. m_posTrainInstances.addReference(inst); // belongs in positive class
  152.     }
  153.     m_posTrainInstances.compactify();
  154.     m_negTrainInstances.compactify();
  155.     // create the root prediction node
  156.     double rootPredictionValue = calcPredictionValue(m_posTrainInstances,
  157.      m_negTrainInstances);
  158.     m_root = new PredictionNode(rootPredictionValue);
  159.     // pre-adjust weights
  160.     updateWeights(m_posTrainInstances, m_negTrainInstances, rootPredictionValue);
  161.     
  162.     // pre-calculate what we can
  163.     generateAttributeIndicesSingle();
  164.   }
  165.   /**
  166.    * Performs one iteration.
  167.    * 
  168.    * @param iteration the index of the current iteration (0-based)
  169.    * @exception Exception if this iteration fails 
  170.    */  
  171.   public void next(int iteration) throws Exception {
  172.     boost();
  173.   }
  174.   /**
  175.    * Performs a single boosting iteration, using two-class optimized method.
  176.    * Will add a new splitter node and two prediction nodes to the tree
  177.    * (unless merging takes place).
  178.    *
  179.    * @exception Exception if try to boost without setting up tree first or there are no 
  180.    * instances to train with
  181.    */
  182.   public void boost() throws Exception {
  183.     if (m_trainInstances == null || m_trainInstances.numInstances() == 0)
  184.       throw new Exception("Trying to boost with no training data");
  185.     // perform the search
  186.     searchForBestTestSingle();
  187.     if (m_search_bestSplitter == null) return; // handle empty instances
  188.     // create the new nodes for the tree, updating the weights
  189.     for (int i=0; i<2; i++) {
  190.       Instances posInstances =
  191. m_search_bestSplitter.instancesDownBranch(i, m_search_bestPathPosInstances);
  192.       Instances negInstances =
  193. m_search_bestSplitter.instancesDownBranch(i, m_search_bestPathNegInstances);
  194.       double predictionValue = calcPredictionValue(posInstances, negInstances);
  195.       PredictionNode newPredictor = new PredictionNode(predictionValue);
  196.       updateWeights(posInstances, negInstances, predictionValue);
  197.       m_search_bestSplitter.setChildForBranch(i, newPredictor);
  198.     }
  199.     // insert the new nodes
  200.     m_search_bestInsertionNode.addChild((Splitter) m_search_bestSplitter, this);
  201.     // free memory
  202.     m_search_bestPathPosInstances = null;
  203.     m_search_bestPathNegInstances = null;
  204.     m_search_bestSplitter = null;
  205.   }
  206.   /**
  207.    * Generates the m_nominalAttIndices and m_numericAttIndices arrays to index
  208.    * the respective attribute types in the training data.
  209.    *
  210.    */
  211.   private void generateAttributeIndicesSingle() {
  212.     // insert indices into vectors
  213.     FastVector nominalIndices = new FastVector();
  214.     FastVector numericIndices = new FastVector();
  215.     for (int i=0; i<m_trainInstances.numAttributes(); i++) {
  216.       if (i != m_trainInstances.classIndex()) {
  217. if (m_trainInstances.attribute(i).isNumeric())
  218.   numericIndices.addElement(new Integer(i));
  219. else
  220.   nominalIndices.addElement(new Integer(i));
  221.       }
  222.     }
  223.     // create nominal array
  224.     m_nominalAttIndices = new int[nominalIndices.size()];
  225.     for (int i=0; i<nominalIndices.size(); i++)
  226.       m_nominalAttIndices[i] = ((Integer)nominalIndices.elementAt(i)).intValue();
  227.     
  228.     // create numeric array
  229.     m_numericAttIndices = new int[numericIndices.size()];
  230.     for (int i=0; i<numericIndices.size(); i++)
  231.       m_numericAttIndices[i] = ((Integer)numericIndices.elementAt(i)).intValue();
  232.   }
  233.   /**
  234.    * Performs a search for the best test (splitter) to add to the tree, by aiming to
  235.    * minimize the Z value.
  236.    *
  237.    * @exception Exception if search fails
  238.    */
  239.   private void searchForBestTestSingle() throws Exception {
  240.     // keep track of total weight for efficient wRemainder calculations
  241.     m_trainTotalWeight = m_trainInstances.sumOfWeights();
  242.     
  243.     m_search_smallestZ = Double.POSITIVE_INFINITY;
  244.     searchForBestTestSingle(m_root, m_posTrainInstances, m_negTrainInstances);
  245.   }
  246.   /**
  247.    * Recursive function that carries out search for the best test (splitter) to add to
  248.    * this part of the tree, by aiming to minimize the Z value. Performs Z-pure cutoff to
  249.    * reduce search space.
  250.    *
  251.    * @param currentNode the root of the subtree to be searched, and the current node 
  252.    * being considered as parent of a new split
  253.    * @param posInstances the positive-class instances that apply at this node
  254.    * @param negInstances the negative-class instances that apply at this node
  255.    * @exception Exception if search fails
  256.    */
  257.   private void searchForBestTestSingle(PredictionNode currentNode,
  258.        Instances posInstances, Instances negInstances)
  259.     throws Exception
  260.   {
  261.     // don't investigate pure or empty nodes any further
  262.     if (posInstances.numInstances() == 0 || negInstances.numInstances() == 0) return;
  263.     // do z-pure cutoff
  264.     if (calcZpure(posInstances, negInstances) >= m_search_smallestZ) return;
  265.     // keep stats
  266.     m_nodesExpanded++;
  267.     m_examplesCounted += posInstances.numInstances() + negInstances.numInstances();
  268.     // evaluate static splitters (nominal)
  269.     for (int i=0; i<m_nominalAttIndices.length; i++)
  270.       evaluateNominalSplitSingle(m_nominalAttIndices[i], currentNode,
  271.  posInstances, negInstances);
  272.     // evaluate dynamic splitters (numeric)
  273.     if (m_numericAttIndices.length > 0) {
  274.       // merge the two sets of instances into one
  275.       Instances allInstances = new Instances(posInstances);
  276.       for (Enumeration e = negInstances.enumerateInstances(); e.hasMoreElements(); )
  277. allInstances.add((Instance) e.nextElement());
  278.     
  279.       // use method of finding the optimal Z split-point
  280.       for (int i=0; i<m_numericAttIndices.length; i++)
  281. evaluateNumericSplitSingle(m_numericAttIndices[i], currentNode,
  282.    posInstances, negInstances, allInstances);
  283.     }
  284.     if (currentNode.getChildren().size() == 0) return;
  285.     // keep searching
  286.     switch (m_searchPath) {
  287.     case SEARCHPATH_ALL:
  288.       goDownAllPathsSingle(currentNode, posInstances, negInstances);
  289.       break;
  290.     case SEARCHPATH_HEAVIEST: 
  291.       goDownHeaviestPathSingle(currentNode, posInstances, negInstances);
  292.       break;
  293.     case SEARCHPATH_ZPURE: 
  294.       goDownZpurePathSingle(currentNode, posInstances, negInstances);
  295.       break;
  296.     case SEARCHPATH_RANDOM: 
  297.       goDownRandomPathSingle(currentNode, posInstances, negInstances);
  298.       break;
  299.     }
  300.   }
  301.   /**
  302.    * Continues single (two-class optimized) search by investigating every node in the
  303.    * subtree under currentNode.
  304.    *
  305.    * @param currentNode the root of the subtree to be searched
  306.    * @param posInstances the positive-class instances that apply at this node
  307.    * @param negInstances the negative-class instances that apply at this node
  308.    * @exception Exception if search fails
  309.    */
  310.   private void goDownAllPathsSingle(PredictionNode currentNode,
  311.     Instances posInstances, Instances negInstances)
  312.     throws Exception
  313.   {
  314.     for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) {
  315.       Splitter split = (Splitter) e.nextElement();
  316.       for (int i=0; i<split.getNumOfBranches(); i++)
  317. searchForBestTestSingle(split.getChildForBranch(i),
  318. split.instancesDownBranch(i, posInstances),
  319. split.instancesDownBranch(i, negInstances));
  320.     }
  321.   }
  322.   /**
  323.    * Continues single (two-class optimized) search by investigating only the path
  324.    * with the most heavily weighted instances.
  325.    *
  326.    * @param currentNode the root of the subtree to be searched
  327.    * @param posInstances the positive-class instances that apply at this node
  328.    * @param negInstances the negative-class instances that apply at this node
  329.    * @exception Exception if search fails
  330.    */
  331.   private void goDownHeaviestPathSingle(PredictionNode currentNode,
  332. Instances posInstances, Instances negInstances)
  333.     throws Exception
  334.   {
  335.     Splitter heaviestSplit = null;
  336.     int heaviestBranch = 0;
  337.     double largestWeight = 0.0;
  338.     for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) {
  339.       Splitter split = (Splitter) e.nextElement();
  340.       for (int i=0; i<split.getNumOfBranches(); i++) {
  341. double weight =
  342.   split.instancesDownBranch(i, posInstances).sumOfWeights() +
  343.   split.instancesDownBranch(i, negInstances).sumOfWeights();
  344. if (weight > largestWeight) {
  345.   heaviestSplit = split;
  346.   heaviestBranch = i;
  347.   largestWeight = weight;
  348. }
  349.       }
  350.     }
  351.     if (heaviestSplit != null)
  352.       searchForBestTestSingle(heaviestSplit.getChildForBranch(heaviestBranch),
  353.       heaviestSplit.instancesDownBranch(heaviestBranch,
  354. posInstances),
  355.       heaviestSplit.instancesDownBranch(heaviestBranch,
  356. negInstances));
  357.   }
  358.   /**
  359.    * Continues single (two-class optimized) search by investigating only the path
  360.    * with the best Z-pure value at each branch.
  361.    *
  362.    * @param currentNode the root of the subtree to be searched
  363.    * @param posInstances the positive-class instances that apply at this node
  364.    * @param negInstances the negative-class instances that apply at this node
  365.    * @exception Exception if search fails
  366.    */
  367.   private void goDownZpurePathSingle(PredictionNode currentNode,
  368.      Instances posInstances, Instances negInstances)
  369.     throws Exception
  370.   {
  371.     double lowestZpure = m_search_smallestZ; // do z-pure cutoff
  372.     PredictionNode bestPath = null;
  373.     Instances bestPosSplit = null, bestNegSplit = null;
  374.     // search for branch with lowest Z-pure
  375.     for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) {
  376.       Splitter split = (Splitter) e.nextElement();
  377.       for (int i=0; i<split.getNumOfBranches(); i++) {
  378. Instances posSplit = split.instancesDownBranch(i, posInstances);
  379. Instances negSplit = split.instancesDownBranch(i, negInstances);
  380. double newZpure = calcZpure(posSplit, negSplit);
  381. if (newZpure < lowestZpure) {
  382.   lowestZpure = newZpure;
  383.   bestPath = split.getChildForBranch(i);
  384.   bestPosSplit = posSplit;
  385.   bestNegSplit = negSplit;
  386. }
  387.       }
  388.     }
  389.     if (bestPath != null)
  390.       searchForBestTestSingle(bestPath, bestPosSplit, bestNegSplit);
  391.   }
  392.   /**
  393.    * Continues single (two-class optimized) search by investigating a random path.
  394.    *
  395.    * @param currentNode the root of the subtree to be searched
  396.    * @param posInstances the positive-class instances that apply at this node
  397.    * @param negInstances the negative-class instances that apply at this node
  398.    * @exception Exception if search fails
  399.    */
  400.   private void goDownRandomPathSingle(PredictionNode currentNode,
  401.       Instances posInstances, Instances negInstances)
  402.     throws Exception {
  403.     FastVector children = currentNode.getChildren();
  404.     Splitter split = (Splitter) children.elementAt(getRandom(children.size()));
  405.     int branch = getRandom(split.getNumOfBranches());
  406.     searchForBestTestSingle(split.getChildForBranch(branch),
  407.     split.instancesDownBranch(branch, posInstances),
  408.     split.instancesDownBranch(branch, negInstances));
  409.   }
  410.   /**
  411.    * Investigates the option of introducing a nominal split under currentNode. If it
  412.    * finds a split that has a Z-value lower than has already been found it will
  413.    * update the search information to record this as the best option so far. 
  414.    *
  415.    * @param attIndex index of a nominal attribute to create a split from
  416.    * @param currentNode the parent under which a split is to be considered
  417.    * @param posInstances the positive-class instances that apply at this node
  418.    * @param negInstances the negative-class instances that apply at this node
  419.    */
  420.   private void evaluateNominalSplitSingle(int attIndex, PredictionNode currentNode,
  421.   Instances posInstances, Instances negInstances)
  422.   {
  423.     
  424.     double[] indexAndZ = findLowestZNominalSplit(posInstances, negInstances, attIndex);
  425.     if (indexAndZ[1] < m_search_smallestZ) {
  426.       m_search_smallestZ = indexAndZ[1];
  427.       m_search_bestInsertionNode = currentNode;
  428.       m_search_bestSplitter = new TwoWayNominalSplit(attIndex, (int) indexAndZ[0]);
  429.       m_search_bestPathPosInstances = posInstances;
  430.       m_search_bestPathNegInstances = negInstances;
  431.     }
  432.   }
  433.   /**
  434.    * Investigates the option of introducing a two-way numeric split under currentNode.
  435.    * If it finds a split that has a Z-value lower than has already been found it will
  436.    * update the search information to record this as the best option so far. 
  437.    *
  438.    * @param attIndex index of a numeric attribute to create a split from
  439.    * @param currentNode the parent under which a split is to be considered
  440.    * @param posInstances the positive-class instances that apply at this node
  441.    * @param negInstances the negative-class instances that apply at this node
  442.    * @param allInstances all of the instances the apply at this node (pos+neg combined)
  443.    */
  444.   private void evaluateNumericSplitSingle(int attIndex, PredictionNode currentNode,
  445.   Instances posInstances, Instances negInstances,
  446.   Instances allInstances)
  447.     throws Exception
  448.   {
  449.     
  450.     double[] splitAndZ = findLowestZNumericSplit(allInstances, attIndex);
  451.     if (splitAndZ[1] < m_search_smallestZ) {
  452.       m_search_smallestZ = splitAndZ[1];
  453.       m_search_bestInsertionNode = currentNode;
  454.       m_search_bestSplitter = new TwoWayNumericSplit(attIndex, splitAndZ[0]);
  455.       m_search_bestPathPosInstances = posInstances;
  456.       m_search_bestPathNegInstances = negInstances;
  457.     }
  458.   }
  459.   /**
  460.    * Calculates the prediction value used for a particular set of instances.
  461.    *
  462.    * @param posInstances the positive-class instances
  463.    * @param negInstances the negative-class instances
  464.    * @return the prediction value
  465.    */
  466.   private double calcPredictionValue(Instances posInstances, Instances negInstances) {
  467.     
  468.     return 0.5 * Math.log( (posInstances.sumOfWeights() + 1.0)
  469.   / (negInstances.sumOfWeights() + 1.0) );
  470.   }
  471.   /**
  472.    * Calculates the Z-pure value for a particular set of instances.
  473.    *
  474.    * @param posInstances the positive-class instances
  475.    * @param negInstances the negative-class instances
  476.    * @return the Z-pure value
  477.    */
  478.   private double calcZpure(Instances posInstances, Instances negInstances) {
  479.     
  480.     double posWeight = posInstances.sumOfWeights();
  481.     double negWeight = negInstances.sumOfWeights();
  482.     return (2.0 * (Math.sqrt(posWeight+1.0) + Math.sqrt(negWeight+1.0))) + 
  483.       (m_trainTotalWeight - (posWeight + negWeight));
  484.   }
  485.   /**
  486.    * Updates the weights of instances that are influenced by a new prediction value.
  487.    *
  488.    * @param posInstances positive-class instances to which the prediction value applies
  489.    * @param negInstances negative-class instances to which the prediction value applies
  490.    * @param predictionValue the new prediction value
  491.    */
  492.   private void updateWeights(Instances posInstances, Instances negInstances,
  493.      double predictionValue) {
  494.     
  495.     // do positives
  496.     double weightMultiplier = Math.pow(Math.E, -predictionValue);
  497.     for (Enumeration e = posInstances.enumerateInstances(); e.hasMoreElements(); ) {
  498.       Instance inst = (Instance) e.nextElement();
  499.       inst.setWeight(inst.weight() * weightMultiplier);
  500.     }
  501.     // do negatives
  502.     weightMultiplier = Math.pow(Math.E, predictionValue);
  503.     for (Enumeration e = negInstances.enumerateInstances(); e.hasMoreElements(); ) {
  504.       Instance inst = (Instance) e.nextElement();
  505.       inst.setWeight(inst.weight() * weightMultiplier);
  506.     }
  507.   }
  508.   /**
  509.    * Finds the nominal attribute value to split on that results in the lowest Z-value.
  510.    *
  511.    * @param posInstances the positive-class instances to split
  512.    * @param negInstances the negative-class instances to split
  513.    * @param attIndex the index of the nominal attribute to find a split for
  514.    * @return a double array, index[0] contains the value to split on, index[1] contains
  515.    * the Z-value of the split
  516.    */
  517.   private double[] findLowestZNominalSplit(Instances posInstances, Instances negInstances,
  518.    int attIndex)
  519.   {
  520.     
  521.     double lowestZ = Double.MAX_VALUE;
  522.     int bestIndex = 0;
  523.     // set up arrays
  524.     double[] posWeights = attributeValueWeights(posInstances, attIndex);
  525.     double[] negWeights = attributeValueWeights(negInstances, attIndex);
  526.     double posWeight = Utils.sum(posWeights);
  527.     double negWeight = Utils.sum(negWeights);
  528.     int maxIndex = posWeights.length;
  529.     if (maxIndex == 2) maxIndex = 1; // avoid repeating due to 2-way symmetry
  530.     for (int i = 0; i < maxIndex; i++) {
  531.       // calculate Z
  532.       double w1 = posWeights[i] + 1.0;
  533.       double w2 = negWeights[i] + 1.0;
  534.       double w3 = posWeight - w1 + 2.0;
  535.       double w4 = negWeight - w2 + 2.0;
  536.       double wRemainder = m_trainTotalWeight + 4.0 - (w1 + w2 + w3 + w4);
  537.       double newZ = (2.0 * (Math.sqrt(w1 * w2) + Math.sqrt(w3 * w4))) + wRemainder;
  538.       // record best option
  539.       if (newZ < lowestZ) { 
  540. lowestZ = newZ;
  541. bestIndex = i;
  542.       }
  543.     }
  544.     // return result
  545.     double[] indexAndZ = new double[2];
  546.     indexAndZ[0] = (double) bestIndex;
  547.     indexAndZ[1] = lowestZ;
  548.     return indexAndZ; 
  549.   }
  550.   /**
  551.    * Simultanously sum the weights of all attribute values for all instances.
  552.    *
  553.    * @param instances the instances to get the weights from 
  554.    * @param attIndex index of the attribute to be evaluated
  555.    * @return a double array containing the weight of each attribute value
  556.    */    
  557.   private double[] attributeValueWeights(Instances instances, int attIndex)
  558.   {
  559.     
  560.     double[] weights = new double[instances.attribute(attIndex).numValues()];
  561.     for(int i = 0; i < weights.length; i++) weights[i] = 0.0;
  562.     for (Enumeration e = instances.enumerateInstances(); e.hasMoreElements(); ) {
  563.       Instance inst = (Instance) e.nextElement();
  564.       if (!inst.isMissing(attIndex)) weights[(int)inst.value(attIndex)] += inst.weight();
  565.     }
  566.     return weights;
  567.   }
  568.   /**
  569.    * Finds the numeric split-point that results in the lowest Z-value.
  570.    *
  571.    * @param instances the instances to find a split for
  572.    * @param attIndex the index of the numeric attribute to find a split for
  573.    * @return a double array, index[0] contains the split-point, index[1] contains the
  574.    * Z-value of the split
  575.    */
  576.   private double[] findLowestZNumericSplit(Instances instances, int attIndex)
  577.     throws Exception
  578.   {
  579.     
  580.     double splitPoint = 0.0;
  581.     double bestVal = Double.MAX_VALUE, currVal, currCutPoint;
  582.     int numMissing = 0;
  583.     double[][] distribution = new double[3][instances.numClasses()];   
  584.     // compute counts for all the values
  585.     for (int i = 0; i < instances.numInstances(); i++) {
  586.       Instance inst = instances.instance(i);
  587.       if (!inst.isMissing(attIndex)) {
  588. distribution[1][(int)inst.classValue()] += inst.weight();
  589.       } else {
  590. distribution[2][(int)inst.classValue()] += inst.weight();
  591. numMissing++;
  592.       }
  593.     }
  594.     // sort instances
  595.     instances.sort(attIndex);
  596.     
  597.     // make split counts for each possible split and evaluate
  598.     for (int i = 0; i < instances.numInstances() - (numMissing + 1); i++) {
  599.       Instance inst = instances.instance(i);
  600.       Instance instPlusOne = instances.instance(i + 1);
  601.       distribution[0][(int)inst.classValue()] += inst.weight();
  602.       distribution[1][(int)inst.classValue()] -= inst.weight();
  603.       if (Utils.sm(inst.value(attIndex), instPlusOne.value(attIndex))) {
  604. currCutPoint = (inst.value(attIndex) + instPlusOne.value(attIndex)) / 2.0;
  605. currVal = conditionedZOnRows(distribution);
  606. if (currVal < bestVal) {
  607.   splitPoint = currCutPoint;
  608.   bestVal = currVal;
  609. }
  610.       }
  611.     }
  612.     double[] splitAndZ = new double[2];
  613.     splitAndZ[0] = splitPoint;
  614.     splitAndZ[1] = bestVal;
  615.     return splitAndZ;
  616.   }
  617.   /**
  618.    * Calculates the Z-value from the rows of a weight distribution array.
  619.    *
  620.    * @param distribution the weight distribution
  621.    * @return the Z-value
  622.    */
  623.   private double conditionedZOnRows(double [][] distribution) {
  624.     
  625.     double w1 = distribution[0][0] + 1.0;
  626.     double w2 = distribution[0][1] + 1.0;
  627.     double w3 = distribution[1][0] + 1.0; 
  628.     double w4 = distribution[1][1] + 1.0;
  629.     double wRemainder = m_trainTotalWeight + 4.0 - (w1 + w2 + w3 + w4);
  630.     return (2.0 * (Math.sqrt(w1 * w2) + Math.sqrt(w3 * w4))) + wRemainder;
  631.   }
  632.   /**
  633.    * Returns the class probability distribution for an instance.
  634.    *
  635.    * @param instance the instance to be classified
  636.    * @return the distribution the tree generates for the instance
  637.    */
  638.   public double[] distributionForInstance(Instance instance) {
  639.     
  640.     double predVal = predictionValueForInstance(instance, m_root, 0.0);
  641.     
  642.     double[] distribution = new double[2];
  643.     distribution[0] = 1.0 / (1.0 + Math.pow(Math.E, predVal));
  644.     distribution[1] = 1.0 / (1.0 + Math.pow(Math.E, -predVal));
  645.     return distribution;
  646.   }
  647.   /**
  648.    * Returns the class prediction value (vote) for an instance.
  649.    *
  650.    * @param inst the instance
  651.    * @param currentNode the root of the tree to get the values from
  652.    * @param currentValue the current value before adding the value contained in the
  653.    * subtree
  654.    * @return the class prediction value (vote)
  655.    */
  656.   protected double predictionValueForInstance(Instance inst, PredictionNode currentNode,
  657.     double currentValue) {
  658.     
  659.     currentValue += currentNode.getValue();
  660.     for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) {
  661.       Splitter split = (Splitter) e.nextElement();
  662.       int branch = split.branchInstanceGoesDown(inst);
  663.       if (branch >= 0)
  664. currentValue = predictionValueForInstance(inst, split.getChildForBranch(branch),
  665.   currentValue);
  666.     }
  667.     return currentValue;
  668.   }
  669.   /**
  670.    * Returns a description of the classifier.
  671.    *
  672.    * @return a string containing a description of the classifier
  673.    */
  674.   public String toString() {
  675.     
  676.     if (m_root == null)
  677.       return ("ADTree not built yet");
  678.     else {
  679.       return ("Alternating decision tree:nn" + toString(m_root, 1) +
  680.       "nLegend: " + legend() +
  681.       "nTree size (total number of nodes): " + numOfAllNodes(m_root) + 
  682.       "nLeaves (number of predictor nodes): " + numOfPredictionNodes(m_root)
  683.       );
  684.     }
  685.   }
  686.   /**
  687.    * Traverses the tree, forming a string that describes it.
  688.    *
  689.    * @param currentNode the current node under investigation
  690.    * @param level the current level in the tree
  691.    * @return the string describing the subtree
  692.    */      
  693.   protected String toString(PredictionNode currentNode, int level) {
  694.     
  695.     StringBuffer text = new StringBuffer();
  696.     
  697.     text.append(": " + Utils.doubleToString(currentNode.getValue(),3));
  698.     
  699.     for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) {
  700.       Splitter split = (Splitter) e.nextElement();
  701.     
  702.       for (int j=0; j<split.getNumOfBranches(); j++) {
  703. PredictionNode child = split.getChildForBranch(j);
  704. if (child != null) {
  705.   text.append("n");
  706.   for (int k = 0; k < level; k++) {
  707.     text.append("|  ");
  708.   }
  709.   text.append("(" + split.orderAdded + ")");
  710.   text.append(split.attributeString(m_trainInstances) + " "
  711.       + split.comparisonString(j, m_trainInstances));
  712.   text.append(toString(child, level + 1));
  713. }
  714.       }
  715.     }
  716.     return text.toString();
  717.   }
  718.   /**
  719.    * Returns graph describing the tree.
  720.    *
  721.    * @return the graph of the tree in dotty format
  722.    * @exception Exception if something goes wrong
  723.    */
  724.   public String graph() throws Exception {
  725.     
  726.     StringBuffer text = new StringBuffer();
  727.     text.append("digraph ADTree {n");
  728.     graphTraverse(m_root, text, 0, 0, m_trainInstances);
  729.     return text.toString() +"}n";
  730.   }
  731.   /**
  732.    * Traverses the tree, graphing each node.
  733.    *
  734.    * @param currentNode the currentNode under investigation
  735.    * @param text the string built so far
  736.    * @param splitOrder the order the parent splitter was added to the tree
  737.    * @param predOrder the order this predictor was added to the split
  738.    * @exception Exception if something goes wrong
  739.    */       
  740.   protected void graphTraverse(PredictionNode currentNode, StringBuffer text,
  741.        int splitOrder, int predOrder, Instances instances)
  742.     throws Exception
  743.   {
  744.     
  745.     text.append("S" + splitOrder + "P" + predOrder + " [label="");
  746.     text.append(Utils.doubleToString(currentNode.getValue(),3));
  747.     if (splitOrder == 0) // show legend in root
  748.       text.append(" (" + legend() + ")");
  749.     text.append("" shape=box style=filled");
  750.     if (instances.numInstances() > 0) text.append(" data=n" + instances + "n,n");
  751.     text.append("]n");
  752.     for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) {
  753.       Splitter split = (Splitter) e.nextElement();
  754.       text.append("S" + splitOrder + "P" + predOrder + "->" + "S" + split.orderAdded +
  755.   " [style=dotted]n");
  756.       text.append("S" + split.orderAdded + " [label="" + split.orderAdded + ": " +
  757.   split.attributeString(m_trainInstances) + ""]n");
  758.       for (int i=0; i<split.getNumOfBranches(); i++) {
  759. PredictionNode child = split.getChildForBranch(i);
  760. if (child != null) {
  761.   text.append("S" + split.orderAdded + "->" + "S" + split.orderAdded + "P" + i +
  762.       " [label="" + split.comparisonString(i, m_trainInstances) + ""]n");
  763.   graphTraverse(child, text, split.orderAdded, i,
  764. split.instancesDownBranch(i, instances));
  765. }
  766.       }
  767.     }  
  768.   }
  769.   /**
  770.    * Returns the legend of the tree, describing how results are to be interpreted.
  771.    *
  772.    * @return a string containing the legend of the classifier
  773.    */
  774.   public String legend() {
  775.     
  776.     Attribute classAttribute = null;
  777.     if (m_trainInstances == null) return "";
  778.     try {classAttribute = m_trainInstances.classAttribute();} catch (Exception x){};
  779.     return ("-ve = " + classAttribute.value(0) +
  780.     ", +ve = " + classAttribute.value(1));
  781.   }
  782.   /**
  783.    * @return a description of the classifier suitable for
  784.    * displaying in the explorer/experimenter gui
  785.    */
  786.   public String globalInfo() {
  787.     return "Builds an alternating decision tree, optimized for 2-class problems only.";
  788.   }
  789.   /**
  790.    * @return tip text for this property suitable for
  791.    * displaying in the explorer/experimenter gui
  792.    */
  793.   public String numOfBoostingIterationsTipText() {
  794.     return "Sets the number of boosting iterations to perform. You will need to manually "
  795.       + "tune this parameter to suit the dataset and the desired complexity/accuracy "
  796.       + "tradeoff. More boosting iterations will result in larger (potentially more "
  797.       + " accurate) trees, but will make learning slower. Each iteration will add 3 nodes "
  798.       + "(1 split + 2 prediction) to the tree unless merging occurs.";
  799.   }
  800.   /**
  801.    * Gets the number of boosting iterations.
  802.    *
  803.    * @return the number of boosting iterations
  804.    */
  805.   public int getNumOfBoostingIterations() {
  806.     
  807.     return m_boostingIterations;
  808.   }
  809.   /**
  810.    * Sets the number of boosting iterations.
  811.    *
  812.    * @param b the number of boosting iterations to use
  813.    */
  814.   public void setNumOfBoostingIterations(int b) {
  815.     
  816.     m_boostingIterations = b; 
  817.   }
  818.   /**
  819.    * @return tip text for this property suitable for
  820.    * displaying in the explorer/experimenter gui
  821.    */
  822.   public String searchPathTipText() {
  823.     return "Sets the type of search to perform when building the tree. The default option"
  824.       + " (Expand all paths) will do an exhaustive search. The other search methods are"
  825.       + " heuristic, so they are not guaranteed to find an optimal solution but they are"
  826.       + " much faster. Expand the heaviest path: searches the path with the most heavily"
  827.       + " weighted instances. Expand the best z-pure path: searches the path determined"
  828.       + " by the best z-pure estimate. Expand a random path: the fastest method, simply"
  829.       + " searches down a single random path on each iteration.";
  830.   }
  831.   /**
  832.    * Gets the method of searching the tree for a new insertion. Will be one of
  833.    * SEARCHPATH_ALL, SEARCHPATH_HEAVIEST, SEARCHPATH_ZPURE, SEARCHPATH_RANDOM.
  834.    *
  835.    * @return the tree searching mode
  836.    */
  837.   public SelectedTag getSearchPath() {
  838.     return new SelectedTag(m_searchPath, TAGS_SEARCHPATH);
  839.   }
  840.   
  841.   /**
  842.    * Sets the method of searching the tree for a new insertion. Will be one of
  843.    * SEARCHPATH_ALL, SEARCHPATH_HEAVIEST, SEARCHPATH_ZPURE, SEARCHPATH_RANDOM.
  844.    *
  845.    * @param newMethod the new tree searching mode
  846.    */
  847.   public void setSearchPath(SelectedTag newMethod) {
  848.     
  849.     if (newMethod.getTags() == TAGS_SEARCHPATH) {
  850.       m_searchPath = newMethod.getSelectedTag().getID();
  851.     }
  852.   }
  853.   /**
  854.    * @return tip text for this property suitable for
  855.    * displaying in the explorer/experimenter gui
  856.    */
  857.   public String randomSeedTipText() {
  858.     return "Sets the random seed to use for a random search.";
  859.   }
  860.   /**
  861.    * Gets random seed for a random walk.
  862.    *
  863.    * @return the random seed
  864.    */
  865.   public int getRandomSeed() {
  866.     
  867.     return m_randomSeed;
  868.   }
  869.   /**
  870.    * Sets random seed for a random walk.
  871.    *
  872.    * @param s the random seed
  873.    */
  874.   public void setRandomSeed(int seed) {
  875.     
  876.     // the actual random object is created when the tree is initialized
  877.     m_randomSeed = seed; 
  878.   }  
  879.   /**
  880.    * @return tip text for this property suitable for
  881.    * displaying in the explorer/experimenter gui
  882.    */
  883.   public String saveInstanceDataTipText() {
  884.     return "Sets whether the tree is to save instance data - the model will take up more"
  885.       + " memory if it does. If enabled you will be able to visualize the instances at"
  886.       + " the prediction nodes when visualizing the tree.";
  887.   }
  888.   /**
  889.    * Gets whether the tree is to save instance data.
  890.    *
  891.    * @return the random seed
  892.    */
  893.   public boolean getSaveInstanceData() {
  894.     
  895.     return m_saveInstanceData;
  896.   }
  897.   /**
  898.    * Sets whether the tree is to save instance data.
  899.    *
  900.    * @param s the random seed
  901.    */
  902.   public void setSaveInstanceData(boolean v) {
  903.     
  904.     m_saveInstanceData = v;
  905.   }
  906.   /**
  907.    * Returns an enumeration describing the available options..
  908.    *
  909.    * @return an enumeration of all the available options.
  910.    */
  911.   public Enumeration listOptions() {
  912.     
  913.     Vector newVector = new Vector(3);
  914.     newVector.addElement(new Option(
  915.     "tNumber of boosting iterations.n"
  916.     +"t(Default = 10)",
  917.     "B", 1,"-B <number of boosting iterations>"));
  918.     newVector.addElement(new Option(
  919.     "tExpand nodes: -3(all), -2(weight), -1(z_pure), "
  920.     +">=0 seed for random walkn"
  921.     +"t(Default = -3)",
  922.     "E", 1,"-E <-3|-2|-1|>=0>"));
  923.     newVector.addElement(new Option(
  924.     "tSave the instance data with the model",
  925.     "D", 0,"-D"));
  926.     return newVector.elements();
  927.   }
  928.   /**
  929.    * Parses a given list of options. Valid options are:<p>
  930.    *
  931.    * -B num <br>
  932.    * Set the number of boosting iterations
  933.    * (default 10) <p>
  934.    *
  935.    * -E num <br>
  936.    * Set the nodes to expand: -3(all), -2(weight), -1(z_pure), >=0 seed for random walk
  937.    * (default -3) <p>
  938.    *
  939.    * -D <br>
  940.    * Save the instance data with the model <p>
  941.    *
  942.    * @param options the list of options as an array of strings
  943.    * @exception Exception if an option is not supported
  944.    */
  945.   public void setOptions(String[] options) throws Exception {
  946.     
  947.     String bString = Utils.getOption('B', options);
  948.     if (bString.length() != 0) setNumOfBoostingIterations(Integer.parseInt(bString));
  949.     String eString = Utils.getOption('E', options);
  950.     if (eString.length() != 0) {
  951.       int value = Integer.parseInt(eString);
  952.       if (value >= 0) {
  953. setSearchPath(new SelectedTag(SEARCHPATH_RANDOM, TAGS_SEARCHPATH));
  954. setRandomSeed(value);
  955.       } else setSearchPath(new SelectedTag(value + 3, TAGS_SEARCHPATH));
  956.     }
  957.     setSaveInstanceData(Utils.getFlag('D', options));
  958.     Utils.checkForRemainingOptions(options);
  959.   }
  960.   /**
  961.    * Gets the current settings of ADTree.
  962.    *
  963.    * @return an array of strings suitable for passing to setOptions()
  964.    */
  965.   public String[] getOptions() {
  966.     
  967.     String[] options = new String[6];
  968.     int current = 0;
  969.     options[current++] = "-B"; options[current++] = "" + getNumOfBoostingIterations();
  970.     options[current++] = "-E"; options[current++] = "" +
  971.  (m_searchPath == SEARCHPATH_RANDOM ?
  972.   m_randomSeed : m_searchPath - 3);
  973.     if (getSaveInstanceData()) options[current++] = "-D";
  974.     while (current < options.length) options[current++] = "";
  975.     return options;
  976.   }
  977.   /**
  978.    * Calls measure function for tree size - the total number of nodes.
  979.    *
  980.    * @return the tree size
  981.    */
  982.   public double measureTreeSize() {
  983.     
  984.     return numOfAllNodes(m_root);
  985.   }
  986.   /**
  987.    * Calls measure function for leaf size - the number of prediction nodes.
  988.    *
  989.    * @return the leaf size
  990.    */
  991.   public double measureNumLeaves() {
  992.     
  993.     return numOfPredictionNodes(m_root);
  994.   }
  995.   /**
  996.    * Calls measure function for prediction leaf size - the number of 
  997.    * prediction nodes without children.
  998.    *
  999.    * @return the leaf size
  1000.    */
  1001.   public double measureNumPredictionLeaves() {
  1002.     
  1003.     return numOfPredictionLeafNodes(m_root);
  1004.   }
  1005.   /**
  1006.    * Returns the number of nodes expanded.
  1007.    *
  1008.    * @return the number of nodes expanded during search
  1009.    */
  1010.   public double measureNodesExpanded() {
  1011.     
  1012.     return m_nodesExpanded;
  1013.   }
  1014.   /**
  1015.    * Returns the number of examples "counted".
  1016.    *
  1017.    * @return the number of nodes processed during search
  1018.    */
  1019.   public double measureExamplesProcessed() {
  1020.     
  1021.     return m_examplesCounted;
  1022.   }
  1023.   /**
  1024.    * Returns an enumeration of the additional measure names.
  1025.    *
  1026.    * @return an enumeration of the measure names
  1027.    */
  1028.   public Enumeration enumerateMeasures() {
  1029.     
  1030.     Vector newVector = new Vector(4);
  1031.     newVector.addElement("measureTreeSize");
  1032.     newVector.addElement("measureNumLeaves");
  1033.     newVector.addElement("measureNumPredictionLeaves");
  1034.     newVector.addElement("measureNodesExpanded");
  1035.     newVector.addElement("measureExamplesProcessed");
  1036.     return newVector.elements();
  1037.   }
  1038.  
  1039.   /**
  1040.    * Returns the value of the named measure.
  1041.    *
  1042.    * @param measureName the name of the measure to query for its value
  1043.    * @return the value of the named measure
  1044.    * @exception IllegalArgumentException if the named measure is not supported
  1045.    */
  1046.   public double getMeasure(String additionalMeasureName) {
  1047.     
  1048.     if (additionalMeasureName.equals("measureTreeSize")) {
  1049.       return measureTreeSize();
  1050.     }
  1051.     else if (additionalMeasureName.equals("measureNumLeaves")) {
  1052.       return measureNumLeaves();
  1053.     }
  1054.     else if (additionalMeasureName.equals("measureNumPredictionLeaves")) {
  1055.       return measureNumPredictionLeaves();
  1056.     }
  1057.     else if (additionalMeasureName.equals("measureNodesExpanded")) {
  1058.       return measureNodesExpanded();
  1059.     }
  1060.     else if (additionalMeasureName.equals("measureExamplesProcessed")) {
  1061.       return measureExamplesProcessed();
  1062.     }
  1063.     else {throw new IllegalArgumentException(additionalMeasureName 
  1064.       + " not supported (ADTree)");
  1065.     }
  1066.   }
  1067.   /**
  1068.    * Returns the total number of nodes in a tree.
  1069.    *
  1070.    * @param root the root of the tree being measured
  1071.    * @return tree size in number of splitter + prediction nodes
  1072.    */       
  1073.   protected int numOfAllNodes(PredictionNode root) {
  1074.     
  1075.     int numSoFar = 0;
  1076.     if (root != null) {
  1077.       numSoFar++;
  1078.       for (Enumeration e = root.children(); e.hasMoreElements(); ) {
  1079. numSoFar++;
  1080. Splitter split = (Splitter) e.nextElement();
  1081. for (int i=0; i<split.getNumOfBranches(); i++)
  1082.     numSoFar += numOfAllNodes(split.getChildForBranch(i));
  1083.       }
  1084.     }
  1085.     return numSoFar;
  1086.   }
  1087.   /**
  1088.    * Returns the number of prediction nodes in a tree.
  1089.    *
  1090.    * @param root the root of the tree being measured
  1091.    * @return tree size in number of prediction nodes
  1092.    */       
  1093.   protected int numOfPredictionNodes(PredictionNode root) {
  1094.     
  1095.     int numSoFar = 0;
  1096.     if (root != null) {
  1097.       numSoFar++;
  1098.       for (Enumeration e = root.children(); e.hasMoreElements(); ) {
  1099. Splitter split = (Splitter) e.nextElement();
  1100. for (int i=0; i<split.getNumOfBranches(); i++)
  1101.     numSoFar += numOfPredictionNodes(split.getChildForBranch(i));
  1102.       }
  1103.     }
  1104.     return numSoFar;
  1105.   }
  1106.   /**
  1107.    * Returns the number of leaf nodes in a tree - prediction nodes without
  1108.    * children.
  1109.    *
  1110.    * @param root the root of the tree being measured
  1111.    * @return tree leaf size in number of prediction nodes
  1112.    */       
  1113.   protected int numOfPredictionLeafNodes(PredictionNode root) {
  1114.     
  1115.     int numSoFar = 0;
  1116.     if (root.getChildren().size() > 0) {
  1117.       for (Enumeration e = root.children(); e.hasMoreElements(); ) {
  1118. Splitter split = (Splitter) e.nextElement();
  1119. for (int i=0; i<split.getNumOfBranches(); i++)
  1120.     numSoFar += numOfPredictionLeafNodes(split.getChildForBranch(i));
  1121.       }
  1122.     } else numSoFar = 1;
  1123.     return numSoFar;
  1124.   }
  1125.   /**
  1126.    * Gets the next random value.
  1127.    *
  1128.    * @param max the maximum value (+1) to be returned
  1129.    * @return the next random value (between 0 and max-1)
  1130.    */
  1131.   protected int getRandom(int max) {
  1132.     
  1133.     return m_random.nextInt(max);
  1134.   }
  1135.   /**
  1136.    * Returns the next number in the order that splitter nodes have been added to
  1137.    * the tree, and records that a new splitter has been added.
  1138.    *
  1139.    * @return the next number in the order
  1140.    */
  1141.   public int nextSplitAddedOrder() {
  1142.     return ++m_lastAddedSplitNum;
  1143.   }
  1144.   /**
  1145.    * Builds a classifier for a set of instances.
  1146.    *
  1147.    * @param instances the instances to train the classifier with
  1148.    * @exception Exception if something goes wrong
  1149.    */
  1150.   public void buildClassifier(Instances instances) throws Exception {
  1151.     // set up the tree
  1152.     initClassifier(instances);
  1153.     // build the tree
  1154.     for (int T = 0; T < m_boostingIterations; T++) boost();
  1155.     // clean up if desired
  1156.     if (!m_saveInstanceData) done();
  1157.   }
  1158.   /**
  1159.    * Frees memory that is no longer needed for a final model - will no longer be able
  1160.    * to increment the classifier after calling this.
  1161.    *
  1162.    */
  1163.   public void done() {
  1164.     m_trainInstances = new Instances(m_trainInstances, 0);
  1165.     m_random = null; 
  1166.     m_numericAttIndices = null;
  1167.     m_nominalAttIndices = null;
  1168.     m_posTrainInstances = null;
  1169.     m_negTrainInstances = null;
  1170.   }
  1171.   /**
  1172.    * Creates a clone that is identical to the current tree, but is independent.
  1173.    * Deep copies the essential elements such as the tree nodes, and the instances
  1174.    * (because the weights change.) Reference copies several elements such as the
  1175.    * potential splitter sets, assuming that such elements should never differ between
  1176.    * clones.
  1177.    *
  1178.    * @return the clone
  1179.    */
  1180.   public Object clone() {
  1181.     
  1182.     ADTree clone = new ADTree();
  1183.     if (m_root != null) { // check for initialization first
  1184.       clone.m_root = (PredictionNode) m_root.clone(); // deep copy the tree
  1185.       clone.m_trainInstances = new Instances(m_trainInstances); // copy training instances
  1186.       
  1187.       // deep copy the random object
  1188.       if (m_random != null) { 
  1189. SerializedObject randomSerial = null;
  1190. try {
  1191.   randomSerial = new SerializedObject(m_random);
  1192. } catch (Exception ignored) {} // we know that Random is serializable
  1193. clone.m_random = (Random) randomSerial.getObject();
  1194.       }
  1195.       clone.m_lastAddedSplitNum = m_lastAddedSplitNum;
  1196.       clone.m_numericAttIndices = m_numericAttIndices;
  1197.       clone.m_nominalAttIndices = m_nominalAttIndices;
  1198.       clone.m_trainTotalWeight = m_trainTotalWeight;
  1199.       // reconstruct pos/negTrainInstances references
  1200.       if (m_posTrainInstances != null) { 
  1201. clone.m_posTrainInstances =
  1202.   new ReferenceInstances(m_trainInstances, m_posTrainInstances.numInstances());
  1203. clone.m_negTrainInstances =
  1204.   new ReferenceInstances(m_trainInstances, m_negTrainInstances.numInstances());
  1205. for (Enumeration e = clone.m_trainInstances.enumerateInstances();
  1206.      e.hasMoreElements(); ) {
  1207.   Instance inst = (Instance) e.nextElement();
  1208.   try { // ignore classValue() exception
  1209.     if ((int) inst.classValue() == 0)
  1210.       clone.m_negTrainInstances.addReference(inst); // belongs in negative class
  1211.     else
  1212.       clone.m_posTrainInstances.addReference(inst); // belongs in positive class
  1213.   } catch (Exception ignored) {} 
  1214. }
  1215.       }
  1216.     }
  1217.     clone.m_nodesExpanded = m_nodesExpanded;
  1218.     clone.m_examplesCounted = m_examplesCounted;
  1219.     clone.m_boostingIterations = m_boostingIterations;
  1220.     clone.m_searchPath = m_searchPath;
  1221.     clone.m_randomSeed = m_randomSeed;
  1222.     return clone;
  1223.   }
  1224.   /**
  1225.    * Merges two trees together. Modifies the tree being acted on, leaving tree passed
  1226.    * as a parameter untouched (cloned). Does not check to see whether training instances
  1227.    * are compatible - strange things could occur if they are not.
  1228.    *
  1229.    * @param mergeWith the tree to merge with
  1230.    * @exception Exception if merge could not be performed
  1231.    */
  1232.   public void merge(ADTree mergeWith) throws Exception {
  1233.     
  1234.     if (m_root == null || mergeWith.m_root == null)
  1235.       throw new Exception("Trying to merge an uninitialized tree");
  1236.     m_root.merge(mergeWith.m_root, this);
  1237.   }
  1238.   /**
  1239.    * Main method for testing this class.
  1240.    *
  1241.    * @param argv the options
  1242.    */
  1243.   public static void main(String [] argv) {
  1244.     
  1245.     try {
  1246.       System.out.println(Evaluation.evaluateModel(new ADTree(), 
  1247.   argv));
  1248.     } catch (Exception e) {
  1249.       System.err.println(e.getMessage());
  1250.     }
  1251.   }
  1252. }