MultiScheme.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 12k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  *    This program is free software; you can redistribute it and/or modify
  3.  *    it under the terms of the GNU General Public License as published by
  4.  *    the Free Software Foundation; either version 2 of the License, or
  5.  *    (at your option) any later version.
  6.  *
  7.  *    This program is distributed in the hope that it will be useful,
  8.  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  *    GNU General Public License for more details.
  11.  *
  12.  *    You should have received a copy of the GNU General Public License
  13.  *    along with this program; if not, write to the Free Software
  14.  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15.  */
  16. /*
  17.  *    MultiScheme.java
  18.  *    Copyright (C) 1999 Len Trigg
  19.  *
  20.  */
  21. package weka.classifiers.meta;
  22. import weka.classifiers.Evaluation;
  23. import weka.classifiers.Classifier;
  24. import weka.classifiers.rules.ZeroR;
  25. import java.io.*;
  26. import java.util.*;
  27. import weka.core.*;
  28. /**
  29.  * Class for selecting a classifier from among several using cross 
  30.  * validation on the training data.<p>
  31.  *
  32.  * Valid options from the command line are:<p>
  33.  *
  34.  * -D <br>
  35.  * Turn on debugging output.<p>
  36.  *
  37.  * -S seed <br>
  38.  * Random number seed (default 1).<p>
  39.  *
  40.  * -B classifierstring <br>
  41.  * Classifierstring should contain the full class name of a scheme
  42.  * included for selection followed by options to the classifier
  43.  * (required, option should be used once for each classifier).<p>
  44.  *
  45.  * -X num_folds <br>
  46.  * Use cross validation error as the basis for classifier selection.
  47.  * (default 0, is to use error on the training data instead)<p>
  48.  *
  49.  * @author Len Trigg (trigg@cs.waikato.ac.nz)
  50.  * @version $Revision: 1.11 $
  51.  */
  52. public class MultiScheme extends Classifier implements OptionHandler {
  53.   /** The classifier that had the best performance on training data. */
  54.   protected Classifier m_Classifier;
  55.  
  56.   /** The list of classifiers */
  57.   protected Classifier [] m_Classifiers = {
  58.      new weka.classifiers.rules.ZeroR()
  59.   };
  60.   /** The index into the vector for the selected scheme */
  61.   protected int m_ClassifierIndex;
  62.   /**
  63.    * Number of folds to use for cross validation (0 means use training
  64.    * error for selection)
  65.    */
  66.   protected int m_NumXValFolds;
  67.   /** Debugging mode, gives extra output if true */
  68.   protected boolean m_Debug;
  69.   /** Random number seed */
  70.   protected int m_Seed = 1;
  71.   /**
  72.    * Returns an enumeration describing the available options.
  73.    *
  74.    * @return an enumeration of all the available options.
  75.    */
  76.   public Enumeration listOptions() {
  77.     Vector newVector = new Vector(4);
  78.     newVector.addElement(new Option(
  79.       "tTurn on debugging output.",
  80.       "D", 0, "-D"));
  81.     newVector.addElement(new Option(
  82.       "tFull class name of classifier to include, followedn"
  83.       + "tby scheme options. May be specified multiple times,n"
  84.       + "trequired at least twice.n"
  85.       + "teg: "weka.classifiers.bayes.NaiveBayes -D"",
  86.       "B", 1, "-B <classifier specification>"));
  87.     newVector.addElement(new Option(
  88.       "tSets the random number seed (default 1).",
  89.       "S", 1, "-S <random number seed>"));
  90.     newVector.addElement(new Option(
  91.       "tUse cross validation for model selection using then"
  92.       + "tgiven number of folds. (default 0, is ton"
  93.       + "tuse training error)",
  94.       "X", 1, "-X <number of folds>"));
  95.     return newVector.elements();
  96.   }
  97.   /**
  98.    * Parses a given list of options. Valid options are:<p>
  99.    *
  100.    * -D <br>
  101.    * Turn on debugging output.<p>
  102.    *
  103.    * -S seed <br>
  104.    * Random number seed (default 1).<p>
  105.    *
  106.    * -B classifierstring <br>
  107.    * Classifierstring should contain the full class name of a scheme
  108.    * included for selection followed by options to the classifier
  109.    * (required, option should be used once for each classifier).<p>
  110.    *
  111.    * -X num_folds <br>
  112.    * Use cross validation error as the basis for classifier selection.
  113.    * (default 0, is to use error on the training data instead)<p>
  114.    *
  115.    * @param options the list of options as an array of strings
  116.    * @exception Exception if an option is not supported
  117.    */
  118.   public void setOptions(String[] options) throws Exception {
  119.     setDebug(Utils.getFlag('D', options));
  120.     
  121.     String numFoldsString = Utils.getOption('X', options);
  122.     if (numFoldsString.length() != 0) {
  123.       setNumFolds(Integer.parseInt(numFoldsString));
  124.     } else {
  125.       setNumFolds(0);
  126.     }
  127.     
  128.     String randomString = Utils.getOption('S', options);
  129.     if (randomString.length() != 0) {
  130.       setSeed(Integer.parseInt(randomString));
  131.     } else {
  132.       setSeed(1);
  133.     }
  134.     // Iterate through the schemes
  135.     FastVector classifiers = new FastVector();
  136.     while (true) {
  137.       String classifierString = Utils.getOption('B', options);
  138.       if (classifierString.length() == 0) {
  139. break;
  140.       }
  141.       String [] classifierSpec = Utils.splitOptions(classifierString);
  142.       if (classifierSpec.length == 0) {
  143. throw new Exception("Invalid classifier specification string");
  144.       }
  145.       String classifierName = classifierSpec[0];
  146.       classifierSpec[0] = "";
  147.       classifiers.addElement(Classifier.forName(classifierName,
  148. classifierSpec));
  149.     }
  150.     if (classifiers.size() <= 1) {
  151.       throw new Exception("At least two classifiers must be specified"
  152.   + " with the -B option.");
  153.     } else {
  154.       Classifier [] classifiersArray = new Classifier [classifiers.size()];
  155.       for (int i = 0; i < classifiersArray.length; i++) {
  156. classifiersArray[i] = (Classifier) classifiers.elementAt(i);
  157.       }
  158.       setClassifiers(classifiersArray);
  159.     }
  160.     
  161.   }
  162.   /**
  163.    * Gets the current settings of the Classifier.
  164.    *
  165.    * @return an array of strings suitable for passing to setOptions
  166.    */
  167.   public String [] getOptions() {
  168.     String [] options = new String [5];
  169.     int current = 0;
  170.     if (m_Classifiers.length != 0) {
  171.       options = new String [m_Classifiers.length * 2 + 5];
  172.       for (int i = 0; i < m_Classifiers.length; i++) {
  173. options[current++] = "-B";
  174. options[current++] = "" + getClassifierSpec(i);
  175.       }
  176.     }
  177.     if (getNumFolds() > 1) {
  178.       options[current++] = "-X"; options[current++] = "" + getNumFolds();
  179.     }
  180.     options[current++] = "-S"; options[current++] = "" + getSeed();
  181.     if (getDebug()) {
  182.       options[current++] = "-D";
  183.     }
  184.     while (current < options.length) {
  185.       options[current++] = "";
  186.     }
  187.     return options;
  188.   }
  189.   /**
  190.    * Sets the list of possible classifers to choose from.
  191.    *
  192.    * @param classifiers an array of classifiers with all options set.
  193.    */
  194.   public void setClassifiers(Classifier [] classifiers) {
  195.     m_Classifiers = classifiers;
  196.   }
  197.   /**
  198.    * Gets the list of possible classifers to choose from.
  199.    *
  200.    * @return the array of Classifiers
  201.    */
  202.   public Classifier [] getClassifiers() {
  203.     return m_Classifiers;
  204.   }
  205.   
  206.   /**
  207.    * Gets a single classifier from the set of available classifiers.
  208.    *
  209.    * @param index the index of the classifier wanted
  210.    * @return the Classifier
  211.    */
  212.   public Classifier getClassifier(int index) {
  213.     return m_Classifiers[index];
  214.   }
  215.   
  216.   /**
  217.    * Gets the classifier specification string, which contains the class name of
  218.    * the classifier and any options to the classifier
  219.    *
  220.    * @param index the index of the classifier string to retrieve, starting from
  221.    * 0.
  222.    * @return the classifier string, or the empty string if no classifier
  223.    * has been assigned (or the index given is out of range).
  224.    */
  225.   protected String getClassifierSpec(int index) {
  226.     
  227.     if (m_Classifiers.length < index) {
  228.       return "";
  229.     }
  230.     Classifier c = getClassifier(index);
  231.     if (c instanceof OptionHandler) {
  232.       return c.getClass().getName() + " "
  233. + Utils.joinOptions(((OptionHandler)c).getOptions());
  234.     }
  235.     return c.getClass().getName();
  236.   }
  237.   /**
  238.    * Sets the seed for random number generation.
  239.    *
  240.    * @param seed the random number seed
  241.    */
  242.   public void setSeed(int seed) {
  243.     
  244.     m_Seed = seed;;
  245.   }
  246.   /**
  247.    * Gets the random number seed.
  248.    * 
  249.    * @return the random number seed
  250.    */
  251.   public int getSeed() {
  252.     return m_Seed;
  253.   }
  254.   /** 
  255.    * Gets the number of folds for cross-validation. A number less
  256.    * than 2 specifies using training error rather than cross-validation.
  257.    *
  258.    * @return the number of folds for cross-validation
  259.    */
  260.   public int getNumFolds() {
  261.     return m_NumXValFolds;
  262.   }
  263.   /**
  264.    * Sets the number of folds for cross-validation. A number less
  265.    * than 2 specifies using training error rather than cross-validation.
  266.    *
  267.    * @param numFolds the number of folds for cross-validation
  268.    */
  269.   public void setNumFolds(int numFolds) {
  270.     
  271.     m_NumXValFolds = numFolds;
  272.   }
  273.   /**
  274.    * Set debugging mode
  275.    *
  276.    * @param debug true if debug output should be printed
  277.    */
  278.   public void setDebug(boolean debug) {
  279.     m_Debug = debug;
  280.   }
  281.   /**
  282.    * Get whether debugging is turned on
  283.    *
  284.    * @return true if debugging output is on
  285.    */
  286.   public boolean getDebug() {
  287.     return m_Debug;
  288.   }
  289.   /**
  290.    * Buildclassifier selects a classifier from the set of classifiers
  291.    * by minimising error on the training data.
  292.    *
  293.    * @param data the training data to be used for generating the
  294.    * boosted classifier.
  295.    * @exception Exception if the classifier could not be built successfully
  296.    */
  297.   public void buildClassifier(Instances data) throws Exception {
  298.     if (m_Classifiers.length == 0) {
  299.       throw new Exception("No base classifiers have been set!");
  300.     }
  301.     Instances newData = new Instances(data);
  302.     newData.deleteWithMissingClass();
  303.     newData.randomize(new Random(m_Seed));
  304.     if (newData.classAttribute().isNominal() && (m_NumXValFolds > 1))
  305.       newData.stratify(m_NumXValFolds);
  306.     Instances train = newData;               // train on all data by default
  307.     Instances test = newData;               // test on training data by default
  308.     Classifier bestClassifier = null;
  309.     int bestIndex = -1;
  310.     double bestPerformance = Double.NaN;
  311.     int numClassifiers = m_Classifiers.length;
  312.     for (int i = 0; i < numClassifiers; i++) {
  313.       Classifier currentClassifier = getClassifier(i);
  314.       Evaluation evaluation;
  315.       if (m_NumXValFolds > 1) {
  316. evaluation = new Evaluation(newData);
  317. for (int j = 0; j < m_NumXValFolds; j++) {
  318.   train = newData.trainCV(m_NumXValFolds, j);
  319.   test = newData.testCV(m_NumXValFolds, j);
  320.   currentClassifier.buildClassifier(train);
  321.   evaluation.setPriors(train);
  322.   evaluation.evaluateModel(currentClassifier, test);
  323. }
  324.       } else {
  325. currentClassifier.buildClassifier(train);
  326. evaluation = new Evaluation(train);
  327. evaluation.evaluateModel(currentClassifier, test);
  328.       }
  329.       double error = evaluation.errorRate();
  330.       if (m_Debug) {
  331. System.err.println("Error rate: " + Utils.doubleToString(error, 6, 4)
  332.    + " for classifier "
  333.    + currentClassifier.getClass().getName());
  334.       }
  335.       if ((i == 0) || (error < bestPerformance)) {
  336. bestClassifier = currentClassifier;
  337. bestPerformance = error;
  338. bestIndex = i;
  339.       }
  340.     }
  341.     m_ClassifierIndex = bestIndex;
  342.     m_Classifier = bestClassifier;
  343.     if (m_NumXValFolds > 1) {
  344.       m_Classifier.buildClassifier(newData);
  345.     }
  346.   }
  347.   /**
  348.    * Classifies a given instance using the selected classifier.
  349.    *
  350.    * @param instance the instance to be classified
  351.    * @exception Exception if instance could not be classified
  352.    * successfully
  353.    */
  354.   public double classifyInstance(Instance instance) throws Exception {
  355.     return m_Classifier.classifyInstance(instance);
  356.   }
  357.   /**
  358.    * Output a representation of this classifier
  359.    */
  360.   public String toString() {
  361.     if (m_Classifier == null) {
  362.       return "MultiScheme: No model built yet.";
  363.     }
  364.     String result = "MultiScheme selection using";
  365.     if (m_NumXValFolds > 1) {
  366.       result += " cross validation error";
  367.     } else {
  368.       result += " error on training data";
  369.     }
  370.     result += " from the following:n";
  371.     for (int i = 0; i < m_Classifiers.length; i++) {
  372.       result += 't' + getClassifierSpec(i) + 'n';
  373.     }
  374.     result += "Selected scheme: "
  375.       + getClassifierSpec(m_ClassifierIndex)
  376.       + "nn"
  377.       + m_Classifier.toString();
  378.     return result;
  379.   }
  380.   /**
  381.    * Main method for testing this class.
  382.    *
  383.    * @param argv should contain the following arguments:
  384.    * -t training file [-T test file] [-c class index]
  385.    */
  386.   public static void main(String [] argv) {
  387.     try {
  388.       System.out.println(Evaluation.evaluateModel(new MultiScheme(), argv));
  389.     } catch (Exception e) {
  390.       System.err.println(e.getMessage());
  391.     }
  392.   }
  393. }