NaiveBayesSimple.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 10k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  *    This program is free software; you can redistribute it and/or modify
  3.  *    it under the terms of the GNU General Public License as published by
  4.  *    the Free Software Foundation; either version 2 of the License, or
  5.  *    (at your option) any later version.
  6.  *
  7.  *    This program is distributed in the hope that it will be useful,
  8.  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  *    GNU General Public License for more details.
  11.  *
  12.  *    You should have received a copy of the GNU General Public License
  13.  *    along with this program; if not, write to the Free Software
  14.  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15.  */
  16. /*
  17.  *    NaiveBayesSimple.java
  18.  *    Copyright (C) 1999 Eibe Frank
  19.  *
  20.  */
  21. package weka.classifiers;
  22. import java.io.*;
  23. import java.util.*;
  24. import weka.core.*;
  25. /**
  26.  * Class for building and using a simple Naive Bayes classifier.
  27.  * Numeric attributes are modelled by a normal distribution. For more
  28.  * information, see<p>
  29.  *
  30.  * Richard Duda and Peter Hart (1973).<i>Pattern
  31.  * Classification and Scene Analysis</i>. Wiley, New York.
  32.  * @author Eibe Frank (eibe@cs.waikato.ac.nz)
  33.  * @version $Revision: 1.8 $ 
  34. */
  35. public class NaiveBayesSimple extends DistributionClassifier {
  36.   /** All the counts for nominal attributes. */
  37.   private double [][][] m_Counts;
  38.   
  39.   /** The means for numeric attributes. */
  40.   private double [][] m_Means;
  41.   /** The standard deviations for numeric attributes. */
  42.   private double [][] m_Devs;
  43.   /** The prior probabilities of the classes. */
  44.   private double [] m_Priors;
  45.   /** The instances used for training. */
  46.   private Instances m_Instances;
  47.   /** Constant for normal distribution. */
  48.   private static double NORM_CONST = Math.sqrt(2 * Math.PI);
  49.   /**
  50.    * Generates the classifier.
  51.    *
  52.    * @param instances set of instances serving as training data 
  53.    * @exception Exception if the classifier has not been generated successfully
  54.    */
  55.   public void buildClassifier(Instances instances) throws Exception {
  56.     int attIndex = 0;
  57.     double sum;
  58.     
  59.     if (instances.checkForStringAttributes()) {
  60.       throw new Exception("Can't handle string attributes!");
  61.     }
  62.     if (instances.classAttribute().isNumeric()) {
  63.       throw new Exception("Naive Bayes: Class is numeric!");
  64.     }
  65.     
  66.     m_Instances = new Instances(instances, 0);
  67.     
  68.     // Reserve space
  69.     m_Counts = new double[instances.numClasses()]
  70.       [instances.numAttributes() - 1][0];
  71.     m_Means = new double[instances.numClasses()]
  72.       [instances.numAttributes() - 1];
  73.     m_Devs = new double[instances.numClasses()]
  74.       [instances.numAttributes() - 1];
  75.     m_Priors = new double[instances.numClasses()];
  76.     Enumeration enum = instances.enumerateAttributes();
  77.     while (enum.hasMoreElements()) {
  78.       Attribute attribute = (Attribute) enum.nextElement();
  79.       if (attribute.isNominal()) {
  80. for (int j = 0; j < instances.numClasses(); j++) {
  81.   m_Counts[j][attIndex] = new double[attribute.numValues()];
  82. }
  83.       } else {
  84. for (int j = 0; j < instances.numClasses(); j++) {
  85.   m_Counts[j][attIndex] = new double[1];
  86. }
  87.       }
  88.       attIndex++;
  89.     }
  90.     
  91.     // Compute counts and sums
  92.     Enumeration enumInsts = instances.enumerateInstances();
  93.     while (enumInsts.hasMoreElements()) {
  94.       Instance instance = (Instance) enumInsts.nextElement();
  95.       if (!instance.classIsMissing()) {
  96. Enumeration enumAtts = instances.enumerateAttributes();
  97. attIndex = 0;
  98. while (enumAtts.hasMoreElements()) {
  99.   Attribute attribute = (Attribute) enumAtts.nextElement();
  100.   if (!instance.isMissing(attribute)) {
  101.     if (attribute.isNominal()) {
  102.       m_Counts[(int)instance.classValue()][attIndex]
  103. [(int)instance.value(attribute)]++;
  104.     } else {
  105.       m_Means[(int)instance.classValue()][attIndex] +=
  106. instance.value(attribute);
  107.       m_Counts[(int)instance.classValue()][attIndex][0]++;
  108.     }
  109.   }
  110.   attIndex++;
  111. }
  112. m_Priors[(int)instance.classValue()]++;
  113.       }
  114.     }
  115.     
  116.     // Compute means
  117.     Enumeration enumAtts = instances.enumerateAttributes();
  118.     attIndex = 0;
  119.     while (enumAtts.hasMoreElements()) {
  120.       Attribute attribute = (Attribute) enumAtts.nextElement();
  121.       if (attribute.isNumeric()) {
  122. for (int j = 0; j < instances.numClasses(); j++) {
  123.   if (m_Counts[j][attIndex][0] < 2) {
  124.     throw new Exception("attribute " + attribute.name() +
  125. ": less than two values for class " +
  126. instances.classAttribute().value(j));
  127.   }
  128.   m_Means[j][attIndex] /= m_Counts[j][attIndex][0];
  129. }
  130.       }
  131.       attIndex++;
  132.     }    
  133.     
  134.     // Compute standard deviations
  135.     enumInsts = instances.enumerateInstances();
  136.     while (enumInsts.hasMoreElements()) {
  137.       Instance instance = 
  138. (Instance) enumInsts.nextElement();
  139.       if (!instance.classIsMissing()) {
  140. enumAtts = instances.enumerateAttributes();
  141. attIndex = 0;
  142. while (enumAtts.hasMoreElements()) {
  143.   Attribute attribute = (Attribute) enumAtts.nextElement();
  144.   if (!instance.isMissing(attribute)) {
  145.     if (attribute.isNumeric()) {
  146.       m_Devs[(int)instance.classValue()][attIndex] +=
  147. (m_Means[(int)instance.classValue()][attIndex]-
  148.  instance.value(attribute))*
  149. (m_Means[(int)instance.classValue()][attIndex]-
  150.  instance.value(attribute));
  151.     }
  152.   }
  153.   attIndex++;
  154. }
  155.       }
  156.     }
  157.     enumAtts = instances.enumerateAttributes();
  158.     attIndex = 0;
  159.     while (enumAtts.hasMoreElements()) {
  160.       Attribute attribute = (Attribute) enumAtts.nextElement();
  161.       if (attribute.isNumeric()) {
  162. for (int j = 0; j < instances.numClasses(); j++) {
  163.   if (m_Devs[j][attIndex] <= 0) {
  164.     throw new Exception("attribute " + attribute.name() +
  165. ": standard deviation is 0 for class " +
  166. instances.classAttribute().value(j));
  167.   }
  168.   else {
  169.     m_Devs[j][attIndex] /= m_Counts[j][attIndex][0] - 1;
  170.     m_Devs[j][attIndex] = Math.sqrt(m_Devs[j][attIndex]);
  171.   }
  172. }
  173.       }
  174.       attIndex++;
  175.     } 
  176.     
  177.     // Normalize counts
  178.     enumAtts = instances.enumerateAttributes();
  179.     attIndex = 0;
  180.     while (enumAtts.hasMoreElements()) {
  181.       Attribute attribute = (Attribute) enumAtts.nextElement();
  182.       if (attribute.isNominal()) {
  183. for (int j = 0; j < instances.numClasses(); j++) {
  184.   sum = Utils.sum(m_Counts[j][attIndex]);
  185.   for (int i = 0; i < attribute.numValues(); i++) {
  186.     m_Counts[j][attIndex][i] =
  187.       (m_Counts[j][attIndex][i] + 1) 
  188.       / (sum + (double)attribute.numValues());
  189.   }
  190. }
  191.       }
  192.       attIndex++;
  193.     }
  194.     
  195.     // Normalize priors
  196.     sum = Utils.sum(m_Priors);
  197.     for (int j = 0; j < instances.numClasses(); j++)
  198.       m_Priors[j] = (m_Priors[j] + 1) 
  199. / (sum + (double)instances.numClasses());
  200.   }
  201.   /**
  202.    * Calculates the class membership probabilities for the given test instance.
  203.    *
  204.    * @param instance the instance to be classified
  205.    * @return predicted class probability distribution
  206.    * @exception Exception if distribution can't be computed
  207.    */
  208.   public double[] distributionForInstance(Instance instance) throws Exception {
  209.     
  210.     double [] probs = new double[instance.numClasses()];
  211.     int attIndex;
  212.     
  213.     for (int j = 0; j < instance.numClasses(); j++) {
  214.       probs[j] = 1;
  215.       Enumeration enumAtts = instance.enumerateAttributes();
  216.       attIndex = 0;
  217.       while (enumAtts.hasMoreElements()) {
  218. Attribute attribute = (Attribute) enumAtts.nextElement();
  219. if (!instance.isMissing(attribute)) {
  220.   if (attribute.isNominal()) {
  221.     probs[j] *= m_Counts[j][attIndex][(int)instance.value(attribute)];
  222.   } else {
  223.     probs[j] *= normalDens(instance.value(attribute),
  224.    m_Means[j][attIndex],
  225.    m_Devs[j][attIndex]);}
  226. }
  227. attIndex++;
  228.       }
  229.       probs[j] *= m_Priors[j];
  230.     }
  231.     
  232.     // Normalize probabilities
  233.     Utils.normalize(probs);
  234.     return probs;
  235.   }
  236.   /**
  237.    * Returns a description of the classifier.
  238.    *
  239.    * @return a description of the classifier as a string.
  240.    */
  241.   public String toString() {
  242.     if (m_Instances == null) {
  243.       return "Naive Bayes (simple): No model built yet.";
  244.     }
  245.     try {
  246.       StringBuffer text = new StringBuffer("Naive Bayes (simple)");
  247.       int attIndex;
  248.       
  249.       for (int i = 0; i < m_Instances.numClasses(); i++) {
  250. text.append("nnClass " + m_Instances.classAttribute().value(i) 
  251.     + ": P(C) = " 
  252.     + Utils.doubleToString(m_Priors[i], 10, 8)
  253.     + "nn");
  254. Enumeration enumAtts = m_Instances.enumerateAttributes();
  255. attIndex = 0;
  256. while (enumAtts.hasMoreElements()) {
  257.   Attribute attribute = (Attribute) enumAtts.nextElement();
  258.   text.append("Attribute " + attribute.name() + "n");
  259.   if (attribute.isNominal()) {
  260.     for (int j = 0; j < attribute.numValues(); j++) {
  261.       text.append(attribute.value(j) + "t");
  262.     }
  263.     text.append("n");
  264.     for (int j = 0; j < attribute.numValues(); j++)
  265.       text.append(Utils.
  266.   doubleToString(m_Counts[i][attIndex][j], 10, 8)
  267.   + "t");
  268.   } else {
  269.     text.append("Mean: " + Utils.
  270. doubleToString(m_Means[i][attIndex], 10, 8) + "t");
  271.     text.append("Standard Deviation: " 
  272. + Utils.doubleToString(m_Devs[i][attIndex], 10, 8));
  273.   }
  274.   text.append("nn");
  275.   attIndex++;
  276. }
  277.       }
  278.       
  279.       return text.toString();
  280.     } catch (Exception e) {
  281.       return "Can't print Naive Bayes classifier!";
  282.     }
  283.   }
  284.   /**
  285.    * Density function of normal distribution.
  286.    */
  287.   private double normalDens(double x, double mean, double stdDev) {
  288.     
  289.     double diff = x - mean;
  290.     
  291.     return (1 / (NORM_CONST * stdDev)) 
  292.       * Math.exp(-(diff * diff / (2 * stdDev * stdDev)));
  293.   }
  294.   /**
  295.    * Main method for testing this class.
  296.    *
  297.    * @param argv the options
  298.    */
  299.   public static void main(String [] argv) {
  300.     Classifier scheme;
  301.     try {
  302.       scheme = new NaiveBayesSimple();
  303.       System.out.println(Evaluation.evaluateModel(scheme, argv));
  304.     } catch (Exception e) {
  305.       System.err.println(e.getMessage());
  306.     }
  307.   }
  308. }