EvaluationUtils.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 5k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  *    This program is free software; you can redistribute it and/or modify
  3.  *    it under the terms of the GNU General Public License as published by
  4.  *    the Free Software Foundation; either version 2 of the License, or
  5.  *    (at your option) any later version.
  6.  *
  7.  *    This program is distributed in the hope that it will be useful,
  8.  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  *    GNU General Public License for more details.
  11.  *
  12.  *    You should have received a copy of the GNU General Public License
  13.  *    along with this program; if not, write to the Free Software
  14.  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15.  */
  16. /*
  17.  *    EvaluationUtils.java
  18.  *    Copyright (C) 2000 Intelligenesis Corp.
  19.  *
  20.  */
  21. package weka.classifiers.evaluation;
  22. import weka.core.FastVector;
  23. import weka.core.Instance;
  24. import weka.core.Instances;
  25. import weka.classifiers.DistributionClassifier;
  26. import java.util.Random;
  27. /**
  28.  * Contains utility functions for generating lists of predictions in 
  29.  * various manners.
  30.  *
  31.  * @author Len Trigg (len@intelligenesis.net)
  32.  * @version $Revision: 1.6 $
  33.  */
  34. public class EvaluationUtils {
  35.   /** Seed used to randomize data in cross-validation */
  36.   private int m_Seed = 1;
  37.   /** Sets the seed for randomization during cross-validation */
  38.   public void setSeed(int seed) { m_Seed = seed; }
  39.   /** Gets the seed for randomization during cross-validation */
  40.   public int getSeed() { return m_Seed; }
  41.   
  42.   /**
  43.    * Generate a bunch of predictions ready for processing, by performing a
  44.    * cross-validation on the supplied dataset.
  45.    *
  46.    * @param classifier the DistributionClassifier to evaluate
  47.    * @param data the dataset
  48.    * @param numFolds the number of folds in the cross-validation.
  49.    * @exception Exception if an error occurs
  50.    */
  51.   public FastVector getCVPredictions(DistributionClassifier classifier, 
  52.                                      Instances data, 
  53.                                      int numFolds) 
  54.     throws Exception {
  55.     FastVector predictions = new FastVector();
  56.     Instances runInstances = new Instances(data);
  57.     Random random = new Random(m_Seed);
  58.     runInstances.randomize(random);
  59.     if (runInstances.classAttribute().isNominal() && (numFolds > 1)) {
  60.       runInstances.stratify(numFolds);
  61.     }
  62.     int inst = 0;
  63.     for (int fold = 0; fold < numFolds; fold++) {
  64.       Instances train = runInstances.trainCV(numFolds, fold);
  65.       Instances test = runInstances.testCV(numFolds, fold);
  66.       FastVector foldPred = getTrainTestPredictions(classifier, train, test);
  67.       predictions.appendElements(foldPred);
  68.     } 
  69.     return predictions;
  70.   }
  71.   /**
  72.    * Generate a bunch of predictions ready for processing, by performing a
  73.    * evaluation on a test set after training on the given training set.
  74.    *
  75.    * @param classifier the DistributionClassifier to evaluate
  76.    * @param train the training dataset
  77.    * @param test the test dataset
  78.    * @exception Exception if an error occurs
  79.    */
  80.   public FastVector getTrainTestPredictions(DistributionClassifier classifier, 
  81.                                             Instances train, Instances test) 
  82.     throws Exception {
  83.     
  84.     classifier.buildClassifier(train);
  85.     return getTestPredictions(classifier, test);
  86.   }
  87.   /**
  88.    * Generate a bunch of predictions ready for processing, by performing a
  89.    * evaluation on a test set assuming the classifier is already trained.
  90.    *
  91.    * @param classifier the pre-trained DistributionClassifier to evaluate
  92.    * @param test the test dataset
  93.    * @exception Exception if an error occurs
  94.    */
  95.   public FastVector getTestPredictions(DistributionClassifier classifier, 
  96.                                        Instances test) 
  97.     throws Exception {
  98.     
  99.     FastVector predictions = new FastVector();
  100.     for (int i = 0; i < test.numInstances(); i++) {
  101.       if (!test.instance(i).classIsMissing()) {
  102.         predictions.addElement(getPrediction(classifier, test.instance(i)));
  103.       }
  104.     }
  105.     return predictions;
  106.   }
  107.   
  108.   /**
  109.    * Generate a single prediction for a test instance given the pre-trained
  110.    * classifier.
  111.    *
  112.    * @param classifier the pre-trained DistributionClassifier to evaluate
  113.    * @param test the test instance
  114.    * @exception Exception if an error occurs
  115.    */
  116.   public Prediction getPrediction(DistributionClassifier classifier,
  117.                                   Instance test)
  118.     throws Exception {
  119.    
  120.     double actual = test.classValue();
  121.     double [] dist = classifier.distributionForInstance(test);
  122.     if (test.classAttribute().isNominal()) {
  123.       return new NominalPrediction(actual, dist, test.weight());
  124.     } else {
  125.       return new NumericPrediction(actual, dist[0], test.weight());
  126.     }
  127.   }
  128. }