MultiScheme.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 12k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  *    This program is free software; you can redistribute it and/or modify
  3.  *    it under the terms of the GNU General Public License as published by
  4.  *    the Free Software Foundation; either version 2 of the License, or
  5.  *    (at your option) any later version.
  6.  *
  7.  *    This program is distributed in the hope that it will be useful,
  8.  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  *    GNU General Public License for more details.
  11.  *
  12.  *    You should have received a copy of the GNU General Public License
  13.  *    along with this program; if not, write to the Free Software
  14.  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15.  */
  16. /*
  17.  *    MultiScheme.java
  18.  *    Copyright (C) 1999 Len Trigg
  19.  *
  20.  */
  21. package weka.classifiers;
  22. import java.io.*;
  23. import java.util.*;
  24. import weka.core.*;
  25. /**
  26.  * Class for selecting a classifier from among several using cross 
  27.  * validation on the training data.<p>
  28.  *
  29.  * Valid options from the command line are:<p>
  30.  *
  31.  * -D <br>
  32.  * Turn on debugging output.<p>
  33.  *
  34.  * -S seed <br>
  35.  * Random number seed (default 1).<p>
  36.  *
  37.  * -B classifierstring <br>
  38.  * Classifierstring should contain the full class name of a scheme
  39.  * included for selection followed by options to the classifier
  40.  * (required, option should be used once for each classifier).<p>
  41.  *
  42.  * -X num_folds <br>
  43.  * Use cross validation error as the basis for classifier selection.
  44.  * (default 0, is to use error on the training data instead)<p>
  45.  *
  46.  * @author Len Trigg (trigg@cs.waikato.ac.nz)
  47.  * @version $Revision: 1.9 $
  48.  */
  49. public class MultiScheme extends Classifier implements OptionHandler {
  50.   /** The classifier that had the best performance on training data. */
  51.   protected Classifier m_Classifier;
  52.  
  53.   /** The list of classifiers */
  54.   protected Classifier [] m_Classifiers = {
  55.      new weka.classifiers.ZeroR()
  56.   };
  57.   /** The index into the vector for the selected scheme */
  58.   protected int m_ClassifierIndex;
  59.   /**
  60.    * Number of folds to use for cross validation (0 means use training
  61.    * error for selection)
  62.    */
  63.   protected int m_NumXValFolds;
  64.   /** Debugging mode, gives extra output if true */
  65.   protected boolean m_Debug;
  66.   /** Random number seed */
  67.   protected int m_Seed = 1;
  68.   /**
  69.    * Returns an enumeration describing the available options
  70.    *
  71.    * @return an enumeration of all the available options
  72.    */
  73.   public Enumeration listOptions() {
  74.     Vector newVector = new Vector(4);
  75.     newVector.addElement(new Option(
  76.       "tTurn on debugging output.",
  77.       "D", 0, "-D"));
  78.     newVector.addElement(new Option(
  79.       "tFull class name of classifier to include, followedn"
  80.       + "tby scheme options. May be specified multiple times,n"
  81.       + "trequired at least twice.n"
  82.       + "teg: "weka.classifiers.NaiveBayes -D"",
  83.       "B", 1, "-B <classifier specification>"));
  84.     newVector.addElement(new Option(
  85.       "tSets the random number seed (default 1).",
  86.       "S", 1, "-S <random number seed>"));
  87.     newVector.addElement(new Option(
  88.       "tUse cross validation for model selection using then"
  89.       + "tgiven number of folds. (default 0, is ton"
  90.       + "tuse training error)",
  91.       "X", 1, "-X <number of folds>"));
  92.     return newVector.elements();
  93.   }
  94.   /**
  95.    * Parses a given list of options. Valid options are:<p>
  96.    *
  97.    * -D <br>
  98.    * Turn on debugging output.<p>
  99.    *
  100.    * -S seed <br>
  101.    * Random number seed (default 1).<p>
  102.    *
  103.    * -B classifierstring <br>
  104.    * Classifierstring should contain the full class name of a scheme
  105.    * included for selection followed by options to the classifier
  106.    * (required, option should be used once for each classifier).<p>
  107.    *
  108.    * -X num_folds <br>
  109.    * Use cross validation error as the basis for classifier selection.
  110.    * (default 0, is to use error on the training data instead)<p>
  111.    *
  112.    * @param options the list of options as an array of strings
  113.    * @exception Exception if an option is not supported
  114.    */
  115.   public void setOptions(String[] options) throws Exception {
  116.     setDebug(Utils.getFlag('D', options));
  117.     
  118.     String numFoldsString = Utils.getOption('X', options);
  119.     if (numFoldsString.length() != 0) {
  120.       setNumFolds(Integer.parseInt(numFoldsString));
  121.     } else {
  122.       setNumFolds(0);
  123.     }
  124.     
  125.     String randomString = Utils.getOption('S', options);
  126.     if (randomString.length() != 0) {
  127.       setSeed(Integer.parseInt(randomString));
  128.     } else {
  129.       setSeed(1);
  130.     }
  131.     // Iterate through the schemes
  132.     FastVector classifiers = new FastVector();
  133.     while (true) {
  134.       String classifierString = Utils.getOption('B', options);
  135.       if (classifierString.length() == 0) {
  136. break;
  137.       }
  138.       String [] classifierSpec = Utils.splitOptions(classifierString);
  139.       if (classifierSpec.length == 0) {
  140. throw new Exception("Invalid classifier specification string");
  141.       }
  142.       String classifierName = classifierSpec[0];
  143.       classifierSpec[0] = "";
  144.       classifiers.addElement(Classifier.forName(classifierName,
  145. classifierSpec));
  146.     }
  147.     if (classifiers.size() <= 1) {
  148.       throw new Exception("At least two classifiers must be specified"
  149.   + " with the -B option.");
  150.     } else {
  151.       Classifier [] classifiersArray = new Classifier [classifiers.size()];
  152.       for (int i = 0; i < classifiersArray.length; i++) {
  153. classifiersArray[i] = (Classifier) classifiers.elementAt(i);
  154.       }
  155.       setClassifiers(classifiersArray);
  156.     }
  157.     
  158.   }
  159.   /**
  160.    * Gets the current settings of the Classifier.
  161.    *
  162.    * @return an array of strings suitable for passing to setOptions
  163.    */
  164.   public String [] getOptions() {
  165.     String [] options = new String [5];
  166.     int current = 0;
  167.     if (m_Classifiers.length != 0) {
  168.       options = new String [m_Classifiers.length * 2 + 5];
  169.       for (int i = 0; i < m_Classifiers.length; i++) {
  170. options[current++] = "-B";
  171. options[current++] = "" + getClassifierSpec(i);
  172.       }
  173.     }
  174.     if (getNumFolds() > 1) {
  175.       options[current++] = "-X"; options[current++] = "" + getNumFolds();
  176.     }
  177.     options[current++] = "-S"; options[current++] = "" + getSeed();
  178.     if (getDebug()) {
  179.       options[current++] = "-D";
  180.     }
  181.     while (current < options.length) {
  182.       options[current++] = "";
  183.     }
  184.     return options;
  185.   }
  186.   /**
  187.    * Sets the list of possible classifers to choose from.
  188.    *
  189.    * @param classifiers an array of classifiers with all options set.
  190.    */
  191.   public void setClassifiers(Classifier [] classifiers) {
  192.     m_Classifiers = classifiers;
  193.   }
  194.   /**
  195.    * Gets the list of possible classifers to choose from.
  196.    *
  197.    * @return the array of Classifiers
  198.    */
  199.   public Classifier [] getClassifiers() {
  200.     return m_Classifiers;
  201.   }
  202.   
  203.   /**
  204.    * Gets a single classifier from the set of available classifiers.
  205.    *
  206.    * @param index the index of the classifier wanted
  207.    * @return the Classifier
  208.    */
  209.   public Classifier getClassifier(int index) {
  210.     return m_Classifiers[index];
  211.   }
  212.   
  213.   /**
  214.    * Gets the classifier specification string, which contains the class name of
  215.    * the classifier and any options to the classifier
  216.    *
  217.    * @param index the index of the classifier string to retrieve, starting from
  218.    * 0.
  219.    * @return the classifier string, or the empty string if no classifier
  220.    * has been assigned (or the index given is out of range).
  221.    */
  222.   protected String getClassifierSpec(int index) {
  223.     
  224.     if (m_Classifiers.length < index) {
  225.       return "";
  226.     }
  227.     Classifier c = getClassifier(index);
  228.     if (c instanceof OptionHandler) {
  229.       return c.getClass().getName() + " "
  230. + Utils.joinOptions(((OptionHandler)c).getOptions());
  231.     }
  232.     return c.getClass().getName();
  233.   }
  234.   /**
  235.    * Sets the seed for random number generation.
  236.    *
  237.    * @param seed the random number seed
  238.    */
  239.   public void setSeed(int seed) {
  240.     
  241.     m_Seed = seed;;
  242.   }
  243.   /**
  244.    * Gets the random number seed.
  245.    * 
  246.    * @return the random number seed
  247.    */
  248.   public int getSeed() {
  249.     return m_Seed;
  250.   }
  251.   /** 
  252.    * Gets the number of folds for cross-validation. A number less
  253.    * than 2 specifies using training error rather than cross-validation.
  254.    *
  255.    * @return the number of folds for cross-validation
  256.    */
  257.   public int getNumFolds() {
  258.     return m_NumXValFolds;
  259.   }
  260.   /**
  261.    * Sets the number of folds for cross-validation. A number less
  262.    * than 2 specifies using training error rather than cross-validation.
  263.    *
  264.    * @param numFolds the number of folds for cross-validation
  265.    */
  266.   public void setNumFolds(int numFolds) {
  267.     
  268.     m_NumXValFolds = numFolds;
  269.   }
  270.   /**
  271.    * Set debugging mode
  272.    *
  273.    * @param debug true if debug output should be printed
  274.    */
  275.   public void setDebug(boolean debug) {
  276.     m_Debug = debug;
  277.   }
  278.   /**
  279.    * Get whether debugging is turned on
  280.    *
  281.    * @return true if debugging output is on
  282.    */
  283.   public boolean getDebug() {
  284.     return m_Debug;
  285.   }
  286.   /**
  287.    * Buildclassifier selects a classifier from the set of classifiers
  288.    * by minimising error on the training data.
  289.    *
  290.    * @param data the training data to be used for generating the
  291.    * boosted classifier.
  292.    * @exception Exception if the classifier could not be built successfully
  293.    */
  294.   public void buildClassifier(Instances data) throws Exception {
  295.     if (m_Classifiers.length == 0) {
  296.       throw new Exception("No base classifiers have been set!");
  297.     }
  298.     Instances newData = new Instances(data);
  299.     newData.deleteWithMissingClass();
  300.     newData.randomize(new Random(m_Seed));
  301.     if (newData.classAttribute().isNominal() && (m_NumXValFolds > 1))
  302.       newData.stratify(m_NumXValFolds);
  303.     Instances train = newData;               // train on all data by default
  304.     Instances test = newData;               // test on training data by default
  305.     Classifier bestClassifier = null;
  306.     int bestIndex = -1;
  307.     double bestPerformance = Double.NaN;
  308.     int numClassifiers = m_Classifiers.length;
  309.     for (int i = 0; i < numClassifiers; i++) {
  310.       Classifier currentClassifier = getClassifier(i);
  311.       Evaluation evaluation;
  312.       if (m_NumXValFolds > 1) {
  313. evaluation = new Evaluation(newData);
  314. for (int j = 0; j < m_NumXValFolds; j++) {
  315.   train = newData.trainCV(m_NumXValFolds, j);
  316.   test = newData.testCV(m_NumXValFolds, j);
  317.   currentClassifier.buildClassifier(train);
  318.   evaluation.setPriors(train);
  319.   evaluation.evaluateModel(currentClassifier, test);
  320. }
  321.       } else {
  322. currentClassifier.buildClassifier(train);
  323. evaluation = new Evaluation(train);
  324. evaluation.evaluateModel(currentClassifier, test);
  325.       }
  326.       double error = evaluation.errorRate();
  327.       if (m_Debug) {
  328. System.err.println("Error rate: " + Utils.doubleToString(error, 6, 4)
  329.    + " for classifier "
  330.    + currentClassifier.getClass().getName());
  331.       }
  332.       if ((i == 0) || (error < bestPerformance)) {
  333. bestClassifier = currentClassifier;
  334. bestPerformance = error;
  335. bestIndex = i;
  336.       }
  337.     }
  338.     m_ClassifierIndex = bestIndex;
  339.     m_Classifier = bestClassifier;
  340.     if (m_NumXValFolds > 1) {
  341.       m_Classifier.buildClassifier(newData);
  342.     }
  343.   }
  344.   /**
  345.    * Classifies a given instance using the selected classifier.
  346.    *
  347.    * @param instance the instance to be classified
  348.    * @exception Exception if instance could not be classified
  349.    * successfully
  350.    */
  351.   public double classifyInstance(Instance instance) throws Exception {
  352.     return m_Classifier.classifyInstance(instance);
  353.   }
  354.   /**
  355.    * Output a representation of this classifier
  356.    */
  357.   public String toString() {
  358.     if (m_Classifier == null) {
  359.       return "MultiScheme: No model built yet.";
  360.     }
  361.     String result = "MultiScheme selection using";
  362.     if (m_NumXValFolds > 1) {
  363.       result += " cross validation error";
  364.     } else {
  365.       result += " error on training data";
  366.     }
  367.     result += " from the following:n";
  368.     for (int i = 0; i < m_Classifiers.length; i++) {
  369.       result += 't' + getClassifierSpec(i) + 'n';
  370.     }
  371.     result += "Selected scheme: "
  372.       + getClassifierSpec(m_ClassifierIndex)
  373.       + "nn"
  374.       + m_Classifier.toString();
  375.     return result;
  376.   }
  377.   /**
  378.    * Main method for testing this class.
  379.    *
  380.    * @param argv should contain the following arguments:
  381.    * -t training file [-T test file] [-c class index]
  382.    */
  383.   public static void main(String [] argv) {
  384.     try {
  385.       System.out.println(Evaluation.evaluateModel(new MultiScheme(), argv));
  386.     } catch (Exception e) {
  387.       System.err.println(e.getMessage());
  388.     }
  389.   }
  390. }