ThresholdSelector.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 28k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  *    This program is free software; you can redistribute it and/or modify
  3.  *    it under the terms of the GNU General Public License as published by
  4.  *    the Free Software Foundation; either version 2 of the License, or
  5.  *    (at your option) any later version.
  6.  *
  7.  *    This program is distributed in the hope that it will be useful,
  8.  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  *    GNU General Public License for more details.
  11.  *
  12.  *    You should have received a copy of the GNU General Public License
  13.  *    along with this program; if not, write to the Free Software
  14.  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15.  */
  16. /*
  17.  *    ThresholdSelector.java
  18.  *    Copyright (C) 1999 Eibe Frank
  19.  *
  20.  */
  21. package weka.classifiers;
  22. import java.util.Enumeration;
  23. import java.util.Random;
  24. import java.util.Vector;
  25. import weka.classifiers.evaluation.EvaluationUtils;
  26. import weka.classifiers.evaluation.ThresholdCurve;
  27. import weka.core.Attribute;
  28. import weka.core.AttributeStats;
  29. import weka.core.FastVector;
  30. import weka.core.Instance;
  31. import weka.core.Instances;
  32. import weka.core.Option;
  33. import weka.core.OptionHandler;
  34. import weka.core.SelectedTag;
  35. import weka.core.Tag;
  36. import weka.core.Utils;
  37. import weka.core.UnsupportedClassTypeException;
  38. /**
  39.  * Class for selecting a threshold on a probability output by a
  40.  * distribution classifier. The threshold is set so that a given
  41.  * performance measure is optimized. Currently this is the
  42.  * F-measure. Performance is measured either on the training data, a hold-out
  43.  * set or using cross-validation. In addition, the probabilities returned
  44.  * by the base learner can have their range expanded so that the output
  45.  * probabilities will reside between 0 and 1 (this is useful if the scheme
  46.  * normally produces probabilities in a very narrow range).<p>
  47.  *
  48.  * Valid options are:<p>
  49.  *
  50.  * -C num <br>
  51.  * The class for which threshold is determined. Valid values are:
  52.  * 1, 2 (for first and second classes, respectively), 3 (for whichever
  53.  * class is least frequent), 4 (for whichever class value is most 
  54.  * frequent), and 5 (for the first class named any of "yes","pos(itive)",
  55.  * "1", or method 3 if no matches). (default 5). <p>
  56.  *
  57.  * -W classname <br>
  58.  * Specify the full class name of the base classifier. <p>
  59.  *
  60.  * -X num <br> 
  61.  * Number of folds used for cross validation. If just a
  62.  * hold-out set is used, this determines the size of the hold-out set
  63.  * (default 3).<p>
  64.  *
  65.  * -R integer <br>
  66.  * Sets whether confidence range correction is applied. This can be used
  67.  * to ensure the confidences range from 0 to 1. Use 0 for no range correction,
  68.  * 1 for correction based on the min/max values seen during threshold selection
  69.  * (default 0).<p>
  70.  *
  71.  * -S seed <br>
  72.  * Random number seed (default 1).<p>
  73.  *
  74.  * -E integer <br>
  75.  * Sets the evaluation mode. Use 0 for evaluation using cross-validation,
  76.  * 1 for evaluation using hold-out set, and 2 for evaluation on the
  77.  * training data (default 1).<p>
  78.  *
  79.  * Options after -- are passed to the designated sub-classifier. <p>
  80.  *
  81.  * @author Eibe Frank (eibe@cs.waikato.ac.nz)
  82.  * @version $Revision: 1.22 $ 
  83.  */
  84. public class ThresholdSelector extends DistributionClassifier 
  85.   implements OptionHandler {
  86.   /* Type of correction applied to threshold range */ 
  87.   public final static int RANGE_NONE = 0;
  88.   public final static int RANGE_BOUNDS = 1;
  89.   public static final Tag [] TAGS_RANGE = {
  90.     new Tag(RANGE_NONE, "No range correction"),
  91.     new Tag(RANGE_BOUNDS, "Correct based on min/max observed")
  92.   };
  93.   /* The evaluation modes */
  94.   public final static int EVAL_TRAINING_SET = 2;
  95.   public final static int EVAL_TUNED_SPLIT = 1;
  96.   public final static int EVAL_CROSS_VALIDATION = 0;
  97.   public static final Tag [] TAGS_EVAL = {
  98.     new Tag(EVAL_TRAINING_SET, "Entire training set"),
  99.     new Tag(EVAL_TUNED_SPLIT, "Single tuned fold"),
  100.     new Tag(EVAL_CROSS_VALIDATION, "N-Fold cross validation")
  101.   };
  102.   /* How to determine which class value to optimize for */
  103.   public final static int OPTIMIZE_0     = 0;
  104.   public final static int OPTIMIZE_1     = 1;
  105.   public final static int OPTIMIZE_LFREQ = 2;
  106.   public final static int OPTIMIZE_MFREQ = 3;
  107.   public final static int OPTIMIZE_POS_NAME = 4;
  108.   public static final Tag [] TAGS_OPTIMIZE = {
  109.     new Tag(OPTIMIZE_0, "First class value"),
  110.     new Tag(OPTIMIZE_1, "Second class value"),
  111.     new Tag(OPTIMIZE_LFREQ, "Least frequent class value"),
  112.     new Tag(OPTIMIZE_MFREQ, "Most frequent class value"),
  113.     new Tag(OPTIMIZE_POS_NAME, "Class value named: "yes", "pos(itive)","1"")
  114.   };
  115.   /** The generated base classifier */
  116.   protected DistributionClassifier m_Classifier = 
  117.     new weka.classifiers.ZeroR();
  118.   /** The upper threshold used as the basis of correction */
  119.   protected double m_HighThreshold = 1;
  120.   /** The lower threshold used as the basis of correction */
  121.   protected double m_LowThreshold = 0;
  122.   /** The threshold that lead to the best performance */
  123.   protected double m_BestThreshold = -Double.MAX_VALUE;
  124.   /** The best value that has been observed */
  125.   protected double m_BestValue = - Double.MAX_VALUE;
  126.   
  127.   /** The number of folds used in cross-validation */
  128.   protected int m_NumXValFolds = 3;
  129.   /** Random number seed */
  130.   protected int m_Seed = 1;
  131.   /** Designated class value, determined during building */
  132.   protected int m_DesignatedClass = 0;
  133.   /** Method to determine which class to optimize for */
  134.   protected int m_ClassMode = OPTIMIZE_POS_NAME;
  135.   /** The evaluation mode */
  136.   protected int m_EvalMode = EVAL_TUNED_SPLIT;
  137.   /** The range correction mode */
  138.   protected int m_RangeMode = RANGE_NONE;
  139.   /** The minimum value for the criterion. If threshold adjustment
  140.       yields less than that, the default threshold of 0.5 is used. */
  141.   protected final static double MIN_VALUE = 0.05;
  142.   /**
  143.    * Collects the classifier predictions using the specified evaluation method.
  144.    *
  145.    * @param instances the set of <code>Instances</code> to generate
  146.    * predictions for.
  147.    * @param mode the evaluation mode.
  148.    * @param numFolds the number of folds to use if not evaluating on the
  149.    * full training set.
  150.    * @return a <code>FastVector</code> containing the predictions.
  151.    * @exception Exception if an error occurs generating the predictions.
  152.    */
  153.   protected FastVector getPredictions(Instances instances, int mode, int numFolds) 
  154.     throws Exception {
  155.     EvaluationUtils eu = new EvaluationUtils();
  156.     eu.setSeed(m_Seed);
  157.     
  158.     switch (mode) {
  159.     case EVAL_TUNED_SPLIT:
  160.       Instances trainData = null, evalData = null;
  161.       Instances data = new Instances(instances);
  162.       data.randomize(new Random(m_Seed));
  163.       data.stratify(numFolds);
  164.       
  165.       // Make sure that both subsets contain at least one positive instance
  166.       for (int subsetIndex = 0; subsetIndex < numFolds; subsetIndex++) {
  167.         trainData = data.trainCV(numFolds, subsetIndex);
  168.         evalData = data.testCV(numFolds, subsetIndex);
  169.         if (checkForInstance(trainData) && checkForInstance(evalData)) {
  170.           break;
  171.         }
  172.       }
  173.       return eu.getTrainTestPredictions(m_Classifier, trainData, evalData);
  174.     case EVAL_TRAINING_SET:
  175.       return eu.getTrainTestPredictions(m_Classifier, instances, instances);
  176.     case EVAL_CROSS_VALIDATION:
  177.       return eu.getCVPredictions(m_Classifier, instances, numFolds);
  178.     default:
  179.       throw new RuntimeException("Unrecognized evaluation mode");
  180.     }
  181.   }
  182.   /**
  183.    * Finds the best threshold, this implementation searches for the
  184.    * highest FMeasure. If no FMeasure higher than MIN_VALUE is found,
  185.    * the default threshold of 0.5 is used.
  186.    *
  187.    * @param predictions a <code>FastVector</code> containing the predictions.
  188.    */
  189.   protected void findThreshold(FastVector predictions) {
  190.     Instances curve = (new ThresholdCurve()).getCurve(predictions, m_DesignatedClass);
  191.     //System.err.println(curve);
  192.     double low = 1.0;
  193.     double high = 0.0;
  194.     if (curve.numInstances() > 0) {
  195.       Instance maxFM = curve.instance(0);
  196.       int indexFM = curve.attribute(ThresholdCurve.FMEASURE_NAME).index();
  197.       int indexThreshold = curve.attribute(ThresholdCurve.THRESHOLD_NAME).index();
  198.       for (int i = 1; i < curve.numInstances(); i++) {
  199.         Instance current = curve.instance(i);
  200.         if (current.value(indexFM) > maxFM.value(indexFM)) {
  201.           maxFM = current;
  202.         }
  203.         if (m_RangeMode == RANGE_BOUNDS) {
  204.           double thresh = current.value(indexThreshold);
  205.           if (thresh < low) {
  206.             low = thresh;
  207.           }
  208.           if (thresh > high) {
  209.             high = thresh;
  210.           }
  211.         }
  212.       }
  213.       if (maxFM.value(indexFM) > MIN_VALUE) {
  214.         m_BestThreshold = maxFM.value(indexThreshold);
  215.         m_BestValue = maxFM.value(indexFM);
  216.         //System.err.println("maxFM: " + maxFM);
  217.       }
  218.       if (m_RangeMode == RANGE_BOUNDS) {
  219.         m_LowThreshold = low;
  220.         m_HighThreshold = high;
  221.         //System.err.println("Threshold range: " + low + " - " + high);
  222.       }
  223.     }
  224.   }
  225.   /**
  226.    * Returns an enumeration describing the available options
  227.    *
  228.    * @return an enumeration of all the available options
  229.    */
  230.   public Enumeration listOptions() {
  231.     Vector newVector = new Vector(6);
  232.     newVector.addElement(new Option(
  233.               "tThe class for which threshold is determined. Valid values are:n" +
  234.               "t1, 2 (for first and second classes, respectively), 3 (for whichevern" +
  235.               "tclass is least frequent), and 4 (for whichever class value is mostn" +
  236.               "tfrequent), and 5 (for the first class named any of "yes","pos(itive)"n" +
  237.               "t"1", or method 3 if no matches). (default 5).",
  238.       "C", 1, "-C <integer>"));
  239.     newVector.addElement(new Option(
  240.       "tFull name of classifier to perform parameter selection on.n"
  241.       + "teg: weka.classifiers.NaiveBayes",
  242.       "W", 1, "-W <classifier class name>"));
  243.     newVector.addElement(new Option(
  244.       "tNumber of folds used for cross validation. If just an" +
  245.       "thold-out set is used, this determines the size of the hold-out setn" +
  246.       "t(default 3).",
  247.       "X", 1, "-X <number of folds>"));
  248.     newVector.addElement(new Option(
  249.       "tSets whether confidence range correction is applied. Thisn" +
  250.               "tcan be used to ensure the confidences range from 0 to 1.n" +
  251.               "tUse 0 for no range correction, 1 for correction based onn" +
  252.               "tthe min/max values seen during threshold selectionn"+
  253.               "t(default 0).",
  254.       "R", 1, "-R <integer>"));
  255.     newVector.addElement(new Option(
  256.       "tSets the random number seed (default 1).",
  257.       "S", 1, "-S <random number seed>"));
  258.     newVector.addElement(new Option(
  259.       "tSets the evaluation mode. Use 0 forn" +
  260.       "tevaluation using cross-validation,n" +
  261.       "t1 for evaluation using hold-out set,n" +
  262.       "tand 2 for evaluation on then" +
  263.       "ttraining data (default 1).",
  264.       "E", 1, "-E <integer>"));
  265.     if ((m_Classifier != null) &&
  266. (m_Classifier instanceof OptionHandler)) {
  267.       newVector.addElement(new Option("",
  268.         "", 0,
  269. "nOptions specific to sub-classifier "
  270.         + m_Classifier.getClass().getName()
  271. + ":n(use -- to signal start of sub-classifier options)"));
  272.       Enumeration enum = ((OptionHandler)m_Classifier).listOptions();
  273.       while (enum.hasMoreElements()) {
  274. newVector.addElement(enum.nextElement());
  275.       }
  276.     }
  277.     return newVector.elements();
  278.   }
  279.   /**
  280.    * Parses a given list of options. Valid options are:<p>
  281.    *
  282.    * -C num <br>
  283.    * The class for which threshold is determined. Valid values are:
  284.    * 1, 2 (for first and second classes, respectively), 3 (for whichever
  285.    * class is least frequent), 4 (for whichever class value is most 
  286.    * frequent), and 5 (for the first class named any of "yes","pos(itive)",
  287.    * "1", or method 3 if no matches). (default 3). <p>
  288.    *
  289.    * -W classname <br>
  290.    * Specify the full class name of classifier to perform cross-validation
  291.    * selection on.<p>
  292.    *
  293.    * -X num <br> 
  294.    * Number of folds used for cross validation. If just a
  295.    * hold-out set is used, this determines the size of the hold-out set
  296.    * (default 3).<p>
  297.    *
  298.    * -R integer <br>
  299.    * Sets whether confidence range correction is applied. This can be used
  300.    * to ensure the confidences range from 0 to 1. Use 0 for no range correction,
  301.    * 1 for correction based on the min/max values seen during threshold 
  302.    * selection (default 0).<p>
  303.    *
  304.    * -S seed <br>
  305.    * Random number seed (default 1).<p>
  306.    *
  307.    * -E integer <br>
  308.    * Sets the evaluation mode. Use 0 for evaluation using cross-validation,
  309.    * 1 for evaluation using hold-out set, and 2 for evaluation on the
  310.    * training data (default 1).<p>
  311.    *
  312.    * Options after -- are passed to the designated sub-classifier. <p>
  313.    *
  314.    * @param options the list of options as an array of strings
  315.    * @exception Exception if an option is not supported
  316.    */
  317.   public void setOptions(String[] options) throws Exception {
  318.     
  319.     String classString = Utils.getOption('C', options);
  320.     if (classString.length() != 0) {
  321.       setDesignatedClass(new SelectedTag(Integer.parseInt(classString) - 1, 
  322.                                          TAGS_OPTIMIZE));
  323.     } else {
  324.       setDesignatedClass(new SelectedTag(OPTIMIZE_LFREQ, TAGS_OPTIMIZE));
  325.     }
  326.     String modeString = Utils.getOption('E', options);
  327.     if (modeString.length() != 0) {
  328.       setEvaluationMode(new SelectedTag(Integer.parseInt(modeString), 
  329.                                          TAGS_EVAL));
  330.     } else {
  331.       setEvaluationMode(new SelectedTag(EVAL_TUNED_SPLIT, TAGS_EVAL));
  332.     }
  333.     String rangeString = Utils.getOption('R', options);
  334.     if (rangeString.length() != 0) {
  335.       setRangeCorrection(new SelectedTag(Integer.parseInt(rangeString) - 1, 
  336.                                          TAGS_RANGE));
  337.     } else {
  338.       setRangeCorrection(new SelectedTag(RANGE_NONE, TAGS_RANGE));
  339.     }
  340.     String foldsString = Utils.getOption('X', options);
  341.     if (foldsString.length() != 0) {
  342.       setNumXValFolds(Integer.parseInt(foldsString));
  343.     } else {
  344.       setNumXValFolds(3);
  345.     }
  346.     String randomString = Utils.getOption('S', options);
  347.     if (randomString.length() != 0) {
  348.       setSeed(Integer.parseInt(randomString));
  349.     } else {
  350.       setSeed(1);
  351.     }
  352.     String classifierName = Utils.getOption('W', options);
  353.     if (classifierName.length() == 0) {
  354.       throw new Exception("A classifier must be specified with"
  355.   + " the -W option.");
  356.     }
  357.     setDistributionClassifier((DistributionClassifier)Classifier.
  358.   forName(classifierName,
  359.   Utils.partitionOptions(options)));
  360.   }
  361.   /**
  362.    * Gets the current settings of the Classifier.
  363.    *
  364.    * @return an array of strings suitable for passing to setOptions
  365.    */
  366.   public String [] getOptions() {
  367.     String [] classifierOptions = new String [0];
  368.     if ((m_Classifier != null) && 
  369. (m_Classifier instanceof OptionHandler)) {
  370.       classifierOptions = ((OptionHandler)m_Classifier).getOptions();
  371.     }
  372.     int current = 0;
  373.     String [] options = new String [classifierOptions.length + 13];
  374.     options[current++] = "-C"; options[current++] = "" + (m_DesignatedClass + 1);
  375.     options[current++] = "-X"; options[current++] = "" + getNumXValFolds();
  376.     options[current++] = "-S"; options[current++] = "" + getSeed();
  377.     if (getDistributionClassifier() != null) {
  378.       options[current++] = "-W";
  379.       options[current++] = getDistributionClassifier().getClass().getName();
  380.     }
  381.     options[current++] = "-E"; options[current++] = "" + m_EvalMode;
  382.     options[current++] = "-R"; options[current++] = "" + m_RangeMode;
  383.     options[current++] = "--";
  384.     System.arraycopy(classifierOptions, 0, options, current, 
  385.      classifierOptions.length);
  386.     current += classifierOptions.length;
  387.     while (current < options.length) {
  388.       options[current++] = "";
  389.     }
  390.     return options;
  391.   }
  392.   /**
  393.    * Generates the classifier.
  394.    *
  395.    * @param instances set of instances serving as training data 
  396.    * @exception Exception if the classifier has not been generated successfully
  397.    */
  398.   public void buildClassifier(Instances instances) 
  399.     throws Exception {
  400.     if (instances.numClasses() > 2) {
  401.       throw new UnsupportedClassTypeException("Only works for two-class datasets!");
  402.     }
  403.     if (!instances.classAttribute().isNominal()) {
  404.       throw new UnsupportedClassTypeException("Class attribute must be nominal!");
  405.     }
  406.     AttributeStats stats = instances.attributeStats(instances.classIndex());
  407.     m_BestThreshold = 0.5;
  408.     m_BestValue = MIN_VALUE;
  409.     m_HighThreshold = 1;
  410.     m_LowThreshold = 0;
  411.     // If data contains only one instance of positive data
  412.     // optimize on training data
  413.     if (stats.distinctCount != 2) {
  414.       System.err.println("Couldn't find examples of both classes. No adjustment.");
  415.       m_Classifier.buildClassifier(instances);
  416.     } else {
  417.       
  418.       // Determine which class value to look for
  419.       switch (m_ClassMode) {
  420.       case OPTIMIZE_0:
  421.         m_DesignatedClass = 0;
  422.         break;
  423.       case OPTIMIZE_1:
  424.         m_DesignatedClass = 1;
  425.         break;
  426.       case OPTIMIZE_POS_NAME:
  427.         Attribute cAtt = instances.classAttribute();
  428.         boolean found = false;
  429.         for (int i = 0; i < cAtt.numValues() && !found; i++) {
  430.           String name = cAtt.value(i).toLowerCase();
  431.           if (name.startsWith("yes") || name.equals("1") || 
  432.               name.startsWith("pos")) {
  433.             found = true;
  434.             m_DesignatedClass = i;
  435.           }
  436.         }
  437.         if (found) {
  438.           break;
  439.         }
  440.         // No named class found, so fall through to default of least frequent
  441.       case OPTIMIZE_LFREQ:
  442.         m_DesignatedClass = (stats.nominalCounts[0] > stats.nominalCounts[1]) ? 1 : 0;
  443.         break;
  444.       case OPTIMIZE_MFREQ:
  445.         m_DesignatedClass = (stats.nominalCounts[0] > stats.nominalCounts[1]) ? 0 : 1;
  446.         break;
  447.       default:
  448.         throw new Exception("Unrecognized class value selection mode");
  449.       }
  450.       
  451.       /*
  452.         System.err.println("ThresholdSelector: Using mode=" 
  453.         + TAGS_OPTIMIZE[m_ClassMode].getReadable());
  454.         System.err.println("ThresholdSelector: Optimizing using class "
  455.         + m_DesignatedClass + "/" 
  456.         + instances.classAttribute().value(m_DesignatedClass));
  457.       */
  458.       
  459.       
  460.       if (stats.nominalCounts[m_DesignatedClass] == 1) {
  461.         System.err.println("Only 1 positive found: optimizing on training data");
  462.         findThreshold(getPredictions(instances, EVAL_TRAINING_SET, 0));
  463.       } else {
  464.         int numFolds = Math.min(m_NumXValFolds, stats.nominalCounts[m_DesignatedClass]);
  465.         //System.err.println("Number of folds for threshold selector: " + numFolds);
  466.         findThreshold(getPredictions(instances, m_EvalMode, numFolds));
  467.         if (m_EvalMode != EVAL_TRAINING_SET) {
  468.           m_Classifier.buildClassifier(instances);
  469.         }
  470.       }
  471.     }
  472.   }
  473.   /**
  474.    * Checks whether instance of designated class is in subset.
  475.    */
  476.   private boolean checkForInstance(Instances data) throws Exception {
  477.     for (int i = 0; i < data.numInstances(); i++) {
  478.       if (((int)data.instance(i).classValue()) == m_DesignatedClass) {
  479. return true;
  480.       }
  481.     }
  482.     return false;
  483.   }
  484.   /**
  485.    * Calculates the class membership probabilities for the given test instance.
  486.    *
  487.    * @param instance the instance to be classified
  488.    * @return predicted class probability distribution
  489.    * @exception Exception if instance could not be classified
  490.    * successfully
  491.    */
  492.   public double [] distributionForInstance(Instance instance) 
  493.     throws Exception {
  494.     
  495.     double [] pred = m_Classifier.distributionForInstance(instance);
  496.     double prob = pred[m_DesignatedClass];
  497.     // Warp probability
  498.     if (prob > m_BestThreshold) {
  499.       prob = 0.5 + (prob - m_BestThreshold) / 
  500.         ((m_HighThreshold - m_BestThreshold) * 2);
  501.     } else {
  502.       prob = (prob - m_LowThreshold) / 
  503.         ((m_BestThreshold - m_LowThreshold) * 2);
  504.     }
  505.     if (prob < 0) {
  506.       prob = 0.0;
  507.     } else if (prob > 1) {
  508.       prob = 1.0;
  509.     }
  510.     // Alter the distribution
  511.     pred[m_DesignatedClass] = prob;
  512.     if (pred.length == 2) { // Handle case when there's only one class
  513.       pred[(m_DesignatedClass + 1) % 2] = 1.0 - prob;
  514.     }
  515.     return pred;
  516.   }
  517.   /**
  518.    * @return a description of the classifier suitable for
  519.    * displaying in the explorer/experimenter gui
  520.    */
  521.   public String globalInfo() {
  522.     return "A metaclassifier that selecting a mid-point threshold on the "
  523.       + "probability output by a DistributionClassifier. The midpoint "
  524.       + "threshold is set so that a given performance measure is optimized. "
  525.       + "Currently this is the F-measure. Performance is measured either on "
  526.       + "the training data, a hold-out set or using cross-validation. In "
  527.       + "addition, the probabilities returned by the base learner can "
  528.       + "have their range expanded so that the output probabilities will "
  529.       + "reside between 0 and 1 (this is useful if the scheme normally "
  530.       + "produces probabilities in a very narrow range).";
  531.   }
  532.     
  533.   /**
  534.    * @return tip text for this property suitable for
  535.    * displaying in the explorer/experimenter gui
  536.    */
  537.   public String designatedClassTipText() {
  538.     return "Sets the class value for which the optimization is performed. "
  539.       + "The options are: pick the first class value; pick the second "
  540.       + "class value; pick whichever class is least frequent; pick whichever "
  541.       + "class value is most frequent; pick the first class named any of "
  542.       + ""yes","pos(itive)", "1", or the least frequent if no matches).";
  543.   }
  544.   /**
  545.    * Gets the method to determine which class value to optimize. Will
  546.    * be one of OPTIMIZE_0, OPTIMIZE_1, OPTIMIZE_LFREQ, OPTIMIZE_MFREQ,
  547.    * OPTIMIZE_POS_NAME.
  548.    *
  549.    * @return the class selection mode.
  550.    */
  551.   public SelectedTag getDesignatedClass() {
  552.     return new SelectedTag(m_ClassMode, TAGS_OPTIMIZE);
  553.   }
  554.   
  555.   /**
  556.    * Sets the method to determine which class value to optimize. Will
  557.    * be one of OPTIMIZE_0, OPTIMIZE_1, OPTIMIZE_LFREQ, OPTIMIZE_MFREQ,
  558.    * OPTIMIZE_POS_NAME.
  559.    *
  560.    * @param newMethod the new class selection mode.
  561.    */
  562.   public void setDesignatedClass(SelectedTag newMethod) {
  563.     
  564.     if (newMethod.getTags() == TAGS_OPTIMIZE) {
  565.       m_ClassMode = newMethod.getSelectedTag().getID();
  566.     }
  567.   }
  568.   /**
  569.    * @return tip text for this property suitable for
  570.    * displaying in the explorer/experimenter gui
  571.    */
  572.   public String evaluationModeTipText() {
  573.     return "Sets the method used to determine the threshold/performance "
  574.       + "curve. The options are: perform optimization based on the entire "
  575.       + "training set (may result in overfitting); perform an n-fold "
  576.       + "cross-validation (may be time consuming); perform one fold of "
  577.       + "an n-fold cross-validation (faster but likely less accurate).";
  578.   }
  579.   /**
  580.    * Sets the evaluation mode used. Will be one of
  581.    * EVAL_TRAINING, EVAL_TUNED_SPLIT, or EVAL_CROSS_VALIDATION
  582.    *
  583.    * @param newMethod the new evaluation mode.
  584.    */
  585.   public void setEvaluationMode(SelectedTag newMethod) {
  586.     
  587.     if (newMethod.getTags() == TAGS_EVAL) {
  588.       m_EvalMode = newMethod.getSelectedTag().getID();
  589.     }
  590.   }
  591.   /**
  592.    * Gets the evaluation mode used. Will be one of
  593.    * EVAL_TRAINING, EVAL_TUNED_SPLIT, or EVAL_CROSS_VALIDATION
  594.    *
  595.    * @return the evaluation mode.
  596.    */
  597.   public SelectedTag getEvaluationMode() {
  598.     return new SelectedTag(m_EvalMode, TAGS_EVAL);
  599.   }
  600.   /**
  601.    * @return tip text for this property suitable for
  602.    * displaying in the explorer/experimenter gui
  603.    */
  604.   public String rangeCorrectionTipText() {
  605.     return "Sets the type of prediction range correction performed. "
  606.       + "The options are: do not do any range correction; "
  607.       + "expand predicted probabilities so that the minimum probability "
  608.       + "observed during the optimization maps to 0, and the maximum "
  609.       + "maps to 1 (values outside this range are clipped to 0 and 1).";
  610.   }
  611.   /**
  612.    * Sets the confidence range correction mode used. Will be one of
  613.    * RANGE_NONE, or RANGE_BOUNDS
  614.    *
  615.    * @param newMethod the new correciton mode.
  616.    */
  617.   public void setRangeCorrection(SelectedTag newMethod) {
  618.     
  619.     if (newMethod.getTags() == TAGS_RANGE) {
  620.       m_RangeMode = newMethod.getSelectedTag().getID();
  621.     }
  622.   }
  623.   /**
  624.    * Gets the confidence range correction mode used. Will be one of
  625.    * RANGE_NONE, or RANGE_BOUNDS
  626.    *
  627.    * @return the confidence correction mode.
  628.    */
  629.   public SelectedTag getRangeCorrection() {
  630.     return new SelectedTag(m_RangeMode, TAGS_RANGE);
  631.   }
  632.   
  633.   /**
  634.    * @return tip text for this property suitable for
  635.    * displaying in the explorer/experimenter gui
  636.    */
  637.   public String seedTipText() {
  638.     return "Sets the seed used for randomization. This is used when "
  639.       + "randomizing the data during optimization.";
  640.   }
  641.   /**
  642.    * Sets the seed for random number generation.
  643.    *
  644.    * @param seed the random number seed
  645.    */
  646.   public void setSeed(int seed) {
  647.     
  648.     m_Seed = seed;
  649.   }
  650.   /**
  651.    * Gets the random number seed.
  652.    * 
  653.    * @return the random number seed
  654.    */
  655.   public int getSeed() {
  656.     return m_Seed;
  657.   }
  658.   /**
  659.    * @return tip text for this property suitable for
  660.    * displaying in the explorer/experimenter gui
  661.    */
  662.   public String numXValFoldsTipText() {
  663.     return "Sets the number of folds used during full cross-validation "
  664.       + "and tuned fold evaluation. This number will be automatically "
  665.       + "reduced if there are insufficient positive examples.";
  666.   }
  667.   /**
  668.    * Get the number of folds used for cross-validation.
  669.    *
  670.    * @return the number of folds used for cross-validation.
  671.    */
  672.   public int getNumXValFolds() {
  673.     
  674.     return m_NumXValFolds;
  675.   }
  676.   
  677.   /**
  678.    * Set the number of folds used for cross-validation.
  679.    *
  680.    * @param newNumFolds the number of folds used for cross-validation.
  681.    */
  682.   public void setNumXValFolds(int newNumFolds) {
  683.     
  684.     if (newNumFolds < 2) {
  685.       throw new IllegalArgumentException("Number of folds must be greater than 1");
  686.     }
  687.     m_NumXValFolds = newNumFolds;
  688.   }
  689.   /**
  690.    * @return tip text for this property suitable for
  691.    * displaying in the explorer/experimenter gui
  692.    */
  693.   public String distributionClassifierTipText() {
  694.     return "Sets the base DistributionClassifier to which the optimization "
  695.       + "will be made.";
  696.   }
  697.   /**
  698.    * Set the DistributionClassifier for which threshold is set. 
  699.    *
  700.    * @param newClassifier the Classifier to use.
  701.    */
  702.   public void setDistributionClassifier(DistributionClassifier newClassifier) {
  703.     m_Classifier = newClassifier;
  704.   }
  705.   /**
  706.    * Get the DistributionClassifier used as the classifier.
  707.    *
  708.    * @return the classifier used as the classifier
  709.    */
  710.   public DistributionClassifier getDistributionClassifier() {
  711.     return m_Classifier;
  712.   }
  713.  
  714.   /**
  715.    * Returns description of the cross-validated classifier.
  716.    *
  717.    * @return description of the cross-validated classifier as a string
  718.    */
  719.   public String toString() {
  720.     if (m_BestValue == -Double.MAX_VALUE)
  721.       return "ThresholdSelector: No model built yet.";
  722.     String result = "Threshold Selector.n"
  723.     + "Classifier: " + m_Classifier.getClass().getName() + "n";
  724.     result += "Index of designated class: " + m_DesignatedClass + "n";
  725.     result += "Evaluation mode: ";
  726.     switch (m_EvalMode) {
  727.     case EVAL_CROSS_VALIDATION:
  728.       result += m_NumXValFolds + "-fold cross-validation";
  729.       break;
  730.     case EVAL_TUNED_SPLIT:
  731.       result += "tuning on 1/" + m_NumXValFolds + " of the data";
  732.       break;
  733.     case EVAL_TRAINING_SET:
  734.     default:
  735.       result += "tuning on the training data";
  736.     }
  737.     result += "n";
  738.     result += "Threshold: " + m_BestThreshold + "n";
  739.     result += "Best value: " + m_BestValue + "n";
  740.     if (m_RangeMode == RANGE_BOUNDS) {
  741.       result += "Expanding range [" + m_LowThreshold + "," + m_HighThreshold
  742.         + "] to [0, 1]n";
  743.     }
  744.     result += m_Classifier.toString();
  745.     return result;
  746.   }
  747.   
  748.   /**
  749.    * Main method for testing this class.
  750.    *
  751.    * @param argv the options
  752.    */
  753.   public static void main(String [] argv) {
  754.     try {
  755.       System.out.println(Evaluation.evaluateModel(new ThresholdSelector(), 
  756.   argv));
  757.     } catch (Exception e) {
  758.       System.err.println(e.getMessage());
  759.     }
  760.   }
  761. }