CostMatrix.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 10k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  *    This program is free software; you can redistribute it and/or modify
  3.  *    it under the terms of the GNU General Public License as published by
  4.  *    the Free Software Foundation; either version 2 of the License, or
  5.  *    (at your option) any later version.
  6.  *
  7.  *    This program is distributed in the hope that it will be useful,
  8.  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  *    GNU General Public License for more details.
  11.  *
  12.  *    You should have received a copy of the GNU General Public License
  13.  *    along with this program; if not, write to the Free Software
  14.  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15.  */
  16. /*
  17.  *    CostMatrix.java
  18.  *    Copyright (C) 2001 Richard Kirkby
  19.  *
  20.  */
  21. package weka.classifiers;
  22. import weka.core.Matrix;
  23. import weka.core.Instances;
  24. import weka.core.Utils;
  25. import java.io.Reader;
  26. import java.io.StreamTokenizer;
  27. import java.util.Random;
  28. /**
  29.  * Class for storing and manipulating a misclassification cost matrix.
  30.  * The element at position i,j in the matrix is the penalty for classifying
  31.  * an instance of class j as class i.
  32.  *
  33.  * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
  34.  * @version $Revision: 1.9 $
  35.  */
  36. public class CostMatrix extends Matrix {
  37.   /** The deafult file extension for cost matrix files */
  38.   public static String FILE_EXTENSION = ".cost";
  39.   /**
  40.    * Creates a cost matrix that is a copy of another.
  41.    *
  42.    * @param toCopy the matrix to copy.
  43.    */
  44.   public CostMatrix(CostMatrix toCopy) {
  45.     
  46.     super(toCopy.size(), toCopy.size());
  47.     for (int x=0; x<toCopy.size(); x++) 
  48.       for (int y=0; y<toCopy.size(); y++) 
  49. setElement(x, y, toCopy.getElement(x, y)); 
  50.   }
  51.   /**
  52.    * Creates a default cost matrix of a particular size. All values will be 0.
  53.    *
  54.    * @param numOfClasses the number of classes that the cost matrix holds.
  55.    */  
  56.   public CostMatrix(int numOfClasses) {
  57.     
  58.     super(numOfClasses, numOfClasses);
  59.   }
  60.   /**
  61.    * Creates a cost matrix from a reader.
  62.    *
  63.    * @param reader the reader to get the values from.
  64.    * @exception Exception if the matrix is invalid.
  65.    */  
  66.   public CostMatrix(Reader reader) throws Exception {
  67.     super(reader);
  68.     // make sure that the matrix is square
  69.     if (numRows() != numColumns())
  70.       throw new Exception("Trying to create a non-square cost matrix");
  71.   }
  72.   /**
  73.    * Sets the cost of all correct classifications to 0, and all
  74.    * misclassifications to 1.
  75.    *
  76.    */ 
  77.   public void initialize() {
  78.     for (int i = 0; i < size(); i++) {
  79.       for (int j = 0; j < size(); j++) {
  80. setElement(i, j, i == j ? 0.0 : 1.0);
  81.       }
  82.     }
  83.   }
  84.   /**
  85.    * Gets the size of the matrix.
  86.    *
  87.    * @return the size.
  88.    */
  89.   public int size() {
  90.     return numColumns();
  91.   }
  92.   /**
  93.    * Applies the cost matrix to a set of instances. If a random number generator is 
  94.    * supplied the instances will be resampled, otherwise they will be rewighted. 
  95.    * Adapted from code once sitting in Instances.java
  96.    *
  97.    * @param data the instances to reweight.
  98.    * @param random a random number generator for resampling, if null then instances are
  99.    * rewighted.
  100.    * @return a new dataset reflecting the cost of misclassification.
  101.    * @exception Exception if the data has no class or the matrix in inappropriate.
  102.    */
  103.   public Instances applyCostMatrix(Instances data, Random random) throws Exception {
  104.     double sumOfWeightFactors = 0, sumOfMissClassWeights,
  105.       sumOfWeights;
  106.     double [] weightOfInstancesInClass, weightFactor, weightOfInstances;
  107.     Instances newData;
  108.     if (data.classIndex() < 0) {
  109.       throw new Exception("Class index is not set!");
  110.     }
  111.  
  112.     if (size() != data.numClasses()) { 
  113.       throw new Exception("Misclassification cost matrix has "+
  114.   "wrong format!");
  115.     }
  116.     weightFactor = new double[data.numClasses()];
  117.     weightOfInstancesInClass = new double[data.numClasses()];
  118.     for (int j = 0; j < data.numInstances(); j++) {
  119.       weightOfInstancesInClass[(int)data.instance(j).classValue()] += 
  120. data.instance(j).weight();
  121.     }
  122.     sumOfWeights = Utils.sum(weightOfInstancesInClass);
  123.     // normalize the matrix if not already
  124.     for (int i=0; i<size(); i++)
  125.       if (!Utils.eq(getElement(i, i),0)) {
  126. CostMatrix normMatrix = new CostMatrix(this);
  127. normMatrix.normalize();
  128. return normMatrix.applyCostMatrix(data, random);
  129.       }
  130.     
  131.     for (int i = 0; i < data.numClasses(); i++) {
  132.       // Using Kai Ming Ting's formula for deriving weights for 
  133.       // the classes and Breiman's heuristic for multiclass 
  134.       // problems.
  135.       sumOfMissClassWeights = 0;
  136.       for (int j = 0; j < data.numClasses(); j++) {
  137. if (Utils.sm(getElement(i,j),0)) {
  138.   throw new Exception("Neg. weights in misclassification "+
  139.       "cost matrix!"); 
  140. }
  141. sumOfMissClassWeights += getElement(i,j);
  142.       }
  143.       weightFactor[i] = sumOfMissClassWeights * sumOfWeights;
  144.       sumOfWeightFactors += sumOfMissClassWeights * 
  145. weightOfInstancesInClass[i];
  146.     }
  147.     for (int i = 0; i < data.numClasses(); i++) {
  148.       weightFactor[i] /= sumOfWeightFactors;
  149.     }
  150.     
  151.     // Store new weights
  152.     weightOfInstances = new double[data.numInstances()];
  153.     for (int i = 0; i < data.numInstances(); i++) {
  154.       weightOfInstances[i] = data.instance(i).weight()*
  155. weightFactor[(int)data.instance(i).classValue()];
  156.     } 
  157.     // Change instances weight or do resampling
  158.     if (random != null) {
  159.       return data.resampleWithWeights(random, weightOfInstances);
  160.     } else { 
  161.       Instances instances = new Instances(data);
  162.       for (int i = 0; i < data.numInstances(); i++) {
  163. instances.instance(i).setWeight(weightOfInstances[i]);
  164.       }
  165.       return instances;
  166.     }
  167.   }
  168.   /**
  169.    * Calculates the expected misclassification cost for each possible class value,
  170.    * given class probability estimates. 
  171.    *
  172.    * @param classProbs the class probability estimates.
  173.    * @return the expected costs.
  174.    * @exception Exception if the wrong number of class probabilities is supplied.
  175.    */
  176.   public double[] expectedCosts(double[] classProbs) throws Exception {
  177.     if (classProbs.length != size())
  178.       throw new Exception("Length of probability estimates don't match cost matrix");
  179.     double[] costs = new double[size()];
  180.     for (int x=0; x<size(); x++)
  181.       for (int y=0; y<size(); y++) 
  182. costs[x] += classProbs[y] * getElement(x, y);
  183.     return costs;
  184.   }
  185.   /**
  186.    * Gets the maximum cost for a particular class value.
  187.    *
  188.    * @param classVal the class value.
  189.    * @return the maximum cost.
  190.    */
  191.   public double getMaxCost(int classVal) {
  192.     double maxCost = Double.NEGATIVE_INFINITY;
  193.     for (int i=0; i<size(); i++) {
  194.       double cost = getElement(classVal, i);
  195.       if (cost > maxCost) maxCost = cost;
  196.     }
  197.     return maxCost;
  198.   }
  199.   /**
  200.    * Normalizes the matrix so that the diagonal contains zeros.
  201.    *
  202.    */
  203.   public void normalize() {
  204.     for (int y=0; y<size(); y++) {
  205.       double diag = getElement(y, y);
  206.       for (int x=0; x<size(); x++)
  207. setElement(x, y, getElement(x, y) - diag);
  208.     }
  209.   }
  210.   /**
  211.    * Loads a cost matrix in the old format from a reader. Adapted from code once sitting 
  212.    * in Instances.java
  213.    *
  214.    * @param reader the reader to get the values from.
  215.    * @exception Exception if the matrix cannot be read correctly.
  216.    */  
  217.   public void readOldFormat(Reader reader) throws Exception {
  218.     StreamTokenizer tokenizer;
  219.     int currentToken;
  220.     double firstIndex, secondIndex, weight;
  221.     tokenizer = new StreamTokenizer(reader);
  222.     initialize();
  223.     tokenizer.commentChar('%');
  224.     tokenizer.eolIsSignificant(true);
  225.     while (StreamTokenizer.TT_EOF != 
  226.    (currentToken = tokenizer.nextToken())) {
  227.       // Skip empty lines 
  228.       if (currentToken == StreamTokenizer.TT_EOL) {
  229. continue;
  230.       }
  231.       // Get index of first class.
  232.       if (currentToken != StreamTokenizer.TT_NUMBER) {
  233. throw new Exception("Only numbers and comments allowed "+
  234.     "in cost file!");
  235.       }
  236.       firstIndex = tokenizer.nval;
  237.       if (!Utils.eq((double)(int)firstIndex,firstIndex)) {
  238. throw new Exception("First number in line has to be "+
  239.     "index of a class!");
  240.       }
  241.       if ((int)firstIndex >= size()) {
  242. throw new Exception("Class index out of range!");
  243.       }
  244.       // Get index of second class.
  245.       if (StreamTokenizer.TT_EOF == 
  246.   (currentToken = tokenizer.nextToken())) {
  247. throw new Exception("Premature end of file!");
  248.       }
  249.       if (currentToken == StreamTokenizer.TT_EOL) {
  250. throw new Exception("Premature end of line!");
  251.       }
  252.       if (currentToken != StreamTokenizer.TT_NUMBER) {
  253. throw new Exception("Only numbers and comments allowed "+
  254.     "in cost file!");
  255.       }
  256.       secondIndex = tokenizer.nval;
  257.       if (!Utils.eq((double)(int)secondIndex,secondIndex)) {
  258. throw new Exception("Second number in line has to be "+
  259.     "index of a class!");
  260.       }
  261.       if ((int)secondIndex >= size()) {
  262. throw new Exception("Class index out of range!");
  263.       }
  264.       if ((int)secondIndex == (int)firstIndex) {
  265. throw new Exception("Diagonal of cost matrix non-zero!");
  266.       }
  267.       // Get cost factor.
  268.       if (StreamTokenizer.TT_EOF == 
  269.   (currentToken = tokenizer.nextToken())) {
  270. throw new Exception("Premature end of file!");
  271.       }
  272.       if (currentToken == StreamTokenizer.TT_EOL) {
  273. throw new Exception("Premature end of line!");
  274.       }
  275.       if (currentToken != StreamTokenizer.TT_NUMBER) {
  276. throw new Exception("Only numbers and comments allowed "+
  277.     "in cost file!");
  278.       }
  279.       weight = tokenizer.nval;
  280.       if (!Utils.gr(weight,0)) {
  281. throw new Exception("Only positive weights allowed!");
  282.       }
  283.       setElement((int)firstIndex, (int)secondIndex, weight);
  284.     }
  285.   }
  286. }