CostMatrix.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 13k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  *    This program is free software; you can redistribute it and/or modify
  3.  *    it under the terms of the GNU General Public License as published by
  4.  *    the Free Software Foundation; either version 2 of the License, or
  5.  *    (at your option) any later version.
  6.  *
  7.  *    This program is distributed in the hope that it will be useful,
  8.  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  *    GNU General Public License for more details.
  11.  *
  12.  *    You should have received a copy of the GNU General Public License
  13.  *    along with this program; if not, write to the Free Software
  14.  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15.  */
  16. /*
  17.  *    CostMatrix.java
  18.  *    Copyright (C) 1999 Intelligenesis Corp.
  19.  *
  20.  */
  21. package weka.classifiers;
  22. import weka.core.Utils;
  23. import weka.core.Instance;
  24. import weka.core.Instances;
  25. import weka.core.Matrix;
  26. import java.io.Reader;
  27. import java.io.FileReader;
  28. import java.io.BufferedReader;
  29. import java.io.InputStreamReader;
  30. import java.io.StreamTokenizer;
  31. import java.util.Random;
  32. /**
  33.  * Class for a misclassification cost matrix. The element in the i'th column
  34.  * of the j'th row is the cost for (mis)classifying an instance of class j as 
  35.  * having class i. It is valid to have non-zero values down the diagonal 
  36.  * (these are typically negative to indicate some varying degree of "gain" 
  37.  * from making a correct prediction).
  38.  *
  39.  * @author Len Trigg (len@intelligenesis.net)
  40.  * @version $Revision: 1.8 $
  41.  */
  42. public class CostMatrix extends Matrix {
  43.   /** The filename extension that should be used for cost files */
  44.   public static String FILE_EXTENSION = ".cost";
  45.   /**
  46.    * Creates a cost matrix identical to an existing matrix.
  47.    *
  48.    * @param toCopy the matrix to copy.
  49.    */
  50.   public CostMatrix(CostMatrix toCopy) {
  51.     this(toCopy.size());
  52.     for (int i = 0; i < size(); i++) {
  53.       for (int j = 0; j < size(); j++) {
  54. setElement(i, j, toCopy.getElement(i, j));
  55.       }
  56.     }
  57.   }
  58.   /**
  59.    * Creates a default cost matrix for the given number of classes. The 
  60.    * default misclassification cost is 1.
  61.    *
  62.    * @param numClasses the number of classes
  63.    */
  64.   public CostMatrix(int numClasses) {
  65.     
  66.     super(numClasses, numClasses);
  67.   }  
  68.   /**
  69.    * Creates a cost matrix from a cost file.
  70.    *
  71.    * @param r a reader from which the cost matrix will be read
  72.    * @exception Exception if an error occurs
  73.    */
  74.   public CostMatrix(Reader r) throws Exception {
  75.     super(r);
  76.     if (numColumns() != numRows()) {
  77.       throw new Exception("Cost matrix is not square");
  78.     }
  79.   }
  80.   /**
  81.    * Creates a cost matrix for the class attribute of the supplied instances, 
  82.    * where the misclassification costs are higher for misclassifying a rare
  83.    * class as a frequent one. The cost of classifying an instance of class i 
  84.    * as class j is weight * Pj / Pi. (Pi and Pj are laplace estimates)
  85.    *
  86.    * @param instances a value of type 'Instances'
  87.    * @param weight a value of type 'double'
  88.    * @return a value of type CostMatrix
  89.    * @exception Exception if no class attribute is assigned, or the class
  90.    * attribute is not nominal
  91.    */
  92.   public static CostMatrix makeFrequencyDependentMatrix(Instances instances,
  93.                                                         double weight) 
  94.     throws Exception {
  95.     if (!instances.classAttribute().isNominal()) {
  96.       throw new Exception("Class attribute is not nominal!");
  97.     }
  98.     int numClasses = instances.numClasses();
  99.     // Collect class probabilities
  100.     double probs [] = new double [numClasses];
  101.     for (int i = 0; i < probs.length; i++) {
  102.       probs[i]++;
  103.     }
  104.     for (int i = 0; i < instances.numInstances(); i++) {
  105.       Instance current = instances.instance(i);
  106.       if (!current.classIsMissing()) {
  107.         probs[(int)current.classValue()]++;
  108.       }
  109.     }
  110.     Utils.normalize(probs);
  111.     // Create and populate the cost matrix
  112.     CostMatrix newMatrix = new CostMatrix(numClasses);
  113.     for (int i = 0; i < numClasses; i++) {
  114.       for (int j = 0; j < numClasses; j++) {
  115.         if (i != j) {
  116.           newMatrix.setElement(i, j, weight * probs[j] / probs[i]);
  117.         }
  118.       }
  119.     }
  120.     
  121.     return newMatrix;
  122.   }
  123.   /**
  124.    * Reads misclassification cost matrix from given reader. 
  125.    * Each line has to contain three numbers: the index of the true 
  126.    * class, the index of the incorrectly assigned class, and the 
  127.    * weight, separated by white space characters. Comments can be 
  128.    * appended to the end of a line by using the '%' character.
  129.    *
  130.    * @param reader the reader from which the cost matrix is to be read
  131.    * @exception Exception if the cost matrix does not have the 
  132.    * right format
  133.    */
  134.   public void readOldFormat(Reader reader)throws Exception {
  135.     initialize();
  136.     StreamTokenizer tokenizer = new StreamTokenizer(reader);
  137.     tokenizer.commentChar('%');
  138.     tokenizer.eolIsSignificant(true);
  139.     int currentToken;
  140.     while (StreamTokenizer.TT_EOF != 
  141.    (currentToken = tokenizer.nextToken())) {
  142.       
  143.       // Skip empty lines 
  144.       if (currentToken == StreamTokenizer.TT_EOL) {
  145. continue;
  146.       }
  147.       
  148.       // Get index of first class.
  149.       if (currentToken != StreamTokenizer.TT_NUMBER) {
  150. throw new Exception("Only numbers and comments allowed "+
  151.     "in cost file!");
  152.       }
  153.       double firstIndex = tokenizer.nval;
  154.       if (!Utils.eq((double)(int)firstIndex, firstIndex)) {
  155. throw new Exception("First number in line has to be "+
  156.     "index of a class!");
  157.       }
  158.       if ((int)firstIndex >= size()) {
  159. throw new Exception("Class index out of range!");
  160.       }
  161.       // Get index of second class.
  162.       if (StreamTokenizer.TT_EOF == 
  163.   (currentToken = tokenizer.nextToken())) {
  164. throw new Exception("Premature end of file!");
  165.       }
  166.       if (currentToken == StreamTokenizer.TT_EOL) {
  167. throw new Exception("Premature end of line!");
  168.       }
  169.       if (currentToken != StreamTokenizer.TT_NUMBER) {
  170. throw new Exception("Only numbers and comments allowed "+
  171.     "in cost file!");
  172.       }
  173.       double secondIndex = tokenizer.nval;
  174.       if (!Utils.eq((double)(int)secondIndex,secondIndex)) {
  175. throw new Exception("Second number in line has to be "+
  176.     "index of a class!");
  177.       }
  178.       if ((int)secondIndex >= size()) {
  179. throw new Exception("Class index out of range!");
  180.       }
  181.       // Get cost factor.
  182.       if (StreamTokenizer.TT_EOF == 
  183.   (currentToken = tokenizer.nextToken())) {
  184. throw new Exception("Premature end of file!");
  185.       }
  186.       if (currentToken == StreamTokenizer.TT_EOL) {
  187. throw new Exception("Premature end of line!");
  188.       }
  189.       if (currentToken != StreamTokenizer.TT_NUMBER) {
  190. throw new Exception("Only numbers and comments allowed "+
  191.     "in cost file!");
  192.       }
  193.       double weight = tokenizer.nval;
  194.       setElement((int)firstIndex, (int)secondIndex, weight);
  195.     }
  196.   }
  197.   /**
  198.    * Sets the costs to default values (i.e. 0 down the diagonal, and 1 for
  199.    * any misclassification).
  200.    */
  201.   public void initialize() {
  202.     for (int i = 0; i < numRows(); i++) {
  203.       for (int j = 0; j < numColumns(); j++) {
  204. if (i != j) {
  205.   setElement(i, j, 1);
  206. } else {
  207.   setElement(i, j, 0);
  208. }
  209.       }
  210.     }
  211.   }
  212.   /**
  213.    * Gets the number of classes.
  214.    *
  215.    * @return the number of classes
  216.    */
  217.   public int size() {
  218.     return numColumns();
  219.   }
  220.   
  221.   /**
  222.    * Normalizes the cost matrix so that diagonal elements are zero. The value
  223.    * of non-zero diagonal elements is subtracted from the row containing the
  224.    * value. For example: <p>
  225.    *
  226.    * <pre><code>
  227.    * 2  5
  228.    * 3 -1
  229.    * </code></pre>
  230.    * 
  231.    * <p> becomes <p>
  232.    *
  233.    * <pre><code>
  234.    * 0  3
  235.    * 4  0
  236.    * </code></pre><p>
  237.    *
  238.    * This normalization will affect total classification cost during 
  239.    * evaluation, but will not affect the decision made by applying minimum
  240.    * expected cost criteria during prediction.
  241.    */
  242.   public void normalize() {
  243.     for (int i = 0; i < size(); i++) {
  244.       double diag = getElement(i, i);
  245.       for (int j = 0; j < size(); j++) {
  246.         addElement(i, j, -diag);
  247.       }
  248.     }
  249.   }
  250.   /** 
  251.    * Changes the dataset to reflect a given set of costs.
  252.    * Sets the weights of instances according to the misclassification
  253.    * cost matrix, or does resampling according to the cost matrix (if
  254.    * a random number generator is provided). Returns a new dataset.
  255.    *
  256.    * @param instances the instances to apply cost weights to.
  257.    * @param random a random number generator 
  258.    * @return the new dataset
  259.    * @exception Exception if the cost matrix does not have the right
  260.    * format 
  261.    */
  262.   public Instances applyCostMatrix(Instances instances, Random random) 
  263.        throws Exception {
  264.     if (instances.classIndex() < 0) {
  265.       throw new Exception("Class index is not set!");
  266.     }
  267.     if (size() != instances.numClasses()) {
  268.       throw new Exception("Cost matrix and instances have different class"
  269.   + " size!");
  270.     }
  271.     // If this cost matrix hasn't been normalized, apply a normalized
  272.     // version instead.
  273.     for (int i = 0; i < size(); i++) {
  274.       if (!Utils.eq(m_Elements[i][i], 0)) {
  275.         CostMatrix cm = new CostMatrix(this);
  276.         cm.normalize();
  277.         return cm.applyCostMatrix(instances, random);
  278.       }
  279.     }
  280.       
  281.     // Determine the prior weights of all instances in each class
  282.     double [] weightOfInstancesInClass = new double [size()];
  283.     for (int j = 0; j < instances.numInstances(); j++) {
  284.       Instance current = instances.instance(j);
  285.       weightOfInstancesInClass[(int)current.classValue()] += 
  286. current.weight();
  287.     }
  288.     double sumOfWeights = Utils.sum(weightOfInstancesInClass);
  289.     double [] weightFactor = new double [size()];
  290.     double sumOfWeightFactors = 0;
  291.     for (int i = 0; i < size(); i++) {
  292.       // Using Kai Ming Ting's formula for deriving weights for 
  293.       // the classes and Breiman's heuristic for multiclass 
  294.       // problems.
  295.       double sumOfMissClassWeights = 0;
  296.       for (int j = 0; j < size(); j++) {
  297. if (Utils.sm(m_Elements[i][j], 0)) {
  298.   throw new Exception("Neg. weights in misclassification "+
  299.       "cost matrix!"); 
  300. }
  301. sumOfMissClassWeights += m_Elements[i][j];
  302.       }
  303.       weightFactor[i] = sumOfMissClassWeights * sumOfWeights;
  304.       sumOfWeightFactors += sumOfMissClassWeights 
  305. * weightOfInstancesInClass[i];
  306.     }
  307.     for (int i = 0; i < size(); i++) {
  308.       weightFactor[i] /= sumOfWeightFactors;
  309.     }
  310.     
  311.     // Store new weights
  312.     double [] weightOfInstances = new double[instances.numInstances()];
  313.     for (int i = 0; i < instances.numInstances(); i++) {
  314.       Instance current = instances.instance(i);
  315.       weightOfInstances[i] = current.weight() 
  316. * weightFactor[(int)current.classValue()];
  317.     } 
  318.     // Change instances weight or do resampling
  319.     if (random != null) {
  320.       return instances.resampleWithWeights(random, weightOfInstances);
  321.     } else { 
  322.       instances = new Instances(instances);
  323.       for (int i = 0; i < instances.numInstances(); i++) {
  324. instances.instance(i).setWeight(weightOfInstances[i]);
  325.       }
  326.       return instances;
  327.     }
  328.   }
  329.   /**
  330.    * Calculates the expected misclassification cost for each possible
  331.    * class value, given class probability estimates.
  332.    *
  333.    * @param probabilities an array containing probability estimates for each 
  334.    * class value.
  335.    * @return an array containing the expected misclassification cost for each
  336.    * class.
  337.    * @exception Exception if the number of probabilities does not match the 
  338.    * number of classes.
  339.    */
  340.   public double [] expectedCosts(double [] probabilities) throws Exception {
  341.     if (probabilities.length != size()) {
  342.       throw new Exception("Number of classes in probability estimates does not"
  343.   + " match size of cost matrix!");
  344.     }
  345.     double [] costs = new double[size()];
  346.     for (int i = 0; i < size(); i++) {
  347.       double expectedCost = 0;
  348.       for (int j = 0; j < size(); j++) {
  349. expectedCost += m_Elements[j][i] * probabilities[j];
  350.       }
  351.       costs[i] = expectedCost;
  352.     }
  353.     return costs;
  354.   }
  355.   /**
  356.    * Gets the maximum misclassification cost possible for a given actual
  357.    * class value
  358.    *
  359.    * @param actualClass the index of the actual class value
  360.    * @return the highest cost possible for misclassifying this class
  361.    */
  362.   public double getMaxCost(int actualClass) {
  363.     return m_Elements[actualClass][Utils.maxIndex(m_Elements[actualClass])];
  364.   }
  365.   /**
  366.    * Tests out creation of a frequency dependent cost matrix from the command
  367.    * line. Either pipe a set of instances into system.in or give the name of
  368.    * a dataset as an argument. The last column will be treated as the class
  369.    * attribute and a cost matrix with weight 1000 output.
  370.    *
  371.    * @param []args a value of type 'String'
  372.    */
  373.   public static void main(String []args) {
  374.     try {
  375.       Reader r = null;
  376.       if (args.length > 1) {
  377. throw (new Exception("Usage: Instances <filename>"));
  378.       } else if (args.length == 0) {
  379.         r = new BufferedReader(new InputStreamReader(System.in));
  380.       } else {
  381.         r = new BufferedReader(new FileReader(args[0]));
  382.       }
  383.       Instances i = new Instances(r);
  384.       i.setClassIndex(i.numAttributes() - 1);
  385.       CostMatrix.makeFrequencyDependentMatrix(i, 1000)
  386.         .write(new java.io.PrintWriter(System.out));
  387.     } catch (Exception ex) {
  388.       System.err.println(ex);
  389.     }
  390.   }
  391. } // CostMatrix