Code/Resource
Windows Develop
Linux-Unix program
Internet-Socket-Network
Web Server
Browser Client
Ftp Server
Ftp Client
Browser Plugins
Proxy Server
Email Server
Email Client
WEB Mail
Firewall-Security
Telnet Server
Telnet Client
ICQ-IM-Chat
Search Engine
Sniffer Package capture
Remote Control
xml-soap-webservice
P2P
WEB(ASP,PHP,...)
TCP/IP Stack
SNMP
Grid Computing
SilverLight
DNS
Cluster Service
Network Security
Communication-Mobile
Game Program
Editor
Multimedia program
Graph program
Compiler program
Compress-Decompress algorithms
Crypt_Decrypt algorithms
Mathematics-Numerical algorithms
MultiLanguage
Disk/Storage
Java Develop
assembly language
Applications
Other systems
Database system
Embedded-SCM Develop
FlashMX/Flex
source in ebook
Delphi VCL
OS Develop
MiddleWare
MPI
MacOS develop
LabView
ELanguage
Software/Tools
E-Books
Article/Document
ADTree.java
Package: Weka-3-2.rar [view]
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 46k
Category:
Windows Develop
Development Platform:
Java
- /*
- * This program is free software; you can redistribute it and/or modify
- * it under the terms of the GNU General Public License as published by
- * the Free Software Foundation; either version 2 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * along with this program; if not, write to the Free Software
- * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
- */
- /*
- * ADTree.java
- * Copyright (C) 2001 Richard Kirkby, Bernhard Pfahringer
- *
- */
- package weka.classifiers.trees.adtree;
- import weka.classifiers.*;
- import weka.core.*;
- import java.io.*;
- import java.util.*;
- /**
- * Class for generating an alternating decision tree. The basic algorithm is based on:<p>
- *
- * Freund, Y., Mason, L.: The alternating decision tree learning algorithm.
- * Proceeding of the Sixteenth International Conference on Machine Learning,
- * Bled, Slovenia, (1999) 124-133.</p>
- *
- * This version currently only supports two-class problems. The number of boosting
- * iterations needs to be manually tuned to suit the dataset and the desired
- * complexity/accuracy tradeoff. Induction of the trees has been optimized, and heuristic
- * search methods have been introduced to speed learning.<p>
- *
- * Valid options are: <p>
- *
- * -B num <br>
- * Set the number of boosting iterations
- * (default 10) <p>
- *
- * -E num <br>
- * Set the nodes to expand: -3(all), -2(weight), -1(z_pure), >=0 seed for random walk
- * (default -3) <p>
- *
- * -D <br>
- * Save the instance data with the model <p>
- *
- * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
- * @author Bernhard Pfahringer (bernhard@cs.waikato.ac.nz)
- * @version $Revision: 1.4 $
- */
- public class ADTree
- extends DistributionClassifier implements OptionHandler, Drawable,
- AdditionalMeasureProducer,
- WeightedInstancesHandler,
- IterativeClassifier
- {
- /** The search modes */
- public final static int SEARCHPATH_ALL = 0;
- public final static int SEARCHPATH_HEAVIEST = 1;
- public final static int SEARCHPATH_ZPURE = 2;
- public final static int SEARCHPATH_RANDOM = 3;
- public static final Tag [] TAGS_SEARCHPATH = {
- new Tag(SEARCHPATH_ALL, "Expand all paths"),
- new Tag(SEARCHPATH_HEAVIEST, "Expand the heaviest path"),
- new Tag(SEARCHPATH_ZPURE, "Expand the best z-pure path"),
- new Tag(SEARCHPATH_RANDOM, "Expand a random path")
- };
- /** The instances used to train the tree */
- protected Instances m_trainInstances;
- /** The root of the tree */
- protected PredictionNode m_root = null;
- /** The random number generator - used for the random search heuristic */
- protected Random m_random = null;
- /** The number of the last splitter added to the tree */
- protected int m_lastAddedSplitNum = 0;
- /** An array containing the inidices to the numeric attributes in the data */
- protected int[] m_numericAttIndices;
- /** An array containing the inidices to the nominal attributes in the data */
- protected int[] m_nominalAttIndices;
- /** The total weight of the instances - used to speed Z calculations */
- protected double m_trainTotalWeight;
- /** The training instances with positive class - referencing the training dataset */
- protected ReferenceInstances m_posTrainInstances;
- /** The training instances with negative class - referencing the training dataset */
- protected ReferenceInstances m_negTrainInstances;
- /** The best node to insert under, as found so far by the latest search */
- protected PredictionNode m_search_bestInsertionNode;
- /** The best splitter to insert, as found so far by the latest search */
- protected Splitter m_search_bestSplitter;
- /** The smallest Z value found so far by the latest search */
- protected double m_search_smallestZ;
- /** The positive instances that apply to the best path found so far */
- protected Instances m_search_bestPathPosInstances;
- /** The negative instances that apply to the best path found so far */
- protected Instances m_search_bestPathNegInstances;
- /** Statistics - the number of prediction nodes investigated during search */
- protected int m_nodesExpanded = 0;
- /** Statistics - the number of instances processed during search */
- protected int m_examplesCounted = 0;
- /** Option - the number of boosting iterations o perform */
- protected int m_boostingIterations = 10;
- /** Option - the search mode */
- protected int m_searchPath = 0;
- /** Option - the seed to use for a random search */
- protected int m_randomSeed = 0;
- /** Option - whether the tree should remember the instance data */
- protected boolean m_saveInstanceData = false;
- /**
- * Sets up the tree ready to be trained, using two-class optimized method.
- *
- * @param instances the instances to train the tree with
- * @exception Exception if training data is unsuitable
- */
- public void initClassifier(Instances instances) throws Exception {
- // clear stats
- m_nodesExpanded = 0;
- m_examplesCounted = 0;
- m_lastAddedSplitNum = 0;
- // make sure training data is suitable
- if (instances.classIndex() < 0) {
- throw new UnassignedClassException("ADTree: Needs a class to be assigned");
- }
- if (instances.checkForStringAttributes()) {
- throw new UnsupportedAttributeTypeException("ADTree: Can't handle string attributes");
- }
- if (!instances.classAttribute().isNominal()) {
- throw new UnsupportedClassTypeException("ADTree: Class must be nominal");
- }
- if (instances.numClasses() != 2) {
- throw new UnsupportedClassTypeException("ADTree: Must be a two-class problem");
- }
- // prepare the random generator
- m_random = new Random(m_randomSeed);
- // create training set
- m_trainInstances = new Instances(instances);
- m_trainInstances.deleteWithMissingClass();
- // create positive/negative subsets
- m_posTrainInstances = new ReferenceInstances(m_trainInstances,
- m_trainInstances.numInstances());
- m_negTrainInstances = new ReferenceInstances(m_trainInstances,
- m_trainInstances.numInstances());
- for (Enumeration e = m_trainInstances.enumerateInstances(); e.hasMoreElements(); ) {
- Instance inst = (Instance) e.nextElement();
- if ((int) inst.classValue() == 0)
- m_negTrainInstances.addReference(inst); // belongs in negative class
- else
- m_posTrainInstances.addReference(inst); // belongs in positive class
- }
- m_posTrainInstances.compactify();
- m_negTrainInstances.compactify();
- // create the root prediction node
- double rootPredictionValue = calcPredictionValue(m_posTrainInstances,
- m_negTrainInstances);
- m_root = new PredictionNode(rootPredictionValue);
- // pre-adjust weights
- updateWeights(m_posTrainInstances, m_negTrainInstances, rootPredictionValue);
- // pre-calculate what we can
- generateAttributeIndicesSingle();
- }
- /**
- * Performs one iteration.
- *
- * @param iteration the index of the current iteration (0-based)
- * @exception Exception if this iteration fails
- */
- public void next(int iteration) throws Exception {
- boost();
- }
- /**
- * Performs a single boosting iteration, using two-class optimized method.
- * Will add a new splitter node and two prediction nodes to the tree
- * (unless merging takes place).
- *
- * @exception Exception if try to boost without setting up tree first or there are no
- * instances to train with
- */
- public void boost() throws Exception {
- if (m_trainInstances == null || m_trainInstances.numInstances() == 0)
- throw new Exception("Trying to boost with no training data");
- // perform the search
- searchForBestTestSingle();
- if (m_search_bestSplitter == null) return; // handle empty instances
- // create the new nodes for the tree, updating the weights
- for (int i=0; i<2; i++) {
- Instances posInstances =
- m_search_bestSplitter.instancesDownBranch(i, m_search_bestPathPosInstances);
- Instances negInstances =
- m_search_bestSplitter.instancesDownBranch(i, m_search_bestPathNegInstances);
- double predictionValue = calcPredictionValue(posInstances, negInstances);
- PredictionNode newPredictor = new PredictionNode(predictionValue);
- updateWeights(posInstances, negInstances, predictionValue);
- m_search_bestSplitter.setChildForBranch(i, newPredictor);
- }
- // insert the new nodes
- m_search_bestInsertionNode.addChild((Splitter) m_search_bestSplitter, this);
- // free memory
- m_search_bestPathPosInstances = null;
- m_search_bestPathNegInstances = null;
- m_search_bestSplitter = null;
- }
- /**
- * Generates the m_nominalAttIndices and m_numericAttIndices arrays to index
- * the respective attribute types in the training data.
- *
- */
- private void generateAttributeIndicesSingle() {
- // insert indices into vectors
- FastVector nominalIndices = new FastVector();
- FastVector numericIndices = new FastVector();
- for (int i=0; i<m_trainInstances.numAttributes(); i++) {
- if (i != m_trainInstances.classIndex()) {
- if (m_trainInstances.attribute(i).isNumeric())
- numericIndices.addElement(new Integer(i));
- else
- nominalIndices.addElement(new Integer(i));
- }
- }
- // create nominal array
- m_nominalAttIndices = new int[nominalIndices.size()];
- for (int i=0; i<nominalIndices.size(); i++)
- m_nominalAttIndices[i] = ((Integer)nominalIndices.elementAt(i)).intValue();
- // create numeric array
- m_numericAttIndices = new int[numericIndices.size()];
- for (int i=0; i<numericIndices.size(); i++)
- m_numericAttIndices[i] = ((Integer)numericIndices.elementAt(i)).intValue();
- }
- /**
- * Performs a search for the best test (splitter) to add to the tree, by aiming to
- * minimize the Z value.
- *
- * @exception Exception if search fails
- */
- private void searchForBestTestSingle() throws Exception {
- // keep track of total weight for efficient wRemainder calculations
- m_trainTotalWeight = m_trainInstances.sumOfWeights();
- m_search_smallestZ = Double.POSITIVE_INFINITY;
- searchForBestTestSingle(m_root, m_posTrainInstances, m_negTrainInstances);
- }
- /**
- * Recursive function that carries out search for the best test (splitter) to add to
- * this part of the tree, by aiming to minimize the Z value. Performs Z-pure cutoff to
- * reduce search space.
- *
- * @param currentNode the root of the subtree to be searched, and the current node
- * being considered as parent of a new split
- * @param posInstances the positive-class instances that apply at this node
- * @param negInstances the negative-class instances that apply at this node
- * @exception Exception if search fails
- */
- private void searchForBestTestSingle(PredictionNode currentNode,
- Instances posInstances, Instances negInstances)
- throws Exception
- {
- // don't investigate pure or empty nodes any further
- if (posInstances.numInstances() == 0 || negInstances.numInstances() == 0) return;
- // do z-pure cutoff
- if (calcZpure(posInstances, negInstances) >= m_search_smallestZ) return;
- // keep stats
- m_nodesExpanded++;
- m_examplesCounted += posInstances.numInstances() + negInstances.numInstances();
- // evaluate static splitters (nominal)
- for (int i=0; i<m_nominalAttIndices.length; i++)
- evaluateNominalSplitSingle(m_nominalAttIndices[i], currentNode,
- posInstances, negInstances);
- // evaluate dynamic splitters (numeric)
- if (m_numericAttIndices.length > 0) {
- // merge the two sets of instances into one
- Instances allInstances = new Instances(posInstances);
- for (Enumeration e = negInstances.enumerateInstances(); e.hasMoreElements(); )
- allInstances.add((Instance) e.nextElement());
- // use method of finding the optimal Z split-point
- for (int i=0; i<m_numericAttIndices.length; i++)
- evaluateNumericSplitSingle(m_numericAttIndices[i], currentNode,
- posInstances, negInstances, allInstances);
- }
- if (currentNode.getChildren().size() == 0) return;
- // keep searching
- switch (m_searchPath) {
- case SEARCHPATH_ALL:
- goDownAllPathsSingle(currentNode, posInstances, negInstances);
- break;
- case SEARCHPATH_HEAVIEST:
- goDownHeaviestPathSingle(currentNode, posInstances, negInstances);
- break;
- case SEARCHPATH_ZPURE:
- goDownZpurePathSingle(currentNode, posInstances, negInstances);
- break;
- case SEARCHPATH_RANDOM:
- goDownRandomPathSingle(currentNode, posInstances, negInstances);
- break;
- }
- }
- /**
- * Continues single (two-class optimized) search by investigating every node in the
- * subtree under currentNode.
- *
- * @param currentNode the root of the subtree to be searched
- * @param posInstances the positive-class instances that apply at this node
- * @param negInstances the negative-class instances that apply at this node
- * @exception Exception if search fails
- */
- private void goDownAllPathsSingle(PredictionNode currentNode,
- Instances posInstances, Instances negInstances)
- throws Exception
- {
- for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) {
- Splitter split = (Splitter) e.nextElement();
- for (int i=0; i<split.getNumOfBranches(); i++)
- searchForBestTestSingle(split.getChildForBranch(i),
- split.instancesDownBranch(i, posInstances),
- split.instancesDownBranch(i, negInstances));
- }
- }
- /**
- * Continues single (two-class optimized) search by investigating only the path
- * with the most heavily weighted instances.
- *
- * @param currentNode the root of the subtree to be searched
- * @param posInstances the positive-class instances that apply at this node
- * @param negInstances the negative-class instances that apply at this node
- * @exception Exception if search fails
- */
- private void goDownHeaviestPathSingle(PredictionNode currentNode,
- Instances posInstances, Instances negInstances)
- throws Exception
- {
- Splitter heaviestSplit = null;
- int heaviestBranch = 0;
- double largestWeight = 0.0;
- for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) {
- Splitter split = (Splitter) e.nextElement();
- for (int i=0; i<split.getNumOfBranches(); i++) {
- double weight =
- split.instancesDownBranch(i, posInstances).sumOfWeights() +
- split.instancesDownBranch(i, negInstances).sumOfWeights();
- if (weight > largestWeight) {
- heaviestSplit = split;
- heaviestBranch = i;
- largestWeight = weight;
- }
- }
- }
- if (heaviestSplit != null)
- searchForBestTestSingle(heaviestSplit.getChildForBranch(heaviestBranch),
- heaviestSplit.instancesDownBranch(heaviestBranch,
- posInstances),
- heaviestSplit.instancesDownBranch(heaviestBranch,
- negInstances));
- }
- /**
- * Continues single (two-class optimized) search by investigating only the path
- * with the best Z-pure value at each branch.
- *
- * @param currentNode the root of the subtree to be searched
- * @param posInstances the positive-class instances that apply at this node
- * @param negInstances the negative-class instances that apply at this node
- * @exception Exception if search fails
- */
- private void goDownZpurePathSingle(PredictionNode currentNode,
- Instances posInstances, Instances negInstances)
- throws Exception
- {
- double lowestZpure = m_search_smallestZ; // do z-pure cutoff
- PredictionNode bestPath = null;
- Instances bestPosSplit = null, bestNegSplit = null;
- // search for branch with lowest Z-pure
- for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) {
- Splitter split = (Splitter) e.nextElement();
- for (int i=0; i<split.getNumOfBranches(); i++) {
- Instances posSplit = split.instancesDownBranch(i, posInstances);
- Instances negSplit = split.instancesDownBranch(i, negInstances);
- double newZpure = calcZpure(posSplit, negSplit);
- if (newZpure < lowestZpure) {
- lowestZpure = newZpure;
- bestPath = split.getChildForBranch(i);
- bestPosSplit = posSplit;
- bestNegSplit = negSplit;
- }
- }
- }
- if (bestPath != null)
- searchForBestTestSingle(bestPath, bestPosSplit, bestNegSplit);
- }
- /**
- * Continues single (two-class optimized) search by investigating a random path.
- *
- * @param currentNode the root of the subtree to be searched
- * @param posInstances the positive-class instances that apply at this node
- * @param negInstances the negative-class instances that apply at this node
- * @exception Exception if search fails
- */
- private void goDownRandomPathSingle(PredictionNode currentNode,
- Instances posInstances, Instances negInstances)
- throws Exception {
- FastVector children = currentNode.getChildren();
- Splitter split = (Splitter) children.elementAt(getRandom(children.size()));
- int branch = getRandom(split.getNumOfBranches());
- searchForBestTestSingle(split.getChildForBranch(branch),
- split.instancesDownBranch(branch, posInstances),
- split.instancesDownBranch(branch, negInstances));
- }
- /**
- * Investigates the option of introducing a nominal split under currentNode. If it
- * finds a split that has a Z-value lower than has already been found it will
- * update the search information to record this as the best option so far.
- *
- * @param attIndex index of a nominal attribute to create a split from
- * @param currentNode the parent under which a split is to be considered
- * @param posInstances the positive-class instances that apply at this node
- * @param negInstances the negative-class instances that apply at this node
- */
- private void evaluateNominalSplitSingle(int attIndex, PredictionNode currentNode,
- Instances posInstances, Instances negInstances)
- {
- double[] indexAndZ = findLowestZNominalSplit(posInstances, negInstances, attIndex);
- if (indexAndZ[1] < m_search_smallestZ) {
- m_search_smallestZ = indexAndZ[1];
- m_search_bestInsertionNode = currentNode;
- m_search_bestSplitter = new TwoWayNominalSplit(attIndex, (int) indexAndZ[0]);
- m_search_bestPathPosInstances = posInstances;
- m_search_bestPathNegInstances = negInstances;
- }
- }
- /**
- * Investigates the option of introducing a two-way numeric split under currentNode.
- * If it finds a split that has a Z-value lower than has already been found it will
- * update the search information to record this as the best option so far.
- *
- * @param attIndex index of a numeric attribute to create a split from
- * @param currentNode the parent under which a split is to be considered
- * @param posInstances the positive-class instances that apply at this node
- * @param negInstances the negative-class instances that apply at this node
- * @param allInstances all of the instances the apply at this node (pos+neg combined)
- */
- private void evaluateNumericSplitSingle(int attIndex, PredictionNode currentNode,
- Instances posInstances, Instances negInstances,
- Instances allInstances)
- throws Exception
- {
- double[] splitAndZ = findLowestZNumericSplit(allInstances, attIndex);
- if (splitAndZ[1] < m_search_smallestZ) {
- m_search_smallestZ = splitAndZ[1];
- m_search_bestInsertionNode = currentNode;
- m_search_bestSplitter = new TwoWayNumericSplit(attIndex, splitAndZ[0]);
- m_search_bestPathPosInstances = posInstances;
- m_search_bestPathNegInstances = negInstances;
- }
- }
- /**
- * Calculates the prediction value used for a particular set of instances.
- *
- * @param posInstances the positive-class instances
- * @param negInstances the negative-class instances
- * @return the prediction value
- */
- private double calcPredictionValue(Instances posInstances, Instances negInstances) {
- return 0.5 * Math.log( (posInstances.sumOfWeights() + 1.0)
- / (negInstances.sumOfWeights() + 1.0) );
- }
- /**
- * Calculates the Z-pure value for a particular set of instances.
- *
- * @param posInstances the positive-class instances
- * @param negInstances the negative-class instances
- * @return the Z-pure value
- */
- private double calcZpure(Instances posInstances, Instances negInstances) {
- double posWeight = posInstances.sumOfWeights();
- double negWeight = negInstances.sumOfWeights();
- return (2.0 * (Math.sqrt(posWeight+1.0) + Math.sqrt(negWeight+1.0))) +
- (m_trainTotalWeight - (posWeight + negWeight));
- }
- /**
- * Updates the weights of instances that are influenced by a new prediction value.
- *
- * @param posInstances positive-class instances to which the prediction value applies
- * @param negInstances negative-class instances to which the prediction value applies
- * @param predictionValue the new prediction value
- */
- private void updateWeights(Instances posInstances, Instances negInstances,
- double predictionValue) {
- // do positives
- double weightMultiplier = Math.pow(Math.E, -predictionValue);
- for (Enumeration e = posInstances.enumerateInstances(); e.hasMoreElements(); ) {
- Instance inst = (Instance) e.nextElement();
- inst.setWeight(inst.weight() * weightMultiplier);
- }
- // do negatives
- weightMultiplier = Math.pow(Math.E, predictionValue);
- for (Enumeration e = negInstances.enumerateInstances(); e.hasMoreElements(); ) {
- Instance inst = (Instance) e.nextElement();
- inst.setWeight(inst.weight() * weightMultiplier);
- }
- }
- /**
- * Finds the nominal attribute value to split on that results in the lowest Z-value.
- *
- * @param posInstances the positive-class instances to split
- * @param negInstances the negative-class instances to split
- * @param attIndex the index of the nominal attribute to find a split for
- * @return a double array, index[0] contains the value to split on, index[1] contains
- * the Z-value of the split
- */
- private double[] findLowestZNominalSplit(Instances posInstances, Instances negInstances,
- int attIndex)
- {
- double lowestZ = Double.MAX_VALUE;
- int bestIndex = 0;
- // set up arrays
- double[] posWeights = attributeValueWeights(posInstances, attIndex);
- double[] negWeights = attributeValueWeights(negInstances, attIndex);
- double posWeight = Utils.sum(posWeights);
- double negWeight = Utils.sum(negWeights);
- int maxIndex = posWeights.length;
- if (maxIndex == 2) maxIndex = 1; // avoid repeating due to 2-way symmetry
- for (int i = 0; i < maxIndex; i++) {
- // calculate Z
- double w1 = posWeights[i] + 1.0;
- double w2 = negWeights[i] + 1.0;
- double w3 = posWeight - w1 + 2.0;
- double w4 = negWeight - w2 + 2.0;
- double wRemainder = m_trainTotalWeight + 4.0 - (w1 + w2 + w3 + w4);
- double newZ = (2.0 * (Math.sqrt(w1 * w2) + Math.sqrt(w3 * w4))) + wRemainder;
- // record best option
- if (newZ < lowestZ) {
- lowestZ = newZ;
- bestIndex = i;
- }
- }
- // return result
- double[] indexAndZ = new double[2];
- indexAndZ[0] = (double) bestIndex;
- indexAndZ[1] = lowestZ;
- return indexAndZ;
- }
- /**
- * Simultanously sum the weights of all attribute values for all instances.
- *
- * @param instances the instances to get the weights from
- * @param attIndex index of the attribute to be evaluated
- * @return a double array containing the weight of each attribute value
- */
- private double[] attributeValueWeights(Instances instances, int attIndex)
- {
- double[] weights = new double[instances.attribute(attIndex).numValues()];
- for(int i = 0; i < weights.length; i++) weights[i] = 0.0;
- for (Enumeration e = instances.enumerateInstances(); e.hasMoreElements(); ) {
- Instance inst = (Instance) e.nextElement();
- if (!inst.isMissing(attIndex)) weights[(int)inst.value(attIndex)] += inst.weight();
- }
- return weights;
- }
- /**
- * Finds the numeric split-point that results in the lowest Z-value.
- *
- * @param instances the instances to find a split for
- * @param attIndex the index of the numeric attribute to find a split for
- * @return a double array, index[0] contains the split-point, index[1] contains the
- * Z-value of the split
- */
- private double[] findLowestZNumericSplit(Instances instances, int attIndex)
- throws Exception
- {
- double splitPoint = 0.0;
- double bestVal = Double.MAX_VALUE, currVal, currCutPoint;
- int numMissing = 0;
- double[][] distribution = new double[3][instances.numClasses()];
- // compute counts for all the values
- for (int i = 0; i < instances.numInstances(); i++) {
- Instance inst = instances.instance(i);
- if (!inst.isMissing(attIndex)) {
- distribution[1][(int)inst.classValue()] += inst.weight();
- } else {
- distribution[2][(int)inst.classValue()] += inst.weight();
- numMissing++;
- }
- }
- // sort instances
- instances.sort(attIndex);
- // make split counts for each possible split and evaluate
- for (int i = 0; i < instances.numInstances() - (numMissing + 1); i++) {
- Instance inst = instances.instance(i);
- Instance instPlusOne = instances.instance(i + 1);
- distribution[0][(int)inst.classValue()] += inst.weight();
- distribution[1][(int)inst.classValue()] -= inst.weight();
- if (Utils.sm(inst.value(attIndex), instPlusOne.value(attIndex))) {
- currCutPoint = (inst.value(attIndex) + instPlusOne.value(attIndex)) / 2.0;
- currVal = conditionedZOnRows(distribution);
- if (currVal < bestVal) {
- splitPoint = currCutPoint;
- bestVal = currVal;
- }
- }
- }
- double[] splitAndZ = new double[2];
- splitAndZ[0] = splitPoint;
- splitAndZ[1] = bestVal;
- return splitAndZ;
- }
- /**
- * Calculates the Z-value from the rows of a weight distribution array.
- *
- * @param distribution the weight distribution
- * @return the Z-value
- */
- private double conditionedZOnRows(double [][] distribution) {
- double w1 = distribution[0][0] + 1.0;
- double w2 = distribution[0][1] + 1.0;
- double w3 = distribution[1][0] + 1.0;
- double w4 = distribution[1][1] + 1.0;
- double wRemainder = m_trainTotalWeight + 4.0 - (w1 + w2 + w3 + w4);
- return (2.0 * (Math.sqrt(w1 * w2) + Math.sqrt(w3 * w4))) + wRemainder;
- }
- /**
- * Returns the class probability distribution for an instance.
- *
- * @param instance the instance to be classified
- * @return the distribution the tree generates for the instance
- */
- public double[] distributionForInstance(Instance instance) {
- double predVal = predictionValueForInstance(instance, m_root, 0.0);
- double[] distribution = new double[2];
- distribution[0] = 1.0 / (1.0 + Math.pow(Math.E, predVal));
- distribution[1] = 1.0 / (1.0 + Math.pow(Math.E, -predVal));
- return distribution;
- }
- /**
- * Returns the class prediction value (vote) for an instance.
- *
- * @param inst the instance
- * @param currentNode the root of the tree to get the values from
- * @param currentValue the current value before adding the value contained in the
- * subtree
- * @return the class prediction value (vote)
- */
- protected double predictionValueForInstance(Instance inst, PredictionNode currentNode,
- double currentValue) {
- currentValue += currentNode.getValue();
- for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) {
- Splitter split = (Splitter) e.nextElement();
- int branch = split.branchInstanceGoesDown(inst);
- if (branch >= 0)
- currentValue = predictionValueForInstance(inst, split.getChildForBranch(branch),
- currentValue);
- }
- return currentValue;
- }
- /**
- * Returns a description of the classifier.
- *
- * @return a string containing a description of the classifier
- */
- public String toString() {
- if (m_root == null)
- return ("ADTree not built yet");
- else {
- return ("Alternating decision tree:nn" + toString(m_root, 1) +
- "nLegend: " + legend() +
- "nTree size (total number of nodes): " + numOfAllNodes(m_root) +
- "nLeaves (number of predictor nodes): " + numOfPredictionNodes(m_root)
- );
- }
- }
- /**
- * Traverses the tree, forming a string that describes it.
- *
- * @param currentNode the current node under investigation
- * @param level the current level in the tree
- * @return the string describing the subtree
- */
- protected String toString(PredictionNode currentNode, int level) {
- StringBuffer text = new StringBuffer();
- text.append(": " + Utils.doubleToString(currentNode.getValue(),3));
- for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) {
- Splitter split = (Splitter) e.nextElement();
- for (int j=0; j<split.getNumOfBranches(); j++) {
- PredictionNode child = split.getChildForBranch(j);
- if (child != null) {
- text.append("n");
- for (int k = 0; k < level; k++) {
- text.append("| ");
- }
- text.append("(" + split.orderAdded + ")");
- text.append(split.attributeString(m_trainInstances) + " "
- + split.comparisonString(j, m_trainInstances));
- text.append(toString(child, level + 1));
- }
- }
- }
- return text.toString();
- }
- /**
- * Returns graph describing the tree.
- *
- * @return the graph of the tree in dotty format
- * @exception Exception if something goes wrong
- */
- public String graph() throws Exception {
- StringBuffer text = new StringBuffer();
- text.append("digraph ADTree {n");
- graphTraverse(m_root, text, 0, 0, m_trainInstances);
- return text.toString() +"}n";
- }
- /**
- * Traverses the tree, graphing each node.
- *
- * @param currentNode the currentNode under investigation
- * @param text the string built so far
- * @param splitOrder the order the parent splitter was added to the tree
- * @param predOrder the order this predictor was added to the split
- * @exception Exception if something goes wrong
- */
- protected void graphTraverse(PredictionNode currentNode, StringBuffer text,
- int splitOrder, int predOrder, Instances instances)
- throws Exception
- {
- text.append("S" + splitOrder + "P" + predOrder + " [label="");
- text.append(Utils.doubleToString(currentNode.getValue(),3));
- if (splitOrder == 0) // show legend in root
- text.append(" (" + legend() + ")");
- text.append("" shape=box style=filled");
- if (instances.numInstances() > 0) text.append(" data=n" + instances + "n,n");
- text.append("]n");
- for (Enumeration e = currentNode.children(); e.hasMoreElements(); ) {
- Splitter split = (Splitter) e.nextElement();
- text.append("S" + splitOrder + "P" + predOrder + "->" + "S" + split.orderAdded +
- " [style=dotted]n");
- text.append("S" + split.orderAdded + " [label="" + split.orderAdded + ": " +
- split.attributeString(m_trainInstances) + ""]n");
- for (int i=0; i<split.getNumOfBranches(); i++) {
- PredictionNode child = split.getChildForBranch(i);
- if (child != null) {
- text.append("S" + split.orderAdded + "->" + "S" + split.orderAdded + "P" + i +
- " [label="" + split.comparisonString(i, m_trainInstances) + ""]n");
- graphTraverse(child, text, split.orderAdded, i,
- split.instancesDownBranch(i, instances));
- }
- }
- }
- }
- /**
- * Returns the legend of the tree, describing how results are to be interpreted.
- *
- * @return a string containing the legend of the classifier
- */
- public String legend() {
- Attribute classAttribute = null;
- if (m_trainInstances == null) return "";
- try {classAttribute = m_trainInstances.classAttribute();} catch (Exception x){};
- return ("-ve = " + classAttribute.value(0) +
- ", +ve = " + classAttribute.value(1));
- }
- /**
- * @return a description of the classifier suitable for
- * displaying in the explorer/experimenter gui
- */
- public String globalInfo() {
- return "Builds an alternating decision tree, optimized for 2-class problems only.";
- }
- /**
- * @return tip text for this property suitable for
- * displaying in the explorer/experimenter gui
- */
- public String numOfBoostingIterationsTipText() {
- return "Sets the number of boosting iterations to perform. You will need to manually "
- + "tune this parameter to suit the dataset and the desired complexity/accuracy "
- + "tradeoff. More boosting iterations will result in larger (potentially more "
- + " accurate) trees, but will make learning slower. Each iteration will add 3 nodes "
- + "(1 split + 2 prediction) to the tree unless merging occurs.";
- }
- /**
- * Gets the number of boosting iterations.
- *
- * @return the number of boosting iterations
- */
- public int getNumOfBoostingIterations() {
- return m_boostingIterations;
- }
- /**
- * Sets the number of boosting iterations.
- *
- * @param b the number of boosting iterations to use
- */
- public void setNumOfBoostingIterations(int b) {
- m_boostingIterations = b;
- }
- /**
- * @return tip text for this property suitable for
- * displaying in the explorer/experimenter gui
- */
- public String searchPathTipText() {
- return "Sets the type of search to perform when building the tree. The default option"
- + " (Expand all paths) will do an exhaustive search. The other search methods are"
- + " heuristic, so they are not guaranteed to find an optimal solution but they are"
- + " much faster. Expand the heaviest path: searches the path with the most heavily"
- + " weighted instances. Expand the best z-pure path: searches the path determined"
- + " by the best z-pure estimate. Expand a random path: the fastest method, simply"
- + " searches down a single random path on each iteration.";
- }
- /**
- * Gets the method of searching the tree for a new insertion. Will be one of
- * SEARCHPATH_ALL, SEARCHPATH_HEAVIEST, SEARCHPATH_ZPURE, SEARCHPATH_RANDOM.
- *
- * @return the tree searching mode
- */
- public SelectedTag getSearchPath() {
- return new SelectedTag(m_searchPath, TAGS_SEARCHPATH);
- }
- /**
- * Sets the method of searching the tree for a new insertion. Will be one of
- * SEARCHPATH_ALL, SEARCHPATH_HEAVIEST, SEARCHPATH_ZPURE, SEARCHPATH_RANDOM.
- *
- * @param newMethod the new tree searching mode
- */
- public void setSearchPath(SelectedTag newMethod) {
- if (newMethod.getTags() == TAGS_SEARCHPATH) {
- m_searchPath = newMethod.getSelectedTag().getID();
- }
- }
- /**
- * @return tip text for this property suitable for
- * displaying in the explorer/experimenter gui
- */
- public String randomSeedTipText() {
- return "Sets the random seed to use for a random search.";
- }
- /**
- * Gets random seed for a random walk.
- *
- * @return the random seed
- */
- public int getRandomSeed() {
- return m_randomSeed;
- }
- /**
- * Sets random seed for a random walk.
- *
- * @param s the random seed
- */
- public void setRandomSeed(int seed) {
- // the actual random object is created when the tree is initialized
- m_randomSeed = seed;
- }
- /**
- * @return tip text for this property suitable for
- * displaying in the explorer/experimenter gui
- */
- public String saveInstanceDataTipText() {
- return "Sets whether the tree is to save instance data - the model will take up more"
- + " memory if it does. If enabled you will be able to visualize the instances at"
- + " the prediction nodes when visualizing the tree.";
- }
- /**
- * Gets whether the tree is to save instance data.
- *
- * @return the random seed
- */
- public boolean getSaveInstanceData() {
- return m_saveInstanceData;
- }
- /**
- * Sets whether the tree is to save instance data.
- *
- * @param s the random seed
- */
- public void setSaveInstanceData(boolean v) {
- m_saveInstanceData = v;
- }
- /**
- * Returns an enumeration describing the available options..
- *
- * @return an enumeration of all the available options.
- */
- public Enumeration listOptions() {
- Vector newVector = new Vector(3);
- newVector.addElement(new Option(
- "tNumber of boosting iterations.n"
- +"t(Default = 10)",
- "B", 1,"-B <number of boosting iterations>"));
- newVector.addElement(new Option(
- "tExpand nodes: -3(all), -2(weight), -1(z_pure), "
- +">=0 seed for random walkn"
- +"t(Default = -3)",
- "E", 1,"-E <-3|-2|-1|>=0>"));
- newVector.addElement(new Option(
- "tSave the instance data with the model",
- "D", 0,"-D"));
- return newVector.elements();
- }
- /**
- * Parses a given list of options. Valid options are:<p>
- *
- * -B num <br>
- * Set the number of boosting iterations
- * (default 10) <p>
- *
- * -E num <br>
- * Set the nodes to expand: -3(all), -2(weight), -1(z_pure), >=0 seed for random walk
- * (default -3) <p>
- *
- * -D <br>
- * Save the instance data with the model <p>
- *
- * @param options the list of options as an array of strings
- * @exception Exception if an option is not supported
- */
- public void setOptions(String[] options) throws Exception {
- String bString = Utils.getOption('B', options);
- if (bString.length() != 0) setNumOfBoostingIterations(Integer.parseInt(bString));
- String eString = Utils.getOption('E', options);
- if (eString.length() != 0) {
- int value = Integer.parseInt(eString);
- if (value >= 0) {
- setSearchPath(new SelectedTag(SEARCHPATH_RANDOM, TAGS_SEARCHPATH));
- setRandomSeed(value);
- } else setSearchPath(new SelectedTag(value + 3, TAGS_SEARCHPATH));
- }
- setSaveInstanceData(Utils.getFlag('D', options));
- Utils.checkForRemainingOptions(options);
- }
- /**
- * Gets the current settings of ADTree.
- *
- * @return an array of strings suitable for passing to setOptions()
- */
- public String[] getOptions() {
- String[] options = new String[6];
- int current = 0;
- options[current++] = "-B"; options[current++] = "" + getNumOfBoostingIterations();
- options[current++] = "-E"; options[current++] = "" +
- (m_searchPath == SEARCHPATH_RANDOM ?
- m_randomSeed : m_searchPath - 3);
- if (getSaveInstanceData()) options[current++] = "-D";
- while (current < options.length) options[current++] = "";
- return options;
- }
- /**
- * Calls measure function for tree size - the total number of nodes.
- *
- * @return the tree size
- */
- public double measureTreeSize() {
- return numOfAllNodes(m_root);
- }
- /**
- * Calls measure function for leaf size - the number of prediction nodes.
- *
- * @return the leaf size
- */
- public double measureNumLeaves() {
- return numOfPredictionNodes(m_root);
- }
- /**
- * Calls measure function for prediction leaf size - the number of
- * prediction nodes without children.
- *
- * @return the leaf size
- */
- public double measureNumPredictionLeaves() {
- return numOfPredictionLeafNodes(m_root);
- }
- /**
- * Returns the number of nodes expanded.
- *
- * @return the number of nodes expanded during search
- */
- public double measureNodesExpanded() {
- return m_nodesExpanded;
- }
- /**
- * Returns the number of examples "counted".
- *
- * @return the number of nodes processed during search
- */
- public double measureExamplesProcessed() {
- return m_examplesCounted;
- }
- /**
- * Returns an enumeration of the additional measure names.
- *
- * @return an enumeration of the measure names
- */
- public Enumeration enumerateMeasures() {
- Vector newVector = new Vector(4);
- newVector.addElement("measureTreeSize");
- newVector.addElement("measureNumLeaves");
- newVector.addElement("measureNumPredictionLeaves");
- newVector.addElement("measureNodesExpanded");
- newVector.addElement("measureExamplesProcessed");
- return newVector.elements();
- }
- /**
- * Returns the value of the named measure.
- *
- * @param measureName the name of the measure to query for its value
- * @return the value of the named measure
- * @exception IllegalArgumentException if the named measure is not supported
- */
- public double getMeasure(String additionalMeasureName) {
- if (additionalMeasureName.equals("measureTreeSize")) {
- return measureTreeSize();
- }
- else if (additionalMeasureName.equals("measureNumLeaves")) {
- return measureNumLeaves();
- }
- else if (additionalMeasureName.equals("measureNumPredictionLeaves")) {
- return measureNumPredictionLeaves();
- }
- else if (additionalMeasureName.equals("measureNodesExpanded")) {
- return measureNodesExpanded();
- }
- else if (additionalMeasureName.equals("measureExamplesProcessed")) {
- return measureExamplesProcessed();
- }
- else {throw new IllegalArgumentException(additionalMeasureName
- + " not supported (ADTree)");
- }
- }
- /**
- * Returns the total number of nodes in a tree.
- *
- * @param root the root of the tree being measured
- * @return tree size in number of splitter + prediction nodes
- */
- protected int numOfAllNodes(PredictionNode root) {
- int numSoFar = 0;
- if (root != null) {
- numSoFar++;
- for (Enumeration e = root.children(); e.hasMoreElements(); ) {
- numSoFar++;
- Splitter split = (Splitter) e.nextElement();
- for (int i=0; i<split.getNumOfBranches(); i++)
- numSoFar += numOfAllNodes(split.getChildForBranch(i));
- }
- }
- return numSoFar;
- }
- /**
- * Returns the number of prediction nodes in a tree.
- *
- * @param root the root of the tree being measured
- * @return tree size in number of prediction nodes
- */
- protected int numOfPredictionNodes(PredictionNode root) {
- int numSoFar = 0;
- if (root != null) {
- numSoFar++;
- for (Enumeration e = root.children(); e.hasMoreElements(); ) {
- Splitter split = (Splitter) e.nextElement();
- for (int i=0; i<split.getNumOfBranches(); i++)
- numSoFar += numOfPredictionNodes(split.getChildForBranch(i));
- }
- }
- return numSoFar;
- }
- /**
- * Returns the number of leaf nodes in a tree - prediction nodes without
- * children.
- *
- * @param root the root of the tree being measured
- * @return tree leaf size in number of prediction nodes
- */
- protected int numOfPredictionLeafNodes(PredictionNode root) {
- int numSoFar = 0;
- if (root.getChildren().size() > 0) {
- for (Enumeration e = root.children(); e.hasMoreElements(); ) {
- Splitter split = (Splitter) e.nextElement();
- for (int i=0; i<split.getNumOfBranches(); i++)
- numSoFar += numOfPredictionLeafNodes(split.getChildForBranch(i));
- }
- } else numSoFar = 1;
- return numSoFar;
- }
- /**
- * Gets the next random value.
- *
- * @param max the maximum value (+1) to be returned
- * @return the next random value (between 0 and max-1)
- */
- protected int getRandom(int max) {
- return m_random.nextInt(max);
- }
- /**
- * Returns the next number in the order that splitter nodes have been added to
- * the tree, and records that a new splitter has been added.
- *
- * @return the next number in the order
- */
- public int nextSplitAddedOrder() {
- return ++m_lastAddedSplitNum;
- }
- /**
- * Builds a classifier for a set of instances.
- *
- * @param instances the instances to train the classifier with
- * @exception Exception if something goes wrong
- */
- public void buildClassifier(Instances instances) throws Exception {
- // set up the tree
- initClassifier(instances);
- // build the tree
- for (int T = 0; T < m_boostingIterations; T++) boost();
- // clean up if desired
- if (!m_saveInstanceData) done();
- }
- /**
- * Frees memory that is no longer needed for a final model - will no longer be able
- * to increment the classifier after calling this.
- *
- */
- public void done() {
- m_trainInstances = new Instances(m_trainInstances, 0);
- m_random = null;
- m_numericAttIndices = null;
- m_nominalAttIndices = null;
- m_posTrainInstances = null;
- m_negTrainInstances = null;
- }
- /**
- * Creates a clone that is identical to the current tree, but is independent.
- * Deep copies the essential elements such as the tree nodes, and the instances
- * (because the weights change.) Reference copies several elements such as the
- * potential splitter sets, assuming that such elements should never differ between
- * clones.
- *
- * @return the clone
- */
- public Object clone() {
- ADTree clone = new ADTree();
- if (m_root != null) { // check for initialization first
- clone.m_root = (PredictionNode) m_root.clone(); // deep copy the tree
- clone.m_trainInstances = new Instances(m_trainInstances); // copy training instances
- // deep copy the random object
- if (m_random != null) {
- SerializedObject randomSerial = null;
- try {
- randomSerial = new SerializedObject(m_random);
- } catch (Exception ignored) {} // we know that Random is serializable
- clone.m_random = (Random) randomSerial.getObject();
- }
- clone.m_lastAddedSplitNum = m_lastAddedSplitNum;
- clone.m_numericAttIndices = m_numericAttIndices;
- clone.m_nominalAttIndices = m_nominalAttIndices;
- clone.m_trainTotalWeight = m_trainTotalWeight;
- // reconstruct pos/negTrainInstances references
- if (m_posTrainInstances != null) {
- clone.m_posTrainInstances =
- new ReferenceInstances(m_trainInstances, m_posTrainInstances.numInstances());
- clone.m_negTrainInstances =
- new ReferenceInstances(m_trainInstances, m_negTrainInstances.numInstances());
- for (Enumeration e = clone.m_trainInstances.enumerateInstances();
- e.hasMoreElements(); ) {
- Instance inst = (Instance) e.nextElement();
- try { // ignore classValue() exception
- if ((int) inst.classValue() == 0)
- clone.m_negTrainInstances.addReference(inst); // belongs in negative class
- else
- clone.m_posTrainInstances.addReference(inst); // belongs in positive class
- } catch (Exception ignored) {}
- }
- }
- }
- clone.m_nodesExpanded = m_nodesExpanded;
- clone.m_examplesCounted = m_examplesCounted;
- clone.m_boostingIterations = m_boostingIterations;
- clone.m_searchPath = m_searchPath;
- clone.m_randomSeed = m_randomSeed;
- return clone;
- }
- /**
- * Merges two trees together. Modifies the tree being acted on, leaving tree passed
- * as a parameter untouched (cloned). Does not check to see whether training instances
- * are compatible - strange things could occur if they are not.
- *
- * @param mergeWith the tree to merge with
- * @exception Exception if merge could not be performed
- */
- public void merge(ADTree mergeWith) throws Exception {
- if (m_root == null || mergeWith.m_root == null)
- throw new Exception("Trying to merge an uninitialized tree");
- m_root.merge(mergeWith.m_root, this);
- }
- /**
- * Main method for testing this class.
- *
- * @param argv the options
- */
- public static void main(String [] argv) {
- try {
- System.out.println(Evaluation.evaluateModel(new ADTree(),
- argv));
- } catch (Exception e) {
- System.err.println(e.getMessage());
- }
- }
- }