StratifiedRemoveFolds.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 10k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  *    This program is free software; you can redistribute it and/or modify
  3.  *    it under the terms of the GNU General Public License as published by
  4.  *    the Free Software Foundation; either version 2 of the License, or
  5.  *    (at your option) any later version.
  6.  *
  7.  *    This program is distributed in the hope that it will be useful,
  8.  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  *    GNU General Public License for more details.
  11.  *
  12.  *    You should have received a copy of the GNU General Public License
  13.  *    along with this program; if not, write to the Free Software
  14.  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15.  */
  16. /*
  17.  *    StratifiedRemoveFolds.java
  18.  *    Copyright (C) 1999 Eibe Frank
  19.  *
  20.  */
  21. package weka.filters.supervised.instance;
  22. import weka.filters.*;
  23. import weka.core.*;
  24. import java.util.*;
  25. /**
  26.  * This filter takes a dataset and outputs folds suitable for cross validation.
  27.  * If you do not want the folds to be stratified then use the unsupervised 
  28.  * version.
  29.  *
  30.  * Valid options are: <p>
  31.  *
  32.  * -V <br>
  33.  * Specifies if inverse of selection is to be output.<p>
  34.  *
  35.  * -N number of folds <br>
  36.  * Specifies number of folds dataset is split into (default 10). <p>
  37.  *
  38.  * -F fold <br>
  39.  * Specifies which fold is selected. (default 1)<p>
  40.  *
  41.  * -S seed <br>
  42.  * Specifies a random number seed for shuffling the dataset.
  43.  * (default 0, don't randomize)<p>
  44.  *
  45.  * @author Eibe Frank (eibe@cs.waikato.ac.nz)
  46.  * @version $Revision: 1.1 $ 
  47. */
  48. public class StratifiedRemoveFolds extends Filter implements SupervisedFilter,
  49.       OptionHandler {
  50.   /** Indicates if inverse of selection is to be output. */
  51.   private boolean m_Inverse = false;
  52.   /** Number of folds to split dataset into */
  53.   private int m_NumFolds = 10;
  54.   /** Fold to output */
  55.   private int m_Fold = 1;
  56.   /** Random number seed. */
  57.   private long m_Seed = 0;
  58.   /**
  59.    * Gets an enumeration describing the available options..
  60.    *
  61.    * @return an enumeration of all the available options.
  62.    */
  63.   public Enumeration listOptions() {
  64.     Vector newVector = new Vector(6);
  65.     newVector.addElement(new Option(
  66.       "tSpecifies if inverse of selection is to be output.n",
  67.       "V", 0, "-V"));
  68.     newVector.addElement(new Option(
  69.               "tSpecifies number of folds dataset is split into. n"
  70.       + "t(default 10)n",
  71.               "N", 1, "-N <number of folds>"));
  72.     newVector.addElement(new Option(
  73.       "tSpecifies which fold is selected. (default 1)n",
  74.       "F", 1, "-F <fold>"));
  75.     newVector.addElement(new Option(
  76.       "tSpecifies random number seed. (default 0, no randomizing)n",
  77.       "S", 1, "-S <seed>"));
  78.     return newVector.elements();
  79.   }
  80.   /**
  81.    * Parses the options for this object. Valid options are: <p>
  82.    *
  83.    * -V <br>
  84.    * Specifies if inverse of selection is to be output.<p>
  85.    *
  86.    * -N number of folds <br>
  87.    * Specifies number of folds dataset is split into (default 10). <p>
  88.    *
  89.    * -F fold <br>
  90.    * Specifies which fold is selected. (default 1)<p>
  91.    *
  92.    * -S seed <br>
  93.    * Specifies a random number seed for shuffling the dataset.
  94.    * (default 0, no randomizing)<p>
  95.    *
  96.    * -A <br>
  97.    * If set, data will not be stratified. <p>
  98.    *
  99.    * @param options the list of options as an array of strings
  100.    * @exception Exception if an option is not supported
  101.    */
  102.   public void setOptions(String[] options) throws Exception {
  103.     setInvertSelection(Utils.getFlag('V', options));
  104.     String numFolds = Utils.getOption('N', options);
  105.     if (numFolds.length() != 0) {
  106.       setNumFolds(Integer.parseInt(numFolds));
  107.     } else {
  108.       setNumFolds(10);
  109.     }
  110.     String fold = Utils.getOption('F', options);
  111.     if (fold.length() != 0) {
  112.       setFold(Integer.parseInt(fold));
  113.     } else {
  114.       setFold(1);
  115.     }
  116.     String seed = Utils.getOption('S', options);
  117.     if (seed.length() != 0) {
  118.       setSeed(Integer.parseInt(seed));
  119.     } else {
  120.       setSeed(0);
  121.     }
  122.     if (getInputFormat() != null) {
  123.       setInputFormat(getInputFormat());
  124.     }
  125.   }
  126.   /**
  127.    * Gets the current settings of the filter.
  128.    *
  129.    * @return an array of strings suitable for passing to setOptions
  130.    */
  131.   public String [] getOptions() {
  132.     String [] options = new String [8];
  133.     int current = 0;
  134.     options[current++] = "-S"; options[current++] = "" + getSeed();
  135.     if (getInvertSelection()) {
  136.       options[current++] = "-V";
  137.     }
  138.     options[current++] = "-N"; options[current++] = "" + getNumFolds();
  139.     options[current++] = "-F"; options[current++] = "" + getFold();
  140.     while (current < options.length) {
  141.       options[current++] = "";
  142.     }
  143.     return options;
  144.   }
  145.   /**
  146.    * Returns a string describing this filter
  147.    *
  148.    * @return a description of the filter suitable for
  149.    * displaying in the explorer/experimenter gui
  150.    */
  151.   public String globalInfo() {
  152.     return "This filter takes a dataset and outputs a specified fold for cross validation. If you do not want the folds to be stratified use the unsupervised version.";
  153.   }
  154.   /**
  155.    * Returns the tip text for this property
  156.    *
  157.    * @return tip text for this property suitable for
  158.    * displaying in the explorer/experimenter gui
  159.    */
  160.   public String invertSelectionTipText() {
  161.     return "Whether to invert the selection.";
  162.   }
  163.   /**
  164.    * Gets if selection is to be inverted.
  165.    *
  166.    * @return true if the selection is to be inverted
  167.    */
  168.   public boolean getInvertSelection() {
  169.     return m_Inverse;
  170.   }
  171.   /**
  172.    * Sets if selection is to be inverted.
  173.    *
  174.    * @param inverse true if inversion is to be performed
  175.    */
  176.   public void setInvertSelection(boolean inverse) {
  177.     
  178.     m_Inverse = inverse;
  179.   }
  180.   /**
  181.    * Returns the tip text for this property
  182.    *
  183.    * @return tip text for this property suitable for
  184.    * displaying in the explorer/experimenter gui
  185.    */
  186.   public String numFoldsTipText() {
  187.     return "The number of folds to split the dataset into.";
  188.   }
  189.   /**
  190.    * Gets the number of folds in which dataset is to be split into.
  191.    * 
  192.    * @return the number of folds the dataset is to be split into.
  193.    */
  194.   public int getNumFolds() {
  195.     return m_NumFolds;
  196.   }
  197.   /**
  198.    * Sets the number of folds the dataset is split into. If the number
  199.    * of folds is zero, it won't split it into folds. 
  200.    *
  201.    * @param numFolds number of folds dataset is to be split into
  202.    * @exception IllegalArgumentException if number of folds is negative
  203.    */
  204.   public void setNumFolds(int numFolds) {
  205.     if (numFolds < 0) {
  206.       throw new IllegalArgumentException("Number of folds has to be positive or zero.");
  207.     }
  208.     m_NumFolds = numFolds;
  209.   }
  210.   /**
  211.    * Returns the tip text for this property
  212.    *
  213.    * @return tip text for this property suitable for
  214.    * displaying in the explorer/experimenter gui
  215.    */
  216.   public String foldTipText() {
  217.     return "The fold which is selected.";
  218.   }
  219.   /**
  220.    * Gets the fold which is selected.
  221.    *
  222.    * @return the fold which is selected
  223.    */
  224.   public int getFold() {
  225.     return m_Fold;
  226.   }
  227.   /**
  228.    * Selects a fold.
  229.    *
  230.    * @param fold the fold to be selected.
  231.    * @exception IllegalArgumentException if fold's index is smaller than 1
  232.    */
  233.   public void setFold(int fold) {
  234.     if (fold < 1) {
  235.       throw new IllegalArgumentException("Fold's index has to be greater than 0.");
  236.     }
  237.     m_Fold = fold;
  238.   }
  239.   /**
  240.    * Returns the tip text for this property
  241.    *
  242.    * @return tip text for this property suitable for
  243.    * displaying in the explorer/experimenter gui
  244.    */
  245.   public String seedTipText() {
  246.     return "the random number seed for shuffling the dataset. If seed is negative, shuffling will not be performed.";
  247.   }
  248.   /**
  249.    * Gets the random number seed used for shuffling the dataset.
  250.    *
  251.    * @return the random number seed
  252.    */
  253.   public long getSeed() {
  254.     return m_Seed;
  255.   }
  256.   /**
  257.    * Sets the random number seed for shuffling the dataset. If seed
  258.    * is negative, shuffling won't be performed.
  259.    *
  260.    * @param seed the random number seed
  261.    */
  262.   public void setSeed(long seed) {
  263.     
  264.     m_Seed = seed;
  265.   }
  266.   /**
  267.    * Sets the format of the input instances.
  268.    *
  269.    * @param instanceInfo an Instances object containing the input instance
  270.    * structure (any instances contained in the object are ignored - only the
  271.    * structure is required).
  272.    * @return true because outputFormat can be collected immediately
  273.    * @exception Exception if the input format can't be set successfully
  274.    */  
  275.   public boolean setInputFormat(Instances instanceInfo) throws Exception {
  276.     if ((m_NumFolds > 0) && (m_NumFolds < m_Fold)) {
  277.       throw new IllegalArgumentException("Fold has to be smaller or equal to "+
  278.                                          "number of folds.");
  279.     }
  280.     super.setInputFormat(instanceInfo);
  281.     setOutputFormat(instanceInfo);
  282.     return true;
  283.   }
  284.   /**
  285.    * Signify that this batch of input to the filter is
  286.    * finished. Output() may now be called to retrieve the filtered
  287.    * instances.
  288.    *
  289.    * @return true if there are instances pending output
  290.    * @exception IllegalStateException if no input structure has been defined 
  291.    */
  292.   public boolean batchFinished() {
  293.     if (getInputFormat() == null) {
  294.       throw new IllegalStateException("No input instance format defined");
  295.     }
  296.     if (m_Seed > 0) {
  297.       // User has provided a random number seed.
  298.       getInputFormat().randomize(new Random(m_Seed));
  299.     }
  300.     // Select out a fold
  301.     getInputFormat().stratify(m_NumFolds);
  302.     Instances instances;
  303.     if (!m_Inverse) {
  304.       instances = getInputFormat().testCV(m_NumFolds, m_Fold - 1);
  305.     } else {
  306.       instances = getInputFormat().trainCV(m_NumFolds, m_Fold - 1);
  307.     }
  308.     for (int i = 0; i < instances.numInstances(); i++) {
  309.       push(instances.instance(i));
  310.     }
  311.     m_NewBatch = true;
  312.     return (numPendingOutput() != 0);
  313.   }
  314.   /**
  315.    * Main method for testing this class.
  316.    *
  317.    * @param argv should contain arguments to the filter: use -h for help
  318.    */
  319.   public static void main(String [] argv) {
  320.     try {
  321.       if (Utils.getFlag('b', argv)) {
  322.   Filter.batchFilterFile(new StratifiedRemoveFolds(), argv);
  323.       } else {
  324. Filter.filterFile(new StratifiedRemoveFolds(), argv);
  325.       }
  326.     } catch (Exception ex) {
  327.       System.out.println(ex.getMessage());
  328.     }
  329.   }
  330. }