BayesNetB.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 7k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  * This program is free software; you can redistribute it and/or modify
  3.  * it under the terms of the GNU General Public License as published by
  4.  * the Free Software Foundation; either version 2 of the License, or
  5.  * (at your option) any later version.
  6.  * 
  7.  * This program is distributed in the hope that it will be useful,
  8.  * but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  * GNU General Public License for more details.
  11.  * 
  12.  * You should have received a copy of the GNU General Public License
  13.  * along with this program; if not, write to the Free Software
  14.  * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15.  */
  16. /*
  17.  * BayesNetB.java
  18.  * Copyright (C) 2001 Remco Bouckaert
  19.  * 
  20.  */
  21. package weka.classifiers.bayes;
  22. import java.io.*;
  23. import java.util.*;
  24. import weka.core.*;
  25. import weka.estimators.*;
  26. import weka.classifiers.*;
  27. /**
  28.  * Class for a Bayes Network classifier based on a hill climbing algorithm for
  29.  * learning structure as described in Buntine, W. (1991). Theory refinement on
  30.  * Bayesian networks. In Proceedings of Seventh Conference on Uncertainty in
  31.  * Artificial Intelligence, Los Angeles, CA, pages 52--60. Morgan Kaufmann.
  32.  * Works with nominal variables and no missing values only.
  33.  * 
  34.  * @author Remco Bouckaert (rrb@xm.co.nz)
  35.  * @version $Revision: 1.3 $
  36.  */
  37. public class BayesNetB extends BayesNet {
  38.   /**
  39.    * buildStructure determines the network structure/graph of the network
  40.    * with Buntines greedy hill climbing algorithm, restricted by its initial
  41.    * structure (which can be an empty graph, or a Naive Bayes graph.
  42.    */
  43.   public void buildStructure() throws Exception {
  44.     // determine base scores
  45.     double[] fBaseScores = new double[m_Instances.numAttributes()];
  46.     int      nNrOfAtts = m_Instances.numAttributes();
  47.     for (int iAttribute = 0; iAttribute < nNrOfAtts; iAttribute++) {
  48.       fBaseScores[iAttribute] = CalcNodeScore(iAttribute);
  49.     } 
  50.     // B algorithm: greedy search (not restricted by ordering like K2)
  51.     boolean     bProgress = true;
  52.     // cache scores & whether adding an arc makes sense
  53.     boolean[][] bAddArcMakesSense = new boolean[nNrOfAtts][nNrOfAtts];
  54.     double[][]  fScore = new double[nNrOfAtts][nNrOfAtts];
  55.     for (int iAttributeHead = 0; iAttributeHead < nNrOfAtts; 
  56.  iAttributeHead++) {
  57.       if (m_ParentSets[iAttributeHead].GetNrOfParents() < m_nMaxNrOfParents) {
  58. // only bother maintaining scores if adding parent does not violate the upper bound on nr of parents
  59. for (int iAttributeTail = 0; iAttributeTail < nNrOfAtts; 
  60.      iAttributeTail++) {
  61.   bAddArcMakesSense[iAttributeHead][iAttributeTail] = 
  62.     AddArcMakesSense(iAttributeHead, iAttributeTail);
  63.   if (bAddArcMakesSense[iAttributeHead][iAttributeTail]) {
  64.     fScore[iAttributeHead][iAttributeTail] = 
  65.       CalcScoreWithExtraParent(iAttributeHead, iAttributeTail);
  66.   } 
  67.       } 
  68.     } 
  69.     // go do the hill climbing
  70.     while (bProgress) {
  71.       bProgress = false;
  72.       int    nBestAttributeTail = -1;
  73.       int    nBestAttributeHead = -1;
  74.       double fBestDeltaScore = 0.0;
  75.       // find best arc to add
  76.       for (int iAttributeHead = 0; iAttributeHead < nNrOfAtts; 
  77.    iAttributeHead++) {
  78. if (m_ParentSets[iAttributeHead].GetNrOfParents() 
  79. < m_nMaxNrOfParents) {
  80.   for (int iAttributeTail = 0; iAttributeTail < nNrOfAtts; 
  81.        iAttributeTail++) {
  82.     if (bAddArcMakesSense[iAttributeHead][iAttributeTail]) {
  83.       if (fScore[iAttributeHead][iAttributeTail] 
  84.       - fBaseScores[iAttributeHead] > fBestDeltaScore) {
  85. if (AddArcMakesSense(iAttributeHead, iAttributeTail)) {
  86.   fBestDeltaScore = fScore[iAttributeHead][iAttributeTail] 
  87.     - fBaseScores[iAttributeHead];
  88.   nBestAttributeTail = iAttributeTail;
  89.   nBestAttributeHead = iAttributeHead;
  90. } else {
  91.   bAddArcMakesSense[iAttributeHead][iAttributeTail] = false;
  92.       } 
  93.     } 
  94.   } 
  95.       } 
  96.       if (nBestAttributeHead >= 0) {
  97. // update network structure
  98. m_ParentSets[nBestAttributeHead].AddParent(nBestAttributeTail, 
  99.    m_Instances);
  100. if (m_ParentSets[nBestAttributeHead].GetNrOfParents() 
  101. < m_nMaxNrOfParents) {
  102.   // only bother updating scores if adding parent does not violate the upper bound on nr of parents
  103.   fBaseScores[nBestAttributeHead] += fBestDeltaScore;
  104.   for (int iAttributeTail = 0; iAttributeTail < nNrOfAtts; 
  105.        iAttributeTail++) {
  106.     bAddArcMakesSense[nBestAttributeHead][iAttributeTail] = 
  107.       AddArcMakesSense(nBestAttributeHead, iAttributeTail);
  108.     if (bAddArcMakesSense[nBestAttributeHead][iAttributeTail]) {
  109.       fScore[nBestAttributeHead][iAttributeTail] = 
  110. CalcScoreWithExtraParent(nBestAttributeHead, iAttributeTail);
  111.     } 
  112.   } 
  113. bProgress = true;
  114.       } 
  115.     } 
  116.   }    // buildStructure
  117.  
  118.   /**
  119.    * AddArcMakesSense checks whether adding the arc from iAttributeTail to iAttributeHead
  120.    * does not already exists and does not introduce a cycle
  121.    * 
  122.    * @param iAttributeHead index of the attribute that becomes head of the arrow
  123.    * @param iAttributeTail index of the attribute that becomes tail of the arrow
  124.    * @return true if adding arc is allowed, otherwise false
  125.    */
  126.   private boolean AddArcMakesSense(int iAttributeHead, int iAttributeTail) {
  127.     if (iAttributeHead == iAttributeTail) {
  128.       return false;
  129.     } 
  130.     // sanity check: arc should not be in parent set already
  131.     for (int iParent = 0; 
  132.  iParent < m_ParentSets[iAttributeHead].GetNrOfParents(); iParent++) {
  133.       if (m_ParentSets[iAttributeHead].GetParent(iParent) == iAttributeTail) {
  134. return false;
  135.       } 
  136.     } 
  137.     // sanity check: arc should not introduce a cycle
  138.     int       nNodes = m_Instances.numAttributes();
  139.     boolean[] bDone = new boolean[nNodes];
  140.     for (int iNode = 0; iNode < nNodes; iNode++) {
  141.       bDone[iNode] = false;
  142.     } 
  143.     // check for cycles
  144.     m_ParentSets[iAttributeHead].AddParent(iAttributeTail, m_Instances);
  145.     for (int iNode = 0; iNode < nNodes; iNode++) {
  146.       // find a node for which all parents are 'done'
  147.       boolean bFound = false;
  148.       for (int iNode2 = 0; !bFound && iNode2 < nNodes; iNode2++) {
  149. if (!bDone[iNode2]) {
  150.   boolean bHasNoParents = true;
  151.   for (int iParent = 0; 
  152.        iParent < m_ParentSets[iNode2].GetNrOfParents(); iParent++) {
  153.     if (!bDone[m_ParentSets[iNode2].GetParent(iParent)]) {
  154.       bHasNoParents = false;
  155.     } 
  156.   } 
  157.   if (bHasNoParents) {
  158.     bDone[iNode2] = true;
  159.     bFound = true;
  160.   } 
  161.       } 
  162.       if (!bFound) {
  163. m_ParentSets[iAttributeHead].DeleteLastParent(m_Instances);
  164. return false;
  165.       } 
  166.     } 
  167.     m_ParentSets[iAttributeHead].DeleteLastParent(m_Instances);
  168.     return true;
  169.   }    // AddArcMakesCycle
  170.  
  171.   /**
  172.    * This will return a string describing the classifier.
  173.    * @return The string.
  174.    */
  175.   public String globalInfo() {
  176.     return "This Bayes Network learning algorithm uses a hill climbing algorithm" +
  177.     " without restriction on the order of variables";
  178.   }
  179.   /**
  180.    * Main method for testing this class.
  181.    * 
  182.    * @param argv the options
  183.    */
  184.   public static void main(String[] argv) {
  185.     try {
  186.       System.out.println(Evaluation.evaluateModel(new BayesNetB(), argv));
  187.     } catch (Exception e) {
  188.       e.printStackTrace();
  189.       System.err.println(e.getMessage());
  190.     } 
  191.   }    // main
  192.  
  193. }      // class BayesNetB