Distribution.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 18k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  *    This program is free software; you can redistribute it and/or modify
  3.  *    it under the terms of the GNU General Public License as published by
  4.  *    the Free Software Foundation; either version 2 of the License, or
  5.  *    (at your option) any later version.
  6.  *
  7.  *    This program is distributed in the hope that it will be useful,
  8.  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  *    GNU General Public License for more details.
  11.  *
  12.  *    You should have received a copy of the GNU General Public License
  13.  *    along with this program; if not, write to the Free Software
  14.  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15.  */
  16. /*
  17.  *    Distribution.java
  18.  *    Copyright (C) 1999 Eibe Frank
  19.  *
  20.  */
  21. package weka.classifiers.trees.j48;
  22. import java.io.*;
  23. import java.util.*;
  24. import weka.core.*;
  25. /**
  26.  * Class for handling a distribution of class values.
  27.  *
  28.  * @author Eibe Frank (eibe@cs.waikato.ac.nz)
  29.  * @version $Revision: 1.7 $
  30.  */
  31. public class Distribution implements Cloneable, Serializable {
  32.   /** Weight of instances per class per bag. */
  33.   private double m_perClassPerBag[][]; 
  34.   /** Weight of instances per bag. */
  35.   private double m_perBag[];           
  36.   /** Weight of instances per class. */
  37.   private double m_perClass[];         
  38.   /** Total weight of instances. */
  39.   private double totaL;            
  40.   /**
  41.    * Creates and initializes a new distribution.
  42.    */
  43.   public Distribution(int numBags,int numClasses) {
  44.     int i;
  45.     m_perClassPerBag = new double [numBags][0];
  46.     m_perBag = new double [numBags];
  47.     m_perClass = new double [numClasses];
  48.     for (i=0;i<numBags;i++)
  49.       m_perClassPerBag[i] = new double [numClasses];
  50.     totaL = 0;
  51.   }
  52.   /**
  53.    * Creates and initializes a new distribution using the given
  54.    * array. WARNING: it just copies a reference to this array.
  55.    */
  56.   public Distribution(double [][] table) {
  57.     int i, j;
  58.     m_perClassPerBag = table;
  59.     m_perBag = new double [table.length];
  60.     m_perClass = new double [table[0].length];
  61.     for (i = 0; i < table.length; i++) 
  62.       for (j  = 0; j < table[i].length; j++) {
  63. m_perBag[i] += table[i][j];
  64. m_perClass[j] += table[i][j];
  65. totaL += table[i][j];
  66.       }
  67.   }
  68.   /**
  69.    * Creates a distribution with only one bag according
  70.    * to instances in source.
  71.    *
  72.    * @exception Exception if something goes wrong
  73.    */
  74.   public Distribution(Instances source) throws Exception {
  75.     
  76.     m_perClassPerBag = new double [1][0];
  77.     m_perBag = new double [1];
  78.     totaL = 0;
  79.     m_perClass = new double [source.numClasses()];
  80.     m_perClassPerBag[0] = new double [source.numClasses()];
  81.     Enumeration enum = source.enumerateInstances();
  82.     while (enum.hasMoreElements())
  83.       add(0,(Instance) enum.nextElement());
  84.   }
  85.   /**
  86.    * Creates a distribution according to given instances and
  87.    * split model.
  88.    *
  89.    * @exception Exception if something goes wrong
  90.    */
  91.   public Distribution(Instances source, 
  92.       ClassifierSplitModel modelToUse)
  93.        throws Exception {
  94.     int index;
  95.     Instance instance;
  96.     double[] weights;
  97.     m_perClassPerBag = new double [modelToUse.numSubsets()][0];
  98.     m_perBag = new double [modelToUse.numSubsets()];
  99.     totaL = 0;
  100.     m_perClass = new double [source.numClasses()];
  101.     for (int i = 0; i < modelToUse.numSubsets(); i++)
  102.       m_perClassPerBag[i] = new double [source.numClasses()];
  103.     Enumeration enum = source.enumerateInstances();
  104.     while (enum.hasMoreElements()) {
  105.       instance = (Instance) enum.nextElement();
  106.       index = modelToUse.whichSubset(instance);
  107.       if (index != -1)
  108. add(index, instance);
  109.       else {
  110. weights = modelToUse.weights(instance);
  111. addWeights(instance, weights);
  112.       }
  113.     }
  114.   }
  115.   /**
  116.    * Creates distribution with only one bag by merging all
  117.    * bags of given distribution.
  118.    */
  119.   public Distribution(Distribution toMerge) {
  120.     totaL = toMerge.totaL;
  121.     m_perClass = new double [toMerge.numClasses()];
  122.     System.arraycopy(toMerge.m_perClass,0,m_perClass,0,toMerge.numClasses());
  123.     m_perClassPerBag = new double [1] [0];
  124.     m_perClassPerBag[0] = new double [toMerge.numClasses()];
  125.     System.arraycopy(toMerge.m_perClass,0,m_perClassPerBag[0],0,
  126.      toMerge.numClasses());
  127.     m_perBag = new double [1];
  128.     m_perBag[0] = totaL;
  129.   }
  130.   /**
  131.    * Creates distribution with two bags by merging all bags apart of
  132.    * the indicated one.
  133.    */
  134.   public Distribution(Distribution toMerge, int index) {
  135.     int i;
  136.     totaL = toMerge.totaL;
  137.     m_perClass = new double [toMerge.numClasses()];
  138.     System.arraycopy(toMerge.m_perClass,0,m_perClass,0,toMerge.numClasses());
  139.     m_perClassPerBag = new double [2] [0];
  140.     m_perClassPerBag[0] = new double [toMerge.numClasses()];
  141.     System.arraycopy(toMerge.m_perClassPerBag[index],0,m_perClassPerBag[0],0,
  142.      toMerge.numClasses());
  143.     m_perClassPerBag[1] = new double [toMerge.numClasses()];
  144.     for (i=0;i<toMerge.numClasses();i++)
  145.       m_perClassPerBag[1][i] = toMerge.m_perClass[i]-m_perClassPerBag[0][i];
  146.     m_perBag = new double [2];
  147.     m_perBag[0] = toMerge.m_perBag[index];
  148.     m_perBag[1] = totaL-m_perBag[0];
  149.   }
  150.   
  151.   /**
  152.    * Returns number of non-empty bags of distribution.
  153.    */
  154.   public final int actualNumBags() {
  155.     
  156.     int returnValue = 0;
  157.     int i;
  158.     for (i=0;i<m_perBag.length;i++)
  159.       if (Utils.gr(m_perBag[i],0))
  160. returnValue++;
  161.     
  162.     return returnValue;
  163.   }
  164.   /**
  165.    * Returns number of classes actually occuring in distribution.
  166.    */
  167.   public final int actualNumClasses() {
  168.     int returnValue = 0;
  169.     int i;
  170.     for (i=0;i<m_perClass.length;i++)
  171.       if (Utils.gr(m_perClass[i],0))
  172. returnValue++;
  173.     
  174.     return returnValue;
  175.   }
  176.   /**
  177.    * Returns number of classes actually occuring in given bag.
  178.    */
  179.   public final int actualNumClasses(int bagIndex) {
  180.     int returnValue = 0;
  181.     int i;
  182.     for (i=0;i<m_perClass.length;i++)
  183.       if (Utils.gr(m_perClassPerBag[bagIndex][i],0))
  184. returnValue++;
  185.     
  186.     return returnValue;
  187.   }
  188.   /**
  189.    * Adds given instance to given bag.
  190.    *
  191.    * @exception Exception if something goes wrong
  192.    */
  193.   public final void add(int bagIndex,Instance instance) 
  194.        throws Exception {
  195.     
  196.     int classIndex;
  197.     double weight;
  198.     classIndex = (int)instance.classValue();
  199.     weight = instance.weight();
  200.     m_perClassPerBag[bagIndex][classIndex] = 
  201.       m_perClassPerBag[bagIndex][classIndex]+weight;
  202.     m_perBag[bagIndex] = m_perBag[bagIndex]+weight;
  203.     m_perClass[classIndex] = m_perClass[classIndex]+weight;
  204.     totaL = totaL+weight;
  205.   }
  206.   /**
  207.    * Subtracts given instance from given bag.
  208.    *
  209.    * @exception Exception if something goes wrong
  210.    */
  211.   public final void sub(int bagIndex,Instance instance) 
  212.        throws Exception {
  213.     
  214.     int classIndex;
  215.     double weight;
  216.     classIndex = (int)instance.classValue();
  217.     weight = instance.weight();
  218.     m_perClassPerBag[bagIndex][classIndex] = 
  219.       m_perClassPerBag[bagIndex][classIndex]-weight;
  220.     m_perBag[bagIndex] = m_perBag[bagIndex]-weight;
  221.     m_perClass[classIndex] = m_perClass[classIndex]-weight;
  222.     totaL = totaL-weight;
  223.   }
  224.   /**
  225.    * Adds counts to given bag.
  226.    */
  227.   public final void add(int bagIndex, double[] counts) {
  228.     
  229.     double sum = Utils.sum(counts);
  230.     for (int i = 0; i < counts.length; i++)
  231.       m_perClassPerBag[bagIndex][i] += counts[i];
  232.     m_perBag[bagIndex] = m_perBag[bagIndex]+sum;
  233.     for (int i = 0; i < counts.length; i++)
  234.       m_perClass[i] = m_perClass[i]+counts[i];
  235.     totaL = totaL+sum;
  236.   }
  237.   /**
  238.    * Adds all instances with unknown values for given attribute, weighted
  239.    * according to frequency of instances in each bag.
  240.    *
  241.    * @exception Exception if something goes wrong
  242.    */
  243.   public final void addInstWithUnknown(Instances source,
  244.        int attIndex)
  245.        throws Exception {
  246.     double [] probs;
  247.     double weight,newWeight;
  248.     int classIndex;
  249.     Instance instance;
  250.     int j;
  251.     probs = new double [m_perBag.length];
  252.     for (j=0;j<m_perBag.length;j++) {
  253.       if (Utils.eq(totaL, 0)) {
  254. probs[j] = 1.0 / probs.length;
  255.       } else {
  256. probs[j] = m_perBag[j]/totaL;
  257.       }
  258.     }
  259.     Enumeration enum = source.enumerateInstances();
  260.     while (enum.hasMoreElements()) {
  261.       instance = (Instance) enum.nextElement();
  262.       if (instance.isMissing(attIndex)) {
  263. classIndex = (int)instance.classValue();
  264. weight = instance.weight();
  265. m_perClass[classIndex] = m_perClass[classIndex]+weight;
  266. totaL = totaL+weight;
  267. for (j = 0; j < m_perBag.length; j++) {
  268.   newWeight = probs[j]*weight;
  269.   m_perClassPerBag[j][classIndex] = m_perClassPerBag[j][classIndex]+
  270.     newWeight;
  271.   m_perBag[j] = m_perBag[j]+newWeight;
  272. }
  273.       }
  274.     }
  275.   }
  276.   /**
  277.    * Adds all instances in given range to given bag.
  278.    *
  279.    * @exception Exception if something goes wrong
  280.    */
  281.   public final void addRange(int bagIndex,Instances source,
  282.      int startIndex, int lastPlusOne)
  283.        throws Exception {
  284.     double sumOfWeights = 0;
  285.     int classIndex;
  286.     Instance instance;
  287.     int i;
  288.     for (i = startIndex; i < lastPlusOne; i++) {
  289.       instance = (Instance) source.instance(i);
  290.       classIndex = (int)instance.classValue();
  291.       sumOfWeights = sumOfWeights+instance.weight();
  292.       m_perClassPerBag[bagIndex][classIndex] += instance.weight();
  293.       m_perClass[classIndex] += instance.weight();
  294.     }
  295.     m_perBag[bagIndex] += sumOfWeights;
  296.     totaL += sumOfWeights;
  297.   }
  298.   /**
  299.    * Adds given instance to all bags weighting it according to given weights.
  300.    *
  301.    * @exception Exception if something goes wrong
  302.    */
  303.   public final void addWeights(Instance instance, 
  304.        double [] weights)
  305.        throws Exception {
  306.     int classIndex;
  307.     int i;
  308.     classIndex = (int)instance.classValue();
  309.     for (i=0;i<m_perBag.length;i++) {
  310.       double weight = instance.weight() * weights[i];
  311.       m_perClassPerBag[i][classIndex] = m_perClassPerBag[i][classIndex] + weight;
  312.       m_perBag[i] = m_perBag[i] + weight;
  313.       m_perClass[classIndex] = m_perClass[classIndex] + weight;
  314.       totaL = totaL + weight;
  315.     }
  316.   }
  317.   /**
  318.    * Checks if at least two bags contain a minimum number of instances.
  319.    */
  320.   public final boolean check(double minNoObj) {
  321.     int counter = 0;
  322.     int i;
  323.     for (i=0;i<m_perBag.length;i++)
  324.       if (Utils.grOrEq(m_perBag[i],minNoObj))
  325. counter++;
  326.     if (counter > 1)
  327.       return true;
  328.     else
  329.       return false;
  330.   }
  331.   /**
  332.    * Clones distribution (Deep copy of distribution).
  333.    */
  334.   public final Object clone() {
  335.     int i,j;
  336.     Distribution newDistribution = new Distribution (m_perBag.length,
  337.      m_perClass.length);
  338.     for (i=0;i<m_perBag.length;i++) {
  339.       newDistribution.m_perBag[i] = m_perBag[i];
  340.       for (j=0;j<m_perClass.length;j++)
  341. newDistribution.m_perClassPerBag[i][j] = m_perClassPerBag[i][j];
  342.     }
  343.     for (j=0;j<m_perClass.length;j++)
  344.       newDistribution.m_perClass[j] = m_perClass[j];
  345.     newDistribution.totaL = totaL;
  346.   
  347.     return newDistribution;
  348.   }
  349.   /**
  350.    * Deletes given instance from given bag.
  351.    *
  352.    * @exception Exception if something goes wrong
  353.    */
  354.   public final void del(int bagIndex,Instance instance) 
  355.        throws Exception {
  356.     int classIndex;
  357.     double weight;
  358.     classIndex = (int)instance.classValue();
  359.     weight = instance.weight();
  360.     m_perClassPerBag[bagIndex][classIndex] = 
  361.       m_perClassPerBag[bagIndex][classIndex]-weight;
  362.     m_perBag[bagIndex] = m_perBag[bagIndex]-weight;
  363.     m_perClass[classIndex] = m_perClass[classIndex]-weight;
  364.     totaL = totaL-weight;
  365.   }
  366.   /**
  367.    * Deletes all instances in given range from given bag.
  368.    *
  369.    * @exception Exception if something goes wrong
  370.    */
  371.   public final void delRange(int bagIndex,Instances source,
  372.      int startIndex, int lastPlusOne)
  373.        throws Exception {
  374.     double sumOfWeights = 0;
  375.     int classIndex;
  376.     Instance instance;
  377.     int i;
  378.     for (i = startIndex; i < lastPlusOne; i++) {
  379.       instance = (Instance) source.instance(i);
  380.       classIndex = (int)instance.classValue();
  381.       sumOfWeights = sumOfWeights+instance.weight();
  382.       m_perClassPerBag[bagIndex][classIndex] -= instance.weight();
  383.       m_perClass[classIndex] -= instance.weight();
  384.     }
  385.     m_perBag[bagIndex] -= sumOfWeights;
  386.     totaL -= sumOfWeights;
  387.   }
  388.   /**
  389.    * Prints distribution.
  390.    */
  391.   
  392.   public final String dumpDistribution() {
  393.     StringBuffer text;
  394.     int i,j;
  395.     text = new StringBuffer();
  396.     for (i=0;i<m_perBag.length;i++) {
  397.       text.append("Bag num "+i+"n");
  398.       for (j=0;j<m_perClass.length;j++)
  399. text.append("Class num "+j+" "+m_perClassPerBag[i][j]+"n");
  400.     }
  401.     return text.toString();
  402.   }
  403.   /**
  404.    * Sets all counts to zero.
  405.    */
  406.   public final void initialize() {
  407.     for (int i = 0; i < m_perClass.length; i++) 
  408.       m_perClass[i] = 0;
  409.     for (int i = 0; i < m_perBag.length; i++)
  410.       m_perBag[i] = 0;
  411.     for (int i = 0; i < m_perBag.length; i++)
  412.       for (int j = 0; j < m_perClass.length; j++)
  413. m_perClassPerBag[i][j] = 0;
  414.     totaL = 0;
  415.   }
  416.   /**
  417.    * Returns matrix with distribution of class values.
  418.    */
  419.   public final double[][] matrix() {
  420.     return m_perClassPerBag;
  421.   }
  422.   
  423.   /**
  424.    * Returns index of bag containing maximum number of instances.
  425.    */
  426.   public final int maxBag() {
  427.     double max;
  428.     int maxIndex;
  429.     int i;
  430.     
  431.     max = 0;
  432.     maxIndex = -1;
  433.     for (i=0;i<m_perBag.length;i++)
  434.       if (Utils.grOrEq(m_perBag[i],max)) {
  435. max = m_perBag[i];
  436. maxIndex = i;
  437.       }
  438.     return maxIndex;
  439.   }
  440.   /**
  441.    * Returns class with highest frequency over all bags.
  442.    */
  443.   public final int maxClass() {
  444.     double maxCount = 0;
  445.     int maxIndex = 0;
  446.     int i;
  447.     for (i=0;i<m_perClass.length;i++)
  448.       if (Utils.gr(m_perClass[i],maxCount)) {
  449. maxCount = m_perClass[i];
  450. maxIndex = i;
  451.       }
  452.     return maxIndex;
  453.   }
  454.   /**
  455.    * Returns class with highest frequency for given bag.
  456.    */
  457.   public final int maxClass(int index) {
  458.     double maxCount = 0;
  459.     int maxIndex = 0;
  460.     int i;
  461.     if (Utils.gr(m_perBag[index],0)) {
  462.       for (i=0;i<m_perClass.length;i++)
  463. if (Utils.gr(m_perClassPerBag[index][i],maxCount)) {
  464.   maxCount = m_perClassPerBag[index][i];
  465.   maxIndex = i;
  466. }
  467.       return maxIndex;
  468.     }else
  469.       return maxClass();
  470.   }
  471.   /**
  472.    * Returns number of bags.
  473.    */
  474.   public final int numBags() {
  475.     
  476.     return m_perBag.length;
  477.   }
  478.   /**
  479.    * Returns number of classes.
  480.    */
  481.   public final int numClasses() {
  482.     return m_perClass.length;
  483.   }
  484.   /**
  485.    * Returns perClass(maxClass()).
  486.    */
  487.   public final double numCorrect() {
  488.     return m_perClass[maxClass()];
  489.   }
  490.   /**
  491.    * Returns perClassPerBag(index,maxClass(index)).
  492.    */
  493.   public final double numCorrect(int index) {
  494.     return m_perClassPerBag[index][maxClass(index)];
  495.   }
  496.   /**
  497.    * Returns total-numCorrect().
  498.    */
  499.   public final double numIncorrect() {
  500.     return totaL-numCorrect();
  501.   }
  502.   /**
  503.    * Returns perBag(index)-numCorrect(index).
  504.    */
  505.   public final double numIncorrect(int index) {
  506.     return m_perBag[index]-numCorrect(index);
  507.   }
  508.   /**
  509.    * Returns number of (possibly fractional) instances of given class in 
  510.    * given bag.
  511.    */
  512.   public final double perClassPerBag(int bagIndex, int classIndex) {
  513.     return m_perClassPerBag[bagIndex][classIndex];
  514.   }
  515.   /**
  516.    * Returns number of (possibly fractional) instances in given bag.
  517.    */
  518.   public final double perBag(int bagIndex) {
  519.     return m_perBag[bagIndex];
  520.   }
  521.   /**
  522.    * Returns number of (possibly fractional) instances of given class.
  523.    */
  524.   public final double perClass(int classIndex) {
  525.     return m_perClass[classIndex];
  526.   }
  527.   /**
  528.    * Returns relative frequency of class over all bags with
  529.    * Laplace correction.
  530.    */
  531.   public final double laplaceProb(int classIndex) {
  532.     return (m_perClass[classIndex] + 1) / 
  533.       (totaL + (double) actualNumClasses());
  534.   }
  535.   /**
  536.    * Returns relative frequency of class for given bag.
  537.    */
  538.   public final double laplaceProb(int classIndex, int intIndex) {
  539.     return (m_perClassPerBag[intIndex][classIndex] + 1.0) /
  540.       (m_perBag[intIndex] + (double) actualNumClasses());
  541.   }
  542.   /**
  543.    * Returns relative frequency of class over all bags.
  544.    */
  545.   public final double prob(int classIndex) {
  546.     if (!Utils.eq(totaL, 0)) {
  547.       return m_perClass[classIndex]/totaL;
  548.     } else {
  549.       return 0;
  550.     }
  551.   }
  552.   /**
  553.    * Returns relative frequency of class for given bag.
  554.    */
  555.   public final double prob(int classIndex,int intIndex) {
  556.     if (Utils.gr(m_perBag[intIndex],0))
  557.       return m_perClassPerBag[intIndex][classIndex]/m_perBag[intIndex];
  558.     else
  559.       return prob(classIndex);
  560.   }
  561.   /** 
  562.    * Subtracts the given distribution from this one. The results
  563.    * has only one bag.
  564.    */
  565.   public final Distribution subtract(Distribution toSubstract) {
  566.     Distribution newDist = new Distribution(1,m_perClass.length);
  567.     newDist.m_perBag[0] = totaL-toSubstract.totaL;
  568.     newDist.totaL = newDist.m_perBag[0];
  569.     for (int i = 0; i < m_perClass.length; i++) {
  570.       newDist.m_perClassPerBag[0][i] = m_perClass[i] - toSubstract.m_perClass[i];
  571.       newDist.m_perClass[i] = newDist.m_perClassPerBag[0][i];
  572.     }
  573.     return newDist;
  574.   }
  575.   /**
  576.    * Returns total number of (possibly fractional) instances.
  577.    */
  578.   public final double total() {
  579.     return totaL;
  580.   }
  581.   /**
  582.    * Shifts given instance from one bag to another one.
  583.    *
  584.    * @exception Exception if something goes wrong
  585.    */
  586.   public final void shift(int from,int to,Instance instance) 
  587.        throws Exception {
  588.     
  589.     int classIndex;
  590.     double weight;
  591.     classIndex = (int)instance.classValue();
  592.     weight = instance.weight();
  593.     m_perClassPerBag[from][classIndex] -= weight;
  594.     m_perClassPerBag[to][classIndex] += weight;
  595.     m_perBag[from] -= weight;
  596.     m_perBag[to] += weight;
  597.   }
  598.   /**
  599.    * Shifts all instances in given range from one bag to another one.
  600.    *
  601.    * @exception Exception if something goes wrong
  602.    */
  603.   public final void shiftRange(int from,int to,Instances source,
  604.        int startIndex,int lastPlusOne) 
  605.        throws Exception {
  606.     
  607.     int classIndex;
  608.     double weight;
  609.     Instance instance;
  610.     int i;
  611.     for (i = startIndex; i < lastPlusOne; i++) {
  612.       instance = (Instance) source.instance(i);
  613.       classIndex = (int)instance.classValue();
  614.       weight = instance.weight();
  615.       m_perClassPerBag[from][classIndex] -= weight;
  616.       m_perClassPerBag[to][classIndex] += weight;
  617.       m_perBag[from] -= weight;
  618.       m_perBag[to] += weight;
  619.     }
  620.   }
  621. }