CostCurve.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 5k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  *    This program is free software; you can redistribute it and/or modify
  3.  *    it under the terms of the GNU General Public License as published by
  4.  *    the Free Software Foundation; either version 2 of the License, or
  5.  *    (at your option) any later version.
  6.  *
  7.  *    This program is distributed in the hope that it will be useful,
  8.  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  *    GNU General Public License for more details.
  11.  *
  12.  *    You should have received a copy of the GNU General Public License
  13.  *    along with this program; if not, write to the Free Software
  14.  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15.  */
  16. /*
  17.  *    CostCurve.java
  18.  *    Copyright (C) 2001 Mark Hall
  19.  *
  20.  */
  21. package weka.classifiers.evaluation;
  22. import weka.classifiers.functions.VotedPerceptron;
  23. import weka.core.Utils;
  24. import weka.core.Attribute;
  25. import weka.core.FastVector;
  26. import weka.core.Instance;
  27. import weka.core.Instances;
  28. import weka.classifiers.DistributionClassifier;
  29. /**
  30.  * Generates points illustrating probablity cost tradeoffs that can be 
  31.  * obtained by varying the threshold value between classes. For example, 
  32.  * the typical threshold value of 0.5 means the predicted probability of 
  33.  * "positive" must be higher than 0.5 for the instance to be predicted as 
  34.  * "positive".
  35.  *
  36.  * @author Mark Hall (mhall@cs.waikato.ac.nz)
  37.  * @version $Revision: 1.4 $
  38.  */
  39. public class CostCurve {
  40.   /** The name of the relation used in cost curve datasets */
  41.   public final static String RELATION_NAME = "CostCurve";
  42.   public final static String PROB_COST_FUNC_NAME = "Probability Cost Function";
  43.   public final static String NORM_EXPECTED_COST_NAME = 
  44.     "Normalized Expected Cost";
  45.   public final static String THRESHOLD_NAME = "Threshold";
  46.   /**
  47.    * Calculates the performance stats for the default class and return 
  48.    * results as a set of Instances. The
  49.    * structure of these Instances is as follows:<p> <ul> 
  50.    * <li> <b>Probability Cost Function </b>
  51.    * <li> <b>Normalized Expected Cost</b>
  52.    * <li> <b>Threshold</b> contains the probability threshold that gives
  53.    * rise to the previous performance values. 
  54.    * </ul> <p>
  55.    *
  56.    * @see TwoClassStats
  57.    * @param classIndex index of the class of interest.
  58.    * @return datapoints as a set of instances, null if no predictions
  59.    * have been made.
  60.    */
  61.   public Instances getCurve(FastVector predictions) {
  62.     if (predictions.size() == 0) {
  63.       return null;
  64.     }
  65.     return getCurve(predictions, 
  66.                     ((NominalPrediction)predictions.elementAt(0))
  67.                     .distribution().length - 1);
  68.   }
  69.   /**
  70.    * Calculates the performance stats for the desired class and return 
  71.    * results as a set of Instances.
  72.    *
  73.    * @param classIndex index of the class of interest.
  74.    * @return datapoints as a set of instances.
  75.    */
  76.   public Instances getCurve(FastVector predictions, int classIndex) {
  77.     if ((predictions.size() == 0) ||
  78.         (((NominalPrediction)predictions.elementAt(0))
  79.          .distribution().length <= classIndex)) {
  80.       return null;
  81.     }
  82.     
  83.     ThresholdCurve tc = new ThresholdCurve();
  84.     Instances threshInst = tc.getCurve(predictions, classIndex);
  85.     Instances insts = makeHeader();
  86.     int fpind = threshInst.attribute(ThresholdCurve.FP_RATE_NAME).index();
  87.     int tpind = threshInst.attribute(ThresholdCurve.TP_RATE_NAME).index();
  88.     int threshind = threshInst.attribute(ThresholdCurve.THRESHOLD_NAME).index();
  89.     
  90.     double [] vals;
  91.     double fpval, tpval, thresh;
  92.     for (int i = 0; i< threshInst.numInstances(); i++) {
  93.       fpval = threshInst.instance(i).value(fpind);
  94.       tpval = threshInst.instance(i).value(tpind);
  95.       thresh = threshInst.instance(i).value(threshind);
  96.       vals = new double [3];
  97.       vals[0] = 0; vals[1] = fpval; vals[2] = thresh;
  98.       insts.add(new Instance(1.0, vals));
  99.       vals = new double [3];
  100.       vals[0] = 1; vals[1] = 1.0 - tpval; vals[2] = thresh;
  101.       insts.add(new Instance(1.0, vals));
  102.     }
  103.     
  104.     return insts;
  105.   }
  106.   private Instances makeHeader() {
  107.     FastVector fv = new FastVector();
  108.     fv.addElement(new Attribute(PROB_COST_FUNC_NAME));
  109.     fv.addElement(new Attribute(NORM_EXPECTED_COST_NAME));
  110.     fv.addElement(new Attribute(THRESHOLD_NAME));
  111.     return new Instances(RELATION_NAME, fv, 100);
  112.   }
  113.   /**
  114.    * Tests the CostCurve generation from the command line.
  115.    * The classifier is currently hardcoded. Pipe in an arff file.
  116.    *
  117.    * @param args currently ignored
  118.    */
  119.   public static void main(String [] args) {
  120.     try {
  121.       
  122.       Instances inst = new Instances(new java.io.InputStreamReader(System.in));
  123.       
  124.       inst.setClassIndex(inst.numAttributes() - 1);
  125.       CostCurve cc = new CostCurve();
  126.       EvaluationUtils eu = new EvaluationUtils();
  127.       DistributionClassifier classifier = new weka.classifiers.functions.VotedPerceptron();
  128.       FastVector predictions = new FastVector();
  129.       for (int i = 0; i < 2; i++) { // Do two runs.
  130. eu.setSeed(i);
  131. predictions.appendElements(eu.getCVPredictions(classifier, inst, 10));
  132. //System.out.println("nnn");
  133.       }
  134.       Instances result = cc.getCurve(predictions);
  135.       System.out.println(result);
  136.       
  137.     } catch (Exception ex) {
  138.       ex.printStackTrace();
  139.     }
  140.   }
  141. }