NaiveBayesSimple.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 10k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  *    This program is free software; you can redistribute it and/or modify
  3.  *    it under the terms of the GNU General Public License as published by
  4.  *    the Free Software Foundation; either version 2 of the License, or
  5.  *    (at your option) any later version.
  6.  *
  7.  *    This program is distributed in the hope that it will be useful,
  8.  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  *    GNU General Public License for more details.
  11.  *
  12.  *    You should have received a copy of the GNU General Public License
  13.  *    along with this program; if not, write to the Free Software
  14.  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15.  */
  16. /*
  17.  *    NaiveBayesSimple.java
  18.  *    Copyright (C) 1999 Eibe Frank
  19.  *
  20.  */
  21. package weka.classifiers.bayes;
  22. import weka.classifiers.Classifier;
  23. import weka.classifiers.DistributionClassifier;
  24. import weka.classifiers.Evaluation;
  25. import java.io.*;
  26. import java.util.*;
  27. import weka.core.*;
  28. /**
  29.  * Class for building and using a simple Naive Bayes classifier.
  30.  * Numeric attributes are modelled by a normal distribution. For more
  31.  * information, see<p>
  32.  *
  33.  * Richard Duda and Peter Hart (1973).<i>Pattern
  34.  * Classification and Scene Analysis</i>. Wiley, New York.
  35.  * @author Eibe Frank (eibe@cs.waikato.ac.nz)
  36.  * @version $Revision: 1.10 $ 
  37. */
  38. public class NaiveBayesSimple extends DistributionClassifier {
  39.   /** All the counts for nominal attributes. */
  40.   private double [][][] m_Counts;
  41.   
  42.   /** The means for numeric attributes. */
  43.   private double [][] m_Means;
  44.   /** The standard deviations for numeric attributes. */
  45.   private double [][] m_Devs;
  46.   /** The prior probabilities of the classes. */
  47.   private double [] m_Priors;
  48.   /** The instances used for training. */
  49.   private Instances m_Instances;
  50.   /** Constant for normal distribution. */
  51.   private static double NORM_CONST = Math.sqrt(2 * Math.PI);
  52.   /**
  53.    * Generates the classifier.
  54.    *
  55.    * @param instances set of instances serving as training data 
  56.    * @exception Exception if the classifier has not been generated successfully
  57.    */
  58.   public void buildClassifier(Instances instances) throws Exception {
  59.     int attIndex = 0;
  60.     double sum;
  61.     
  62.     if (instances.checkForStringAttributes()) {
  63.       throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
  64.     }
  65.     if (instances.classAttribute().isNumeric()) {
  66.       throw new UnsupportedClassTypeException("Naive Bayes: Class is numeric!");
  67.     }
  68.     
  69.     m_Instances = new Instances(instances, 0);
  70.     
  71.     // Reserve space
  72.     m_Counts = new double[instances.numClasses()]
  73.       [instances.numAttributes() - 1][0];
  74.     m_Means = new double[instances.numClasses()]
  75.       [instances.numAttributes() - 1];
  76.     m_Devs = new double[instances.numClasses()]
  77.       [instances.numAttributes() - 1];
  78.     m_Priors = new double[instances.numClasses()];
  79.     Enumeration enum = instances.enumerateAttributes();
  80.     while (enum.hasMoreElements()) {
  81.       Attribute attribute = (Attribute) enum.nextElement();
  82.       if (attribute.isNominal()) {
  83. for (int j = 0; j < instances.numClasses(); j++) {
  84.   m_Counts[j][attIndex] = new double[attribute.numValues()];
  85. }
  86.       } else {
  87. for (int j = 0; j < instances.numClasses(); j++) {
  88.   m_Counts[j][attIndex] = new double[1];
  89. }
  90.       }
  91.       attIndex++;
  92.     }
  93.     
  94.     // Compute counts and sums
  95.     Enumeration enumInsts = instances.enumerateInstances();
  96.     while (enumInsts.hasMoreElements()) {
  97.       Instance instance = (Instance) enumInsts.nextElement();
  98.       if (!instance.classIsMissing()) {
  99. Enumeration enumAtts = instances.enumerateAttributes();
  100. attIndex = 0;
  101. while (enumAtts.hasMoreElements()) {
  102.   Attribute attribute = (Attribute) enumAtts.nextElement();
  103.   if (!instance.isMissing(attribute)) {
  104.     if (attribute.isNominal()) {
  105.       m_Counts[(int)instance.classValue()][attIndex]
  106. [(int)instance.value(attribute)]++;
  107.     } else {
  108.       m_Means[(int)instance.classValue()][attIndex] +=
  109. instance.value(attribute);
  110.       m_Counts[(int)instance.classValue()][attIndex][0]++;
  111.     }
  112.   }
  113.   attIndex++;
  114. }
  115. m_Priors[(int)instance.classValue()]++;
  116.       }
  117.     }
  118.     
  119.     // Compute means
  120.     Enumeration enumAtts = instances.enumerateAttributes();
  121.     attIndex = 0;
  122.     while (enumAtts.hasMoreElements()) {
  123.       Attribute attribute = (Attribute) enumAtts.nextElement();
  124.       if (attribute.isNumeric()) {
  125. for (int j = 0; j < instances.numClasses(); j++) {
  126.   if (m_Counts[j][attIndex][0] < 2) {
  127.     throw new Exception("attribute " + attribute.name() +
  128. ": less than two values for class " +
  129. instances.classAttribute().value(j));
  130.   }
  131.   m_Means[j][attIndex] /= m_Counts[j][attIndex][0];
  132. }
  133.       }
  134.       attIndex++;
  135.     }    
  136.     
  137.     // Compute standard deviations
  138.     enumInsts = instances.enumerateInstances();
  139.     while (enumInsts.hasMoreElements()) {
  140.       Instance instance = 
  141. (Instance) enumInsts.nextElement();
  142.       if (!instance.classIsMissing()) {
  143. enumAtts = instances.enumerateAttributes();
  144. attIndex = 0;
  145. while (enumAtts.hasMoreElements()) {
  146.   Attribute attribute = (Attribute) enumAtts.nextElement();
  147.   if (!instance.isMissing(attribute)) {
  148.     if (attribute.isNumeric()) {
  149.       m_Devs[(int)instance.classValue()][attIndex] +=
  150. (m_Means[(int)instance.classValue()][attIndex]-
  151.  instance.value(attribute))*
  152. (m_Means[(int)instance.classValue()][attIndex]-
  153.  instance.value(attribute));
  154.     }
  155.   }
  156.   attIndex++;
  157. }
  158.       }
  159.     }
  160.     enumAtts = instances.enumerateAttributes();
  161.     attIndex = 0;
  162.     while (enumAtts.hasMoreElements()) {
  163.       Attribute attribute = (Attribute) enumAtts.nextElement();
  164.       if (attribute.isNumeric()) {
  165. for (int j = 0; j < instances.numClasses(); j++) {
  166.   if (m_Devs[j][attIndex] <= 0) {
  167.     throw new Exception("attribute " + attribute.name() +
  168. ": standard deviation is 0 for class " +
  169. instances.classAttribute().value(j));
  170.   }
  171.   else {
  172.     m_Devs[j][attIndex] /= m_Counts[j][attIndex][0] - 1;
  173.     m_Devs[j][attIndex] = Math.sqrt(m_Devs[j][attIndex]);
  174.   }
  175. }
  176.       }
  177.       attIndex++;
  178.     } 
  179.     
  180.     // Normalize counts
  181.     enumAtts = instances.enumerateAttributes();
  182.     attIndex = 0;
  183.     while (enumAtts.hasMoreElements()) {
  184.       Attribute attribute = (Attribute) enumAtts.nextElement();
  185.       if (attribute.isNominal()) {
  186. for (int j = 0; j < instances.numClasses(); j++) {
  187.   sum = Utils.sum(m_Counts[j][attIndex]);
  188.   for (int i = 0; i < attribute.numValues(); i++) {
  189.     m_Counts[j][attIndex][i] =
  190.       (m_Counts[j][attIndex][i] + 1) 
  191.       / (sum + (double)attribute.numValues());
  192.   }
  193. }
  194.       }
  195.       attIndex++;
  196.     }
  197.     
  198.     // Normalize priors
  199.     sum = Utils.sum(m_Priors);
  200.     for (int j = 0; j < instances.numClasses(); j++)
  201.       m_Priors[j] = (m_Priors[j] + 1) 
  202. / (sum + (double)instances.numClasses());
  203.   }
  204.   /**
  205.    * Calculates the class membership probabilities for the given test instance.
  206.    *
  207.    * @param instance the instance to be classified
  208.    * @return predicted class probability distribution
  209.    * @exception Exception if distribution can't be computed
  210.    */
  211.   public double[] distributionForInstance(Instance instance) throws Exception {
  212.     
  213.     double [] probs = new double[instance.numClasses()];
  214.     int attIndex;
  215.     
  216.     for (int j = 0; j < instance.numClasses(); j++) {
  217.       probs[j] = 1;
  218.       Enumeration enumAtts = instance.enumerateAttributes();
  219.       attIndex = 0;
  220.       while (enumAtts.hasMoreElements()) {
  221. Attribute attribute = (Attribute) enumAtts.nextElement();
  222. if (!instance.isMissing(attribute)) {
  223.   if (attribute.isNominal()) {
  224.     probs[j] *= m_Counts[j][attIndex][(int)instance.value(attribute)];
  225.   } else {
  226.     probs[j] *= normalDens(instance.value(attribute),
  227.    m_Means[j][attIndex],
  228.    m_Devs[j][attIndex]);}
  229. }
  230. attIndex++;
  231.       }
  232.       probs[j] *= m_Priors[j];
  233.     }
  234.     
  235.     // Normalize probabilities
  236.     Utils.normalize(probs);
  237.     return probs;
  238.   }
  239.   /**
  240.    * Returns a description of the classifier.
  241.    *
  242.    * @return a description of the classifier as a string.
  243.    */
  244.   public String toString() {
  245.     if (m_Instances == null) {
  246.       return "Naive Bayes (simple): No model built yet.";
  247.     }
  248.     try {
  249.       StringBuffer text = new StringBuffer("Naive Bayes (simple)");
  250.       int attIndex;
  251.       
  252.       for (int i = 0; i < m_Instances.numClasses(); i++) {
  253. text.append("nnClass " + m_Instances.classAttribute().value(i) 
  254.     + ": P(C) = " 
  255.     + Utils.doubleToString(m_Priors[i], 10, 8)
  256.     + "nn");
  257. Enumeration enumAtts = m_Instances.enumerateAttributes();
  258. attIndex = 0;
  259. while (enumAtts.hasMoreElements()) {
  260.   Attribute attribute = (Attribute) enumAtts.nextElement();
  261.   text.append("Attribute " + attribute.name() + "n");
  262.   if (attribute.isNominal()) {
  263.     for (int j = 0; j < attribute.numValues(); j++) {
  264.       text.append(attribute.value(j) + "t");
  265.     }
  266.     text.append("n");
  267.     for (int j = 0; j < attribute.numValues(); j++)
  268.       text.append(Utils.
  269.   doubleToString(m_Counts[i][attIndex][j], 10, 8)
  270.   + "t");
  271.   } else {
  272.     text.append("Mean: " + Utils.
  273. doubleToString(m_Means[i][attIndex], 10, 8) + "t");
  274.     text.append("Standard Deviation: " 
  275. + Utils.doubleToString(m_Devs[i][attIndex], 10, 8));
  276.   }
  277.   text.append("nn");
  278.   attIndex++;
  279. }
  280.       }
  281.       
  282.       return text.toString();
  283.     } catch (Exception e) {
  284.       return "Can't print Naive Bayes classifier!";
  285.     }
  286.   }
  287.   /**
  288.    * Density function of normal distribution.
  289.    */
  290.   private double normalDens(double x, double mean, double stdDev) {
  291.     
  292.     double diff = x - mean;
  293.     
  294.     return (1 / (NORM_CONST * stdDev)) 
  295.       * Math.exp(-(diff * diff / (2 * stdDev * stdDev)));
  296.   }
  297.   /**
  298.    * Main method for testing this class.
  299.    *
  300.    * @param argv the options
  301.    */
  302.   public static void main(String [] argv) {
  303.     Classifier scheme;
  304.     try {
  305.       scheme = new NaiveBayesSimple();
  306.       System.out.println(Evaluation.evaluateModel(scheme, argv));
  307.     } catch (Exception e) {
  308.       System.err.println(e.getMessage());
  309.     }
  310.   }
  311. }