DecisionStump.java
Package: Weka-3-2.rar
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 20k
Category: Windows Develop
Development Platform: Java
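
The listing below is Weka's DecisionStump, a one-level decision tree learner that is normally wrapped in a boosting scheme. As a quick orientation, here is a minimal sketch (not part of the upload) of driving the class through the Weka API; the ARFF file name and the choice of the last attribute as the class are assumed placeholders.

// Hypothetical usage sketch for the class in the listing below.
// "training.arff" and the class-index choice are placeholders, not part of the package.
import java.io.BufferedReader;
import java.io.FileReader;
import weka.core.Instances;
import weka.classifiers.trees.DecisionStump;

public class StumpDemo {
  public static void main(String[] args) throws Exception {
    // Load a training set and declare the last attribute as the class.
    Instances data = new Instances(new BufferedReader(new FileReader("training.arff")));
    data.setClassIndex(data.numAttributes() - 1);

    // Build the one-level tree and query its class distribution for one instance.
    DecisionStump stump = new DecisionStump();
    stump.buildClassifier(data);
    double[] dist = stump.distributionForInstance(data.instance(0));
    System.out.println(stump);   // textual description of the chosen split
    for (int i = 0; i < dist.length; i++) {
      System.out.println("class " + i + ": " + dist[i]);
    }
  }
}
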
/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * DecisionStump.java
 * Copyright (C) 1999 Eibe Frank
 *
 */
package weka.classifiers.trees;

import weka.classifiers.Classifier;
import weka.classifiers.DistributionClassifier;
import weka.classifiers.Evaluation;
import weka.classifiers.Sourcable;
import weka.classifiers.meta.LogitBoost;
import java.io.*;
import java.util.*;
import weka.core.*;
/**
 * Class for building and using a decision stump. Usually used in conjunction
 * with a boosting algorithm.
 *
 * Typical usage: <p>
 * <code>java weka.classifiers.meta.LogitBoost -I 100 -W weka.classifiers.trees.DecisionStump
 * -t training_data </code><p>
 *
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @version $Revision: 1.14 $
 */
public class DecisionStump extends DistributionClassifier
  implements WeightedInstancesHandler, Sourcable {

  /** The attribute used for classification. */
  private int m_AttIndex;

  /** The split point (index respectively). */
  private double m_SplitPoint;

  /** The distribution of class values or the means in each subset. */
  private double[][] m_Distribution;

  /** The instances used for training. */
  private Instances m_Instances;
  /**
   * Generates the classifier.
   *
   * @param instances set of instances serving as training data
   * @exception Exception if the classifier has not been generated successfully
   */
  public void buildClassifier(Instances instances) throws Exception {

    double bestVal = Double.MAX_VALUE, currVal;
    double bestPoint = -Double.MAX_VALUE, sum;
    int bestAtt = -1, numClasses;

    if (instances.checkForStringAttributes()) {
      throw new UnsupportedAttributeTypeException("Can't handle string attributes!");
    }
    double[][] bestDist = new double[3][instances.numClasses()];

    m_Instances = new Instances(instances);
    m_Instances.deleteWithMissingClass();

    if (m_Instances.classAttribute().isNominal()) {
      numClasses = m_Instances.numClasses();
    } else {
      numClasses = 1;
    }

    // For each attribute
    boolean first = true;
    for (int i = 0; i < m_Instances.numAttributes(); i++) {
      if (i != m_Instances.classIndex()) {

        // Reserve space for distribution.
        m_Distribution = new double[3][numClasses];

        // Compute value of criterion for best split on attribute
        if (m_Instances.attribute(i).isNominal()) {
          currVal = findSplitNominal(i);
        } else {
          currVal = findSplitNumeric(i);
        }
        if ((first) || (Utils.sm(currVal, bestVal))) {
          bestVal = currVal;
          bestAtt = i;
          bestPoint = m_SplitPoint;
          for (int j = 0; j < 3; j++) {
            System.arraycopy(m_Distribution[j], 0, bestDist[j], 0, numClasses);
          }
        }

        // First attribute has been investigated
        first = false;
      }
    }

    // Set attribute, split point and distribution.
    m_AttIndex = bestAtt;
    m_SplitPoint = bestPoint;
    m_Distribution = bestDist;
    if (m_Instances.classAttribute().isNominal()) {
      for (int i = 0; i < m_Distribution.length; i++) {
        Utils.normalize(m_Distribution[i]);
      }
    }

    // Save memory
    m_Instances = new Instances(m_Instances, 0);
  }
  /**
   * Calculates the class membership probabilities for the given test instance.
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @exception Exception if distribution can't be computed
   */
  public double[] distributionForInstance(Instance instance) throws Exception {
    return m_Distribution[whichSubset(instance)];
  }
  /**
   * Returns the decision tree as Java source code.
   *
   * @return the tree as Java source code
   * @exception Exception if something goes wrong
   */
  public String toSource(String className) throws Exception {
    StringBuffer text = new StringBuffer("class ");
    Attribute c = m_Instances.classAttribute();
    text.append(className)
      .append(" {\n"
              + " public static double classify(Object [] i) {\n");
    text.append(" /* " + m_Instances.attribute(m_AttIndex).name() + " */\n");
    text.append(" if (i[").append(m_AttIndex);
    text.append("] == null) { return ");
    text.append(sourceClass(c, m_Distribution[2])).append(";");
    if (m_Instances.attribute(m_AttIndex).isNominal()) {
      text.append(" } else if (((String)i[").append(m_AttIndex);
      text.append("]).equals(\"");
      text.append(m_Instances.attribute(m_AttIndex).value((int)m_SplitPoint));
      text.append("\")");
    } else {
      text.append(" } else if (((Double)i[").append(m_AttIndex);
      text.append("]).doubleValue() <= ").append(m_SplitPoint);
    }
    text.append(") { return ");
    text.append(sourceClass(c, m_Distribution[0])).append(";");
    text.append(" } else { return ");
    text.append(sourceClass(c, m_Distribution[1])).append(";");
    text.append(" }\n }\n}\n");
    return text.toString();
  }
  private String sourceClass(Attribute c, double[] dist) {
    if (c.isNominal()) {
      return Integer.toString(Utils.maxIndex(dist));
    } else {
      return Double.toString(dist[0]);
    }
  }
  /**
   * Returns a description of the classifier.
   *
   * @return a description of the classifier as a string.
   */
  public String toString() {
    if (m_Instances == null) {
      return "Decision Stump: No model built yet.";
    }
    try {
      StringBuffer text = new StringBuffer();
      text.append("Decision Stump\n\n");
      text.append("Classifications\n\n");
      Attribute att = m_Instances.attribute(m_AttIndex);
      if (att.isNominal()) {
        text.append(att.name() + " = " + att.value((int)m_SplitPoint) + " : ");
        text.append(printClass(m_Distribution[0]));
        text.append(att.name() + " != " + att.value((int)m_SplitPoint) + " : ");
        text.append(printClass(m_Distribution[1]));
      } else {
        text.append(att.name() + " <= " + m_SplitPoint + " : ");
        text.append(printClass(m_Distribution[0]));
        text.append(att.name() + " > " + m_SplitPoint + " : ");
        text.append(printClass(m_Distribution[1]));
      }
      text.append(att.name() + " is missing : ");
      text.append(printClass(m_Distribution[2]));
      if (m_Instances.classAttribute().isNominal()) {
        text.append("\nClass distributions\n\n");
        if (att.isNominal()) {
          text.append(att.name() + " = " + att.value((int)m_SplitPoint) + "\n");
          text.append(printDist(m_Distribution[0]));
          text.append(att.name() + " != " + att.value((int)m_SplitPoint) + "\n");
          text.append(printDist(m_Distribution[1]));
        } else {
          text.append(att.name() + " <= " + m_SplitPoint + "\n");
          text.append(printDist(m_Distribution[0]));
          text.append(att.name() + " > " + m_SplitPoint + "\n");
          text.append(printDist(m_Distribution[1]));
        }
        text.append(att.name() + " is missing\n");
        text.append(printDist(m_Distribution[2]));
      }
      return text.toString();
    } catch (Exception e) {
      return "Can't print decision stump classifier!";
    }
  }
  /**
   * Prints a class distribution.
   *
   * @param dist the class distribution to print
   * @return the distribution as a string
   * @exception Exception if distribution can't be printed
   */
  private String printDist(double[] dist) throws Exception {
    StringBuffer text = new StringBuffer();
    if (m_Instances.classAttribute().isNominal()) {
      for (int i = 0; i < m_Instances.numClasses(); i++) {
        text.append(m_Instances.classAttribute().value(i) + "\t");
      }
      text.append("\n");
      for (int i = 0; i < m_Instances.numClasses(); i++) {
        text.append(dist[i] + "\t");
      }
      text.append("\n");
    }
    return text.toString();
  }
  /**
   * Prints a classification.
   *
   * @param dist the class distribution
   * @return the classification as a string
   * @exception Exception if the classification can't be printed
   */
  private String printClass(double[] dist) throws Exception {
    StringBuffer text = new StringBuffer();
    if (m_Instances.classAttribute().isNominal()) {
      text.append(m_Instances.classAttribute().value(Utils.maxIndex(dist)));
    } else {
      text.append(dist[0]);
    }
    return text.toString() + "\n";
  }
  /**
   * Finds best split for nominal attribute and returns value.
   *
   * @param index attribute index
   * @return value of criterion for the best split
   * @exception Exception if something goes wrong
   */
  private double findSplitNominal(int index) throws Exception {
    if (m_Instances.classAttribute().isNominal()) {
      return findSplitNominalNominal(index);
    } else {
      return findSplitNominalNumeric(index);
    }
  }
  /**
   * Finds best split for nominal attribute and nominal class
   * and returns value.
   *
   * @param index attribute index
   * @return value of criterion for the best split
   * @exception Exception if something goes wrong
   */
  private double findSplitNominalNominal(int index) throws Exception {
    double bestVal = Double.MAX_VALUE, currVal;
    double[][] counts = new double[m_Instances.attribute(index).numValues()
                                   + 1][m_Instances.numClasses()];
    double[] sumCounts = new double[m_Instances.numClasses()];
    double[][] bestDist = new double[3][m_Instances.numClasses()];
    int numMissing = 0;

    // Compute counts for all the values
    for (int i = 0; i < m_Instances.numInstances(); i++) {
      Instance inst = m_Instances.instance(i);
      if (inst.isMissing(index)) {
        numMissing++;
        counts[m_Instances.attribute(index).numValues()]
          [(int)inst.classValue()] += inst.weight();
      } else {
        counts[(int)inst.value(index)][(int)inst.classValue()] += inst.weight();
      }
    }

    // Compute sum of counts
    for (int i = 0; i < m_Instances.attribute(index).numValues(); i++) {
      for (int j = 0; j < m_Instances.numClasses(); j++) {
        sumCounts[j] += counts[i][j];
      }
    }

    // Make split counts for each possible split and evaluate
    System.arraycopy(counts[m_Instances.attribute(index).numValues()], 0,
                     m_Distribution[2], 0, m_Instances.numClasses());
    for (int i = 0; i < m_Instances.attribute(index).numValues(); i++) {
      for (int j = 0; j < m_Instances.numClasses(); j++) {
        m_Distribution[0][j] = counts[i][j];
        m_Distribution[1][j] = sumCounts[j] - counts[i][j];
      }
      currVal = ContingencyTables.entropyConditionedOnRows(m_Distribution);
      if (Utils.sm(currVal, bestVal)) {
        bestVal = currVal;
        m_SplitPoint = (double)i;
        for (int j = 0; j < 3; j++) {
          System.arraycopy(m_Distribution[j], 0, bestDist[j], 0,
                           m_Instances.numClasses());
        }
      }
    }

    // No missing values in training data.
    if (numMissing == 0) {
      System.arraycopy(sumCounts, 0, bestDist[2], 0, m_Instances.numClasses());
    }

    m_Distribution = bestDist;
    return bestVal;
  }
  /**
   * Finds best split for nominal attribute and numeric class
   * and returns value.
   *
   * @param index attribute index
   * @return value of criterion for the best split
   * @exception Exception if something goes wrong
   */
  private double findSplitNominalNumeric(int index) throws Exception {
    double bestVal = Double.MAX_VALUE, currVal;
    double[] sumsSquaresPerValue =
      new double[m_Instances.attribute(index).numValues()],
      sumsPerValue = new double[m_Instances.attribute(index).numValues()],
      weightsPerValue = new double[m_Instances.attribute(index).numValues()];
    double totalSumSquaresW = 0, totalSumW = 0, totalSumOfWeightsW = 0,
      totalSumOfWeights = 0, totalSum = 0;
    double[] sumsSquares = new double[3], sumOfWeights = new double[3];
    double[][] bestDist = new double[3][1];

    // Compute counts for all the values
    for (int i = 0; i < m_Instances.numInstances(); i++) {
      Instance inst = m_Instances.instance(i);
      if (inst.isMissing(index)) {
        m_Distribution[2][0] += inst.classValue() * inst.weight();
        sumsSquares[2] += inst.classValue() * inst.classValue() * inst.weight();
        sumOfWeights[2] += inst.weight();
      } else {
        weightsPerValue[(int)inst.value(index)] += inst.weight();
        sumsPerValue[(int)inst.value(index)] += inst.classValue() * inst.weight();
        sumsSquaresPerValue[(int)inst.value(index)] +=
          inst.classValue() * inst.classValue() * inst.weight();
      }
      totalSumOfWeights += inst.weight();
      totalSum += inst.classValue() * inst.weight();
    }

    // Check if the total weight is zero
    if (Utils.eq(totalSumOfWeights, 0)) {
      return bestVal;
    }

    // Compute sum of counts without missing ones
    for (int i = 0; i < m_Instances.attribute(index).numValues(); i++) {
      totalSumOfWeightsW += weightsPerValue[i];
      totalSumSquaresW += sumsSquaresPerValue[i];
      totalSumW += sumsPerValue[i];
    }

    // Make split counts for each possible split and evaluate
    for (int i = 0; i < m_Instances.attribute(index).numValues(); i++) {
      m_Distribution[0][0] = sumsPerValue[i];
      sumsSquares[0] = sumsSquaresPerValue[i];
      sumOfWeights[0] = weightsPerValue[i];
      m_Distribution[1][0] = totalSumW - sumsPerValue[i];
      sumsSquares[1] = totalSumSquaresW - sumsSquaresPerValue[i];
      sumOfWeights[1] = totalSumOfWeightsW - weightsPerValue[i];
      currVal = variance(m_Distribution, sumsSquares, sumOfWeights);
      if (Utils.sm(currVal, bestVal)) {
        bestVal = currVal;
        m_SplitPoint = (double)i;
        for (int j = 0; j < 3; j++) {
          if (!Utils.eq(sumOfWeights[j], 0)) {
            bestDist[j][0] = m_Distribution[j][0] / sumOfWeights[j];
          } else {
            bestDist[j][0] = totalSum / totalSumOfWeights;
          }
        }
      }
    }

    m_Distribution = bestDist;
    return bestVal;
  }
  /**
   * Finds best split for numeric attribute and returns value.
   *
   * @param index attribute index
   * @return value of criterion for the best split
   * @exception Exception if something goes wrong
   */
  private double findSplitNumeric(int index) throws Exception {
    if (m_Instances.classAttribute().isNominal()) {
      return findSplitNumericNominal(index);
    } else {
      return findSplitNumericNumeric(index);
    }
  }
  /**
   * Finds best split for numeric attribute and nominal class
   * and returns value.
   *
   * @param index attribute index
   * @return value of criterion for the best split
   * @exception Exception if something goes wrong
   */
  private double findSplitNumericNominal(int index) throws Exception {
    double bestVal = Double.MAX_VALUE, currVal, currCutPoint;
    int numMissing = 0;
    double[] sum = new double[m_Instances.numClasses()];
    double[][] bestDist = new double[3][m_Instances.numClasses()];

    // Compute counts for all the values
    for (int i = 0; i < m_Instances.numInstances(); i++) {
      Instance inst = m_Instances.instance(i);
      if (!inst.isMissing(index)) {
        m_Distribution[1][(int)inst.classValue()] += inst.weight();
      } else {
        m_Distribution[2][(int)inst.classValue()] += inst.weight();
        numMissing++;
      }
    }
    System.arraycopy(m_Distribution[1], 0, sum, 0, m_Instances.numClasses());

    // Sort instances
    m_Instances.sort(index);

    // Make split counts for each possible split and evaluate
    for (int i = 0; i < m_Instances.numInstances() - (numMissing + 1); i++) {
      Instance inst = m_Instances.instance(i);
      Instance instPlusOne = m_Instances.instance(i + 1);
      m_Distribution[0][(int)inst.classValue()] += inst.weight();
      m_Distribution[1][(int)inst.classValue()] -= inst.weight();
      if (Utils.sm(inst.value(index), instPlusOne.value(index))) {
        currCutPoint = (inst.value(index) + instPlusOne.value(index)) / 2.0;
        currVal = ContingencyTables.entropyConditionedOnRows(m_Distribution);
        if (Utils.sm(currVal, bestVal)) {
          m_SplitPoint = currCutPoint;
          bestVal = currVal;
          for (int j = 0; j < 3; j++) {
            System.arraycopy(m_Distribution[j], 0, bestDist[j], 0,
                             m_Instances.numClasses());
          }
        }
      }
    }

    // No missing values in training data.
    if (numMissing == 0) {
      System.arraycopy(sum, 0, bestDist[2], 0, m_Instances.numClasses());
    }

    m_Distribution = bestDist;
    return bestVal;
  }
  /**
   * Finds best split for numeric attribute and numeric class
   * and returns value.
   *
   * @param index attribute index
   * @return value of criterion for the best split
   * @exception Exception if something goes wrong
   */
  private double findSplitNumericNumeric(int index) throws Exception {
    double bestVal = Double.MAX_VALUE, currVal, currCutPoint;
    int numMissing = 0;
    double[] sumsSquares = new double[3], sumOfWeights = new double[3];
    double[][] bestDist = new double[3][1];
    double totalSum = 0, totalSumOfWeights = 0;

    // Compute counts for all the values
    for (int i = 0; i < m_Instances.numInstances(); i++) {
      Instance inst = m_Instances.instance(i);
      if (!inst.isMissing(index)) {
        m_Distribution[1][0] += inst.classValue() * inst.weight();
        sumsSquares[1] += inst.classValue() * inst.classValue() * inst.weight();
        sumOfWeights[1] += inst.weight();
      } else {
        m_Distribution[2][0] += inst.classValue() * inst.weight();
        sumsSquares[2] += inst.classValue() * inst.classValue() * inst.weight();
        sumOfWeights[2] += inst.weight();
        numMissing++;
      }
      totalSumOfWeights += inst.weight();
      totalSum += inst.classValue() * inst.weight();
    }

    // Check if the total weight is zero
    if (Utils.eq(totalSumOfWeights, 0)) {
      return bestVal;
    }

    // Sort instances
    m_Instances.sort(index);

    // Make split counts for each possible split and evaluate
    for (int i = 0; i < m_Instances.numInstances() - (numMissing + 1); i++) {
      Instance inst = m_Instances.instance(i);
      Instance instPlusOne = m_Instances.instance(i + 1);
      m_Distribution[0][0] += inst.classValue() * inst.weight();
      sumsSquares[0] += inst.classValue() * inst.classValue() * inst.weight();
      sumOfWeights[0] += inst.weight();
      m_Distribution[1][0] -= inst.classValue() * inst.weight();
      sumsSquares[1] -= inst.classValue() * inst.classValue() * inst.weight();
      sumOfWeights[1] -= inst.weight();
      if (Utils.sm(inst.value(index), instPlusOne.value(index))) {
        currCutPoint = (inst.value(index) + instPlusOne.value(index)) / 2.0;
        currVal = variance(m_Distribution, sumsSquares, sumOfWeights);
        if (Utils.sm(currVal, bestVal)) {
          m_SplitPoint = currCutPoint;
          bestVal = currVal;
          for (int j = 0; j < 3; j++) {
            if (!Utils.eq(sumOfWeights[j], 0)) {
              bestDist[j][0] = m_Distribution[j][0] / sumOfWeights[j];
            } else {
              bestDist[j][0] = totalSum / totalSumOfWeights;
            }
          }
        }
      }
    }

    m_Distribution = bestDist;
    return bestVal;
  }
  /**
   * Computes the variance-based split criterion for the subsets: for every
   * subset with positive weight it accumulates the weighted sum of squared
   * class values minus the squared sum of class values divided by the weight.
   *
   * @param s the sums of class values per subset
   * @param sS the sums of squared class values per subset
   * @param sumOfWeights the total weight per subset
   * @return the value of the criterion
   */
  private double variance(double[][] s, double[] sS, double[] sumOfWeights) {
    double var = 0;
    for (int i = 0; i < s.length; i++) {
      if (Utils.gr(sumOfWeights[i], 0)) {
        var += sS[i] - ((s[i][0] * s[i][0]) / (double) sumOfWeights[i]);
      }
    }
    return var;
  }
  /**
   * Returns the subset an instance falls into.
   */
  private int whichSubset(Instance instance) throws Exception {
    if (instance.isMissing(m_AttIndex)) {
      return 2;
    } else if (instance.attribute(m_AttIndex).isNominal()) {
      if ((int)instance.value(m_AttIndex) == m_SplitPoint) {
        return 0;
      } else {
        return 1;
      }
    } else {
      if (Utils.smOrEq(instance.value(m_AttIndex), m_SplitPoint)) {
        return 0;
      } else {
        return 1;
      }
    }
  }
  /**
   * Main method for testing this class.
   *
   * @param argv the options
   */
  public static void main(String[] argv) {
    Classifier scheme;
    try {
      scheme = new DecisionStump();
      System.out.println(Evaluation.evaluateModel(scheme, argv));
    } catch (Exception e) {
      System.err.println(e.getMessage());
    }
  }
}
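
As the class Javadoc above notes, the stump is usually used as the weak learner inside a boosting scheme. The sketch below mirrors the command line quoted in that Javadoc from Java code; Classifier.forName and the -I/-W options are taken from the Weka API and the comment itself, while the ARFF path is again an assumed placeholder.

// Hypothetical companion sketch: the boosted-stump setup from the class Javadoc, driven from code.
// "training.arff" is a placeholder; -I (iterations) and -W (weak learner) are the
// LogitBoost options quoted in the Javadoc above.
import java.io.BufferedReader;
import java.io.FileReader;
import weka.core.Instances;
import weka.classifiers.Classifier;

public class BoostedStumpDemo {
  public static void main(String[] args) throws Exception {
    Instances data = new Instances(new BufferedReader(new FileReader("training.arff")));
    data.setClassIndex(data.numAttributes() - 1);

    // 100 LogitBoost iterations (-I) with DecisionStump as the weak learner (-W).
    Classifier boost = Classifier.forName("weka.classifiers.meta.LogitBoost",
        new String[] {"-I", "100", "-W", "weka.classifiers.trees.DecisionStump"});
    boost.buildClassifier(data);
    System.out.println(boost);
  }
}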