ConfusionMatrix.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 9k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  *    This program is free software; you can redistribute it and/or modify
  3.  *    it under the terms of the GNU General Public License as published by
  4.  *    the Free Software Foundation; either version 2 of the License, or
  5.  *    (at your option) any later version.
  6.  *
  7.  *    This program is distributed in the hope that it will be useful,
  8.  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  *    GNU General Public License for more details.
  11.  *
  12.  *    You should have received a copy of the GNU General Public License
  13.  *    along with this program; if not, write to the Free Software
  14.  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15.  */
  16. /*
  17.  *    NominalPrediction.java
  18.  *    Copyright (C) 2000 Intelligenesis Corp.
  19.  *
  20.  */
  21. package weka.classifiers.evaluation;
  22. import weka.core.Matrix;
  23. import weka.core.FastVector;
  24. import weka.core.Utils;
  25. import weka.classifiers.CostMatrix;
  26. /**
  27.  * Cells of this matrix correspond to counts of the number (or weight)
  28.  * of predictions for each actual value / predicted value combination.
  29.  *
  30.  * @author Len Trigg (len@intelligenesis.net)
  31.  * @version $Revision: 1.4 $
  32.  */
  33. public class ConfusionMatrix extends Matrix {
  34.   /** Stores the names of the classes */
  35.   protected String [] m_ClassNames;
  36.   /**
  37.    * Creates the confusion matrix with the given class names.
  38.    *
  39.    * @param classNames an array containing the names the classes.
  40.    */
  41.   public ConfusionMatrix(String [] classNames) {
  42.     super(classNames.length, classNames.length);
  43.     m_ClassNames = (String [])classNames.clone();
  44.   }
  45.   /**
  46.    * Makes a copy of this ConfusionMatrix after applying the
  47.    * supplied CostMatrix to the cells. The resulting ConfusionMatrix
  48.    * can be used to get cost-weighted statistics.
  49.    *
  50.    * @param costs the CostMatrix.
  51.    * @return a ConfusionMatrix that has had costs applied.
  52.    * @exception Exception if the CostMatrix is not of the same size
  53.    * as this ConfusionMatrix.
  54.    */
  55.   public ConfusionMatrix makeWeighted(CostMatrix costs) throws Exception {
  56.     if (costs.size() != size()) {
  57.       throw new Exception("Cost and confusion matrices must be the same size");
  58.     }
  59.     ConfusionMatrix weighted = new ConfusionMatrix(m_ClassNames);
  60.     for (int row = 0; row < size(); row++) {
  61.       for (int col = 0; col < size(); col++) {
  62.         weighted.setElement(row, col, getElement(row, col) * 
  63.                             costs.getElement(row, col));
  64.       }
  65.     }
  66.     return weighted;
  67.   }
  68.   /**
  69.    * Creates and returns a clone of this object.
  70.    *
  71.    * @return a clone of this instance.
  72.    * @exception CloneNotSupportedException if an error occurs
  73.    */
  74.   public Object clone() throws CloneNotSupportedException {
  75.     ConfusionMatrix m = (ConfusionMatrix)super.clone();
  76.     m.m_ClassNames = (String [])m_ClassNames.clone();
  77.     return m;
  78.   }
  79.   /**
  80.    * Gets the number of classes.
  81.    *
  82.    * @return the number of classes
  83.    */
  84.   public int size() {
  85.     return m_ClassNames.length;
  86.   }
  87.   /**
  88.    * Gets the name of one of the classes.
  89.    *
  90.    * @param index the index of the class.
  91.    * @return the class name.
  92.    */
  93.   public String className(int index) {
  94.     return m_ClassNames[index];
  95.   }
  96.   /**
  97.    * Includes a prediction in the confusion matrix.
  98.    *
  99.    * @param pred the NominalPrediction to include
  100.    * @exception Exception if no valid prediction was made (i.e. 
  101.    * unclassified).
  102.    */
  103.   public void addPrediction(NominalPrediction pred) throws Exception {
  104.     if (pred.predicted() == NominalPrediction.MISSING_VALUE) {
  105.       throw new Exception("No predicted value given.");
  106.     }
  107.     if (pred.actual() == NominalPrediction.MISSING_VALUE) {
  108.       throw new Exception("No actual value given.");
  109.     }
  110.     addElement((int)pred.actual(), (int)pred.predicted(), pred.weight());
  111.   }
  112.   /**
  113.    * Includes a whole bunch of predictions in the confusion matrix.
  114.    *
  115.    * @param predictions a FastVector containing the NominalPredictions
  116.    * to include
  117.    * @exception Exception if no valid prediction was made (i.e. 
  118.    * unclassified).
  119.    */
  120.   public void addPredictions(FastVector predictions) throws Exception {
  121.     for (int i = 0; i < predictions.size(); i++) {
  122.       addPrediction((NominalPrediction)predictions.elementAt(i));
  123.     }
  124.   }
  125.   
  126.   /**
  127.    * Gets the performance with respect to one of the classes
  128.    * as a TwoClassStats object.
  129.    *
  130.    * @param classIndex the index of the class of interest.
  131.    * @return the generated TwoClassStats object.
  132.    */
  133.   public TwoClassStats getTwoClassStats(int classIndex) {
  134.     double fp = 0, tp = 0, fn = 0, tn = 0;
  135.     for (int row = 0; row < size(); row++) {
  136.       for (int col = 0; col < size(); col++) {
  137.         if (row == classIndex) {
  138.           if (col == classIndex) {
  139.             tp += getElement(row, col);
  140.           } else {
  141.             fn += getElement(row, col);
  142.           }          
  143.         } else {
  144.           if (col == classIndex) {
  145.             fp += getElement(row, col);
  146.           } else {
  147.             tn += getElement(row, col);
  148.           }          
  149.         }
  150.       }
  151.     }
  152.     return new TwoClassStats(tp, fp, tn, fn);
  153.   }
  154.   /**
  155.    * Gets the number of correct classifications (that is, for which a
  156.    * correct prediction was made). (Actually the sum of the weights of
  157.    * these classifications)
  158.    *
  159.    * @return the number of correct classifications 
  160.    */
  161.   public double correct() {
  162.     double correct = 0;
  163.     for (int i = 0; i < size(); i++) {
  164.       correct += getElement(i, i);
  165.     }
  166.     return correct;
  167.   }
  168.   /**
  169.    * Gets the number of incorrect classifications (that is, for which an
  170.    * incorrect prediction was made). (Actually the sum of the weights of
  171.    * these classifications)
  172.    *
  173.    * @return the number of incorrect classifications 
  174.    */
  175.   public double incorrect() {
  176.     double incorrect = 0;
  177.     for (int row = 0; row < size(); row++) {
  178.       for (int col = 0; col < size(); col++) {
  179.         if (row != col) {
  180.           incorrect += getElement(row, col);
  181.         }
  182.       }
  183.     }
  184.     return incorrect;
  185.   }
  186.   /**
  187.    * Gets the number of predictions that were made
  188.    * (actually the sum of the weights of predictions where the
  189.    * class value was known).
  190.    *
  191.    * @return the number of predictions with known class
  192.    */
  193.   public double total() {
  194.     double total = 0;
  195.     for (int row = 0; row < size(); row++) {
  196.       for (int col = 0; col < size(); col++) {
  197.         total += getElement(row, col);
  198.       }
  199.     }
  200.     return total;
  201.   }
  202.   /**
  203.    * Returns the estimated error rate.
  204.    *
  205.    * @return the estimated error rate (between 0 and 1).
  206.    */
  207.   public double errorRate() {
  208.     return incorrect() / total();
  209.   }
  210.   /**
  211.    * Calls toString() with a default title.
  212.    *
  213.    * @return the confusion matrix as a string
  214.    */
  215.   public String toString() {
  216.     return toString("=== Confusion Matrix ===n");
  217.   }
  218.   /**
  219.    * Outputs the performance statistics as a classification confusion
  220.    * matrix. For each class value, shows the distribution of 
  221.    * predicted class values.
  222.    *
  223.    * @param title the title for the confusion matrix
  224.    * @return the confusion matrix as a String
  225.    */
  226.   public String toString(String title) {
  227.     StringBuffer text = new StringBuffer();
  228.     char [] IDChars = {'a','b','c','d','e','f','g','h','i','j',
  229.        'k','l','m','n','o','p','q','r','s','t',
  230.        'u','v','w','x','y','z'};
  231.     int IDWidth;
  232.     boolean fractional = false;
  233.     // Find the maximum value in the matrix
  234.     // and check for fractional display requirement 
  235.     double maxval = 0;
  236.     for (int i = 0; i < size(); i++) {
  237.       for (int j = 0; j < size(); j++) {
  238. double current = getElement(i, j);
  239.         if (current < 0) {
  240.           current *= -10;
  241.         }
  242. if (current > maxval) {
  243.   maxval = current;
  244. }
  245. double fract = current - Math.rint(current);
  246. if (!fractional
  247.     && ((Math.log(fract) / Math.log(10)) >= -2)) {
  248.   fractional = true;
  249. }
  250.       }
  251.     }
  252.     IDWidth = 1 + Math.max((int)(Math.log(maxval) / Math.log(10) 
  253.  + (fractional ? 3 : 0)),
  254.      (int)(Math.log(size()) / 
  255.    Math.log(IDChars.length)));
  256.     text.append(title).append("n");
  257.     for (int i = 0; i < size(); i++) {
  258.       if (fractional) {
  259. text.append(" ").append(num2ShortID(i,IDChars,IDWidth - 3))
  260.           .append("   ");
  261.       } else {
  262. text.append(" ").append(num2ShortID(i,IDChars,IDWidth));
  263.       }
  264.     }
  265.     text.append("     actual classn");
  266.     for (int i = 0; i< size(); i++) { 
  267.       for (int j = 0; j < size(); j++) {
  268. text.append(" ").append(
  269.     Utils.doubleToString(getElement(i, j),
  270.  IDWidth,
  271.  (fractional ? 2 : 0)));
  272.       }
  273.       text.append(" | ").append(num2ShortID(i,IDChars,IDWidth))
  274.         .append(" = ").append(m_ClassNames[i]).append("n");
  275.     }
  276.     return text.toString();
  277.   }
  278.   /**
  279.    * Method for generating indices for the confusion matrix.
  280.    *
  281.    * @param num integer to format
  282.    * @return the formatted integer as a string
  283.    */
  284.   private static String num2ShortID(int num, char [] IDChars, int IDWidth) {
  285.     
  286.     char ID [] = new char [IDWidth];
  287.     int i;
  288.     
  289.     for(i = IDWidth - 1; i >=0; i--) {
  290.       ID[i] = IDChars[num % IDChars.length];
  291.       num = num / IDChars.length - 1;
  292.       if (num < 0) {
  293. break;
  294.       }
  295.     }
  296.     for(i--; i >= 0; i--) {
  297.       ID[i] = ' ';
  298.     }
  299.     return new String(ID);
  300.   }
  301. }