Evaluation.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 78k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  *    This program is free software; you can redistribute it and/or modify
  3.  *    it under the terms of the GNU General Public License as published by
  4.  *    the Free Software Foundation; either version 2 of the License, or
  5.  *    (at your option) any later version.
  6.  *
  7.  *    This program is distributed in the hope that it will be useful,
  8.  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  *    GNU General Public License for more details.
  11.  *
  12.  *    You should have received a copy of the GNU General Public License
  13.  *    along with this program; if not, write to the Free Software
  14.  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15.  */
  16. /*
  17.  *    Evaluation.java
  18.  *    Copyright (C) 1999 Eibe Frank,Len Trigg
  19.  *
  20.  */
  21. package weka.classifiers;
  22. import java.util.*;
  23. import java.io.*;
  24. import weka.core.*;
  25. import weka.estimators.*;
  26. import java.util.zip.GZIPInputStream;
  27. import java.util.zip.GZIPOutputStream;
  28. /**
  29.  * Class for evaluating machine learning models. <p>
  30.  *
  31.  * ------------------------------------------------------------------- <p>
  32.  *
  33.  * General options when evaluating a learning scheme from the command-line: <p>
  34.  *
  35.  * -t filename <br>
  36.  * Name of the file with the training data. (required) <p>
  37.  *
  38.  * -T filename <br>
  39.  * Name of the file with the test data. If missing a cross-validation 
  40.  * is performed. <p>
  41.  *
  42.  * -c index <br>
  43.  * Index of the class attribute (1, 2, ...; default: last). <p>
  44.  *
  45.  * -x number <br>
  46.  * The number of folds for the cross-validation (default: 10). <p>
  47.  *
  48.  * -s seed <br>
  49.  * Random number seed for the cross-validation (default: 1). <p>
  50.  *
  51.  * -m filename <br>
  52.  * The name of a file containing a cost matrix. <p>
  53.  *
  54.  * -l filename <br>
  55.  * Loads classifier from the given file. <p>
  56.  *
  57.  * -d filename <br>
  58.  * Saves classifier built from the training data into the given file. <p>
  59.  *
  60.  * -v <br>
  61.  * Outputs no statistics for the training data. <p>
  62.  *
  63.  * -o <br>
  64.  * Outputs statistics only, not the classifier. <p>
  65.  * 
  66.  * -i <br>
  67.  * Outputs information-retrieval statistics per class. <p>
  68.  *
  69.  * -k <br>
  70.  * Outputs information-theoretic statistics. <p>
  71.  *
  72.  * -p range <br>
  73.  * Outputs predictions for test instances, along with the attributes in 
  74.  * the specified range (and nothing else). Use '-p 0' if no attributes are
  75.  * desired. <p>
  76.  *
  77.  * -r <br>
  78.  * Outputs cumulative margin distribution (and nothing else). <p>
  79.  *
  80.  * -g <br> 
  81.  * Only for classifiers that implement "Graphable." Outputs
  82.  * the graph representation of the classifier (and nothing
  83.  * else). <p>
  84.  *
  85.  * ------------------------------------------------------------------- <p>
  86.  *
  87.  * Example usage as the main of a classifier (called FunkyClassifier):
  88.  * <code> <pre>
  89.  * public static void main(String [] args) {
  90.  *   try {
  91.  *     Classifier scheme = new FunkyClassifier();
  92.  *     System.out.println(Evaluation.evaluateModel(scheme, args));
  93.  *   } catch (Exception e) {
  94.  *     System.err.println(e.getMessage());
  95.  *   }
  96.  * }
  97.  * </pre> </code> 
  98.  * <p>
  99.  *
  100.  * ------------------------------------------------------------------ <p>
  101.  *
  102.  * Example usage from within an application:
  103.  * <code> <pre>
  104.  * Instances trainInstances = ... instances got from somewhere
  105.  * Instances testInstances = ... instances got from somewhere
  106.  * Classifier scheme = ... scheme got from somewhere
  107.  *
  108.  * Evaluation evaluation = new Evaluation(trainInstances);
  109.  * evaluation.evaluateModel(scheme, testInstances);
  110.  * System.out.println(evaluation.toSummaryString());
  111.  * </pre> </code> 
  112.  *
  113.  *
  114.  * @author   Eibe Frank (eibe@cs.waikato.ac.nz)
  115.  * @author   Len Trigg (trigg@cs.waikato.ac.nz)
  116.  * @version  $Revision: 1.42 $
  117.   */
  118. public class Evaluation implements Summarizable {
  119.   /** The number of classes. */
  120.   private int m_NumClasses;
  121.   /** The number of folds for a cross-validation. */
  122.   private int m_NumFolds;
  123.  
  124.   /** The weight of all incorrectly classified instances. */
  125.   private double m_Incorrect;
  126.   /** The weight of all correctly classified instances. */
  127.   private double m_Correct;
  128.   /** The weight of all unclassified instances. */
  129.   private double m_Unclassified;
  130.   /*** The weight of all instances that had no class assigned to them. */
  131.   private double m_MissingClass;
  132.   /** The weight of all instances that had a class assigned to them. */
  133.   private double m_WithClass;
  134.   /** Array for storing the confusion matrix. */
  135.   private double [][] m_ConfusionMatrix;
  136.   /** The names of the classes. */
  137.   private String [] m_ClassNames;
  138.   /** Is the class nominal or numeric? */
  139.   private boolean m_ClassIsNominal;
  140.   
  141.   /** The prior probabilities of the classes */
  142.   private double [] m_ClassPriors;
  143.   /** The sum of counts for priors */
  144.   private double m_ClassPriorsSum;
  145.   /** The cost matrix (if given). */
  146.   private CostMatrix m_CostMatrix;
  147.   /** The total cost of predictions (includes instance weights) */
  148.   private double m_TotalCost;
  149.   /** Sum of errors. */
  150.   private double m_SumErr;
  151.   
  152.   /** Sum of absolute errors. */
  153.   private double m_SumAbsErr;
  154.   /** Sum of squared errors. */
  155.   private double m_SumSqrErr;
  156.   /** Sum of class values. */
  157.   private double m_SumClass;
  158.   
  159.   /** Sum of squared class values. */
  160.   private double m_SumSqrClass;
  161.   /*** Sum of predicted values. */
  162.   private double m_SumPredicted;
  163.   /** Sum of squared predicted values. */
  164.   private double m_SumSqrPredicted;
  165.   /** Sum of predicted * class values. */
  166.   private double m_SumClassPredicted;
  167.   /** Sum of absolute errors of the prior */
  168.   private double m_SumPriorAbsErr;
  169.   /** Sum of absolute errors of the prior */
  170.   private double m_SumPriorSqrErr;
  171.   /** Total Kononenko & Bratko Information */
  172.   private double m_SumKBInfo;
  173.   /*** Resolution of the margin histogram */
  174.   private static int k_MarginResolution = 500;
  175.   /** Cumulative margin distribution */
  176.   private double m_MarginCounts [];
  177.   /** Number of non-missing class training instances seen */
  178.   private int m_NumTrainClassVals;
  179.   /** Array containing all numeric training class values seen */
  180.   private double [] m_TrainClassVals;
  181.   /** Array containing all numeric training class weights */
  182.   private double [] m_TrainClassWeights;
  183.   /** Numeric class error estimator for prior */
  184.   private Estimator m_PriorErrorEstimator;
  185.   /** Numeric class error estimator for scheme */
  186.   private Estimator m_ErrorEstimator;
  187.   /**
  188.    * The minimum probablility accepted from an estimator to avoid
  189.    * taking log(0) in Sf calculations.
  190.    */
  191.   private static final double MIN_SF_PROB = Double.MIN_VALUE;
  192.   /** Total entropy of prior predictions */
  193.   private double m_SumPriorEntropy;
  194.   
  195.   /** Total entropy of scheme predictions */
  196.   private double m_SumSchemeEntropy;
  197.   
  198.   /**
  199.    * Initializes all the counters for the evaluation.
  200.    *
  201.    * @param data set of training instances, to get some header 
  202.    * information and prior class distribution information
  203.    * @exception Exception if the class is not defined
  204.    */
  205.   public Evaluation(Instances data) throws Exception {
  206.     
  207.     this(data, null);
  208.   }
  209.   /**
  210.    * Initializes all the counters for the evaluation and also takes a
  211.    * cost matrix as parameter.
  212.    *
  213.    * @param data set of instances, to get some header information
  214.    * @param costMatrix the cost matrix---if null, default costs will be used
  215.    * @exception Exception if cost matrix is not compatible with 
  216.    * data, the class is not defined or the class is numeric
  217.    */
  218.   public Evaluation(Instances data, CostMatrix costMatrix) 
  219.        throws Exception {
  220.     
  221.     m_NumClasses = data.numClasses();
  222.     m_NumFolds = 1;
  223.     m_ClassIsNominal = data.classAttribute().isNominal();
  224.     if (m_ClassIsNominal) {
  225.       m_ConfusionMatrix = new double [m_NumClasses][m_NumClasses];
  226.       m_ClassNames = new String [m_NumClasses];
  227.       for(int i = 0; i < m_NumClasses; i++) {
  228. m_ClassNames[i] = data.classAttribute().value(i);
  229.       }
  230.     }
  231.     m_CostMatrix = costMatrix;
  232.     if (m_CostMatrix != null) {
  233.       if (!m_ClassIsNominal) {
  234. throw new Exception("Class has to be nominal if cost matrix " + 
  235.     "given!");
  236.       }
  237.       if (m_CostMatrix.size() != m_NumClasses) {
  238. throw new Exception("Cost matrix not compatible with data!");
  239.       }
  240.     }
  241.     m_ClassPriors = new double [m_NumClasses];
  242.     setPriors(data);
  243.     m_MarginCounts = new double [k_MarginResolution + 1];
  244.   }
  245.   /**
  246.    * Returns a copy of the confusion matrix.
  247.    *
  248.    * @return a copy of the confusion matrix as a two-dimensional array
  249.    */
  250.   public double[][] confusionMatrix() {
  251.     double[][] newMatrix = new double[m_ConfusionMatrix.length][0];
  252.     for (int i = 0; i < m_ConfusionMatrix.length; i++) {
  253.       newMatrix[i] = new double[m_ConfusionMatrix[i].length];
  254.       System.arraycopy(m_ConfusionMatrix[i], 0, newMatrix[i], 0,
  255.        m_ConfusionMatrix[i].length);
  256.     }
  257.     return newMatrix;
  258.   }
  259.   /**
  260.    * Performs a (stratified if class is nominal) cross-validation 
  261.    * for a classifier on a set of instances.
  262.    *
  263.    * @param classifier the classifier with any options set.
  264.    * @param data the data on which the cross-validation is to be 
  265.    * performed 
  266.    * @param numFolds the number of folds for the cross-validation
  267.    * @exception Exception if a classifier could not be generated 
  268.    * successfully or the class is not defined
  269.    */
  270.   public void crossValidateModel(Classifier classifier,
  271.  Instances data, int numFolds) 
  272.     throws Exception {
  273.     
  274.     // Make a copy of the data we can reorder
  275.     data = new Instances(data);
  276.     if (data.classAttribute().isNominal()) {
  277.       data.stratify(numFolds);
  278.     }
  279.     // Do the folds
  280.     for (int i = 0; i < numFolds; i++) {
  281.       Instances train = data.trainCV(numFolds, i);
  282.       setPriors(train);
  283.       classifier.buildClassifier(train);
  284.       Instances test = data.testCV(numFolds, i);
  285.       evaluateModel(classifier, test);
  286.     }
  287.     m_NumFolds = numFolds;
  288.   }
  289.   /**
  290.    * Performs a (stratified if class is nominal) cross-validation 
  291.    * for a classifier on a set of instances.
  292.    *
  293.    * @param classifier a string naming the class of the classifier
  294.    * @param data the data on which the cross-validation is to be 
  295.    * performed 
  296.    * @param numFolds the number of folds for the cross-validation
  297.    * @param options the options to the classifier. Any options
  298.    * accepted by the classifier will be removed from this array.
  299.    * @exception Exception if a classifier could not be generated 
  300.    * successfully or the class is not defined
  301.    */
  302.   public void crossValidateModel(String classifierString,
  303.  Instances data, int numFolds,
  304.  String[] options) 
  305.        throws Exception {
  306.     
  307.     crossValidateModel(Classifier.forName(classifierString, options),
  308.        data, numFolds);
  309.   }
  310.   /**
  311.    * Evaluates a classifier with the options given in an array of
  312.    * strings. <p>
  313.    *
  314.    * Valid options are: <p>
  315.    *
  316.    * -t filename <br>
  317.    * Name of the file with the training data. (required) <p>
  318.    *
  319.    * -T filename <br>
  320.    * Name of the file with the test data. If missing a cross-validation 
  321.    * is performed. <p>
  322.    *
  323.    * -c index <br>
  324.    * Index of the class attribute (1, 2, ...; default: last). <p>
  325.    *
  326.    * -x number <br>
  327.    * The number of folds for the cross-validation (default: 10). <p>
  328.    *
  329.    * -s seed <br>
  330.    * Random number seed for the cross-validation (default: 1). <p>
  331.    *
  332.    * -m filename <br>
  333.    * The name of a file containing a cost matrix. <p>
  334.    *
  335.    * -l filename <br>
  336.    * Loads classifier from the given file. <p>
  337.    *
  338.    * -d filename <br>
  339.    * Saves classifier built from the training data into the given file. <p>
  340.    *
  341.    * -v <br>
  342.    * Outputs no statistics for the training data. <p>
  343.    *
  344.    * -o <br>
  345.    * Outputs statistics only, not the classifier. <p>
  346.    * 
  347.    * -i <br>
  348.    * Outputs detailed information-retrieval statistics per class. <p>
  349.    *
  350.    * -k <br>
  351.    * Outputs information-theoretic statistics. <p>
  352.    *
  353.    * -p range <br>
  354.    * Outputs predictions for test instances, along with the attributes in 
  355.    * the specified range (and nothing else). Use '-p 0' if no attributes are
  356.    * desired. <p>
  357.    *
  358.    * -r <br>
  359.    * Outputs cumulative margin distribution (and nothing else). <p>
  360.    *
  361.    * -g <br> 
  362.    * Only for classifiers that implement "Graphable." Outputs
  363.    * the graph representation of the classifier (and nothing
  364.    * else). <p>
  365.    *
  366.    * @param classifierString class of machine learning classifier as a string
  367.    * @param options the array of string containing the options
  368.    * @exception Exception if model could not be evaluated successfully
  369.    * @return a string describing the results 
  370.    */
  371.   public static String evaluateModel(String classifierString, 
  372.      String [] options) throws Exception {
  373.     Classifier classifier;  
  374.     // Create classifier
  375.     try {
  376.       classifier = 
  377.       (Classifier)Class.forName(classifierString).newInstance();
  378.     } catch (Exception e) {
  379.       throw new Exception("Can't find class with name " 
  380.   + classifierString + '.');
  381.     }
  382.     return evaluateModel(classifier, options);
  383.   }
  384.   
  385.   /**
  386.    * A test method for this class. Just extracts the first command line
  387.    * argument as a classifier class name and calls evaluateModel.
  388.    * @param args an array of command line arguments, the first of which
  389.    * must be the class name of a classifier.
  390.    */
  391.   public static void main(String [] args) {
  392.     try {
  393.       if (args.length == 0) {
  394. throw new Exception("The first argument must be the class name"
  395.     + " of a classifier");
  396.       }
  397.       String classifier = args[0];
  398.       args[0] = "";
  399.       System.out.println(evaluateModel(classifier, args));
  400.     } catch (Exception ex) {
  401.       ex.printStackTrace();
  402.       System.err.println(ex.getMessage());
  403.     }
  404.   }
  405.   /**
  406.    * Evaluates a classifier with the options given in an array of
  407.    * strings. <p>
  408.    *
  409.    * Valid options are: <p>
  410.    *
  411.    * -t name of training file <br>
  412.    * Name of the file with the training data. (required) <p>
  413.    *
  414.    * -T name of test file <br>
  415.    * Name of the file with the test data. If missing a cross-validation 
  416.    * is performed. <p>
  417.    *
  418.    * -c class index <br>
  419.    * Index of the class attribute (1, 2, ...; default: last). <p>
  420.    *
  421.    * -x number of folds <br>
  422.    * The number of folds for the cross-validation (default: 10). <p>
  423.    *
  424.    * -s random number seed <br>
  425.    * Random number seed for the cross-validation (default: 1). <p>
  426.    *
  427.    * -m file with cost matrix <br>
  428.    * The name of a file containing a cost matrix. <p>
  429.    *
  430.    * -l name of model input file <br>
  431.    * Loads classifier from the given file. <p>
  432.    *
  433.    * -d name of model output file <br>
  434.    * Saves classifier built from the training data into the given file. <p>
  435.    *
  436.    * -v <br>
  437.    * Outputs no statistics for the training data. <p>
  438.    *
  439.    * -o <br>
  440.    * Outputs statistics only, not the classifier. <p>
  441.    * 
  442.    * -i <br>
  443.    * Outputs detailed information-retrieval statistics per class. <p>
  444.    *
  445.    * -k <br>
  446.    * Outputs information-theoretic statistics. <p>
  447.    *
  448.    * -p <br>
  449.    * Outputs predictions for test instances (and nothing else). <p>
  450.    *
  451.    * -r <br>
  452.    * Outputs cumulative margin distribution (and nothing else). <p>
  453.    *
  454.    * -g <br> 
  455.    * Only for classifiers that implement "Graphable." Outputs
  456.    * the graph representation of the classifier (and nothing
  457.    * else). <p>
  458.    *
  459.    * @param classifier machine learning classifier
  460.    * @param options the array of string containing the options
  461.    * @exception Exception if model could not be evaluated successfully
  462.    * @return a string describing the results */
  463.   public static String evaluateModel(Classifier classifier,
  464.      String [] options) throws Exception {
  465.       
  466.     Instances train = null, tempTrain, test = null, template = null;
  467.     int seed = 1, folds = 10, classIndex = -1;
  468.     String trainFileName, testFileName, sourceClass, 
  469.       classIndexString, seedString, foldsString, objectInputFileName, 
  470.       objectOutputFileName, attributeRangeString;
  471.     boolean IRstatistics = false, noOutput = false,
  472.       printClassifications = false, trainStatistics = true,
  473.       printMargins = false, printComplexityStatistics = false,
  474.       printGraph = false, classStatistics = false, printSource = false;
  475.     StringBuffer text = new StringBuffer();
  476.     BufferedReader trainReader = null, testReader = null;
  477.     ObjectInputStream objectInputStream = null;
  478.     Random random;
  479.     CostMatrix costMatrix = null;
  480.     StringBuffer schemeOptionsText = null;
  481.     Range attributesToOutput = null;
  482.     long trainTimeStart = 0, trainTimeElapsed = 0,
  483.       testTimeStart = 0, testTimeElapsed = 0;
  484.     
  485.     try {
  486.       // Get basic options (options the same for all schemes)
  487.       classIndexString = Utils.getOption('c', options);
  488.       if (classIndexString.length() != 0) {
  489. classIndex = Integer.parseInt(classIndexString);
  490.       }
  491.       trainFileName = Utils.getOption('t', options); 
  492.       objectInputFileName = Utils.getOption('l', options);
  493.       objectOutputFileName = Utils.getOption('d', options);
  494.       testFileName = Utils.getOption('T', options);
  495.       if (trainFileName.length() == 0) {
  496. if (objectInputFileName.length() == 0) {
  497.   throw new Exception("No training file and no object "+
  498.       "input file given.");
  499. if (testFileName.length() == 0) {
  500.   throw new Exception("No training file and no test "+
  501.       "file given.");
  502. }
  503.       } else if ((objectInputFileName.length() != 0) &&
  504.  ((!(classifier instanceof UpdateableClassifier)) ||
  505.  (testFileName.length() == 0))) {
  506. throw new Exception("Classifier not incremental, or no " +
  507.     "test file provided: can't "+
  508.     "use both train and model file.");
  509.       }
  510.       try {
  511. if (trainFileName.length() != 0) {
  512.   trainReader = new BufferedReader(new FileReader(trainFileName));
  513. }
  514. if (testFileName.length() != 0) {
  515.   testReader = new BufferedReader(new FileReader(testFileName));
  516. }
  517. if (objectInputFileName.length() != 0) {
  518.           InputStream is = new FileInputStream(objectInputFileName);
  519.           if (objectInputFileName.endsWith(".gz")) {
  520.             is = new GZIPInputStream(is);
  521.           }
  522.   objectInputStream = new ObjectInputStream(is);
  523. }
  524.       } catch (Exception e) {
  525. throw new Exception("Can't open file " + e.getMessage() + '.');
  526.       }
  527.       if (testFileName.length() != 0) {
  528. template = test = new Instances(testReader, 1);
  529. if (classIndex != -1) {
  530.   test.setClassIndex(classIndex - 1);
  531. } else {
  532.   test.setClassIndex(test.numAttributes() - 1);
  533. }
  534. if (classIndex > test.numAttributes()) {
  535.   throw new Exception("Index of class attribute too large.");
  536. }
  537.       }
  538.       if (trainFileName.length() != 0) {
  539. if ((classifier instanceof UpdateableClassifier) &&
  540.     (testFileName.length() != 0)) {
  541.   train = new Instances(trainReader, 1);
  542. } else {
  543.   train = new Instances(trainReader);
  544. }
  545.         template = train;
  546. if (classIndex != -1) {
  547.   train.setClassIndex(classIndex - 1);
  548. } else {
  549.   train.setClassIndex(train.numAttributes() - 1);
  550. }
  551. if (classIndex > train.numAttributes()) {
  552.   throw new Exception("Index of class attribute too large.");
  553. }
  554. //train = new Instances(train);
  555.       }
  556.       if (template == null) {
  557.         throw new Exception("No actual dataset provided to use as template");
  558.       }
  559.       seedString = Utils.getOption('s', options);
  560.       if (seedString.length() != 0) {
  561. seed = Integer.parseInt(seedString);
  562.       }
  563.       foldsString = Utils.getOption('x', options);
  564.       if (foldsString.length() != 0) {
  565. folds = Integer.parseInt(foldsString);
  566.       }
  567.       costMatrix = handleCostOption(Utils.getOption('m', options), template.numClasses());
  568.       classStatistics = Utils.getFlag('i', options);
  569.       noOutput = Utils.getFlag('o', options);
  570.       trainStatistics = !Utils.getFlag('v', options);
  571.       printComplexityStatistics = Utils.getFlag('k', options);
  572.       printMargins = Utils.getFlag('r', options);
  573.       printGraph = Utils.getFlag('g', options);
  574.       sourceClass = Utils.getOption('z', options);
  575.       printSource = (sourceClass.length() != 0);
  576.       
  577.       // Check -p option
  578.       try {
  579. attributeRangeString = Utils.getOption('p', options);
  580.       }
  581.       catch (Exception e) {
  582. throw new Exception(e.getMessage() + "nNOTE: the -p option has changed. " +
  583.     "It now expects a parameter specifying a range of attributes " +
  584.     "to list with the predictions. Use '-p 0' for none.");
  585.       }
  586.       if (attributeRangeString.length() != 0) {
  587. printClassifications = true;
  588. if (!attributeRangeString.equals("0")) 
  589.   attributesToOutput = new Range(attributeRangeString);
  590.       }
  591.       // If a model file is given, we can't process 
  592.       // scheme-specific options
  593.       if (objectInputFileName.length() != 0) {
  594. Utils.checkForRemainingOptions(options);
  595.       } else {
  596. // Set options for classifier
  597. if (classifier instanceof OptionHandler) {
  598.   for (int i = 0; i < options.length; i++) {
  599.     if (options[i].length() != 0) {
  600.       if (schemeOptionsText == null) {
  601. schemeOptionsText = new StringBuffer();
  602.       }
  603.       if (options[i].indexOf(' ') != -1) {
  604. schemeOptionsText.append('"' + options[i] + "" ");
  605.       } else {
  606. schemeOptionsText.append(options[i] + " ");
  607.       }
  608.     }
  609.   }
  610.   ((OptionHandler)classifier).setOptions(options);
  611. }
  612.       }
  613.       Utils.checkForRemainingOptions(options);
  614.     } catch (Exception e) {
  615.       throw new Exception("nWeka exception: " + e.getMessage()
  616.    + makeOptionString(classifier));
  617.     }
  618.     // Setup up evaluation objects
  619.     Evaluation trainingEvaluation = new Evaluation(new Instances(template, 0), costMatrix);
  620.     Evaluation testingEvaluation = new Evaluation(new Instances(template, 0), costMatrix);
  621.     
  622.     if (objectInputFileName.length() != 0) {
  623.       
  624.       // Load classifier from file
  625.       classifier = (Classifier) objectInputStream.readObject();
  626.       objectInputStream.close();
  627.     }
  628.     
  629.     // Build the classifier if no object file provided
  630.     if ((classifier instanceof UpdateableClassifier) &&
  631. (testFileName.length() != 0) &&
  632. (costMatrix == null) &&
  633. (trainFileName.length() != 0)) {
  634.       
  635.       // Build classifier incrementally
  636.       trainingEvaluation.setPriors(train);
  637.       testingEvaluation.setPriors(train);
  638.       trainTimeStart = System.currentTimeMillis();
  639.       if (objectInputFileName.length() == 0) {
  640. classifier.buildClassifier(train);
  641.       }
  642.       while (train.readInstance(trainReader)) {
  643. trainingEvaluation.updatePriors(train.instance(0));
  644. testingEvaluation.updatePriors(train.instance(0));
  645. ((UpdateableClassifier)classifier).
  646.   updateClassifier(train.instance(0));
  647. train.delete(0);
  648.       }
  649.       trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
  650.       trainReader.close();
  651.     } else if (objectInputFileName.length() == 0) {
  652.       
  653.       // Build classifier in one go
  654.       tempTrain = new Instances(train);
  655.       trainingEvaluation.setPriors(tempTrain);
  656.       testingEvaluation.setPriors(tempTrain);
  657.       trainTimeStart = System.currentTimeMillis();
  658.       classifier.buildClassifier(tempTrain);
  659.       trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
  660.     } 
  661.     // Save the classifier if an object output file is provided
  662.     if (objectOutputFileName.length() != 0) {
  663.       OutputStream os = new FileOutputStream(objectOutputFileName);
  664.       if (objectOutputFileName.endsWith(".gz")) {
  665.         os = new GZIPOutputStream(os);
  666.       }
  667.       ObjectOutputStream objectOutputStream = new ObjectOutputStream(os);
  668.       objectOutputStream.writeObject(classifier);
  669.       objectOutputStream.flush();
  670.       objectOutputStream.close();
  671.     }
  672.     // If classifier is drawable output string describing graph
  673.     if ((classifier instanceof Drawable)
  674. && (printGraph)){
  675.       return ((Drawable)classifier).graph();
  676.     }
  677.     // Output the classifier as equivalent source
  678.     if ((classifier instanceof Sourcable)
  679. && (printSource)){
  680.       return wekaStaticWrapper((Sourcable) classifier, sourceClass);
  681.     }
  682.     // Output test instance predictions only
  683.     if (printClassifications) {
  684.       return printClassifications(classifier, new Instances(template, 0),
  685.   testFileName, classIndex, attributesToOutput);
  686.     }
  687.     // Output model
  688.     if (!(noOutput || printMargins)) {
  689.       if (classifier instanceof OptionHandler) {
  690. if (schemeOptionsText != null) {
  691.   text.append("nOptions: "+schemeOptionsText);
  692.   text.append("n");
  693. }
  694.       }
  695.       text.append("n" + classifier.toString() + "n");
  696.     }
  697.     if (!printMargins && (costMatrix != null)) {
  698.       text.append("n=== Evaluation Cost Matrix ===nn")
  699.         .append(costMatrix.toString());
  700.     }
  701.     // Compute error estimate from training data
  702.     if ((trainStatistics) &&
  703. (trainFileName.length() != 0)) {
  704.       if ((classifier instanceof UpdateableClassifier) &&
  705.   (testFileName.length() != 0) &&
  706.   (costMatrix == null)) {
  707. // Classifier was trained incrementally, so we have to 
  708. // reopen the training data in order to test on it.
  709. trainReader = new BufferedReader(new FileReader(trainFileName));
  710. // Incremental testing
  711. train = new Instances(trainReader, 1);
  712. if (classIndex != -1) {
  713.   train.setClassIndex(classIndex - 1);
  714. } else {
  715.   train.setClassIndex(train.numAttributes() - 1);
  716. }
  717. testTimeStart = System.currentTimeMillis();
  718. while (train.readInstance(trainReader)) {
  719.   trainingEvaluation.
  720.   evaluateModelOnce((Classifier)classifier, 
  721.     train.instance(0));
  722.   train.delete(0);
  723. }
  724. testTimeElapsed = System.currentTimeMillis() - testTimeStart;
  725. trainReader.close();
  726.       } else {
  727. testTimeStart = System.currentTimeMillis();
  728. trainingEvaluation.evaluateModel(classifier, 
  729.  train);
  730. testTimeElapsed = System.currentTimeMillis() - testTimeStart;
  731.       }
  732.       // Print the results of the training evaluation
  733.       if (printMargins) {
  734. return trainingEvaluation.toCumulativeMarginDistributionString();
  735.       } else {
  736. text.append("nTime taken to build model: " +
  737.     Utils.doubleToString(trainTimeElapsed / 1000.0,2) +
  738.     " seconds");
  739. text.append("nTime taken to test model on training data: " +
  740.     Utils.doubleToString(testTimeElapsed / 1000.0,2) +
  741.     " seconds");
  742. text.append(trainingEvaluation.
  743.     toSummaryString("nn=== Error on training" + 
  744.     " data ===n", printComplexityStatistics));
  745. if (template.classAttribute().isNominal()) {
  746.   if (classStatistics) {
  747.     text.append("nn" + trainingEvaluation.toClassDetailsString());
  748.   }
  749.   text.append("nn" + trainingEvaluation.toMatrixString());
  750. }
  751.       }
  752.     }
  753.     // Compute proper error estimates
  754.     if (testFileName.length() != 0) {
  755.       // Testing is on the supplied test data
  756.       while (test.readInstance(testReader)) {
  757.   
  758. testingEvaluation.evaluateModelOnce((Classifier)classifier, 
  759.                                             test.instance(0));
  760. test.delete(0);
  761.       }
  762.       testReader.close();
  763.       text.append("nn" + testingEvaluation.
  764.   toSummaryString("=== Error on test data ===n",
  765.   printComplexityStatistics));
  766.     } else if (trainFileName.length() != 0) {
  767.       // Testing is via cross-validation on training data
  768.       random = new Random(seed);
  769.       random.setSeed(seed);
  770.       train.randomize(random);
  771.       testingEvaluation.
  772. crossValidateModel(classifier, train, folds);
  773.       if (template.classAttribute().isNumeric()) {
  774. text.append("nnn" + testingEvaluation.
  775.     toSummaryString("=== Cross-validation ===n",
  776.     printComplexityStatistics));
  777.       } else {
  778. text.append("nnn" + testingEvaluation.
  779.     toSummaryString("=== Stratified " + 
  780.     "cross-validation ===n",
  781.     printComplexityStatistics));
  782.       }
  783.     }
  784.     if (template.classAttribute().isNominal()) {
  785.       if (classStatistics) {
  786. text.append("nn" + testingEvaluation.toClassDetailsString());
  787.       }
  788.       text.append("nn" + testingEvaluation.toMatrixString());
  789.     }
  790.     return text.toString();
  791.   }
  792.   /**
  793.    * Attempts to load a cost matrix.
  794.    *
  795.    * @param costFileName the filename of the cost matrix
  796.    * @param numClasses the number of classes that should be in the cost matrix
  797.    * (only used if the cost file is in old format).
  798.    * @return a <code>CostMatrix</code> value, or null if costFileName is empty
  799.    * @exception Exception if an error occurs.
  800.    */
  801.   private static CostMatrix handleCostOption(String costFileName, 
  802.                                              int numClasses) 
  803.     throws Exception {
  804.     if ((costFileName != null) && (costFileName.length() != 0)) {
  805.       System.out.println(
  806.            "NOTE: The behaviour of the -m option has changed between WEKA 3.0"
  807.            +" and WEKA 3.1. -m now carries out cost-sensitive *evaluation*"
  808.            +" only. For cost-sensitive *prediction*, use one of the"
  809.            +" cost-sensitive metaschemes such as"
  810.            +" weka.classifiers.CostSensitiveClassifier or"
  811.            +" weka.classifiers.MetaCost");
  812.       Reader costReader = null;
  813.       try {
  814.         costReader = new BufferedReader(new FileReader(costFileName));
  815.       } catch (Exception e) {
  816.         throw new Exception("Can't open file " + e.getMessage() + '.');
  817.       }
  818.       try {
  819.         // First try as a proper cost matrix format
  820.         return new CostMatrix(costReader);
  821.       } catch (Exception ex) {
  822.         try {
  823.           // Now try as the poxy old format :-)
  824.           //System.err.println("Attempting to read old format cost file");
  825.           try {
  826.             costReader.close(); // Close the old one
  827.             costReader = new BufferedReader(new FileReader(costFileName));
  828.           } catch (Exception e) {
  829.             throw new Exception("Can't open file " + e.getMessage() + '.');
  830.           }
  831.           CostMatrix costMatrix = new CostMatrix(numClasses);
  832.           //System.err.println("Created default cost matrix");
  833.           costMatrix.readOldFormat(costReader);
  834.           return costMatrix;
  835.           //System.err.println("Read old format");
  836.         } catch (Exception e2) {
  837.           // re-throw the original exception
  838.           //System.err.println("Re-throwing original exception");
  839.           throw ex;
  840.         }
  841.       }
  842.     } else {
  843.       return null;
  844.     }
  845.   }
  846.   /**
  847.    * Evaluates the classifier on a given set of instances.
  848.    *
  849.    * @param classifier machine learning classifier
  850.    * @param data set of test instances for evaluation
  851.    * @exception Exception if model could not be evaluated 
  852.    * successfully
  853.    */
  854.   public void evaluateModel(Classifier classifier,
  855.     Instances data) throws Exception {
  856.     
  857.     double [] predicted;
  858.     for (int i = 0; i < data.numInstances(); i++) {
  859.       evaluateModelOnce((Classifier)classifier, 
  860. data.instance(i));
  861.     }
  862.   }
  863.   
  864.   /**
  865.    * Evaluates the classifier on a single instance.
  866.    *
  867.    * @param classifier machine learning classifier
  868.    * @param instance the test instance to be classified
  869.    * @return the prediction made by the clasifier
  870.    * @exception Exception if model could not be evaluated 
  871.    * successfully or the data contains string attributes
  872.    */
  873.   public double evaluateModelOnce(Classifier classifier,
  874.   Instance instance) throws Exception {
  875.   
  876.     Instance classMissing = (Instance)instance.copy();
  877.     double pred=0;
  878.     classMissing.setDataset(instance.dataset());
  879.     classMissing.setClassMissing();
  880.     if (m_ClassIsNominal) {
  881.       if (classifier instanceof DistributionClassifier) {
  882. double [] dist = ((DistributionClassifier)classifier).
  883.  distributionForInstance(classMissing);
  884. pred = Utils.maxIndex(dist);
  885. updateStatsForClassifier(dist,
  886.  instance);
  887.       } else {
  888. pred = classifier.classifyInstance(classMissing);
  889. updateStatsForClassifier(makeDistribution(pred),
  890.  instance);
  891.       }
  892.     } else {
  893.       pred = classifier.classifyInstance(classMissing);
  894.       updateStatsForPredictor(pred,
  895.       instance);
  896.     }
  897.     return pred;
  898.   }
  899.   /**
  900.    * Evaluates the supplied distribution on a single instance.
  901.    *
  902.    * @param dist the supplied distribution
  903.    * @param instance the test instance to be classified
  904.    * @exception Exception if model could not be evaluated 
  905.    * successfully
  906.    */
  907.   public double evaluateModelOnce(double [] dist, 
  908.   Instance instance) throws Exception {
  909.     double pred;
  910.     if (m_ClassIsNominal) {
  911.       pred = Utils.maxIndex(dist);
  912.       updateStatsForClassifier(dist, instance);
  913.     } else {
  914.       pred = dist[0];
  915.       updateStatsForPredictor(pred, instance);
  916.     }
  917.     return pred;
  918.   }
  919.   /**
  920.    * Evaluates the supplied prediction on a single instance.
  921.    *
  922.    * @param prediction the supplied prediction
  923.    * @param instance the test instance to be classified
  924.    * @exception Exception if model could not be evaluated 
  925.    * successfully
  926.    */
  927.   public void evaluateModelOnce(double prediction,
  928. Instance instance) throws Exception {
  929.     
  930.     if (m_ClassIsNominal) {
  931.       updateStatsForClassifier(makeDistribution(prediction), 
  932.        instance);
  933.     } else {
  934.       updateStatsForPredictor(prediction, instance);
  935.     }
  936.   }
  937.   /**
  938.    * Wraps a static classifier in enough source to test using the weka
  939.    * class libraries.
  940.    *
  941.    * @param classifier a Sourcable Classifier
  942.    * @param className the name to give to the source code class
  943.    * @return the source for a static classifier that can be tested with
  944.    * weka libraries.
  945.    */
  946.   protected static String wekaStaticWrapper(Sourcable classifier, 
  947.                                             String className) 
  948.     throws Exception {
  949.     
  950.     //String className = "StaticClassifier";
  951.     String staticClassifier = classifier.toSource(className);
  952.     return "package weka.classifiers;n"
  953.     +"import weka.core.Attribute;n"
  954.     +"import weka.core.Instance;n"
  955.     +"import weka.core.Instances;n"
  956.     +"import weka.classifiers.Classifier;nn"
  957.     +"public class WekaWrapper extends Classifier {nn"
  958.     +"  public void buildClassifier(Instances i) throws Exception {n"
  959.     +"  }nn"
  960.     +"  public double classifyInstance(Instance i) throws Exception {nn"
  961.     +"    Object [] s = new Object [i.numAttributes()];n"
  962.     +"    for (int j = 0; j < s.length; j++) {n"
  963.     +"      if (!i.isMissing(j)) {n"
  964.     +"        if (i.attribute(j).type() == Attribute.NOMINAL) {n"
  965.     +"          s[j] = i.attribute(j).value((int) i.value(j));n"
  966.     +"        } else if (i.attribute(j).type() == Attribute.NUMERIC) {n"
  967.     +"          s[j] = new Double(i.value(j));n"
  968.     +"        }n"
  969.     +"      }n"
  970.     +"    }n"
  971.     +"    return " + className + ".classify(s);n"
  972.     +"  }nn"
  973.     +"}nn"
  974.     +staticClassifier; // The static classifer class
  975.   }
  976.   /**
  977.    * Gets the number of test instances that had a known class value
  978.    * (actually the sum of the weights of test instances with known 
  979.    * class value).
  980.    *
  981.    * @return the number of test instances with known class
  982.    */
  983.   public final double numInstances() {
  984.     
  985.     return m_WithClass;
  986.   }
  987.   /**
  988.    * Gets the number of instances incorrectly classified (that is, for
  989.    * which an incorrect prediction was made). (Actually the sum of the weights
  990.    * of these instances)
  991.    *
  992.    * @return the number of incorrectly classified instances 
  993.    */
  994.   public final double incorrect() {
  995.     return m_Incorrect;
  996.   }
  997.   /**
  998.    * Gets the percentage of instances incorrectly classified (that is, for
  999.    * which an incorrect prediction was made).
  1000.    *
  1001.    * @return the percent of incorrectly classified instances 
  1002.    * (between 0 and 100)
  1003.    */
  1004.   public final double pctIncorrect() {
  1005.     return 100 * m_Incorrect / m_WithClass;
  1006.   }
  1007.   /**
  1008.    * Gets the total cost, that is, the cost of each prediction times the
  1009.    * weight of the instance, summed over all instances.
  1010.    *
  1011.    * @return the total cost
  1012.    */
  1013.   public final double totalCost() {
  1014.     return m_TotalCost;
  1015.   }
  1016.   
  1017.   /**
  1018.    * Gets the average cost, that is, total cost of misclassifications
  1019.    * (incorrect plus unclassified) over the total number of instances.
  1020.    *
  1021.    * @return the average cost.  
  1022.    */
  1023.   public final double avgCost() {
  1024.     return m_TotalCost / m_WithClass;
  1025.   }
  1026.   /**
  1027.    * Gets the number of instances correctly classified (that is, for
  1028.    * which a correct prediction was made). (Actually the sum of the weights
  1029.    * of these instances)
  1030.    *
  1031.    * @return the number of correctly classified instances
  1032.    */
  1033.   public final double correct() {
  1034.     
  1035.     return m_Correct;
  1036.   }
  1037.   /**
  1038.    * Gets the percentage of instances correctly classified (that is, for
  1039.    * which a correct prediction was made).
  1040.    *
  1041.    * @return the percent of correctly classified instances (between 0 and 100)
  1042.    */
  1043.   public final double pctCorrect() {
  1044.     
  1045.     return 100 * m_Correct / m_WithClass;
  1046.   }
  1047.   
  1048.   /**
  1049.    * Gets the number of instances not classified (that is, for
  1050.    * which no prediction was made by the classifier). (Actually the sum
  1051.    * of the weights of these instances)
  1052.    *
  1053.    * @return the number of unclassified instances
  1054.    */
  1055.   public final double unclassified() {
  1056.     
  1057.     return m_Unclassified;
  1058.   }
  1059.   /**
  1060.    * Gets the percentage of instances not classified (that is, for
  1061.    * which no prediction was made by the classifier).
  1062.    *
  1063.    * @return the percent of unclassified instances (between 0 and 100)
  1064.    */
  1065.   public final double pctUnclassified() {
  1066.     
  1067.     return 100 * m_Unclassified / m_WithClass;
  1068.   }
  1069.   /**
  1070.    * Returns the estimated error rate or the root mean squared error
  1071.    * (if the class is numeric). If a cost matrix was given this
  1072.    * error rate gives the average cost.
  1073.    *
  1074.    * @return the estimated error rate (between 0 and 1, or between 0 and 
  1075.    * maximum cost)
  1076.    */
  1077.   public final double errorRate() {
  1078.     if (!m_ClassIsNominal) {
  1079.       return Math.sqrt(m_SumSqrErr / m_WithClass);
  1080.     }
  1081.     if (m_CostMatrix == null) {
  1082.       return m_Incorrect / m_WithClass;
  1083.     } else {
  1084.       return avgCost();
  1085.     }
  1086.   }
  1087.   /**
  1088.    * Returns value of kappa statistic if class is nominal.
  1089.    *
  1090.    * @return the value of the kappa statistic
  1091.    */
  1092.   public final double kappa() {
  1093.     
  1094.     double[] sumRows = new double[m_ConfusionMatrix.length];
  1095.     double[] sumColumns = new double[m_ConfusionMatrix.length];
  1096.     double sumOfWeights = 0;
  1097.     for (int i = 0; i < m_ConfusionMatrix.length; i++) {
  1098.       for (int j = 0; j < m_ConfusionMatrix.length; j++) {
  1099. sumRows[i] += m_ConfusionMatrix[i][j];
  1100. sumColumns[j] += m_ConfusionMatrix[i][j];
  1101. sumOfWeights += m_ConfusionMatrix[i][j];
  1102.       }
  1103.     }
  1104.     double correct = 0, chanceAgreement = 0;
  1105.     for (int i = 0; i < m_ConfusionMatrix.length; i++) {
  1106.       chanceAgreement += (sumRows[i] * sumColumns[i]);
  1107.       correct += m_ConfusionMatrix[i][i];
  1108.     }
  1109.     chanceAgreement /= (sumOfWeights * sumOfWeights);
  1110.     correct /= sumOfWeights;
  1111.     if (chanceAgreement < 1) {
  1112.       return (correct - chanceAgreement) / (1 - chanceAgreement);
  1113.     } else {
  1114.       return 1;
  1115.     }
  1116.   }
  1117.   /**
  1118.    * Returns the correlation coefficient if the class is numeric.
  1119.    *
  1120.    * @return the correlation coefficient
  1121.    * @exception Exception if class is not numeric
  1122.    */
  1123.   public final double correlationCoefficient() throws Exception {
  1124.     if (m_ClassIsNominal) {
  1125.       throw
  1126. new Exception("Can't compute correlation coefficient: " + 
  1127.       "class is nominal!");
  1128.     }
  1129.     double correlation = 0;
  1130.     double varActual = 
  1131.       m_SumSqrClass - m_SumClass * m_SumClass / m_WithClass;
  1132.     double varPredicted = 
  1133.       m_SumSqrPredicted - m_SumPredicted * m_SumPredicted / 
  1134.       m_WithClass;
  1135.     double varProd = 
  1136.       m_SumClassPredicted - m_SumClass * m_SumPredicted / m_WithClass;
  1137.     if (Utils.smOrEq(varActual * varPredicted, 0.0)) {
  1138.       correlation = 0.0;
  1139.     } else {
  1140.       correlation = varProd / Math.sqrt(varActual * varPredicted);
  1141.     }
  1142.     return correlation;
  1143.   }
  1144.   /**
  1145.    * Returns the mean absolute error. Refers to the error of the
  1146.    * predicted values for numeric classes, and the error of the 
  1147.    * predicted probability distribution for nominal classes.
  1148.    *
  1149.    * @return the mean absolute error 
  1150.    */
  1151.   public final double meanAbsoluteError() {
  1152.     return m_SumAbsErr / m_WithClass;
  1153.   }
  1154.   /**
  1155.    * Returns the mean absolute error of the prior.
  1156.    *
  1157.    * @return the mean absolute error 
  1158.    */
  1159.   public final double meanPriorAbsoluteError() {
  1160.     return m_SumPriorAbsErr / m_WithClass;
  1161.   }
  1162.   /**
  1163.    * Returns the relative absolute error.
  1164.    *
  1165.    * @return the relative absolute error 
  1166.    * @exception Exception if it can't be computed
  1167.    */
  1168.   public final double relativeAbsoluteError() throws Exception {
  1169.     return 100 * meanAbsoluteError() / meanPriorAbsoluteError();
  1170.   }
  1171.   
  1172.   /**
  1173.    * Returns the root mean squared error.
  1174.    *
  1175.    * @return the root mean squared error 
  1176.    */
  1177.   public final double rootMeanSquaredError() {
  1178.     return Math.sqrt(m_SumSqrErr / m_WithClass);
  1179.   }
  1180.   
  1181.   /**
  1182.    * Returns the root mean prior squared error.
  1183.    *
  1184.    * @return the root mean prior squared error 
  1185.    */
  1186.   public final double rootMeanPriorSquaredError() {
  1187.     return Math.sqrt(m_SumPriorSqrErr / m_WithClass);
  1188.   }
  1189.   
  1190.   /**
  1191.    * Returns the root relative squared error if the class is numeric.
  1192.    *
  1193.    * @return the root relative squared error 
  1194.    */
  1195.   public final double rootRelativeSquaredError() {
  1196.     return 100.0 * rootMeanSquaredError() / 
  1197.       rootMeanPriorSquaredError();
  1198.   }
  1199.   /**
  1200.    * Calculate the entropy of the prior distribution
  1201.    *
  1202.    * @return the entropy of the prior distribution
  1203.    * @exception Exception if the class is not nominal
  1204.    */
  1205.   public final double priorEntropy() throws Exception {
  1206.     if (!m_ClassIsNominal) {
  1207.       throw
  1208. new Exception("Can't compute entropy of class prior: " + 
  1209.       "class numeric!");
  1210.     }
  1211.     double entropy = 0;
  1212.     for(int i = 0; i < m_NumClasses; i++) {
  1213.       entropy -= m_ClassPriors[i] / m_ClassPriorsSum 
  1214. * Utils.log2(m_ClassPriors[i] / m_ClassPriorsSum);
  1215.     }
  1216.     return entropy;
  1217.   }
  1218.   /**
  1219.    * Return the total Kononenko & Bratko Information score in bits
  1220.    *
  1221.    * @return the K&B information score
  1222.    * @exception Exception if the class is not nominal
  1223.    */
  1224.   public final double KBInformation() throws Exception {
  1225.     if (!m_ClassIsNominal) {
  1226.       throw
  1227. new Exception("Can't compute K&B Info score: " + 
  1228.       "class numeric!");
  1229.     }
  1230.     return m_SumKBInfo;
  1231.   }
  1232.   /**
  1233.    * Return the Kononenko & Bratko Information score in bits per 
  1234.    * instance.
  1235.    *
  1236.    * @return the K&B information score
  1237.    * @exception Exception if the class is not nominal
  1238.    */
  1239.   public final double KBMeanInformation() throws Exception {
  1240.     if (!m_ClassIsNominal) {
  1241.       throw
  1242. new Exception("Can't compute K&B Info score: "
  1243.        + "class numeric!");
  1244.     }
  1245.     return m_SumKBInfo / m_WithClass;
  1246.   }
  1247.   /**
  1248.    * Return the Kononenko & Bratko Relative Information score
  1249.    *
  1250.    * @return the K&B relative information score
  1251.    * @exception Exception if the class is not nominal
  1252.    */
  1253.   public final double KBRelativeInformation() throws Exception {
  1254.     if (!m_ClassIsNominal) {
  1255.       throw
  1256. new Exception("Can't compute K&B Info score: " + 
  1257.       "class numeric!");
  1258.     }
  1259.     return 100.0 * KBInformation() / priorEntropy();
  1260.   }
  1261.   /**
  1262.    * Returns the total entropy for the null model
  1263.    * 
  1264.    * @return the total null model entropy
  1265.    */
  1266.   public final double SFPriorEntropy() {
  1267.     return m_SumPriorEntropy;
  1268.   }
  1269.   /**
  1270.    * Returns the entropy per instance for the null model
  1271.    * 
  1272.    * @return the null model entropy per instance
  1273.    */
  1274.   public final double SFMeanPriorEntropy() {
  1275.     return m_SumPriorEntropy / m_WithClass;
  1276.   }
  1277.   /**
  1278.    * Returns the total entropy for the scheme
  1279.    * 
  1280.    * @return the total scheme entropy
  1281.    */
  1282.   public final double SFSchemeEntropy() {
  1283.     return m_SumSchemeEntropy;
  1284.   }
  1285.   /**
  1286.    * Returns the entropy per instance for the scheme
  1287.    * 
  1288.    * @return the scheme entropy per instance
  1289.    */
  1290.   public final double SFMeanSchemeEntropy() {
  1291.     return m_SumSchemeEntropy / m_WithClass;
  1292.   }
  1293.   /**
  1294.    * Returns the total SF, which is the null model entropy minus
  1295.    * the scheme entropy.
  1296.    * 
  1297.    * @return the total SF
  1298.    */
  1299.   public final double SFEntropyGain() {
  1300.     return m_SumPriorEntropy - m_SumSchemeEntropy;
  1301.   }
  1302.   /**
  1303.    * Returns the SF per instance, which is the null model entropy
  1304.    * minus the scheme entropy, per instance.
  1305.    * 
  1306.    * @return the SF per instance
  1307.    */
  1308.   public final double SFMeanEntropyGain() {
  1309.     
  1310.     return (m_SumPriorEntropy - m_SumSchemeEntropy) / m_WithClass;
  1311.   }
  1312.   /**
  1313.    * Output the cumulative margin distribution as a string suitable
  1314.    * for input for gnuplot or similar package.
  1315.    *
  1316.    * @return the cumulative margin distribution
  1317.    * @exception Exception if the class attribute is nominal
  1318.    */
  1319.   public String toCumulativeMarginDistributionString() throws Exception {
  1320.     if (!m_ClassIsNominal) {
  1321.       throw new Exception("Class must be nominal for margin distributions");
  1322.     }
  1323.     String result = "";
  1324.     double cumulativeCount = 0;
  1325.     double margin;
  1326.     for(int i = 0; i <= k_MarginResolution; i++) {
  1327.       if (m_MarginCounts[i] != 0) {
  1328. cumulativeCount += m_MarginCounts[i];
  1329. margin = (double)i * 2.0 / k_MarginResolution - 1.0;
  1330. result = result + Utils.doubleToString(margin, 7, 3) + ' ' 
  1331. + Utils.doubleToString(cumulativeCount * 100 
  1332.        / m_WithClass, 7, 3) + 'n';
  1333.       } else if (i == 0) {
  1334. result = Utils.doubleToString(-1.0, 7, 3) + ' ' 
  1335. + Utils.doubleToString(0, 7, 3) + 'n';
  1336.       }
  1337.     }
  1338.     return result;
  1339.   }
  1340.   /**
  1341.    * Calls toSummaryString() with no title and no complexity stats
  1342.    *
  1343.    * @return a summary description of the classifier evaluation
  1344.    */
  1345.   public String toSummaryString() {
  1346.     return toSummaryString("", false);
  1347.   }
  1348.   /**
  1349.    * Calls toSummaryString() with a default title.
  1350.    *
  1351.    * @param printComplexityStatistics if true, complexity statistics are
  1352.    * returned as well
  1353.    */
  1354.   public String toSummaryString(boolean printComplexityStatistics) {
  1355.     
  1356.     return toSummaryString("=== Summary ===n", printComplexityStatistics);
  1357.   }
  1358.   /**
  1359.    * Outputs the performance statistics in summary form. Lists 
  1360.    * number (and percentage) of instances classified correctly, 
  1361.    * incorrectly and unclassified. Outputs the total number of 
  1362.    * instances classified, and the number of instances (if any) 
  1363.    * that had no class value provided. 
  1364.    *
  1365.    * @param title the title for the statistics
  1366.    * @param printComplexityStatistics if true, complexity statistics are
  1367.    * returned as well
  1368.    * @return the summary as a String
  1369.    */
  1370.   public String toSummaryString(String title, 
  1371. boolean printComplexityStatistics) { 
  1372.     
  1373.     double mae, mad = 0;
  1374.     StringBuffer text = new StringBuffer();
  1375.     text.append(title + "n");
  1376.     try {
  1377.       if (m_WithClass > 0) {
  1378. if (m_ClassIsNominal) {
  1379.   text.append("Correctly Classified Instances     ");
  1380.   text.append(Utils.doubleToString(correct(), 12, 4) + "     " +
  1381.       Utils.doubleToString(pctCorrect(),
  1382.    12, 4) + " %n");
  1383.   text.append("Incorrectly Classified Instances   ");
  1384.   text.append(Utils.doubleToString(incorrect(), 12, 4) + "     " +
  1385.       Utils.doubleToString(pctIncorrect(),
  1386.    12, 4) + " %n");
  1387.   text.append("Kappa statistic                    ");
  1388.   text.append(Utils.doubleToString(kappa(), 12, 4) + "n");
  1389.   
  1390.   if (m_CostMatrix != null) {
  1391.     text.append("Total Cost                         ");
  1392.     text.append(Utils.doubleToString(totalCost(), 12, 4) + "n");
  1393.     text.append("Average Cost                       ");
  1394.     text.append(Utils.doubleToString(avgCost(), 12, 4) + "n");
  1395.   }
  1396.   if (printComplexityStatistics) {
  1397.     text.append("K&B Relative Info Score            ");
  1398.     text.append(Utils.doubleToString(KBRelativeInformation(), 12, 4) 
  1399. + " %n");
  1400.     text.append("K&B Information Score              ");
  1401.     text.append(Utils.doubleToString(KBInformation(), 12, 4) 
  1402. + " bits");
  1403.     text.append(Utils.doubleToString(KBMeanInformation(), 12, 4) 
  1404. + " bits/instancen");
  1405.   }
  1406. } else {        
  1407.   text.append("Correlation coefficient            ");
  1408.   text.append(Utils.doubleToString(correlationCoefficient(), 12 , 4) +
  1409.       "n");
  1410. }
  1411. if (printComplexityStatistics) {
  1412.   text.append("Class complexity | order 0         ");
  1413.   text.append(Utils.doubleToString(SFPriorEntropy(), 12, 4) 
  1414.       + " bits");
  1415.   text.append(Utils.doubleToString(SFMeanPriorEntropy(), 12, 4) 
  1416.       + " bits/instancen");
  1417.   text.append("Class complexity | scheme          ");
  1418.   text.append(Utils.doubleToString(SFSchemeEntropy(), 12, 4) 
  1419.       + " bits");
  1420.   text.append(Utils.doubleToString(SFMeanSchemeEntropy(), 12, 4) 
  1421.       + " bits/instancen");
  1422.   text.append("Complexity improvement     (Sf)    ");
  1423.   text.append(Utils.doubleToString(SFEntropyGain(), 12, 4) + " bits");
  1424.   text.append(Utils.doubleToString(SFMeanEntropyGain(), 12, 4) 
  1425.       + " bits/instancen");
  1426. }
  1427. text.append("Mean absolute error                ");
  1428. text.append(Utils.doubleToString(meanAbsoluteError(), 12, 4) 
  1429.     + "n");
  1430. text.append("Root mean squared error            ");
  1431. text.append(Utils.
  1432.     doubleToString(rootMeanSquaredError(), 12, 4) 
  1433.     + "n");
  1434. text.append("Relative absolute error            ");
  1435. text.append(Utils.doubleToString(relativeAbsoluteError(), 
  1436.  12, 4) + " %n");
  1437. text.append("Root relative squared error        ");
  1438. text.append(Utils.doubleToString(rootRelativeSquaredError(), 
  1439.  12, 4) + " %n");
  1440.       }
  1441.       if (Utils.gr(unclassified(), 0)) {
  1442. text.append("UnClassified Instances             ");
  1443. text.append(Utils.doubleToString(unclassified(), 12,4) +  "     " +
  1444.     Utils.doubleToString(pctUnclassified(),
  1445.  12, 4) + " %n");
  1446.       }
  1447.       text.append("Total Number of Instances          ");
  1448.       text.append(Utils.doubleToString(m_WithClass, 12, 4) + "n");
  1449.       if (m_MissingClass > 0) {
  1450. text.append("Ignored Class Unknown Instances            ");
  1451. text.append(Utils.doubleToString(m_MissingClass, 12, 4) + "n");
  1452.       }
  1453.     } catch (Exception ex) {
  1454.       // Should never occur since the class is known to be nominal 
  1455.       // here
  1456.       System.err.println("Arggh - Must be a bug in Evaluation class");
  1457.     }
  1458.    
  1459.     return text.toString(); 
  1460.   }
  1461.   
  1462.   /**
  1463.    * Calls toMatrixString() with a default title.
  1464.    *
  1465.    * @return the confusion matrix as a string
  1466.    * @exception Exception if the class is numeric
  1467.    */
  1468.   public String toMatrixString() throws Exception {
  1469.     return toMatrixString("=== Confusion Matrix ===n");
  1470.   }
  1471.   /**
  1472.    * Outputs the performance statistics as a classification confusion
  1473.    * matrix. For each class value, shows the distribution of 
  1474.    * predicted class values.
  1475.    *
  1476.    * @param title the title for the confusion matrix
  1477.    * @return the confusion matrix as a String
  1478.    * @exception Exception if the class is numeric
  1479.    */
  1480.   public String toMatrixString(String title) throws Exception {
  1481.     StringBuffer text = new StringBuffer();
  1482.     char [] IDChars = {'a','b','c','d','e','f','g','h','i','j',
  1483.        'k','l','m','n','o','p','q','r','s','t',
  1484.        'u','v','w','x','y','z'};
  1485.     int IDWidth;
  1486.     boolean fractional = false;
  1487.     if (!m_ClassIsNominal) {
  1488.       throw new Exception("Evaluation: No confusion matrix possible!");
  1489.     }
  1490.     // Find the maximum value in the matrix
  1491.     // and check for fractional display requirement 
  1492.     double maxval = 0;
  1493.     for(int i = 0; i < m_NumClasses; i++) {
  1494.       for(int j = 0; j < m_NumClasses; j++) {
  1495. double current = m_ConfusionMatrix[i][j];
  1496.         if (current < 0) {
  1497.           current *= -10;
  1498.         }
  1499. if (current > maxval) {
  1500.   maxval = current;
  1501. }
  1502. double fract = current - Math.rint(current);
  1503. if (!fractional
  1504.     && ((Math.log(fract) / Math.log(10)) >= -2)) {
  1505.   fractional = true;
  1506. }
  1507.       }
  1508.     }
  1509.     IDWidth = 1 + Math.max((int)(Math.log(maxval) / Math.log(10) 
  1510.  + (fractional ? 3 : 0)),
  1511.      (int)(Math.log(m_NumClasses) / 
  1512.    Math.log(IDChars.length)));
  1513.     text.append(title).append("n");
  1514.     for(int i = 0; i < m_NumClasses; i++) {
  1515.       if (fractional) {
  1516. text.append(" ").append(num2ShortID(i,IDChars,IDWidth - 3))
  1517.           .append("   ");
  1518.       } else {
  1519. text.append(" ").append(num2ShortID(i,IDChars,IDWidth));
  1520.       }
  1521.     }
  1522.     text.append("   <-- classified asn");
  1523.     for(int i = 0; i< m_NumClasses; i++) { 
  1524.       for(int j = 0; j < m_NumClasses; j++) {
  1525. text.append(" ").append(
  1526.     Utils.doubleToString(m_ConfusionMatrix[i][j],
  1527.  IDWidth,
  1528.  (fractional ? 2 : 0)));
  1529.       }
  1530.       text.append(" | ").append(num2ShortID(i,IDChars,IDWidth))
  1531.         .append(" = ").append(m_ClassNames[i]).append("n");
  1532.     }
  1533.     return text.toString();
  1534.   }
  1535.   public String toClassDetailsString() throws Exception {
  1536.     return toClassDetailsString("=== Detailed Accuracy By Class ===n");
  1537.   }
  1538.   /**
  1539.    * Generates a breakdown of the accuracy for each class,
  1540.    * incorporating various information-retrieval statistics, such as
  1541.    * true/false positive rate, precision/recall/F-Measure.  Should be
  1542.    * useful for ROC curves, recall/precision curves.  
  1543.    * 
  1544.    * @param title the title to prepend the stats string with 
  1545.    * @return the statistics presented as a string
  1546.    */
  1547.   public String toClassDetailsString(String title) throws Exception {
  1548.     if (!m_ClassIsNominal) {
  1549.       throw new Exception("Evaluation: No confusion matrix possible!");
  1550.     }
  1551.     StringBuffer text = new StringBuffer(title 
  1552.  + "nTP Rate   FP Rate"
  1553.                                          + "   Precision   Recall"
  1554.                                          + "  F-Measure   Classn");
  1555.     for(int i = 0; i < m_NumClasses; i++) {
  1556.       text.append(Utils.doubleToString(truePositiveRate(i), 7, 3))
  1557.         .append("   ");
  1558.       text.append(Utils.doubleToString(falsePositiveRate(i), 7, 3))
  1559.         .append("    ");
  1560.       text.append(Utils.doubleToString(precision(i), 7, 3))
  1561.         .append("   ");
  1562.       text.append(Utils.doubleToString(recall(i), 7, 3))
  1563.         .append("   ");
  1564.       text.append(Utils.doubleToString(fMeasure(i), 7, 3))
  1565.         .append("    ");
  1566.       text.append(m_ClassNames[i]).append('n');
  1567.     }
  1568.     return text.toString();
  1569.   }
  1570.   /**
  1571.    * Calculate the number of true positives with respect to a particular class. 
  1572.    * This is defined as<p>
  1573.    * <pre>
  1574.    * correctly classified positives
  1575.    * </pre>
  1576.    *
  1577.    * @param classIndex the index of the class to consider as "positive"
  1578.    * @return the true positive rate
  1579.    */
  1580.   public double numTruePositives(int classIndex) {
  1581.     double correct = 0;
  1582.     for (int j = 0; j < m_NumClasses; j++) {
  1583.       if (j == classIndex) {
  1584. correct += m_ConfusionMatrix[classIndex][j];
  1585.       }
  1586.     }
  1587.     return correct;
  1588.   }
  1589.   /**
  1590.    * Calculate the true positive rate with respect to a particular class. 
  1591.    * This is defined as<p>
  1592.    * <pre>
  1593.    * correctly classified positives
  1594.    * ------------------------------
  1595.    *       total positives
  1596.    * </pre>
  1597.    *
  1598.    * @param classIndex the index of the class to consider as "positive"
  1599.    * @return the true positive rate
  1600.    */
  1601.   public double truePositiveRate(int classIndex) {
  1602.     double correct = 0, total = 0;
  1603.     for (int j = 0; j < m_NumClasses; j++) {
  1604.       if (j == classIndex) {
  1605. correct += m_ConfusionMatrix[classIndex][j];
  1606.       }
  1607.       total += m_ConfusionMatrix[classIndex][j];
  1608.     }
  1609.     if (total == 0) {
  1610.       return 0;
  1611.     }
  1612.     return correct / total;
  1613.   }
  1614.   /**
  1615.    * Calculate the number of true negatives with respect to a particular class. 
  1616.    * This is defined as<p>
  1617.    * <pre>
  1618.    * correctly classified negatives
  1619.    * </pre>
  1620.    *
  1621.    * @param classIndex the index of the class to consider as "positive"
  1622.    * @return the true positive rate
  1623.    */
  1624.   public double numTrueNegatives(int classIndex) {
  1625.     double correct = 0;
  1626.     for (int i = 0; i < m_NumClasses; i++) {
  1627.       if (i != classIndex) {
  1628. for (int j = 0; j < m_NumClasses; j++) {
  1629.   if (j != classIndex) {
  1630.     correct += m_ConfusionMatrix[i][j];
  1631.   }
  1632. }
  1633.       }
  1634.     }
  1635.     return correct;
  1636.   }
  1637.   /**
  1638.    * Calculate the true negative rate with respect to a particular class. 
  1639.    * This is defined as<p>
  1640.    * <pre>
  1641.    * correctly classified negatives
  1642.    * ------------------------------
  1643.    *       total negatives
  1644.    * </pre>
  1645.    *
  1646.    * @param classIndex the index of the class to consider as "positive"
  1647.    * @return the true positive rate
  1648.    */
  1649.   public double trueNegativeRate(int classIndex) {
  1650.     double correct = 0, total = 0;
  1651.     for (int i = 0; i < m_NumClasses; i++) {
  1652.       if (i != classIndex) {
  1653. for (int j = 0; j < m_NumClasses; j++) {
  1654.   if (j != classIndex) {
  1655.     correct += m_ConfusionMatrix[i][j];
  1656.   }
  1657.   total += m_ConfusionMatrix[i][j];
  1658. }
  1659.       }
  1660.     }
  1661.     if (total == 0) {
  1662.       return 0;
  1663.     }
  1664.     return correct / total;
  1665.   }
  1666.   /**
  1667.    * Calculate number of false positives with respect to a particular class. 
  1668.    * This is defined as<p>
  1669.    * <pre>
  1670.    * incorrectly classified negatives
  1671.    * </pre>
  1672.    *
  1673.    * @param classIndex the index of the class to consider as "positive"
  1674.    * @return the false positive rate
  1675.    */
  1676.   public double numFalsePositives(int classIndex) {
  1677.     double incorrect = 0;
  1678.     for (int i = 0; i < m_NumClasses; i++) {
  1679.       if (i != classIndex) {
  1680. for (int j = 0; j < m_NumClasses; j++) {
  1681.   if (j == classIndex) {
  1682.     incorrect += m_ConfusionMatrix[i][j];
  1683.   }
  1684. }
  1685.       }
  1686.     }
  1687.     return incorrect;
  1688.   }
  1689.   /**
  1690.    * Calculate the false positive rate with respect to a particular class. 
  1691.    * This is defined as<p>
  1692.    * <pre>
  1693.    * incorrectly classified negatives
  1694.    * --------------------------------
  1695.    *        total negatives
  1696.    * </pre>
  1697.    *
  1698.    * @param classIndex the index of the class to consider as "positive"
  1699.    * @return the false positive rate
  1700.    */
  1701.   public double falsePositiveRate(int classIndex) {
  1702.     double incorrect = 0, total = 0;
  1703.     for (int i = 0; i < m_NumClasses; i++) {
  1704.       if (i != classIndex) {
  1705. for (int j = 0; j < m_NumClasses; j++) {
  1706.   if (j == classIndex) {
  1707.     incorrect += m_ConfusionMatrix[i][j];
  1708.   }
  1709.   total += m_ConfusionMatrix[i][j];
  1710. }
  1711.       }
  1712.     }
  1713.     if (total == 0) {
  1714.       return 0;
  1715.     }
  1716.     return incorrect / total;
  1717.   }
  1718.   /**
  1719.    * Calculate number of false negatives with respect to a particular class. 
  1720.    * This is defined as<p>
  1721.    * <pre>
  1722.    * incorrectly classified positives
  1723.    * </pre>
  1724.    *
  1725.    * @param classIndex the index of the class to consider as "positive"
  1726.    * @return the false positive rate
  1727.    */
  1728.   public double numFalseNegatives(int classIndex) {
  1729.     double incorrect = 0;
  1730.     for (int i = 0; i < m_NumClasses; i++) {
  1731.       if (i == classIndex) {
  1732. for (int j = 0; j < m_NumClasses; j++) {
  1733.   if (j != classIndex) {
  1734.     incorrect += m_ConfusionMatrix[i][j];
  1735.   }
  1736. }
  1737.       }
  1738.     }
  1739.     return incorrect;
  1740.   }
  1741.   /**
  1742.    * Calculate the false negative rate with respect to a particular class. 
  1743.    * This is defined as<p>
  1744.    * <pre>
  1745.    * incorrectly classified positives
  1746.    * --------------------------------
  1747.    *        total positives
  1748.    * </pre>
  1749.    *
  1750.    * @param classIndex the index of the class to consider as "positive"
  1751.    * @return the false positive rate
  1752.    */
  1753.   public double falseNegativeRate(int classIndex) {
  1754.     double incorrect = 0, total = 0;
  1755.     for (int i = 0; i < m_NumClasses; i++) {
  1756.       if (i == classIndex) {
  1757. for (int j = 0; j < m_NumClasses; j++) {
  1758.   if (j != classIndex) {
  1759.     incorrect += m_ConfusionMatrix[i][j];
  1760.   }
  1761.   total += m_ConfusionMatrix[i][j];
  1762. }
  1763.       }
  1764.     }
  1765.     if (total == 0) {
  1766.       return 0;
  1767.     }
  1768.     return incorrect / total;
  1769.   }
  1770.   /**
  1771.    * Calculate the recall with respect to a particular class. 
  1772.    * This is defined as<p>
  1773.    * <pre>
  1774.    * correctly classified positives
  1775.    * ------------------------------
  1776.    *       total positives
  1777.    * </pre><p>
  1778.    * (Which is also the same as the truePositiveRate.)
  1779.    *
  1780.    * @param classIndex the index of the class to consider as "positive"
  1781.    * @return the recall
  1782.    */
  1783.   public double recall(int classIndex) {
  1784.     return truePositiveRate(classIndex);
  1785.   }
  1786.   /**
  1787.    * Calculate the precision with respect to a particular class. 
  1788.    * This is defined as<p>
  1789.    * <pre>
  1790.    * correctly classified positives
  1791.    * ------------------------------
  1792.    *  total predicted as positive
  1793.    * </pre>
  1794.    *
  1795.    * @param classIndex the index of the class to consider as "positive"
  1796.    * @return the precision
  1797.    */
  1798.   public double precision(int classIndex) {
  1799.     double correct = 0, total = 0;
  1800.     for (int i = 0; i < m_NumClasses; i++) {
  1801.       if (i == classIndex) {
  1802. correct += m_ConfusionMatrix[i][classIndex];
  1803.       }
  1804.       total += m_ConfusionMatrix[i][classIndex];
  1805.     }
  1806.     if (total == 0) {
  1807.       return 0;
  1808.     }
  1809.     return correct / total;
  1810.   }
  1811.   /**
  1812.    * Calculate the F-Measure with respect to a particular class. 
  1813.    * This is defined as<p>
  1814.    * <pre>
  1815.    * 2 * recall * precision
  1816.    * ----------------------
  1817.    *   recall + precision
  1818.    * </pre>
  1819.    *
  1820.    * @param classIndex the index of the class to consider as "positive"
  1821.    * @return the F-Measure
  1822.    */
  1823.   public double fMeasure(int classIndex) {
  1824.     double precision = precision(classIndex);
  1825.     double recall = recall(classIndex);
  1826.     if ((precision + recall) == 0) {
  1827.       return 0;
  1828.     }
  1829.     return 2 * precision * recall / (precision + recall);
  1830.   }
  1831.   /**
  1832.    * Sets the class prior probabilities
  1833.    *
  1834.    * @param train the training instances used to determine
  1835.    * the prior probabilities
  1836.    * @exception Exception if the class attribute of the instances is not
  1837.    * set
  1838.    */
  1839.   public void setPriors(Instances train) throws Exception {
  1840.     if (!m_ClassIsNominal) {
  1841.       m_NumTrainClassVals = 0;
  1842.       m_TrainClassVals = null;
  1843.       m_TrainClassWeights = null;
  1844.       m_PriorErrorEstimator = null;
  1845.       m_ErrorEstimator = null;
  1846.       for (int i = 0; i < train.numInstances(); i++) {
  1847. Instance currentInst = train.instance(i);
  1848. if (!currentInst.classIsMissing()) {
  1849.   addNumericTrainClass(currentInst.classValue(), 
  1850.   currentInst.weight());
  1851. }
  1852.       }
  1853.     } else {
  1854.       for (int i = 0; i < m_NumClasses; i++) {
  1855. m_ClassPriors[i] = 1;
  1856.       }
  1857.       m_ClassPriorsSum = m_NumClasses;
  1858.       for (int i = 0; i < train.numInstances(); i++) {
  1859. if (!train.instance(i).classIsMissing()) {
  1860.   m_ClassPriors[(int)train.instance(i).classValue()] += 
  1861.     train.instance(i).weight();
  1862.   m_ClassPriorsSum += train.instance(i).weight();
  1863. }
  1864.       }
  1865.     }
  1866.   }
  1867.   /**
  1868.    * Updates the class prior probabilities (when incrementally 
  1869.    * training)
  1870.    *
  1871.    * @param instance the new training instance seen
  1872.    * @exception Exception if the class of the instance is not
  1873.    * set
  1874.    */
  1875.   public void updatePriors(Instance instance) throws Exception
  1876.   {
  1877.     if (!instance.classIsMissing()) {
  1878.       if (!m_ClassIsNominal) {
  1879. if (!instance.classIsMissing()) {
  1880.   addNumericTrainClass(instance.classValue(), 
  1881.        instance.weight());
  1882. }
  1883.       } else {
  1884. m_ClassPriors[(int)instance.classValue()] += 
  1885.   instance.weight();
  1886. m_ClassPriorsSum += instance.weight();
  1887.       }
  1888.     }    
  1889.   }
  1890.   /**
  1891.    * Tests whether the current evaluation object is equal to another
  1892.    * evaluation object
  1893.    *
  1894.    * @param obj the object to compare against
  1895.    * @return true if the two objects are equal
  1896.    */
  1897.   public boolean equals(Object obj) {
  1898.     if ((obj == null) || !(obj.getClass().equals(this.getClass()))) {
  1899.       return false;
  1900.     }
  1901.     Evaluation cmp = (Evaluation) obj;
  1902.     if (m_ClassIsNominal != cmp.m_ClassIsNominal) return false;
  1903.     if (m_NumClasses != cmp.m_NumClasses) return false;
  1904.     if (m_Incorrect != cmp.m_Incorrect) return false;
  1905.     if (m_Correct != cmp.m_Correct) return false;
  1906.     if (m_Unclassified != cmp.m_Unclassified) return false;
  1907.     if (m_MissingClass != cmp.m_MissingClass) return false;
  1908.     if (m_WithClass != cmp.m_WithClass) return false;
  1909.     if (m_SumErr != cmp.m_SumErr) return false;
  1910.     if (m_SumAbsErr != cmp.m_SumAbsErr) return false;
  1911.     if (m_SumSqrErr != cmp.m_SumSqrErr) return false;
  1912.     if (m_SumClass != cmp.m_SumClass) return false;
  1913.     if (m_SumSqrClass != cmp.m_SumSqrClass) return false;
  1914.     if (m_SumPredicted != cmp.m_SumPredicted) return false;
  1915.     if (m_SumSqrPredicted != cmp.m_SumSqrPredicted) return false;
  1916.     if (m_SumClassPredicted != cmp.m_SumClassPredicted) return false;
  1917.     if (m_ClassIsNominal) {
  1918.       for (int i = 0; i < m_NumClasses; i++) {
  1919. for (int j = 0; j < m_NumClasses; j++) {
  1920.   if (m_ConfusionMatrix[i][j] != cmp.m_ConfusionMatrix[i][j]) {
  1921.     return false;
  1922.   }
  1923. }
  1924.       }
  1925.     }
  1926.     
  1927.     return true;
  1928.   }
  1929.   /**
  1930.    * Prints the predictions for the given dataset into a String variable.
  1931.    */
  1932.   private static String printClassifications(Classifier classifier, 
  1933.      Instances train,
  1934.      String testFileName,
  1935.      int classIndex,
  1936.      Range attributesToOutput) throws Exception {
  1937.     StringBuffer text = new StringBuffer();
  1938.     if (testFileName.length() != 0) {
  1939.       BufferedReader testReader = null;
  1940.       try {
  1941. testReader = new BufferedReader(new FileReader(testFileName));
  1942.       } catch (Exception e) {
  1943. throw new Exception("Can't open file " + e.getMessage() + '.');
  1944.       }
  1945.       Instances test = new Instances(testReader, 1);
  1946.       if (classIndex != -1) {
  1947. test.setClassIndex(classIndex - 1);
  1948.       } else {
  1949. test.setClassIndex(test.numAttributes() - 1);
  1950.       }
  1951.       int i = 0;
  1952.       while (test.readInstance(testReader)) {
  1953. Instance instance = test.instance(0);    
  1954. Instance withMissing = (Instance)instance.copy();
  1955. withMissing.setDataset(test);
  1956. double predValue = 
  1957.   ((Classifier)classifier).classifyInstance(withMissing);
  1958. if (test.classAttribute().isNumeric()) {
  1959.   if (Instance.isMissingValue(predValue)) {
  1960.     text.append(i + " missing ");
  1961.   } else {
  1962.     text.append(i + " " + predValue + " ");
  1963.   }
  1964.   if (instance.classIsMissing()) {
  1965.     text.append("missing");
  1966.   } else {
  1967.     text.append(instance.classValue());
  1968.   }
  1969.   text.append(" " + attributeValuesString(withMissing, attributesToOutput) + "n");
  1970. } else {
  1971.   if (Instance.isMissingValue(predValue)) {
  1972.     text.append(i + " missing ");
  1973.   } else {
  1974.     text.append(i + " "
  1975.          + test.classAttribute().value((int)predValue) + " ");
  1976.   }
  1977.   if (classifier instanceof DistributionClassifier) {
  1978.     if (Instance.isMissingValue(predValue)) {
  1979.       text.append("missing ");
  1980.     } else {
  1981.       text.append(((DistributionClassifier)classifier).
  1982.            distributionForInstance(withMissing)
  1983.            [(int)predValue]+" ");
  1984.     }
  1985.   }
  1986.   text.append(instance.toString(instance.classIndex()) + " "
  1987.       + attributeValuesString(withMissing, attributesToOutput) + "n");
  1988. }
  1989. test.delete(0);
  1990. i++;
  1991.       }
  1992.       testReader.close();
  1993.     }
  1994.     return text.toString();
  1995.   }
  1996.   /**
  1997.    * Builds a string listing the attribute values in a specified range of indices,
  1998.    * separated by commas and enclosed in brackets.
  1999.    *
  2000.    * @param instance the instance to print the values from
  2001.    * @param attributes the range of the attributes to list
  2002.    * @return a string listing values of the attributes in the range
  2003.    */
  2004.   private static String attributeValuesString(Instance instance, Range attRange) {
  2005.     StringBuffer text = new StringBuffer();
  2006.     if (attRange != null) {
  2007.       boolean firstOutput = true;
  2008.       attRange.setUpper(instance.numAttributes() - 1);
  2009.       for (int i=0; i<instance.numAttributes(); i++)
  2010. if (attRange.isInRange(i) && i != instance.classIndex()) {
  2011.   if (firstOutput) text.append("(");
  2012.   else text.append(",");
  2013.   text.append(instance.toString(i));
  2014.   firstOutput = false;
  2015. }
  2016.       if (!firstOutput) text.append(")");
  2017.     }
  2018.     return text.toString();
  2019.   }
  2020.   /**
  2021.    * Make up the help string giving all the command line options
  2022.    *
  2023.    * @param classifier the classifier to include options for
  2024.    * @return a string detailing the valid command line options
  2025.    */
  2026.   private static String makeOptionString(Classifier classifier) {
  2027.     StringBuffer optionsText = new StringBuffer("");
  2028.     // General options
  2029.     optionsText.append("nnGeneral options:nn");
  2030.     optionsText.append("-t <name of training file>n");
  2031.     optionsText.append("tSets training file.n");
  2032.     optionsText.append("-T <name of test file>n");
  2033.     optionsText.append("tSets test file. If missing, a cross-validation");
  2034.     optionsText.append(" will be performed on the training data.n");
  2035.     optionsText.append("-c <class index>n");
  2036.     optionsText.append("tSets index of class attribute (default: last).n");
  2037.     optionsText.append("-x <number of folds>n");
  2038.     optionsText.append("tSets number of folds for cross-validation (default: 10).n");
  2039.     optionsText.append("-s <random number seed>n");
  2040.     optionsText.append("tSets random number seed for cross-validation (default: 1).n");
  2041.     optionsText.append("-m <name of file with cost matrix>n");
  2042.     optionsText.append("tSets file with cost matrix.n");
  2043.     optionsText.append("-l <name of input file>n");
  2044.     optionsText.append("tSets model input file.n");
  2045.     optionsText.append("-d <name of output file>n");
  2046.     optionsText.append("tSets model output file.n");
  2047.     optionsText.append("-vn");
  2048.     optionsText.append("tOutputs no statistics for training data.n");
  2049.     optionsText.append("-on");
  2050.     optionsText.append("tOutputs statistics only, not the classifier.n");
  2051.     optionsText.append("-in");
  2052.     optionsText.append("tOutputs detailed information-retrieval");
  2053.     optionsText.append(" statistics for each class.n");
  2054.     optionsText.append("-kn");
  2055.     optionsText.append("tOutputs information-theoretic statistics.n");
  2056.     optionsText.append("-p <attribute range>n");
  2057.     optionsText.append("tOnly outputs predictions for test instances, along with attributes "
  2058.        + "(0 for none).n");
  2059.     optionsText.append("-rn");
  2060.     optionsText.append("tOnly outputs cumulative margin distribution.n");
  2061.     if (classifier instanceof Sourcable) {
  2062.       optionsText.append("-z <class name>n");
  2063.       optionsText.append("tOnly outputs the source representation"
  2064.  + " of the classifier, giving it the supplied"
  2065.  + " name.n");
  2066.     }
  2067.     if (classifier instanceof Drawable) {
  2068.       optionsText.append("-gn");
  2069.       optionsText.append("tOnly outputs the graph representation"
  2070.  + " of the classifier.n");
  2071.     }
  2072.     // Get scheme-specific options
  2073.     if (classifier instanceof OptionHandler) {
  2074.       optionsText.append("nOptions specific to "
  2075.   + classifier.getClass().getName()
  2076.   + ":nn");
  2077.       Enumeration enum = ((OptionHandler)classifier).listOptions();
  2078.       while (enum.hasMoreElements()) {
  2079. Option option = (Option) enum.nextElement();
  2080. optionsText.append(option.synopsis() + 'n');
  2081. optionsText.append(option.description() + "n");
  2082.       }
  2083.     }
  2084.     return optionsText.toString();
  2085.   }
  2086.   /**
  2087.    * Method for generating indices for the confusion matrix.
  2088.    *
  2089.    * @param num integer to format
  2090.    * @return the formatted integer as a string
  2091.    */
  2092.   private String num2ShortID(int num,char [] IDChars,int IDWidth) {
  2093.     
  2094.     char ID [] = new char [IDWidth];
  2095.     int i;
  2096.     
  2097.     for(i = IDWidth - 1; i >=0; i--) {
  2098.       ID[i] = IDChars[num % IDChars.length];
  2099.       num = num / IDChars.length - 1;
  2100.       if (num < 0) {
  2101. break;
  2102.       }
  2103.     }
  2104.     for(i--; i >= 0; i--) {
  2105.       ID[i] = ' ';
  2106.     }
  2107.     return new String(ID);
  2108.   }
  2109.   /**
  2110.    * Convert a single prediction into a probability distribution
  2111.    * with all zero probabilities except the predicted value which
  2112.    * has probability 1.0;
  2113.    *
  2114.    * @param predictedClass the index of the predicted class
  2115.    * @return the probability distribution
  2116.    */
  2117.   private double [] makeDistribution(double predictedClass) {
  2118.     double [] result = new double [m_NumClasses];
  2119.     if (Instance.isMissingValue(predictedClass)) {
  2120.       return result;
  2121.     }
  2122.     if (m_ClassIsNominal) {
  2123.       result[(int)predictedClass] = 1.0;
  2124.     } else {
  2125.       result[0] = predictedClass;
  2126.     }
  2127.     return result;
  2128.   } 
  2129.   /**
  2130.    * Updates all the statistics about a classifiers performance for 
  2131.    * the current test instance.
  2132.    *
  2133.    * @param predictedDistribution the probabilities assigned to 
  2134.    * each class
  2135.    * @param instance the instance to be classified
  2136.    * @exception Exception if the class of the instance is not
  2137.    * set
  2138.    */
  2139.   private void updateStatsForClassifier(double [] predictedDistribution,
  2140. Instance instance)
  2141.        throws Exception {
  2142.     int actualClass = (int)instance.classValue();
  2143.     double costFactor = 1;
  2144.     if (!instance.classIsMissing()) {
  2145.       updateMargins(predictedDistribution, actualClass, instance.weight());
  2146.       // Determine the predicted class (doesn't detect multiple 
  2147.       // classifications)
  2148.       int predictedClass = -1;
  2149.       double bestProb = 0.0;
  2150.       for(int i = 0; i < m_NumClasses; i++) {
  2151. if (predictedDistribution[i] > bestProb) {
  2152.   predictedClass = i;
  2153.   bestProb = predictedDistribution[i];
  2154. }
  2155.       }
  2156.       m_WithClass += instance.weight();
  2157.       // Determine misclassification cost
  2158.       if (m_CostMatrix != null) {
  2159.         if (predictedClass < 0) {
  2160.           // For missing predictions, we assume the worst possible cost.
  2161.           // This is pretty harsh.
  2162.           // Perhaps we could take the negative of the cost of a correct
  2163.           // prediction (-m_CostMatrix.getElement(actualClass,actualClass)),
  2164.           // although often this will be zero
  2165.           m_TotalCost += instance.weight()
  2166.             * m_CostMatrix.getMaxCost(actualClass);
  2167.         } else {
  2168.           m_TotalCost += instance.weight() 
  2169.             * m_CostMatrix.getElement(actualClass, predictedClass);
  2170.         }
  2171.       }
  2172.       // Update counts when no class was predicted
  2173.       if (predictedClass < 0) {
  2174. m_Unclassified += instance.weight();
  2175. return;
  2176.       }
  2177.       double predictedProb = Math.max(MIN_SF_PROB,
  2178.       predictedDistribution[actualClass]);
  2179.       double priorProb = Math.max(MIN_SF_PROB,
  2180.   m_ClassPriors[actualClass]
  2181.   / m_ClassPriorsSum);
  2182.       if (predictedProb >= priorProb) {
  2183. m_SumKBInfo += (Utils.log2(predictedProb) - 
  2184. Utils.log2(priorProb))
  2185.   * instance.weight();
  2186.       } else {
  2187. m_SumKBInfo -= (Utils.log2(1.0-predictedProb) - 
  2188. Utils.log2(1.0-priorProb))
  2189.   * instance.weight();
  2190.       }
  2191.       m_SumSchemeEntropy -= Utils.log2(predictedProb) * instance.weight();
  2192.       m_SumPriorEntropy -= Utils.log2(priorProb) * instance.weight();
  2193.       updateNumericScores(predictedDistribution, 
  2194.   makeDistribution(instance.classValue()), 
  2195.   instance.weight());
  2196.       // Update other stats
  2197.       m_ConfusionMatrix[actualClass][predictedClass] += instance.weight();
  2198.       if (predictedClass != actualClass) {
  2199. m_Incorrect += instance.weight();
  2200.       } else {
  2201. m_Correct += instance.weight();
  2202.       }
  2203.     } else {
  2204.       m_MissingClass += instance.weight();
  2205.     }
  2206.   }
  2207.   /**
  2208.    * Updates all the statistics about a predictors performance for 
  2209.    * the current test instance.
  2210.    *
  2211.    * @param predictedValue the numeric value the classifier predicts
  2212.    * @param instance the instance to be classified
  2213.    * @exception Exception if the class of the instance is not
  2214.    * set
  2215.    */
  2216.   private void updateStatsForPredictor(double predictedValue,
  2217.        Instance instance) 
  2218.        throws Exception {
  2219.     if (!instance.classIsMissing()){
  2220.       // Update stats
  2221.       m_WithClass += instance.weight();
  2222.       if (Instance.isMissingValue(predictedValue)) {
  2223. m_Unclassified += instance.weight();
  2224. return;
  2225.       }
  2226.       m_SumClass += instance.weight() * instance.classValue();
  2227.       m_SumSqrClass += instance.weight() * instance.classValue()
  2228.       * instance.classValue();
  2229.       m_SumClassPredicted += instance.weight() 
  2230.       * instance.classValue() * predictedValue;
  2231.       m_SumPredicted += predictedValue;
  2232.       m_SumSqrPredicted += predictedValue * predictedValue;
  2233.       if (m_ErrorEstimator == null) {
  2234. setNumericPriorsFromBuffer();
  2235.       }
  2236.       double predictedProb = Math.max(m_ErrorEstimator.getProbability(
  2237.       predictedValue 
  2238.       - instance.classValue()),
  2239.       MIN_SF_PROB);
  2240.       double priorProb = Math.max(m_PriorErrorEstimator.getProbability(
  2241.                           instance.classValue()),
  2242.   MIN_SF_PROB);
  2243.       m_SumSchemeEntropy -= Utils.log2(predictedProb) * instance.weight();
  2244.       m_SumPriorEntropy -= Utils.log2(priorProb) * instance.weight();
  2245.       m_ErrorEstimator.addValue(predictedValue - instance.classValue(), 
  2246. instance.weight());
  2247.       updateNumericScores(makeDistribution(predictedValue),
  2248.   makeDistribution(instance.classValue()),
  2249.   instance.weight());
  2250.      
  2251.     } else
  2252.       m_MissingClass += instance.weight();
  2253.   }
  2254.   /**
  2255.    * Update the cumulative record of classification margins
  2256.    *
  2257.    * @param predictedDistribution the probability distribution predicted for
  2258.    * the current instance
  2259.    * @param actualClass the index of the actual instance class
  2260.    * @param weight the weight assigned to the instance
  2261.    */
  2262.   private void updateMargins(double [] predictedDistribution, 
  2263.      int actualClass, double weight) {
  2264.     double probActual = predictedDistribution[actualClass];
  2265.     double probNext = 0;
  2266.     for(int i = 0; i < m_NumClasses; i++)
  2267.       if ((i != actualClass) &&
  2268.   (predictedDistribution[i] > probNext))
  2269. probNext = predictedDistribution[i];
  2270.     double margin = probActual - probNext;
  2271.     int bin = (int)((margin + 1.0) / 2.0 * k_MarginResolution);
  2272.     m_MarginCounts[bin] += weight;
  2273.   }
  2274.   /**
  2275.    * Update the numeric accuracy measures. For numeric classes, the
  2276.    * accuracy is between the actual and predicted class values. For 
  2277.    * nominal classes, the accuracy is between the actual and 
  2278.    * predicted class probabilities.
  2279.    *
  2280.    * @param predicted the predicted values
  2281.    * @param actual the actual value
  2282.    * @param weight the weight associated with this prediction
  2283.    */
  2284.   private void updateNumericScores(double [] predicted, 
  2285.    double [] actual, double weight) {
  2286.     double diff;
  2287.     double sumErr = 0, sumAbsErr = 0, sumSqrErr = 0;
  2288.     double sumPriorAbsErr = 0, sumPriorSqrErr = 0;
  2289.     for(int i = 0; i < m_NumClasses; i++) {
  2290.       diff = predicted[i] - actual[i];
  2291.       sumErr += diff;
  2292.       sumAbsErr += Math.abs(diff);
  2293.       sumSqrErr += diff * diff;
  2294.       diff = (m_ClassPriors[i] / m_ClassPriorsSum) - actual[i];
  2295.       sumPriorAbsErr += Math.abs(diff);
  2296.       sumPriorSqrErr += diff * diff;
  2297.     }
  2298.     m_SumErr += weight * sumErr / m_NumClasses;
  2299.     m_SumAbsErr += weight * sumAbsErr / m_NumClasses;
  2300.     m_SumSqrErr += weight * sumSqrErr / m_NumClasses;
  2301.     m_SumPriorAbsErr += weight * sumPriorAbsErr / m_NumClasses;
  2302.     m_SumPriorSqrErr += weight * sumPriorSqrErr / m_NumClasses;
  2303.   }
  2304.   /**
  2305.    * Adds a numeric (non-missing) training class value and weight to 
  2306.    * the buffer of stored values.
  2307.    *
  2308.    * @param classValue the class value
  2309.    * @param weight the instance weight
  2310.    */
  2311.   private void addNumericTrainClass(double classValue, double weight) {
  2312.     if (m_TrainClassVals == null) {
  2313.       m_TrainClassVals = new double [100];
  2314.       m_TrainClassWeights = new double [100];
  2315.     }
  2316.     if (m_NumTrainClassVals == m_TrainClassVals.length) {
  2317.       double [] temp = new double [m_TrainClassVals.length * 2];
  2318.       System.arraycopy(m_TrainClassVals, 0, 
  2319.        temp, 0, m_TrainClassVals.length);
  2320.       m_TrainClassVals = temp;
  2321.       temp = new double [m_TrainClassWeights.length * 2];
  2322.       System.arraycopy(m_TrainClassWeights, 0, 
  2323.        temp, 0, m_TrainClassWeights.length);
  2324.       m_TrainClassWeights = temp;
  2325.     }
  2326.     m_TrainClassVals[m_NumTrainClassVals] = classValue;
  2327.     m_TrainClassWeights[m_NumTrainClassVals] = weight;
  2328.     m_NumTrainClassVals++;
  2329.   }
  2330.   /**
  2331.    * Sets up the priors for numeric class attributes from the 
  2332.    * training class values that have been seen so far.
  2333.    */
  2334.   private void setNumericPriorsFromBuffer() {
  2335.     
  2336.     double numPrecision = 0.01; // Default value
  2337.     if (m_NumTrainClassVals > 1) {
  2338.       double [] temp = new double [m_NumTrainClassVals];
  2339.       System.arraycopy(m_TrainClassVals, 0, temp, 0, m_NumTrainClassVals);
  2340.       int [] index = Utils.sort(temp);
  2341.       double lastVal = temp[index[0]];
  2342.       double currentVal, deltaSum = 0;
  2343.       int distinct = 0;
  2344.       for (int i = 1; i < temp.length; i++) {
  2345. double current = temp[index[i]];
  2346. if (current != lastVal) {
  2347.   deltaSum += current - lastVal;
  2348.   lastVal = current;
  2349.   distinct++;
  2350. }
  2351.       }
  2352.       if (distinct > 0) {
  2353. numPrecision = deltaSum / distinct;
  2354.       }
  2355.     }
  2356.     m_PriorErrorEstimator = new KernelEstimator(numPrecision);
  2357.     m_ErrorEstimator = new KernelEstimator(numPrecision);
  2358.     m_ClassPriors[0] = m_ClassPriorsSum = 0.0001; // zf correction
  2359.     for (int i = 0; i < m_NumTrainClassVals; i++) {
  2360.       m_ClassPriors[0] += m_TrainClassVals[i] * m_TrainClassWeights[i];
  2361.       m_ClassPriorsSum += m_TrainClassWeights[i];
  2362.       m_PriorErrorEstimator.addValue(m_TrainClassVals[i],
  2363.      m_TrainClassWeights[i]);
  2364.     }
  2365.   }
  2366. }