CostSensitiveClassifier.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 19k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  *    This program is free software; you can redistribute it and/or modify
  3.  *    it under the terms of the GNU General Public License as published by
  4.  *    the Free Software Foundation; either version 2 of the License, or
  5.  *    (at your option) any later version.
  6.  *
  7.  *    This program is distributed in the hope that it will be useful,
  8.  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  *    GNU General Public License for more details.
  11.  *
  12.  *    You should have received a copy of the GNU General Public License
  13.  *    along with this program; if not, write to the Free Software
  14.  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15.  */
  16. /*
  17.  *    CostSensitiveClassifier.java
  18.  *    Copyright (C) 1999 Intelligenesis Corp.
  19.  *
  20.  */
  21. package weka.classifiers;
  22. import java.io.BufferedReader;
  23. import java.io.File;
  24. import java.io.FileReader;
  25. import java.util.Enumeration;
  26. import java.util.Random;
  27. import java.util.Vector;
  28. import weka.core.Instance;
  29. import weka.core.Instances;
  30. import weka.core.Option;
  31. import weka.core.OptionHandler;
  32. import weka.core.SelectedTag;
  33. import weka.core.Tag;
  34. import weka.core.Utils;
  35. import weka.core.WeightedInstancesHandler;
  36. import weka.core.Drawable;
  37. import weka.filters.Filter;
  38. /**
  39.  * This metaclassifier makes its base classifier cost-sensitive. Two methods
  40.  * can be used to introduce cost-sensitivity: reweighting training instances 
  41.  * according to the total cost assigned to each class; or predicting the class
  42.  * with minimum expected misclassification cost (rather than the most likely 
  43.  * class). The minimum expected cost approach requires that the base classifier
  44.  * be a DistributionClassifier. <p>
  45.  *
  46.  * Valid options are:<p>
  47.  *
  48.  * -M <br>
  49.  * Minimize expected misclassification cost. The base classifier must 
  50.  * produce probability estimates i.e. a DistributionClassifier).
  51.  * (default is to reweight training instances according to costs per class)<p>
  52.  *
  53.  * -W classname <br>
  54.  * Specify the full class name of a classifier (required).<p>
  55.  *
  56.  * -C cost file <br>
  57.  * File name of a cost matrix to use. If this is not supplied, a cost
  58.  * matrix will be loaded on demand. The name of the on-demand file
  59.  * is the relation name of the training data plus ".cost", and the
  60.  * path to the on-demand file is specified with the -D option.<p>
  61.  *
  62.  * -D directory <br>
  63.  * Name of a directory to search for cost files when loading costs on demand
  64.  * (default current directory). <p>
  65.  *
  66.  * -S seed <br>
  67.  * Random number seed used when reweighting by resampling (default 1).<p>
  68.  *
  69.  * Options after -- are passed to the designated classifier.<p>
  70.  *
  71.  * @author Len Trigg (len@intelligenesis.net)
  72.  * @version $Revision: 1.10 $
  73.  */
  74. public class CostSensitiveClassifier extends Classifier
  75.   implements OptionHandler, Drawable {
  76.   /* Specify possible sources of the cost matrix */
  77.   public static final int MATRIX_ON_DEMAND = 1;
  78.   public static final int MATRIX_SUPPLIED = 2;
  79.   public static final Tag [] TAGS_MATRIX_SOURCE = {
  80.     new Tag(MATRIX_ON_DEMAND, "Load cost matrix on demand"),
  81.     new Tag(MATRIX_SUPPLIED, "Use explicit cost matrix")
  82.   };
  83.   /** Indicates the current cost matrix source */
  84.   protected int m_MatrixSource = MATRIX_ON_DEMAND;
  85.   /** 
  86.    * The directory used when loading cost files on demand, null indicates
  87.    * current directory 
  88.    */
  89.   protected File m_OnDemandDirectory = new File(System.getProperty("user.dir"));
  90.   /** The name of the cost file, for command line options */
  91.   protected String m_CostFile;
  92.   /** The cost matrix */
  93.   protected CostMatrix m_CostMatrix = new CostMatrix(1);
  94.   /** The classifier */
  95.   protected Classifier m_Classifier = new weka.classifiers.ZeroR();
  96.   /** Seed for reweighting using resampling. */
  97.   protected int m_Seed = 1;
  98.   /** 
  99.    * True if the costs should be used by selecting the minimum expected
  100.    * cost (false means weight training data by the costs)
  101.    */
  102.   protected boolean m_MinimizeExpectedCost;
  103.   
  104.   /**
  105.    * Returns an enumeration describing the available options
  106.    *
  107.    * @return an enumeration of all the available options
  108.    */
  109.   public Enumeration listOptions() {
  110.     Vector newVector = new Vector(5);
  111.     newVector.addElement(new Option(
  112.       "tMinimize expected misclassification cost. Then"
  113.       +"tbase classifier must produce probability estimatesn"
  114.       +"t(i.e. a DistributionClassifier). Default is ton"
  115.       +"treweight training instances according to costs per class",
  116.       "M", 0, "-M"));
  117.     newVector.addElement(new Option(
  118.       "tFull class name of classifier to use. (required)n"
  119.       + "teg: weka.classifiers.NaiveBayes",
  120.       "W", 1, "-W <class name>"));
  121.     newVector.addElement(new Option(
  122.       "tFile name of a cost matrix to use. If this is not supplied,n"
  123.               +"ta cost matrix will be loaded on demand. The name of then"
  124.               +"ton-demand file is the relation name of the training datan"
  125.               +"tplus ".cost", and the path to the on-demand file isn"
  126.               +"tspecified with the -D option.",
  127.       "C", 1, "-C <cost file name>"));
  128.     newVector.addElement(new Option(
  129.               "tName of a directory to search for cost files when loadingn"
  130.               +"tcosts on demand (default current directory).",
  131.               "D", 1, "-D <directory>"));
  132.     newVector.addElement(new Option(
  133.       "tSeed used when reweighting via resampling. (Default 1)",
  134.       "S", 1, "-S <num>"));
  135.     return newVector.elements();
  136.   }
  137.   /**
  138.    * Parses a given list of options. Valid options are:<p>
  139.    *
  140.    * -M <br>
  141.    * Minimize expected misclassification cost. The base classifier must 
  142.    * produce probability estimates i.e. a DistributionClassifier).
  143.    * (default is to reweight training instances according to costs per class)<p>
  144.    *
  145.    * -W classname <br>
  146.    * Specify the full class name of a classifier (required).<p>
  147.    *
  148.    * -C cost file <br>
  149.    * File name of a cost matrix to use. If this is not supplied, a cost
  150.    * matrix will be loaded on demand. The name of the on-demand file
  151.    * is the relation name of the training data plus ".cost", and the
  152.    * path to the on-demand file is specified with the -D option.<p>
  153.    *
  154.    * -D directory <br>
  155.    * Name of a directory to search for cost files when loading costs on demand
  156.    * (default current directory). <p>
  157.    *
  158.    * -S seed <br>
  159.    * Random number seed used when reweighting by resampling (default 1).<p>
  160.    *
  161.    * Options after -- are passed to the designated classifier.<p>
  162.    *
  163.    * @param options the list of options as an array of strings
  164.    * @exception Exception if an option is not supported
  165.    */
  166.   public void setOptions(String[] options) throws Exception {
  167.     setMinimizeExpectedCost(Utils.getFlag('M', options));
  168.     String seedString = Utils.getOption('S', options);
  169.     if (seedString.length() != 0) {
  170.       setSeed(Integer.parseInt(seedString));
  171.     } else {
  172.       setSeed(1);
  173.     }
  174.     String classifierName = Utils.getOption('W', options);
  175.     if (classifierName.length() == 0) {
  176.       throw new Exception("A classifier must be specified with"
  177.   + " the -W option.");
  178.     }
  179.     setClassifier(Classifier.forName(classifierName,
  180.      Utils.partitionOptions(options)));
  181.     String costFile = Utils.getOption('C', options);
  182.     if (costFile.length() != 0) {
  183.       try {
  184. setCostMatrix(new CostMatrix(new BufferedReader(
  185.      new FileReader(costFile))));
  186.       } catch (Exception ex) {
  187. // now flag as possible old format cost matrix. Delay cost matrix
  188. // loading until buildClassifer is called
  189. setCostMatrix(null);
  190.       }
  191.       setCostMatrixSource(new SelectedTag(MATRIX_SUPPLIED,
  192.                                           TAGS_MATRIX_SOURCE));
  193.       m_CostFile = costFile;
  194.     } else {
  195.       setCostMatrixSource(new SelectedTag(MATRIX_ON_DEMAND, 
  196.                                           TAGS_MATRIX_SOURCE));
  197.     }
  198.     
  199.     String demandDir = Utils.getOption('D', options);
  200.     if (demandDir.length() != 0) {
  201.       setOnDemandDirectory(new File(demandDir));
  202.     }
  203.   }
  204.   /**
  205.    * Gets the current settings of the Classifier.
  206.    *
  207.    * @return an array of strings suitable for passing to setOptions
  208.    */
  209.   public String [] getOptions() {
  210.     String [] classifierOptions = new String [0];
  211.     if ((m_Classifier != null) && 
  212. (m_Classifier instanceof OptionHandler)) {
  213.       classifierOptions = ((OptionHandler)m_Classifier).getOptions();
  214.     }
  215.     String [] options = new String [classifierOptions.length + 9];
  216.     int current = 0;
  217.     if (m_MatrixSource == MATRIX_SUPPLIED) {
  218.       if (m_CostFile != null) {
  219.         options[current++] = "-C";
  220.         options[current++] = "" + m_CostFile;
  221.       }
  222.     } else {
  223.       options[current++] = "-D";
  224.       options[current++] = "" + getOnDemandDirectory();
  225.     }
  226.     options[current++] = "-S"; options[current++] = "" + getSeed();
  227.     if (getMinimizeExpectedCost()) {
  228.       options[current++] = "-M";
  229.     }
  230.     if (getClassifier() != null) {
  231.       options[current++] = "-W";
  232.       options[current++] = getClassifier().getClass().getName();
  233.     }
  234.     options[current++] = "--";
  235.     System.arraycopy(classifierOptions, 0, options, current, 
  236.      classifierOptions.length);
  237.     current += classifierOptions.length;
  238.     while (current < options.length) {
  239.       options[current++] = "";
  240.     }
  241.     return options;
  242.   }
  243.   /**
  244.    * @return a description of the classifier suitable for
  245.    * displaying in the explorer/experimenter gui
  246.    */
  247.   public String globalInfo() {
  248.     return "A metaclassifier that makes its base classifier cost-sensitive. "
  249.       + "Two methods can be used to introduce cost-sensitivity: reweighting "
  250.       + "training instances according to the total cost assigned to each "
  251.       + "class; or predicting the class with minimum expected "
  252.       + "misclassification cost (rather than the most likely class). The "
  253.       + "minimum expected cost approach requires that the base classifier be "
  254.       + "a DistributionClassifier (and is optimal if given accurate "
  255.       + "probabilities by it's base classifier). Performance can often be "
  256.       + "improved by using a Bagged classifier to improve the probability "
  257.       + "estimates of the base classifier.";
  258.   }
  259.   /**
  260.    * @return tip text for this property suitable for
  261.    * displaying in the explorer/experimenter gui
  262.    */
  263.   public String costMatrixSourceTipText() {
  264.     return "Sets where to get the cost matrix. The two options are"
  265.       + "to use the supplied explicit cost matrix (the setting of the "
  266.       + "costMatrix property), or to load a cost matrix from a file when "
  267.       + "required (this file will be loaded from the directory set by the "
  268.       + "onDemandDirectory property and will be named relation_name" 
  269.       + CostMatrix.FILE_EXTENSION + ").";
  270.   }
  271.   /**
  272.    * Gets the source location method of the cost matrix. Will be one of
  273.    * MATRIX_ON_DEMAND or MATRIX_SUPPLIED.
  274.    *
  275.    * @return the cost matrix source.
  276.    */
  277.   public SelectedTag getCostMatrixSource() {
  278.     return new SelectedTag(m_MatrixSource, TAGS_MATRIX_SOURCE);
  279.   }
  280.   
  281.   /**
  282.    * Sets the source location of the cost matrix. Values other than
  283.    * MATRIX_ON_DEMAND or MATRIX_SUPPLIED will be ignored.
  284.    *
  285.    * @param newMethod the cost matrix location method.
  286.    */
  287.   public void setCostMatrixSource(SelectedTag newMethod) {
  288.     
  289.     if (newMethod.getTags() == TAGS_MATRIX_SOURCE) {
  290.       m_MatrixSource = newMethod.getSelectedTag().getID();
  291.     }
  292.   }
  293.   /**
  294.    * @return tip text for this property suitable for
  295.    * displaying in the explorer/experimenter gui
  296.    */
  297.   public String onDemandDirectoryTipText() {
  298.     return "Sets the directory where cost files are loaded from. This option "
  299.       + "is used when the costMatrixSource is set to "On Demand".";
  300.   }
  301.   /**
  302.    * Returns the directory that will be searched for cost files when
  303.    * loading on demand.
  304.    *
  305.    * @return The cost file search directory.
  306.    */
  307.   public File getOnDemandDirectory() {
  308.     return m_OnDemandDirectory;
  309.   }
  310.   /**
  311.    * Sets the directory that will be searched for cost files when
  312.    * loading on demand.
  313.    *
  314.    * @param newDir The cost file search directory.
  315.    */
  316.   public void setOnDemandDirectory(File newDir) {
  317.     if (newDir.isDirectory()) {
  318.       m_OnDemandDirectory = newDir;
  319.     } else {
  320.       m_OnDemandDirectory = new File(newDir.getParent());
  321.     }
  322.     m_MatrixSource = MATRIX_ON_DEMAND;
  323.   }
  324.   /**
  325.    * @return tip text for this property suitable for
  326.    * displaying in the explorer/experimenter gui
  327.    */
  328.   public String minimizeExpectedCostTipText() {
  329.     return "Sets whether the minimum expected cost criteria will be used. If "
  330.       + "this is false, the training data will be reweighted according to the "
  331.       + "costs assigned to each class. If true, the minimum expected cost "
  332.       + "criteria will be used.";
  333.   }
  334.   /**
  335.    * Gets the value of MinimizeExpectedCost.
  336.    *
  337.    * @return Value of MinimizeExpectedCost.
  338.    */
  339.   public boolean getMinimizeExpectedCost() {
  340.     
  341.     return m_MinimizeExpectedCost;
  342.   }
  343.   
  344.   /**
  345.    * Set the value of MinimizeExpectedCost.
  346.    *
  347.    * @param newMinimizeExpectedCost Value to assign to MinimizeExpectedCost.
  348.    */
  349.   public void setMinimizeExpectedCost(boolean newMinimizeExpectedCost) {
  350.     
  351.     m_MinimizeExpectedCost = newMinimizeExpectedCost;
  352.   }
  353.   
  354.   /**
  355.    * @return tip text for this property suitable for
  356.    * displaying in the explorer/experimenter gui
  357.    */
  358.   public String classifierTipText() {
  359.     return "Sets the Classifier used as the basis for "
  360.       + "the cost-sensitive classification. This must be a "
  361.       + "DistributionClassifier if using the minimum expected cost criteria.";
  362.   }
  363.   /**
  364.    * Sets the distribution classifier
  365.    *
  366.    * @param classifier the classifier with all options set.
  367.    */
  368.   public void setClassifier(Classifier classifier) {
  369.     m_Classifier = classifier;
  370.   }
  371.   /**
  372.    * Gets the classifier used.
  373.    *
  374.    * @return the classifier
  375.    */
  376.   public Classifier getClassifier() {
  377.     return m_Classifier;
  378.   }
  379.   
  380.   /**
  381.    * Gets the classifier specification string, which contains the class name of
  382.    * the classifier and any options to the classifier
  383.    *
  384.    * @return the classifier string.
  385.    */
  386.   protected String getClassifierSpec() {
  387.     
  388.     Classifier c = getClassifier();
  389.     if (c instanceof OptionHandler) {
  390.       return c.getClass().getName() + " "
  391. + Utils.joinOptions(((OptionHandler)c).getOptions());
  392.     }
  393.     return c.getClass().getName();
  394.   }
  395.   
  396.   /**
  397.    * @return tip text for this property suitable for
  398.    * displaying in the explorer/experimenter gui
  399.    */
  400.   public String costMatrixTipText() {
  401.     return "Sets the cost matrix explicitly. This matrix is used if the "
  402.       + "costMatrixSource property is set to "Supplied".";
  403.   }
  404.   /**
  405.    * Gets the misclassification cost matrix.
  406.    *
  407.    * @return the cost matrix
  408.    */
  409.   public CostMatrix getCostMatrix() {
  410.     
  411.     return m_CostMatrix;
  412.   }
  413.   
  414.   /**
  415.    * Sets the misclassification cost matrix.
  416.    *
  417.    * @param the cost matrix
  418.    */
  419.   public void setCostMatrix(CostMatrix newCostMatrix) {
  420.     
  421.     m_CostMatrix = newCostMatrix;
  422.     m_MatrixSource = MATRIX_SUPPLIED;
  423.   }
  424.   
  425.   /**
  426.    * @return tip text for this property suitable for
  427.    * displaying in the explorer/experimenter gui
  428.    */
  429.   public String seedTipText() {
  430.     return "Sets the random number seed when reweighting instances. Ignored "
  431.       + "when using minimum expected cost criteria.";
  432.   }
  433.   
  434.   /**
  435.    * Set seed for resampling.
  436.    *
  437.    * @param seed the seed for resampling
  438.    */
  439.   public void setSeed(int seed) {
  440.     m_Seed = seed;
  441.   }
  442.   /**
  443.    * Get seed for resampling.
  444.    *
  445.    * @return the seed for resampling
  446.    */
  447.   public int getSeed() {
  448.     return m_Seed;
  449.   }
  450.   /**
  451.    * Builds the model of the base learner.
  452.    *
  453.    * @param data the training data
  454.    * @exception Exception if the classifier could not be built successfully
  455.    */
  456.   public void buildClassifier(Instances data) throws Exception {
  457.     if (m_Classifier == null) {
  458.       throw new Exception("No base classifier has been set!");
  459.     }
  460.     if (m_MinimizeExpectedCost 
  461. && !(m_Classifier instanceof DistributionClassifier)) {
  462.       throw new Exception("Classifier must be a DistributionClassifier to use"
  463.   + " minimum expected cost method");
  464.     }
  465.     if (!data.classAttribute().isNominal()) {
  466.       throw new Exception("Class attribute must be nominal!");
  467.     }
  468.     if (m_MatrixSource == MATRIX_ON_DEMAND) {
  469.       String costName = data.relationName() + CostMatrix.FILE_EXTENSION;
  470.       File costFile = new File(getOnDemandDirectory(), costName);
  471.       if (!costFile.exists()) {
  472.         throw new Exception("On-demand cost file doesn't exist: " + costFile);
  473.       }
  474.       setCostMatrix(new CostMatrix(new BufferedReader(
  475.                                    new FileReader(costFile))));
  476.     } else if (m_CostMatrix == null) {
  477.       // try loading an old format cost file
  478.       m_CostMatrix = new CostMatrix(data.numClasses());
  479.       m_CostMatrix.readOldFormat(new BufferedReader(
  480.        new FileReader(m_CostFile)));
  481.     }
  482.     if (!m_MinimizeExpectedCost) {
  483.       Random random = null;
  484.       if (!(m_Classifier instanceof WeightedInstancesHandler)) {
  485. random = new Random(m_Seed);
  486.       }
  487.       data = m_CostMatrix.applyCostMatrix(data, random);
  488.     }
  489.     m_Classifier.buildClassifier(data);
  490.   }
  491.   /**
  492.    * Classifies a given instance by choosing the class with the minimum
  493.    * expected misclassification cost.
  494.    *
  495.    * @param instance the instance to be classified
  496.    * @exception Exception if instance could not be classified
  497.    * successfully
  498.    */
  499.   public double classifyInstance(Instance instance) throws Exception {
  500.     if (!m_MinimizeExpectedCost) {
  501.       return m_Classifier.classifyInstance(instance);
  502.     }
  503.     double [] pred = ((DistributionClassifier) m_Classifier)
  504.       .distributionForInstance(instance);
  505.     double [] costs = m_CostMatrix.expectedCosts(pred);
  506.     /*
  507.     for (int i = 0; i < pred.length; i++) {
  508.       System.out.print(pred[i] + " ");
  509.     }
  510.     System.out.println();
  511.     for (int i = 0; i < costs.length; i++) {
  512.       System.out.print(costs[i] + " ");
  513.     }
  514.     System.out.println("n");
  515.     */
  516.     
  517.     return Utils.minIndex(costs);
  518.   }
  519.   /**
  520.    * Returns graph describing the classifier (if possible).
  521.    *
  522.    * @return the graph of the classifier in dotty format
  523.    * @exception Exception if the classifier cannot be graphed
  524.    */
  525.   public String graph() throws Exception {
  526.     
  527.     if (m_Classifier instanceof Drawable)
  528.       return ((Drawable)m_Classifier).graph();
  529.     else throw new Exception("Classifier: " + getClassifierSpec()
  530.      + " cannot be graphed");
  531.   }
  532.   /**
  533.    * Output a representation of this classifier
  534.    */
  535.   public String toString() {
  536.     if (m_Classifier == null) {
  537.       return "CostSensitiveClassifier: No model built yet.";
  538.     }
  539.     String result = "CostSensitiveClassifier using ";
  540.       if (m_MinimizeExpectedCost) {
  541. result += "minimized expected misclasification costn";
  542.       } else {
  543. result += "reweighted training instancesn";
  544.       }
  545.       result += "n" + getClassifierSpec()
  546. + "nnClassifier Modeln"
  547. + m_Classifier.toString()
  548. + "nnCost Matrixn"
  549. + m_CostMatrix.toString();
  550.     return result;
  551.   }
  552.   /**
  553.    * Main method for testing this class.
  554.    *
  555.    * @param argv should contain the following arguments:
  556.    * -t training file [-T test file] [-c class index]
  557.    */
  558.   public static void main(String [] argv) {
  559.     try {
  560.       System.out.println(Evaluation
  561.  .evaluateModel(new CostSensitiveClassifier(),
  562. argv));
  563.     } catch (Exception e) {
  564.       System.err.println(e.getMessage());
  565.     }
  566.   }
  567. }