DecisionTable.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 34k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  *    This program is free software; you can redistribute it and/or modify
  3.  *    it under the terms of the GNU General Public License as published by
  4.  *    the Free Software Foundation; either version 2 of the License, or
  5.  *    (at your option) any later version.
  6.  *
  7.  *    This program is distributed in the hope that it will be useful,
  8.  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  *    GNU General Public License for more details.
  11.  *
  12.  *    You should have received a copy of the GNU General Public License
  13.  *    along with this program; if not, write to the Free Software
  14.  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15.  */
  16. /*
  17.  *    DecisionTable.java
  18.  *    Copyright (C) 1999 Mark Hall
  19.  *
  20.  */
  21. package weka.classifiers;
  22. import java.io.*;
  23. import java.util.*;
  24. import weka.core.*;
  25. import weka.filters.*;
  26. import weka.classifiers.*;
  27. import weka.classifiers.j48.*;
  28. /**
  29.  * Class for building and using a simple decision table majority classifier.
  30.  * For more information see: <p>
  31.  *
  32.  * Kohavi R. (1995).<i> The Power of Decision Tables.</i> In Proc
  33.  * European Conference on Machine Learning.<p>
  34.  *
  35.  * Valid options are: <p>
  36.  *
  37.  * -S num <br>
  38.  * Number of fully expanded non improving subsets to consider
  39.  * before terminating a best first search.
  40.  * (Default = 5) <p>
  41.  *
  42.  * -X num <br>
  43.  * Use cross validation to evaluate features. Use number of folds = 1 for
  44.  * leave one out CV. (Default = leave one out CV) <p>
  45.  * 
  46.  * -I <br>
  47.  * Use nearest neighbour instead of global table majority. <p>
  48.  *
  49.  * -R <br>
  50.  * Prints the decision table. <p>
  51.  *
  52.  * @author Mark Hall (mhall@cs.waikato.ac.nz)
  53.  * @version $Revision: 1.21 $ 
  54.  */
  55. public class DecisionTable 
  56.   extends DistributionClassifier 
  57.   implements OptionHandler, WeightedInstancesHandler, 
  58.      AdditionalMeasureProducer {
  59.   
  60.   /** The hashtable used to hold training instances */
  61.   private Hashtable m_entries;
  62.   /** Holds the final feature set */
  63.   private int [] m_decisionFeatures;
  64.   /** Discretization filter */
  65.   private DiscretizeFilter m_disTransform;
  66.   /** Filter used to remove columns discarded by feature selection */
  67.   private AttributeFilter m_delTransform;
  68.   /** IB1 used to classify non matching instances rather than majority class */
  69.   private IBk m_ibk;
  70.   
  71.   /** Holds the training instances */
  72.   private Instances m_theInstances;
  73.   
  74.   /** The number of attributes in the dataset */
  75.   private int m_numAttributes;
  76.   /** The number of instances in the dataset */
  77.   private int m_numInstances;
  78.   /** Class is nominal */
  79.   private boolean m_classIsNominal;
  80.   /** Output debug info */
  81.   private boolean m_debug;
  82.   /** Use the IBk classifier rather than majority class */
  83.   private boolean m_useIBk;
  84.   /** Display Rules */
  85.   private boolean m_displayRules;
  86.   /** 
  87.    * Maximum number of fully expanded non improving subsets for a best 
  88.    * first search. 
  89.    */
  90.   private int m_maxStale;
  91.   /** Number of folds for cross validating feature sets */
  92.   private int m_CVFolds;
  93.   /** Random numbers for use in cross validation */
  94.   private Random m_rr;
  95.   /** Holds the majority class */
  96.   private double m_majority;
  97.   /**
  98.    * Class for a node in a linked list. Used in best first search.
  99.    */
  100.   public class Link {
  101.     /** The group */
  102.     BitSet m_group;
  103.     /** The merit */
  104.     double m_merit;
  105.     /**
  106.      * The constructor.
  107.      *
  108.      * @param gr the group
  109.      * @param mer the merit
  110.      */
  111.     public Link (BitSet gr, double mer) {
  112.       m_group = (BitSet)gr.clone();
  113.       m_merit = mer;
  114.     }
  115.   
  116.     /**
  117.      * Gets the group.
  118.      */
  119.     public BitSet getGroup() {
  120.       return m_group;
  121.     }
  122.   
  123.     /**
  124.      * Gets the merit.
  125.      */
  126.     public double getMerit() {
  127.       return m_merit;
  128.     }
  129.     /**
  130.      * Returns string representation.
  131.      */
  132.     public String toString() {
  133.       return ("Node: "+m_group.toString()+"  "+m_merit);
  134.     }
  135.   }
  136.   
  137.   /**
  138.    * Class for handling a linked list. Used in best first search.
  139.    * Extends the Vector class.
  140.    */
  141.   public class LinkedList extends FastVector {
  142.     /**
  143.      * Removes an element (Link) at a specific index from the list.
  144.      *
  145.      * @param index the index of the element to be removed.
  146.      */
  147.     public void removeLinkAt(int index) throws Exception {
  148.       if ((index >= 0) && (index < size())) {
  149. removeElementAt(index);
  150.       } else {
  151. throw new Exception("index out of range (removeLinkAt)");
  152.       }
  153.     }
  154.     /**
  155.      * Returns the element (Link) at a specific index from the list.
  156.      *
  157.      * @param index the index of the element to be returned.
  158.      */
  159.     public Link getLinkAt(int index) throws Exception {
  160.       if (size()==0) {
  161. throw new Exception("List is empty (getLinkAt)");
  162.       } else if ((index >= 0) && (index < size())) {
  163. return ((Link)(elementAt(index)));
  164.       } else {
  165. throw new Exception("index out of range (getLinkAt)");
  166.       }
  167.     }
  168.     /**
  169.      * Aadds an element (Link) to the list.
  170.      *
  171.      * @param gr the feature set specification
  172.      * @param mer the "merit" of this feature set
  173.      */
  174.     public void addToList(BitSet gr, double mer) {
  175.       Link newL = new Link(gr, mer);
  176.       if (size()==0) {
  177. addElement(newL);
  178.       }
  179.       else if (mer > ((Link)(firstElement())).getMerit()) {
  180. insertElementAt(newL,0);
  181.       } else {
  182. int i = 0;
  183. int size = size();
  184. boolean done = false;
  185. while ((!done) && (i < size)) {
  186.   if (mer > ((Link)(elementAt(i))).getMerit()) {
  187.     insertElementAt(newL,i);
  188.     done = true;
  189.   } else if (i == size-1) {
  190.     addElement(newL);
  191.     done = true;
  192.   } else {
  193.     i++;
  194.   }
  195. }
  196.       }
  197.     }
  198.   }
  199.   /**
  200.    * Class providing keys to the hash table
  201.    */
  202.   public class hashKey implements Serializable {
  203.     
  204.     /** Array of attribute values for an instance */
  205.     private double [] attributes;
  206.     
  207.     /** True for an index if the corresponding attribute value is missing. */
  208.     private boolean [] missing;
  209.     /** The values */
  210.     private String [] values;
  211.     /** The key */
  212.     private int key;
  213.     /**
  214.      * Constructor for a hashKey
  215.      *
  216.      * @param t an instance from which to generate a key
  217.      * @param numAtts the number of attributes
  218.      */
  219.     public hashKey(Instance t, int numAtts) throws Exception {
  220.       int i;
  221.       int cindex = t.classIndex();
  222.       key = -999;
  223.       attributes = new double [numAtts];
  224.       missing = new boolean [numAtts];
  225.       for (i=0;i<numAtts;i++) {
  226. if (i == cindex) {
  227.   missing[i] = true;
  228. } else {
  229.   if ((missing[i] = t.isMissing(i)) == false) {
  230.     attributes[i] = t.value(i);
  231.   }
  232. }
  233.       }
  234.     }
  235.     /**
  236.      * Convert a hash entry to a string
  237.      *
  238.      * @param t the set of instances
  239.      * @param maxColWidth width to make the fields
  240.      */
  241.     public String toString(Instances t, int maxColWidth) {
  242.       int i;
  243.       int cindex = t.classIndex();
  244.       StringBuffer text = new StringBuffer();
  245.       
  246.       for (i=0;i<attributes.length;i++) {
  247. if (i != cindex) {
  248.   if (missing[i]) {
  249.     text.append("?");
  250.     for (int j=0;j<maxColWidth;j++) {
  251.       text.append(" ");
  252.     }
  253.   } else {
  254.     String ss = t.attribute(i).value((int)attributes[i]);
  255.     StringBuffer sb = new StringBuffer(ss);
  256.     
  257.     for (int j=0;j < (maxColWidth-ss.length()+1); j++) {
  258. sb.append(" ");
  259.     }
  260.     text.append(sb);
  261.   }
  262. }
  263.       }
  264.       return text.toString();
  265.     }
  266.     /**
  267.      * Constructor for a hashKey
  268.      *
  269.      * @param t an array of feature values
  270.      */
  271.     public hashKey(double [] t) {
  272.       int i;
  273.       int l = t.length;
  274.       key = -999;
  275.       attributes = new double [l];
  276.       missing = new boolean [l];
  277.       for (i=0;i<l;i++) {
  278. if (t[i] == Double.MAX_VALUE) {
  279.   missing[i] = true;
  280. } else {
  281.   missing[i] = false;
  282.   attributes[i] = t[i];
  283. }
  284.       }
  285.     }
  286.     
  287.     /**
  288.      * Calculates a hash code
  289.      *
  290.      * @return the hash code as an integer
  291.      */
  292.     public int hashCode() {
  293.       int hv = 0;
  294.       
  295.       if (key != -999)
  296. return key;
  297.       for (int i=0;i<attributes.length;i++) {
  298. if (missing[i]) {
  299.   hv += (i*13);
  300. } else {
  301.   hv += (i * 5 * (attributes[i]+1));
  302. }
  303.       }
  304.       if (key == -999) {
  305. key = hv;
  306.       }
  307.       return hv;
  308.     }
  309.     /**
  310.      * Tests if two instances are equal
  311.      *
  312.      * @param b a key to compare with
  313.      */
  314.     public boolean equals(Object b) {
  315.       
  316.       if ((b == null) || !(b.getClass().equals(this.getClass()))) {
  317.         return false;
  318.       }
  319.       boolean ok = true;
  320.       boolean l;
  321.       if (b instanceof hashKey) {
  322. hashKey n = (hashKey)b;
  323. for (int i=0;i<attributes.length;i++) {
  324.   l = n.missing[i];
  325.   if (missing[i] || l) {
  326.     if ((missing[i] && !l) || (!missing[i] && l)) {
  327.       ok = false;
  328.       break;
  329.     }
  330.   } else {
  331.     if (attributes[i] != n.attributes[i]) {
  332.       ok = false;
  333.       break;
  334.     }
  335.   }
  336. }
  337.       } else {
  338. return false;
  339.       }
  340.       return ok;
  341.     }
  342.     
  343.     /**
  344.      * Prints the hash code
  345.      */
  346.     public void print_hash_code() {
  347.       
  348.       System.out.println("Hash val: "+hashCode());
  349.     }
  350.   }
  351.   /**
  352.    * Inserts an instance into the hash table
  353.    *
  354.    * @param inst instance to be inserted
  355.    * @exception Exception if the instance can't be inserted
  356.    */
  357.   private void insertIntoTable(Instance inst, double [] instA)
  358.        throws Exception {
  359.     double [] tempClassDist2;
  360.     double [] newDist;
  361.     hashKey thekey;
  362.     if (instA != null) {
  363.       thekey = new hashKey(instA);
  364.     } else {
  365.       thekey = new hashKey(inst, inst.numAttributes());
  366.     }
  367.       
  368.     // see if this one is already in the table
  369.     tempClassDist2 = (double []) m_entries.get(thekey);
  370.     if (tempClassDist2 == null) {
  371.       if (m_classIsNominal) {
  372. newDist = new double [m_theInstances.classAttribute().numValues()];
  373. newDist[(int)inst.classValue()] = inst.weight();
  374. // add to the table
  375. m_entries.put(thekey, newDist);
  376.       } else {
  377. newDist = new double [2];
  378. newDist[0] = inst.classValue() * inst.weight();
  379. newDist[1] = inst.weight();
  380. // add to the table
  381. m_entries.put(thekey, newDist);
  382.       }
  383.     } else { 
  384.       // update the distribution for this instance
  385.       if (m_classIsNominal) {
  386. tempClassDist2[(int)inst.classValue()]+=inst.weight();
  387. // update the table
  388. m_entries.put(thekey, tempClassDist2);
  389.       } else  {
  390. tempClassDist2[0] += (inst.classValue() * inst.weight());
  391. tempClassDist2[1] += inst.weight();
  392. // update the table
  393. m_entries.put(thekey, tempClassDist2);
  394.       }
  395.     }
  396.   }
  397.   /**
  398.    * Classifies an instance for internal leave one out cross validation
  399.    * of feature sets
  400.    *
  401.    * @param instance instance to be "left out" and classified
  402.    * @param instA feature values of the selected features for the instance
  403.    * @return the classification of the instance
  404.    */
  405.   double classifyInstanceLeaveOneOut(Instance instance, double [] instA)
  406.        throws Exception {
  407.     hashKey thekey;
  408.     double [] tempDist;
  409.     double [] normDist;
  410.     thekey = new hashKey(instA);
  411.     if (m_classIsNominal) {
  412.       // if this one is not in the table
  413.       if ((tempDist = (double [])m_entries.get(thekey)) == null) {
  414. throw new Error("This should never happen!");
  415.       } else {
  416. normDist = new double [tempDist.length];
  417. System.arraycopy(tempDist,0,normDist,0,tempDist.length);
  418. normDist[(int)instance.classValue()] -= instance.weight();
  419. // update the table
  420. // first check to see if the class counts are all zero now
  421. boolean ok = false;
  422. for (int i=0;i<normDist.length;i++) {
  423.   if (!Utils.eq(normDist[i],0.0)) {
  424.     ok = true;
  425.     break;
  426.   }
  427. }
  428. if (ok) {
  429.   Utils.normalize(normDist);
  430.   return Utils.maxIndex(normDist);
  431. } else {
  432.   return m_majority;
  433. }
  434.       }
  435.       //      return Utils.maxIndex(tempDist);
  436.     } else {
  437.       // see if this one is already in the table
  438.       if ((tempDist = (double[])m_entries.get(thekey)) != null) {
  439. normDist = new double [tempDist.length];
  440. System.arraycopy(tempDist,0,normDist,0,tempDist.length);
  441. normDist[0] -= (instance.classValue() * instance.weight());
  442. normDist[1] -= instance.weight();
  443. if (Utils.eq(normDist[1],0.0)) {
  444.     return m_majority;
  445. } else {
  446.   return (normDist[0] / normDist[1]);
  447. }
  448.       } else {
  449. throw new Error("This should never happen!");
  450.       }
  451.     }
  452.     
  453.     // shouldn't get here 
  454.     // return 0.0;
  455.   }
  456.   /**
  457.    * Calculates the accuracy on a test fold for internal cross validation
  458.    * of feature sets
  459.    *
  460.    * @param fold set of instances to be "left out" and classified
  461.    * @param fs currently selected feature set
  462.    * @return the accuracy for the fold
  463.    */
  464.   double classifyFoldCV(Instances fold, int [] fs) throws Exception {
  465.     int i;
  466.     int ruleCount = 0;
  467.     int numFold = fold.numInstances();
  468.     int numCl = m_theInstances.classAttribute().numValues();
  469.     double [][] class_distribs = new double [numFold][numCl];
  470.     double [] instA = new double [fs.length];
  471.     double [] normDist;
  472.     hashKey thekey;
  473.     double acc = 0.0;
  474.     int classI = m_theInstances.classIndex();
  475.     Instance inst;
  476.     if (m_classIsNominal) {
  477.       normDist = new double [numCl];
  478.     } else {
  479.       normDist = new double [2];
  480.     }
  481.     // first *remove* instances
  482.     for (i=0;i<numFold;i++) {
  483.       inst = fold.instance(i);
  484.       for (int j=0;j<fs.length;j++) {
  485. if (fs[j] == classI) {
  486.   instA[j] = Double.MAX_VALUE; // missing for the class
  487. } else if (inst.isMissing(fs[j])) {
  488.   instA[j] = Double.MAX_VALUE;
  489. } else{
  490.   instA[j] = inst.value(fs[j]);
  491. }
  492.       }
  493.       thekey = new hashKey(instA);
  494.       if ((class_distribs[i] = (double [])m_entries.get(thekey)) == null) {
  495. throw new Error("This should never happen!");
  496.       } else {
  497. if (m_classIsNominal) {
  498.   class_distribs[i][(int)inst.classValue()] -= inst.weight();
  499. } else {
  500.   class_distribs[i][0] -= (inst.classValue() * inst.weight());
  501.   class_distribs[i][1] -= inst.weight();
  502. }
  503. ruleCount++;
  504.       }
  505.     }
  506.     // now classify instances
  507.     for (i=0;i<numFold;i++) {
  508.       inst = fold.instance(i);
  509.       System.arraycopy(class_distribs[i],0,normDist,0,normDist.length);
  510.       if (m_classIsNominal) {
  511. boolean ok = false;
  512. for (int j=0;j<normDist.length;j++) {
  513.   if (!Utils.eq(normDist[j],0.0)) {
  514.     ok = true;
  515.     break;
  516.   }
  517. }
  518. if (ok) {
  519.   Utils.normalize(normDist);
  520.   if (Utils.maxIndex(normDist) == inst.classValue())
  521.     acc += inst.weight();
  522. } else {
  523.   if (inst.classValue() == m_majority) {
  524.     acc += inst.weight();
  525.   }
  526. }
  527.       } else {
  528. if (Utils.eq(normDist[1],0.0)) {
  529.     acc += ((inst.weight() * (m_majority - inst.classValue())) * 
  530.     (inst.weight() * (m_majority - inst.classValue())));
  531. } else {
  532.   double t = (normDist[0] / normDist[1]);
  533.   acc += ((inst.weight() * (t - inst.classValue())) * 
  534.   (inst.weight() * (t - inst.classValue())));
  535. }
  536.       }
  537.     }
  538.     // now re-insert instances
  539.     for (i=0;i<numFold;i++) {
  540.       inst = fold.instance(i);
  541.       if (m_classIsNominal) {
  542. class_distribs[i][(int)inst.classValue()] += inst.weight();
  543.       } else {
  544. class_distribs[i][0] += (inst.classValue() * inst.weight());
  545. class_distribs[i][1] += inst.weight();
  546.       }
  547.     }
  548.     return acc;
  549.   }
  550.   /**
  551.    * Evaluates a feature subset by cross validation
  552.    *
  553.    * @param feature_set the subset to be evaluated
  554.    * @param num_atts the number of attributes in the subset
  555.    * @return the estimated accuracy
  556.    * @exception Exception if subset can't be evaluated
  557.    */
  558.   private double estimateAccuracy(BitSet feature_set, int num_atts)
  559.     throws Exception {
  560.     int i;
  561.     Instances newInstances;
  562.     int [] fs = new int [num_atts];
  563.     double acc = 0.0;
  564.     double [][] evalArray;
  565.     double [] instA = new double [num_atts];
  566.     int classI = m_theInstances.classIndex();
  567.     
  568.     int index = 0;
  569.     for (i=0;i<m_numAttributes;i++) {
  570.       if (feature_set.get(i)) {
  571. fs[index++] = i;
  572.       }
  573.     }
  574.     // create new hash table
  575.     m_entries = new Hashtable((int)(m_theInstances.numInstances() * 1.5));
  576.     // insert instances into the hash table
  577.     for (i=0;i<m_numInstances;i++) {
  578.       Instance inst = m_theInstances.instance(i);
  579.       for (int j=0;j<fs.length;j++) {
  580. if (fs[j] == classI) {
  581.   instA[j] = Double.MAX_VALUE; // missing for the class
  582. } else if (inst.isMissing(fs[j])) {
  583.   instA[j] = Double.MAX_VALUE;
  584. } else {
  585.   instA[j] = inst.value(fs[j]);
  586. }
  587.       }
  588.       insertIntoTable(inst, instA);
  589.     }
  590.     
  591.     
  592.     if (m_CVFolds == 1) {
  593.       // calculate leave one out error
  594.       for (i=0;i<m_numInstances;i++) {
  595. Instance inst = m_theInstances.instance(i);
  596. for (int j=0;j<fs.length;j++) {
  597.   if (fs[j] == classI) {
  598.     instA[j] = Double.MAX_VALUE; // missing for the class
  599.   } else if (inst.isMissing(fs[j])) {
  600.     instA[j] = Double.MAX_VALUE;
  601.   } else {
  602.     instA[j] = inst.value(fs[j]);
  603.   }
  604. }
  605. double t = classifyInstanceLeaveOneOut(inst, instA);
  606. if (m_classIsNominal) {
  607.   if (t == inst.classValue()) {
  608.     acc+=inst.weight();
  609.   }
  610. } else {
  611.   acc += ((inst.weight() * (t - inst.classValue())) * 
  612.   (inst.weight() * (t - inst.classValue())));
  613. }
  614. // weight_sum += inst.weight();
  615.       }
  616.     } else {
  617.       m_theInstances.randomize(m_rr);
  618.       m_theInstances.stratify(m_CVFolds);
  619.       // calculate 10 fold cross validation error
  620.       for (i=0;i<m_CVFolds;i++) {
  621. Instances insts = m_theInstances.testCV(m_CVFolds,i);
  622. acc += classifyFoldCV(insts, fs);
  623.       }
  624.     }
  625.   
  626.     if (m_classIsNominal) {
  627.       return (acc / m_theInstances.sumOfWeights());
  628.     } else {
  629.       return -(Math.sqrt(acc / m_theInstances.sumOfWeights()));   
  630.     }
  631.   }
  632.   /**
  633.    * Returns a String representation of a feature subset
  634.    *
  635.    * @param sub BitSet representation of a subset
  636.    * @return String containing subset
  637.    */
  638.   private String printSub(BitSet sub) {
  639.     int i;
  640.     String s="";
  641.     for (int jj=0;jj<m_numAttributes;jj++) {
  642.       if (sub.get(jj)) {
  643. s += " "+(jj+1);
  644.       }
  645.     }
  646.     return s;
  647.   }
  648.     
  649.   /**
  650.    * Does a best first search 
  651.    */
  652.   private void best_first() throws Exception {
  653.     int i,j,classI,count=0,fc,tree_count=0;
  654.     int evals=0;
  655.     BitSet best_group, temp_group;
  656.     int [] stale;
  657.     double [] best_merit;
  658.     double merit;
  659.     boolean z;
  660.     boolean added;
  661.     Link tl;
  662.   
  663.     Hashtable lookup = new Hashtable((int)(200.0*m_numAttributes*1.5));
  664.     LinkedList bfList = new LinkedList();
  665.     best_merit = new double[1]; best_merit[0] = 0.0;
  666.     stale = new int[1]; stale[0] = 0;
  667.     best_group = new BitSet(m_numAttributes);
  668.     // Add class to initial subset
  669.     classI = m_theInstances.classIndex();
  670.     best_group.set(classI);
  671.     best_merit[0] = estimateAccuracy(best_group, 1);
  672.     if (m_debug)
  673.       System.out.println("Accuracy of initial subset: "+best_merit[0]);
  674.     // add the initial group to the list
  675.     bfList.addToList(best_group,best_merit[0]);
  676.     // add initial subset to the hashtable
  677.     lookup.put(best_group,"");
  678.     while (stale[0] < m_maxStale) {
  679.       added = false;
  680.       // finished search?
  681.       if (bfList.size()==0) {
  682. stale[0] = m_maxStale;
  683. break;
  684.       }
  685.       // copy the feature set at the head of the list
  686.       tl = bfList.getLinkAt(0);
  687.       temp_group = (BitSet)(tl.getGroup().clone());
  688.       // remove the head of the list
  689.       bfList.removeLinkAt(0);
  690.       for (i=0;i<m_numAttributes;i++) {
  691. // if (search_direction == 1)
  692. z = ((i != classI) && (!temp_group.get(i)));
  693. if (z) {
  694.   // set the bit (feature to add/delete) */
  695.   temp_group.set(i);
  696.   
  697.   /* if this subset has been seen before, then it is already in 
  698.      the list (or has been fully expanded) */
  699.   BitSet tt = (BitSet)temp_group.clone();
  700.   if (lookup.containsKey(tt) == false) {
  701.     fc = 0;
  702.     for (int jj=0;jj<m_numAttributes;jj++) {
  703.       if (tt.get(jj)) {
  704. fc++;
  705.       }
  706.     }
  707.     merit = estimateAccuracy(tt, fc);
  708.     if (m_debug) {
  709.       System.out.println("evaluating: "+printSub(tt)+" "+merit); 
  710.     }
  711.     
  712.     // is this better than the best?
  713.     // if (search_direction == 1)
  714.     z = ((merit - best_merit[0]) > 0.00001);
  715.  
  716.     // else
  717.     // z = ((best_merit[0] - merit) > 0.00001);
  718.     if (z) {
  719.       if (m_debug) {
  720. System.out.println("new best feature set: "+printSub(tt)+
  721.    " "+merit);
  722.       }
  723.       added = true;
  724.       stale[0] = 0;
  725.       best_merit[0] = merit;
  726.       best_group = (BitSet)(temp_group.clone());
  727.     }
  728.     // insert this one in the list and the hash table
  729.     bfList.addToList(tt, merit);
  730.     lookup.put(tt,"");
  731.     count++;
  732.   }
  733.   // unset this addition(deletion)
  734.   temp_group.clear(i);
  735. }
  736.       }
  737.       /* if we haven't added a new feature subset then full expansion 
  738.  of this node hasn't resulted in anything better */
  739.       if (!added) {
  740. stale[0]++;
  741.       }
  742.     }
  743.    
  744.     // set selected features
  745.     for (i=0,j=0;i<m_numAttributes;i++) {
  746.       if (best_group.get(i)) {
  747. j++;
  748.       }
  749.     }
  750.     
  751.     m_decisionFeatures = new int[j];
  752.     for (i=0,j=0;i<m_numAttributes;i++) {
  753.       if (best_group.get(i)) {
  754. m_decisionFeatures[j++] = i;    
  755.       }
  756.     }
  757.   }
  758.  
  759.   /**
  760.    * Resets the options.
  761.    */
  762.   protected void resetOptions()  {
  763.     m_entries = null;
  764.     m_decisionFeatures = null;
  765.     m_debug = false;
  766.     m_useIBk = false;
  767.     m_CVFolds = 1;
  768.     m_maxStale = 5;
  769.     m_displayRules = false;
  770.   }
  771.    /**
  772.    * Constructor for a DecisionTable
  773.    */
  774.   public DecisionTable() {
  775.     resetOptions();
  776.   }
  777.   /**
  778.    * Returns an enumeration describing the available options
  779.    *
  780.    * @return an enumeration of all the available options
  781.    */
  782.   public Enumeration listOptions() {
  783.     Vector newVector = new Vector(5);
  784.     newVector.addElement(new Option(
  785.               "tNumber of fully expanded non improving subsets to considern" +
  786.       "tbefore terminating a best first search.n" +
  787.       "tUse in conjunction with -B. (Default = 5)",
  788.               "S", 1, "-S <number of non improving nodes>"));
  789.     
  790.     newVector.addElement(new Option(
  791.               "tUse cross validation to evaluate features.n" +
  792.       "tUse number of folds = 1 for leave one out CV.n" +
  793.       "t(Default = leave one out CV)",
  794.               "X", 1, "-X <number of folds>"));
  795.      newVector.addElement(new Option(
  796.               "tUse nearest neighbour instead of global table majority.n",
  797.               "I", 0, "-I"));
  798.      newVector.addElement(new Option(
  799.               "tDisplay decision table rules.n",
  800.               "R", 0, "-R")); 
  801.     return newVector.elements();
  802.   }
  803.   /**
  804.    * Sets the number of folds for cross validation (1 = leave one out)
  805.    *
  806.    * @param folds the number of folds
  807.    */
  808.   public void setCrossVal(int folds) {
  809.     m_CVFolds = folds;
  810.   }
  811.   /**
  812.    * Gets the number of folds for cross validation
  813.    *
  814.    * @return the number of cross validation folds
  815.    */
  816.   public int getCrossVal() {
  817.     return m_CVFolds;
  818.   }
  819.   /**
  820.    * Sets the number of non improving decision tables to consider
  821.    * before abandoning the search.
  822.    *
  823.    * @param stale the number of nodes
  824.    */
  825.   public void setMaxStale(int stale) {
  826.     m_maxStale = stale;
  827.   }
  828.   /**
  829.    * Gets the number of non improving decision tables
  830.    *
  831.    * @return the number of non improving decision tables
  832.    */
  833.   public int getMaxStale() {
  834.     return m_maxStale;
  835.   }
  836.   /**
  837.    * Sets whether IBk should be used instead of the majority class
  838.    *
  839.    * @param ibk true if IBk is to be used
  840.    */
  841.   public void setUseIBk(boolean ibk) {
  842.     m_useIBk = ibk;
  843.   }
  844.   
  845.   /**
  846.    * Gets whether IBk is being used instead of the majority class
  847.    *
  848.    * @return true if IBk is being used
  849.    */
  850.   public boolean getUseIBk() {
  851.     return m_useIBk;
  852.   }
  853.   /**
  854.    * Sets whether rules are to be printed
  855.    *
  856.    * @param rules true if rules are to be printed
  857.    */
  858.   public void setDisplayRules(boolean rules) {
  859.     m_displayRules = rules;
  860.   }
  861.   
  862.   /**
  863.    * Gets whether rules are being printed
  864.    *
  865.    * @return true if rules are being printed
  866.    */
  867.   public boolean getDisplayRules() {
  868.     return m_displayRules;
  869.   }
  870.   /**
  871.    * Parses the options for this object.
  872.    *
  873.    * Valid options are: <p>
  874.    *
  875.    * -S num <br>
  876.    * Number of fully expanded non improving subsets to consider
  877.    * before terminating a best first search.
  878.    * (Default = 5) <p>
  879.    *
  880.    * -X num <br>
  881.    * Use cross validation to evaluate features. Use number of folds = 1 for
  882.    * leave one out CV. (Default = leave one out CV) <p>
  883.    * 
  884.    * -I <br>
  885.    * Use nearest neighbour instead of global table majority. <p>
  886.    *
  887.    * -R <br>
  888.    * Prints the decision table. <p>
  889.    *
  890.    * @param options the list of options as an array of strings
  891.    * @exception Exception if an option is not supported
  892.    */
  893.   public void setOptions(String[] options) throws Exception {
  894.     String optionString;
  895.     resetOptions();
  896.     optionString = Utils.getOption('X',options);
  897.     if (optionString.length() != 0) {
  898.       m_CVFolds = Integer.parseInt(optionString);
  899.     }
  900.     optionString = Utils.getOption('S',options);
  901.     if (optionString.length() != 0) {
  902.       m_maxStale = Integer.parseInt(optionString);
  903.     }
  904.     m_useIBk = Utils.getFlag('I',options);
  905.     m_displayRules = Utils.getFlag('R',options);
  906.   }
  907.   /**
  908.    * Gets the current settings of the classifier.
  909.    *
  910.    * @return an array of strings suitable for passing to setOptions
  911.    */
  912.   public String [] getOptions() {
  913.     String [] options = new String [7];
  914.     int current = 0;
  915.     options[current++] = "-X"; options[current++] = "" + m_CVFolds;
  916.     options[current++] = "-S"; options[current++] = "" + m_maxStale;
  917.     if (m_useIBk) {
  918.       options[current++] = "-I";
  919.     }
  920.     if (m_displayRules) {
  921.       options[current++] = "-R";
  922.     }
  923.     while (current < options.length) {
  924.       options[current++] = "";
  925.     }
  926.     return options;
  927.   }
  928.   
  929.   /**
  930.    * Generates the classifier.
  931.    *
  932.    * @param data set of instances serving as training data 
  933.    * @exception Exception if the classifier has not been generated successfully
  934.    */
  935.   public void buildClassifier(Instances data) throws Exception {
  936.     int i;
  937.     m_rr = new Random(1);
  938.     m_theInstances = new Instances(data);
  939.     m_theInstances.deleteWithMissingClass();
  940.     if (m_theInstances.numInstances() == 0) {
  941.       throw new Exception("No training instances without missing class!");
  942.     }
  943.     if (m_theInstances.checkForStringAttributes()) {
  944.       throw new Exception("Can't handle string attributes!");
  945.     }
  946.     
  947.     m_disTransform = new DiscretizeFilter();
  948.     if (m_theInstances.classAttribute().isNumeric()) {
  949.       m_classIsNominal = false;
  950.       
  951.       // use binned discretisation if the class is numeric
  952.       m_disTransform.setUseMDL(false);
  953.       m_disTransform.setBins(10);
  954.       m_disTransform.setInvertSelection(true);
  955.       
  956.       // Discretize all attributes EXCEPT the class 
  957.       String rangeList = "";
  958.       rangeList+=(m_theInstances.classIndex()+1);
  959.       //System.out.println("The class col: "+m_theInstances.classIndex());
  960.       
  961.       m_disTransform.setAttributeIndices(rangeList);
  962.     } else {
  963.       m_disTransform.setUseBetterEncoding(true);
  964.       m_classIsNominal = true;
  965.     }
  966.     m_disTransform.setInputFormat(m_theInstances);
  967.     m_theInstances = Filter.useFilter(m_theInstances, m_disTransform);
  968.     
  969.     m_numAttributes = m_theInstances.numAttributes();
  970.     m_numInstances = m_theInstances.numInstances();
  971.     m_majority = m_theInstances.meanOrMode(m_theInstances.classAttribute());
  972.     
  973.     best_first();
  974.     
  975.     // reduce instances to selected features
  976.     m_delTransform = new AttributeFilter();
  977.     m_delTransform.setInvertSelection(true);
  978.     
  979.     // set features to keep
  980.     m_delTransform.setAttributeIndicesArray(m_decisionFeatures); 
  981.     m_delTransform.setInputFormat(m_theInstances);
  982.     m_theInstances = Filter.useFilter(m_theInstances, m_delTransform);
  983.     
  984.     // reset the number of attributes
  985.     m_numAttributes = m_theInstances.numAttributes();
  986.     
  987.     // create hash table
  988.     m_entries = new Hashtable((int)(m_theInstances.numInstances() * 1.5));
  989.     
  990.     // insert instances into the hash table
  991.     for (i=0;i<m_numInstances;i++) {
  992.       Instance inst = m_theInstances.instance(i);
  993.       insertIntoTable(inst, null);
  994.     }
  995.     
  996.     // Replace the global table majority with nearest neighbour?
  997.     if (m_useIBk) {
  998.       m_ibk = new IBk();
  999.       m_ibk.buildClassifier(m_theInstances);
  1000.     }
  1001.     
  1002.     // Save memory
  1003.     m_theInstances = new Instances(m_theInstances, 0);
  1004.   }
  1005.   /**
  1006.    * Calculates the class membership probabilities for the given 
  1007.    * test instance.
  1008.    *
  1009.    * @param instance the instance to be classified
  1010.    * @return predicted class probability distribution
  1011.    * @exception Exception if distribution can't be computed
  1012.    */
  1013.   public double [] distributionForInstance(Instance instance)
  1014.        throws Exception {
  1015.     hashKey thekey;
  1016.     double [] tempDist;
  1017.     double [] normDist;
  1018.     m_disTransform.input(instance);
  1019.     m_disTransform.batchFinished();
  1020.     instance = m_disTransform.output();
  1021.     m_delTransform.input(instance);
  1022.     m_delTransform.batchFinished();
  1023.     instance = m_delTransform.output();
  1024.     thekey = new hashKey(instance, instance.numAttributes());
  1025.     
  1026.     // if this one is not in the table
  1027.     if ((tempDist = (double [])m_entries.get(thekey)) == null) {
  1028.       if (m_useIBk) {
  1029. tempDist = m_ibk.distributionForInstance(instance);
  1030.       } else {
  1031. if (!m_classIsNominal) {
  1032.   tempDist = new double[1];
  1033.   tempDist[0] = m_majority;
  1034. } else {
  1035.   tempDist = new double [m_theInstances.classAttribute().numValues()];
  1036.   tempDist[(int)m_majority] = 1.0;
  1037. }
  1038.       }
  1039.     } else {
  1040.       if (!m_classIsNominal) {
  1041. normDist = new double[1];
  1042. normDist[0] = (tempDist[0] / tempDist[1]);
  1043. tempDist = normDist;
  1044.       } else {
  1045. // normalise distribution
  1046. normDist = new double [tempDist.length];
  1047. System.arraycopy(tempDist,0,normDist,0,tempDist.length);
  1048. Utils.normalize(normDist);
  1049. tempDist = normDist;
  1050.       }
  1051.     }
  1052.     return tempDist;
  1053.   }
  1054.   /**
  1055.    * Returns a string description of the features selected
  1056.    *
  1057.    * @return a string of features
  1058.    */
  1059.   public String printFeatures() {
  1060.     int i;
  1061.     String s = "";
  1062.    
  1063.     for (i=0;i<m_decisionFeatures.length;i++) {
  1064.       if (i==0) {
  1065. s = ""+(m_decisionFeatures[i]+1);
  1066.       } else {
  1067. s += ","+(m_decisionFeatures[i]+1);
  1068.       }
  1069.     }
  1070.     return s;
  1071.   }
  1072.   /**
  1073.    * Returns the number of rules
  1074.    * @return the number of rules
  1075.    */
  1076.   public double measureNumRules() {
  1077.     return m_entries.size();
  1078.   }
  1079.   /**
  1080.    * Returns an enumeration of the additional measure names
  1081.    * @return an enumeration of the measure names
  1082.    */
  1083.   public Enumeration enumerateMeasures() {
  1084.     Vector newVector = new Vector(1);
  1085.     newVector.addElement("measureNumRules");
  1086.     return newVector.elements();
  1087.   }
  1088.   /**
  1089.    * Returns the value of the named measure
  1090.    * @param measureName the name of the measure to query for its value
  1091.    * @return the value of the named measure
  1092.    * @exception IllegalArgumentException if the named measure is not supported
  1093.    */
  1094.   public double getMeasure(String additionalMeasureName) {
  1095.     if (additionalMeasureName.compareTo("measureNumRules") == 0) {
  1096.       return measureNumRules();
  1097.     } else {
  1098.       throw new IllegalArgumentException(additionalMeasureName 
  1099.   + " not supported (DecisionTable)");
  1100.     }
  1101.   }
  1102.   /**
  1103.    * Returns a description of the classifier.
  1104.    *
  1105.    * @return a description of the classifier as a string.
  1106.    */
  1107.   public String toString() {
  1108.     if (m_entries == null) {
  1109.       return "Decision Table: No model built yet.";
  1110.     } else {
  1111.       StringBuffer text = new StringBuffer();
  1112.       
  1113.       text.append("Decision Table:"+
  1114.   "nnNumber of training instances: "+m_numInstances+
  1115.   "nNumber of Rules : "+m_entries.size()+"n");
  1116.       
  1117.       if (m_useIBk) {
  1118. text.append("Non matches covered by IB1.n");
  1119.       } else {
  1120. text.append("Non matches covered by Majority class.n");
  1121.       }
  1122.       
  1123.       text.append("Best first search for feature set,nterminated after "+
  1124.   m_maxStale+" non improving subsets.n");
  1125.       
  1126.       text.append("Evaluation (for feature selection): CV ");
  1127.       if (m_CVFolds > 1) {
  1128. text.append("("+m_CVFolds+" fold) ");
  1129.       } else {
  1130.   text.append("(leave one out) ");
  1131.       }
  1132.       text.append("nFeature set: "+printFeatures());
  1133.       
  1134.       if (m_displayRules) {
  1135. // find out the max column width
  1136. int maxColWidth = 0;
  1137. for (int i=0;i<m_theInstances.numAttributes();i++) {
  1138.   if (m_theInstances.attribute(i).name().length() > maxColWidth) {
  1139.     maxColWidth = m_theInstances.attribute(i).name().length();
  1140.   }
  1141.   if (m_classIsNominal || (i != m_theInstances.classIndex())) {
  1142.     Enumeration e = m_theInstances.attribute(i).enumerateValues();
  1143.     while (e.hasMoreElements()) {
  1144.       String ss = (String)e.nextElement();
  1145.       if (ss.length() > maxColWidth) {
  1146. maxColWidth = ss.length();
  1147.       }
  1148.     }
  1149.   }
  1150. }
  1151. text.append("nnRules:n");
  1152. StringBuffer tm = new StringBuffer();
  1153. for (int i=0;i<m_theInstances.numAttributes();i++) {
  1154.   if (m_theInstances.classIndex() != i) {
  1155.     int d = maxColWidth - m_theInstances.attribute(i).name().length();
  1156.     tm.append(m_theInstances.attribute(i).name());
  1157.     for (int j=0;j<d+1;j++) {
  1158.       tm.append(" ");
  1159.     }
  1160.   }
  1161. }
  1162. tm.append(m_theInstances.attribute(m_theInstances.classIndex()).name()+"  ");
  1163. for (int i=0;i<tm.length()+10;i++) {
  1164.   text.append("=");
  1165. }
  1166. text.append("n");
  1167. text.append(tm);
  1168. text.append("n");
  1169. for (int i=0;i<tm.length()+10;i++) {
  1170.   text.append("=");
  1171. }
  1172. text.append("n");
  1173. Enumeration e = m_entries.keys();
  1174. while (e.hasMoreElements()) {
  1175.   hashKey tt = (hashKey)e.nextElement();
  1176.   text.append(tt.toString(m_theInstances,maxColWidth));
  1177.   double [] ClassDist = (double []) m_entries.get(tt);
  1178.   if (m_classIsNominal) {
  1179.     int m = Utils.maxIndex(ClassDist);
  1180.     try {
  1181.       text.append(m_theInstances.classAttribute().value(m)+"n");
  1182.     } catch (Exception ee) {
  1183.       System.out.println(ee.getMessage());
  1184.     }
  1185.   } else {
  1186.     text.append((ClassDist[0] / ClassDist[1])+"n");
  1187.   }
  1188. }
  1189. for (int i=0;i<tm.length()+10;i++) {
  1190.   text.append("=");
  1191. }
  1192. text.append("n");
  1193. text.append("n");
  1194.       }
  1195.       return text.toString();
  1196.     }
  1197.   }
  1198.   /**
  1199.    * Main method for testing this class.
  1200.    *
  1201.    * @param argv the command-line options
  1202.    */
  1203.   public static void main(String [] argv) {
  1204.     
  1205.     Classifier scheme;
  1206.     
  1207.     try {
  1208.       scheme = new DecisionTable();
  1209.       System.out.println(Evaluation.evaluateModel(scheme,argv));
  1210.     }
  1211.     catch (Exception e) {
  1212.       e.printStackTrace();
  1213.       System.out.println(e.getMessage());
  1214.     }
  1215.   }
  1216. }