BayesNetB2.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 12k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  * This program is free software; you can redistribute it and/or modify
  3.  * it under the terms of the GNU General Public License as published by
  4.  * the Free Software Foundation; either version 2 of the License, or
  5.  * (at your option) any later version.
  6.  * 
  7.  * This program is distributed in the hope that it will be useful,
  8.  * but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  * GNU General Public License for more details.
  11.  * 
  12.  * You should have received a copy of the GNU General Public License
  13.  * along with this program; if not, write to the Free Software
  14.  * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15.  */
  16. /*
  17.  * BayesNetB2.java
  18.  * Copyright (C) 2001 Remco Bouckaert
  19.  * 
  20.  */
  21. package weka.classifiers.bayes;
  22. import java.io.*;
  23. import java.util.*;
  24. import weka.core.*;
  25. import weka.estimators.*;
  26. import weka.classifiers.*;
  27. /**
  28.  * Class for a Bayes Network classifier based on Buntines hill climbing algorithm for
  29.  * learning structure, but augmented to allow arc reversal as an operation.
  30.  * Works with nominal variables only.
  31.  * 
  32.  * @author Remco Bouckaert (rrb@xm.co.nz)
  33.  * @version $Revision: 1.2 $
  34.  */
  35. public class BayesNetB2 extends BayesNetB {
  36.   /**
  37.    * buildStructure determines the network structure/graph of the network
  38.    * with Buntines greedy hill climbing algorithm, restricted by its initial
  39.    * structure (which can be an empty graph, or a Naive Bayes graph.
  40.    */
  41.   public void buildStructure() throws Exception {
  42.     // determine base scores
  43.     double[] fBaseScores = new double[m_Instances.numAttributes()];
  44.     int      nNrOfAtts = m_Instances.numAttributes();
  45.     for (int iAttribute = 0; iAttribute < nNrOfAtts; iAttribute++) {
  46.       fBaseScores[iAttribute] = CalcNodeScore(iAttribute);
  47.     } 
  48.     // Determine initial structure by finding a good parent-set for classification
  49.     // node using greedy search
  50.     int     iAttribute = m_Instances.classIndex();
  51.     double  fBestScore = fBaseScores[iAttribute];
  52.     // /////////////////////////////////////////////////////////////////////////////////////////
  53.     /*
  54.      * int nBestAttribute1 = -1;
  55.      * int nBestAttribute2 = -1;
  56.      * for (int iAttribute1 = 0; iAttribute1 < m_Instances.numAttributes(); iAttribute1++) {
  57.      * if (iAttribute != iAttribute1) {
  58.      * for (int iAttribute2 = 0; iAttribute2 < iAttribute1; iAttribute2++) {
  59.      * if (iAttribute != iAttribute2) {
  60.      * m_ParentSets[iAttribute].AddParent(iAttribute1, m_Instances);
  61.      * double fScore = CalcScoreWithExtraParent(iAttribute, iAttribute2);
  62.      * m_ParentSets[iAttribute].DeleteLastParent(m_Instances);
  63.      * if (fScore > fBestScore) {
  64.      * fBestScore = fScore;
  65.      * nBestAttribute1 = iAttribute1;
  66.      * nBestAttribute2 = iAttribute2;
  67.      * }
  68.      * }
  69.      * }
  70.      * }
  71.      * }
  72.      * if (nBestAttribute1 != -1) {
  73.      * m_ParentSets[iAttribute].AddParent(nBestAttribute1, m_Instances);
  74.      * m_ParentSets[iAttribute].AddParent(nBestAttribute2, m_Instances);
  75.      * fBaseScores[iAttribute] = fBestScore;
  76.      * System.out.println("Added " +  nBestAttribute1 + " & " + nBestAttribute2);
  77.      * }
  78.      */
  79.     int     m_nMaxNrOfClassifierParents = 4;
  80.     // /////////////////////////////////////////////////////////////////////////////////////////
  81.     // double fBestScore = CalcNodeScore(iAttribute);
  82.     boolean bProgress = true;
  83.     while (bProgress 
  84.    && m_ParentSets[iAttribute].GetNrOfParents() 
  85.       < m_nMaxNrOfClassifierParents) {
  86.       int nBestAttribute = -1;
  87.       for (int iAttribute2 = 0; iAttribute2 < m_Instances.numAttributes(); 
  88.    iAttribute2++) {
  89. if (iAttribute != iAttribute2) {
  90.   double fScore = CalcScoreWithExtraParent(iAttribute, iAttribute2);
  91.   if (fScore > fBestScore) {
  92.     fBestScore = fScore;
  93.     nBestAttribute = iAttribute2;
  94.   } 
  95.       } 
  96.       if (nBestAttribute != -1) {
  97. m_ParentSets[iAttribute].AddParent(nBestAttribute, m_Instances);
  98. fBaseScores[iAttribute] = fBestScore;
  99.       } else {
  100. bProgress = false;
  101.       } 
  102.     } 
  103.     // Recalc Base scores
  104.     // Correction for Naive Bayes structures: delete arcs from classification node to children
  105.     for (int iParent = 0; 
  106.  iParent < m_ParentSets[iAttribute].GetNrOfParents(); iParent++) {
  107.       int nParentNode = m_ParentSets[iAttribute].GetParent(iParent);
  108.       if (IsArc(nParentNode, iAttribute)) {
  109. m_ParentSets[nParentNode].DeleteLastParent(m_Instances);
  110.       } 
  111.       // recalc base scores
  112.       fBaseScores[nParentNode] = CalcNodeScore(nParentNode);
  113.     } 
  114.     // super.buildStructure();
  115.     // Do algorithm B from here onwards
  116.     // cache scores & whether adding an arc makes sense
  117.     boolean[][] bAddArcMakesSense = new boolean[nNrOfAtts][nNrOfAtts];
  118.     double[][]  fScore = new double[nNrOfAtts][nNrOfAtts];
  119.     for (int iAttributeHead = 0; iAttributeHead < nNrOfAtts; 
  120.  iAttributeHead++) {
  121.       if (m_ParentSets[iAttributeHead].GetNrOfParents() < m_nMaxNrOfParents) {
  122. // only bother maintaining scores if adding parent does not violate the upper bound on nr of parents
  123. for (int iAttributeTail = 0; iAttributeTail < nNrOfAtts; 
  124.      iAttributeTail++) {
  125.   bAddArcMakesSense[iAttributeHead][iAttributeTail] = 
  126.     AddArcMakesSense(iAttributeHead, iAttributeTail);
  127.   if (bAddArcMakesSense[iAttributeHead][iAttributeTail]) {
  128.     fScore[iAttributeHead][iAttributeTail] = 
  129.       CalcScoreWithExtraParent(iAttributeHead, iAttributeTail);
  130.   } 
  131.       } 
  132.     } 
  133.     bProgress = true;
  134.     // go do the hill climbing
  135.     while (bProgress) {
  136.       bProgress = false;
  137.       int    nBestAttributeTail = -1;
  138.       int    nBestAttributeHead = -1;
  139.       double fBestDeltaScore = 0.0;
  140.       // find best arc to add
  141.       for (int iAttributeHead = 0; iAttributeHead < nNrOfAtts; 
  142.    iAttributeHead++) {
  143. if (m_ParentSets[iAttributeHead].GetNrOfParents() 
  144. < m_nMaxNrOfParents) {
  145.   for (int iAttributeTail = 0; iAttributeTail < nNrOfAtts; 
  146.        iAttributeTail++) {
  147.     if (bAddArcMakesSense[iAttributeHead][iAttributeTail]) {
  148.       // System.out.println("gain " +  iAttributeTail + " -> " + iAttributeHead + ": "+ (fScore[iAttributeHead][iAttributeTail] - fBaseScores[iAttributeHead]));
  149.       if (fScore[iAttributeHead][iAttributeTail] 
  150.       - fBaseScores[iAttributeHead] > fBestDeltaScore) {
  151. if (AddArcMakesSense(iAttributeHead, iAttributeTail)) {
  152.   fBestDeltaScore = fScore[iAttributeHead][iAttributeTail] 
  153.     - fBaseScores[iAttributeHead];
  154.   nBestAttributeTail = iAttributeTail;
  155.   nBestAttributeHead = iAttributeHead;
  156. } else {
  157.   bAddArcMakesSense[iAttributeHead][iAttributeTail] = false;
  158.       } 
  159.     } 
  160.   } 
  161.       } 
  162.       if (nBestAttributeHead >= 0) {
  163. // update network structure
  164. // System.out.println("Added " + nBestAttributeTail + " -> " + nBestAttributeHead);
  165. m_ParentSets[nBestAttributeHead].AddParent(nBestAttributeTail, 
  166.    m_Instances);
  167. if (m_ParentSets[nBestAttributeHead].GetNrOfParents() 
  168. < m_nMaxNrOfParents) {
  169.   // only bother updating scores if adding parent does not violate the upper bound on nr of parents
  170.   fBaseScores[nBestAttributeHead] += fBestDeltaScore;
  171.   // System.out.println(fScore[nBestAttributeHead][nBestAttributeTail] + " " + fBaseScores[nBestAttributeHead] + " " + fBestDeltaScore);
  172.   for (int iAttributeTail = 0; iAttributeTail < nNrOfAtts; 
  173.        iAttributeTail++) {
  174.     bAddArcMakesSense[nBestAttributeHead][iAttributeTail] = 
  175.       AddArcMakesSense(nBestAttributeHead, iAttributeTail);
  176.     if (bAddArcMakesSense[nBestAttributeHead][iAttributeTail]) {
  177.       fScore[nBestAttributeHead][iAttributeTail] = 
  178. CalcScoreWithExtraParent(nBestAttributeHead, iAttributeTail);
  179.       // System.out.println(iAttributeTail + " -> " + nBestAttributeHead + ": " + fScore[nBestAttributeHead][iAttributeTail]);
  180.     } 
  181.   } 
  182. bProgress = true;
  183.       } 
  184.     } 
  185.   }    // buildStructure
  186.  
  187.   /**
  188.    * IsArc checks whether the arc from iAttributeTail to iAttributeHead already exists
  189.    * 
  190.    * @param index of the attribute that becomes head of the arrow
  191.    * @param index of the attribute that becomes tail of the arrow
  192.    */
  193.   private boolean IsArc(int iAttributeHead, int iAttributeTail) {
  194.     for (int iParent = 0; 
  195.  iParent < m_ParentSets[iAttributeHead].GetNrOfParents(); iParent++) {
  196.       if (m_ParentSets[iAttributeHead].GetParent(iParent) == iAttributeTail) {
  197. return true;
  198.       } 
  199.     } 
  200.     return false;
  201.   }    // IsArc
  202.  
  203.   /**
  204.    * AddArcMakesSense checks whether adding the arc from iAttributeTail to iAttributeHead
  205.    * does not already exists and does not introduce a cycle
  206.    * 
  207.    * @param index of the attribute that becomes head of the arrow
  208.    * @param index of the attribute that becomes tail of the arrow
  209.    */
  210.   private boolean AddArcMakesSense(int iAttributeHead, int iAttributeTail) {
  211.     if (iAttributeHead == iAttributeTail) {
  212.       return false;
  213.     } 
  214.     // sanity check: arc should not be in parent set already
  215.     if (IsArc(iAttributeHead, iAttributeTail)) {
  216.       return false;
  217.     } 
  218.     // sanity check: arc should not introduce a cycle
  219.     int       nNodes = m_Instances.numAttributes();
  220.     boolean[] bDone = new boolean[nNodes];
  221.     for (int iNode = 0; iNode < nNodes; iNode++) {
  222.       bDone[iNode] = false;
  223.     } 
  224.     // check for cycles
  225.     m_ParentSets[iAttributeHead].AddParent(iAttributeTail, m_Instances);
  226.     for (int iNode = 0; iNode < nNodes; iNode++) {
  227.       // find a node for which all parents are 'done'
  228.       boolean bFound = false;
  229.       for (int iNode2 = 0; !bFound && iNode2 < nNodes; iNode2++) {
  230. if (!bDone[iNode2]) {
  231.   boolean bHasNoParents = true;
  232.   for (int iParent = 0; 
  233.        iParent < m_ParentSets[iNode2].GetNrOfParents(); iParent++) {
  234.     if (!bDone[m_ParentSets[iNode2].GetParent(iParent)]) {
  235.       bHasNoParents = false;
  236.     } 
  237.   } 
  238.   if (bHasNoParents) {
  239.     bDone[iNode2] = true;
  240.     bFound = true;
  241.   } 
  242.       } 
  243.       if (!bFound) {
  244. m_ParentSets[iAttributeHead].DeleteLastParent(m_Instances);
  245. return false;
  246.       } 
  247.     } 
  248.     m_ParentSets[iAttributeHead].DeleteLastParent(m_Instances);
  249.     return true;
  250.   }    // AddArcMakesCycle
  251.  
  252.   /**
  253.    * ReverseArcMakesCycle checks whether the arc from iAttributeTail to
  254.    * iAttributeHead exists and reversing does not introduce a cycle
  255.    * 
  256.    * @param index of the attribute that is head of the arrow
  257.    * @param index of the attribute that is tail of the arrow
  258.    */
  259.   private boolean ReverseArcMakesCycle(int iAttributeHead, 
  260.        int iAttributeTail) {
  261.     if (iAttributeHead == iAttributeTail) {
  262.       return false;
  263.     } 
  264.     // sanity check: arc should be in parent set already
  265.     if (!IsArc(iAttributeHead, iAttributeTail)) {
  266.       return false;
  267.     } 
  268.     // sanity check: arc should not introduce a cycle
  269.     int       nNodes = m_Instances.numAttributes();
  270.     boolean[] bDone = new boolean[nNodes];
  271.     for (int iNode = 0; iNode < nNodes; iNode++) {
  272.       bDone[iNode] = false;
  273.     } 
  274.     // check for cycles
  275.     m_ParentSets[iAttributeTail].AddParent(iAttributeHead, m_Instances);
  276.     for (int iNode = 0; iNode < nNodes; iNode++) {
  277.       // find a node for which all parents are 'done'
  278.       boolean bFound = false;
  279.       for (int iNode2 = 0; !bFound && iNode2 < nNodes; iNode2++) {
  280. if (!bDone[iNode2]) {
  281.   boolean bHasNoParents = true;
  282.   for (int iParent = 0; 
  283.        iParent < m_ParentSets[iNode2].GetNrOfParents(); iParent++) {
  284.     if (!bDone[m_ParentSets[iNode2].GetParent(iParent)]) {
  285.       // this one has a parent which is not 'done' UNLESS it is the arc to be reversed
  286.       if (iNode2 != iAttributeHead 
  287.       || m_ParentSets[iNode2].GetParent(iParent) 
  288.  != iAttributeTail) {
  289. bHasNoParents = false;
  290.       } 
  291.     } 
  292.   } 
  293.   if (bHasNoParents) {
  294.     bDone[iNode2] = true;
  295.     bFound = true;
  296.   } 
  297.       } 
  298.       if (!bFound) {
  299. m_ParentSets[iAttributeTail].DeleteLastParent(m_Instances);
  300. return false;
  301.       } 
  302.     } 
  303.     m_ParentSets[iAttributeTail].DeleteLastParent(m_Instances);
  304.     return true;
  305.   }    // ReverseArcMakesCycle
  306.  
  307.   /**
  308.    * Main method for testing this class.
  309.    * 
  310.    * @param argv the options
  311.    */
  312.   public static void main(String[] argv) {
  313.     try {
  314.       System.out.println(Evaluation.evaluateModel(new BayesNetB2(), argv));
  315.     } catch (Exception e) {
  316.       e.printStackTrace();
  317.       System.err.println(e.getMessage());
  318.     } 
  319.   }    // main
  320.  
  321. }      // class BayesNetB2