DecisionTable.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 34k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  *    This program is free software; you can redistribute it and/or modify
  3.  *    it under the terms of the GNU General Public License as published by
  4.  *    the Free Software Foundation; either version 2 of the License, or
  5.  *    (at your option) any later version.
  6.  *
  7.  *    This program is distributed in the hope that it will be useful,
  8.  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  *    GNU General Public License for more details.
  11.  *
  12.  *    You should have received a copy of the GNU General Public License
  13.  *    along with this program; if not, write to the Free Software
  14.  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15.  */
  16. /*
  17.  *    DecisionTable.java
  18.  *    Copyright (C) 1999 Mark Hall
  19.  *
  20.  */
  21. package weka.classifiers.rules;
  22. import weka.classifiers.Classifier;
  23. import weka.classifiers.DistributionClassifier;
  24. import weka.classifiers.Evaluation;
  25. import weka.classifiers.lazy.IBk;
  26. import weka.classifiers.lazy.IB1;
  27. import java.io.*;
  28. import java.util.*;
  29. import weka.core.*;
  30. import weka.filters.Filter;
  31. import weka.filters.unsupervised.attribute.Remove;
  32. /**
  33.  * Class for building and using a simple decision table majority classifier.
  34.  * For more information see: <p>
  35.  *
  36.  * Kohavi R. (1995).<i> The Power of Decision Tables.</i> In Proc
  37.  * European Conference on Machine Learning.<p>
  38.  *
  39.  * Valid options are: <p>
  40.  *
  41.  * -S num <br>
  42.  * Number of fully expanded non improving subsets to consider
  43.  * before terminating a best first search.
  44.  * (Default = 5) <p>
  45.  *
  46.  * -X num <br>
  47.  * Use cross validation to evaluate features. Use number of folds = 1 for
  48.  * leave one out CV. (Default = leave one out CV) <p>
  49.  * 
  50.  * -I <br>
  51.  * Use nearest neighbour instead of global table majority. <p>
  52.  *
  53.  * -R <br>
  54.  * Prints the decision table. <p>
  55.  *
  56.  * @author Mark Hall (mhall@cs.waikato.ac.nz)
  57.  * @version $Revision: 1.25 $ 
  58.  */
  59. public class DecisionTable 
  60.   extends DistributionClassifier 
  61.   implements OptionHandler, WeightedInstancesHandler, 
  62.      AdditionalMeasureProducer {
  63.   
  64.   /** The hashtable used to hold training instances */
  65.   private Hashtable m_entries;
  66.   /** Holds the final feature set */
  67.   private int [] m_decisionFeatures;
  68.   /** Discretization filter */
  69.   private Filter m_disTransform;
  70.   /** Filter used to remove columns discarded by feature selection */
  71.   private Remove m_delTransform;
  72.   /** IB1 used to classify non matching instances rather than majority class */
  73.   private IBk m_ibk;
  74.   
  75.   /** Holds the training instances */
  76.   private Instances m_theInstances;
  77.   
  78.   /** The number of attributes in the dataset */
  79.   private int m_numAttributes;
  80.   /** The number of instances in the dataset */
  81.   private int m_numInstances;
  82.   /** Class is nominal */
  83.   private boolean m_classIsNominal;
  84.   /** Output debug info */
  85.   private boolean m_debug;
  86.   /** Use the IBk classifier rather than majority class */
  87.   private boolean m_useIBk;
  88.   /** Display Rules */
  89.   private boolean m_displayRules;
  90.   /** 
  91.    * Maximum number of fully expanded non improving subsets for a best 
  92.    * first search. 
  93.    */
  94.   private int m_maxStale;
  95.   /** Number of folds for cross validating feature sets */
  96.   private int m_CVFolds;
  97.   /** Random numbers for use in cross validation */
  98.   private Random m_rr;
  99.   /** Holds the majority class */
  100.   private double m_majority;
  101.   /**
  102.    * Class for a node in a linked list. Used in best first search.
  103.    */
  104.   public class Link {
  105.     /** The group */
  106.     BitSet m_group;
  107.     /** The merit */
  108.     double m_merit;
  109.     /**
  110.      * The constructor.
  111.      *
  112.      * @param gr the group
  113.      * @param mer the merit
  114.      */
  115.     public Link (BitSet gr, double mer) {
  116.       m_group = (BitSet)gr.clone();
  117.       m_merit = mer;
  118.     }
  119.   
  120.     /**
  121.      * Gets the group.
  122.      */
  123.     public BitSet getGroup() {
  124.       return m_group;
  125.     }
  126.   
  127.     /**
  128.      * Gets the merit.
  129.      */
  130.     public double getMerit() {
  131.       return m_merit;
  132.     }
  133.     /**
  134.      * Returns string representation.
  135.      */
  136.     public String toString() {
  137.       return ("Node: "+m_group.toString()+"  "+m_merit);
  138.     }
  139.   }
  140.   
  141.   /**
  142.    * Class for handling a linked list. Used in best first search.
  143.    * Extends the Vector class.
  144.    */
  145.   public class LinkedList extends FastVector {
  146.     /**
  147.      * Removes an element (Link) at a specific index from the list.
  148.      *
  149.      * @param index the index of the element to be removed.
  150.      */
  151.     public void removeLinkAt(int index) throws Exception {
  152.       if ((index >= 0) && (index < size())) {
  153. removeElementAt(index);
  154.       } else {
  155. throw new Exception("index out of range (removeLinkAt)");
  156.       }
  157.     }
  158.     /**
  159.      * Returns the element (Link) at a specific index from the list.
  160.      *
  161.      * @param index the index of the element to be returned.
  162.      */
  163.     public Link getLinkAt(int index) throws Exception {
  164.       if (size()==0) {
  165. throw new Exception("List is empty (getLinkAt)");
  166.       } else if ((index >= 0) && (index < size())) {
  167. return ((Link)(elementAt(index)));
  168.       } else {
  169. throw new Exception("index out of range (getLinkAt)");
  170.       }
  171.     }
  172.     /**
  173.      * Aadds an element (Link) to the list.
  174.      *
  175.      * @param gr the feature set specification
  176.      * @param mer the "merit" of this feature set
  177.      */
  178.     public void addToList(BitSet gr, double mer) {
  179.       Link newL = new Link(gr, mer);
  180.       if (size()==0) {
  181. addElement(newL);
  182.       }
  183.       else if (mer > ((Link)(firstElement())).getMerit()) {
  184. insertElementAt(newL,0);
  185.       } else {
  186. int i = 0;
  187. int size = size();
  188. boolean done = false;
  189. while ((!done) && (i < size)) {
  190.   if (mer > ((Link)(elementAt(i))).getMerit()) {
  191.     insertElementAt(newL,i);
  192.     done = true;
  193.   } else if (i == size-1) {
  194.     addElement(newL);
  195.     done = true;
  196.   } else {
  197.     i++;
  198.   }
  199. }
  200.       }
  201.     }
  202.   }
  203.   /**
  204.    * Class providing keys to the hash table
  205.    */
  206.   public class hashKey implements Serializable {
  207.     
  208.     /** Array of attribute values for an instance */
  209.     private double [] attributes;
  210.     
  211.     /** True for an index if the corresponding attribute value is missing. */
  212.     private boolean [] missing;
  213.     /** The values */
  214.     private String [] values;
  215.     /** The key */
  216.     private int key;
  217.     /**
  218.      * Constructor for a hashKey
  219.      *
  220.      * @param t an instance from which to generate a key
  221.      * @param numAtts the number of attributes
  222.      */
  223.     public hashKey(Instance t, int numAtts) throws Exception {
  224.       int i;
  225.       int cindex = t.classIndex();
  226.       key = -999;
  227.       attributes = new double [numAtts];
  228.       missing = new boolean [numAtts];
  229.       for (i=0;i<numAtts;i++) {
  230. if (i == cindex) {
  231.   missing[i] = true;
  232. } else {
  233.   if ((missing[i] = t.isMissing(i)) == false) {
  234.     attributes[i] = t.value(i);
  235.   }
  236. }
  237.       }
  238.     }
  239.     /**
  240.      * Convert a hash entry to a string
  241.      *
  242.      * @param t the set of instances
  243.      * @param maxColWidth width to make the fields
  244.      */
  245.     public String toString(Instances t, int maxColWidth) {
  246.       int i;
  247.       int cindex = t.classIndex();
  248.       StringBuffer text = new StringBuffer();
  249.       
  250.       for (i=0;i<attributes.length;i++) {
  251. if (i != cindex) {
  252.   if (missing[i]) {
  253.     text.append("?");
  254.     for (int j=0;j<maxColWidth;j++) {
  255.       text.append(" ");
  256.     }
  257.   } else {
  258.     String ss = t.attribute(i).value((int)attributes[i]);
  259.     StringBuffer sb = new StringBuffer(ss);
  260.     
  261.     for (int j=0;j < (maxColWidth-ss.length()+1); j++) {
  262. sb.append(" ");
  263.     }
  264.     text.append(sb);
  265.   }
  266. }
  267.       }
  268.       return text.toString();
  269.     }
  270.     /**
  271.      * Constructor for a hashKey
  272.      *
  273.      * @param t an array of feature values
  274.      */
  275.     public hashKey(double [] t) {
  276.       int i;
  277.       int l = t.length;
  278.       key = -999;
  279.       attributes = new double [l];
  280.       missing = new boolean [l];
  281.       for (i=0;i<l;i++) {
  282. if (t[i] == Double.MAX_VALUE) {
  283.   missing[i] = true;
  284. } else {
  285.   missing[i] = false;
  286.   attributes[i] = t[i];
  287. }
  288.       }
  289.     }
  290.     
  291.     /**
  292.      * Calculates a hash code
  293.      *
  294.      * @return the hash code as an integer
  295.      */
  296.     public int hashCode() {
  297.       int hv = 0;
  298.       
  299.       if (key != -999)
  300. return key;
  301.       for (int i=0;i<attributes.length;i++) {
  302. if (missing[i]) {
  303.   hv += (i*13);
  304. } else {
  305.   hv += (i * 5 * (attributes[i]+1));
  306. }
  307.       }
  308.       if (key == -999) {
  309. key = hv;
  310.       }
  311.       return hv;
  312.     }
  313.     /**
  314.      * Tests if two instances are equal
  315.      *
  316.      * @param b a key to compare with
  317.      */
  318.     public boolean equals(Object b) {
  319.       
  320.       if ((b == null) || !(b.getClass().equals(this.getClass()))) {
  321.         return false;
  322.       }
  323.       boolean ok = true;
  324.       boolean l;
  325.       if (b instanceof hashKey) {
  326. hashKey n = (hashKey)b;
  327. for (int i=0;i<attributes.length;i++) {
  328.   l = n.missing[i];
  329.   if (missing[i] || l) {
  330.     if ((missing[i] && !l) || (!missing[i] && l)) {
  331.       ok = false;
  332.       break;
  333.     }
  334.   } else {
  335.     if (attributes[i] != n.attributes[i]) {
  336.       ok = false;
  337.       break;
  338.     }
  339.   }
  340. }
  341.       } else {
  342. return false;
  343.       }
  344.       return ok;
  345.     }
  346.     
  347.     /**
  348.      * Prints the hash code
  349.      */
  350.     public void print_hash_code() {
  351.       
  352.       System.out.println("Hash val: "+hashCode());
  353.     }
  354.   }
  355.   /**
  356.    * Inserts an instance into the hash table
  357.    *
  358.    * @param inst instance to be inserted
  359.    * @exception Exception if the instance can't be inserted
  360.    */
  361.   private void insertIntoTable(Instance inst, double [] instA)
  362.        throws Exception {
  363.     double [] tempClassDist2;
  364.     double [] newDist;
  365.     hashKey thekey;
  366.     if (instA != null) {
  367.       thekey = new hashKey(instA);
  368.     } else {
  369.       thekey = new hashKey(inst, inst.numAttributes());
  370.     }
  371.       
  372.     // see if this one is already in the table
  373.     tempClassDist2 = (double []) m_entries.get(thekey);
  374.     if (tempClassDist2 == null) {
  375.       if (m_classIsNominal) {
  376. newDist = new double [m_theInstances.classAttribute().numValues()];
  377. newDist[(int)inst.classValue()] = inst.weight();
  378. // add to the table
  379. m_entries.put(thekey, newDist);
  380.       } else {
  381. newDist = new double [2];
  382. newDist[0] = inst.classValue() * inst.weight();
  383. newDist[1] = inst.weight();
  384. // add to the table
  385. m_entries.put(thekey, newDist);
  386.       }
  387.     } else { 
  388.       // update the distribution for this instance
  389.       if (m_classIsNominal) {
  390. tempClassDist2[(int)inst.classValue()]+=inst.weight();
  391. // update the table
  392. m_entries.put(thekey, tempClassDist2);
  393.       } else  {
  394. tempClassDist2[0] += (inst.classValue() * inst.weight());
  395. tempClassDist2[1] += inst.weight();
  396. // update the table
  397. m_entries.put(thekey, tempClassDist2);
  398.       }
  399.     }
  400.   }
  401.   /**
  402.    * Classifies an instance for internal leave one out cross validation
  403.    * of feature sets
  404.    *
  405.    * @param instance instance to be "left out" and classified
  406.    * @param instA feature values of the selected features for the instance
  407.    * @return the classification of the instance
  408.    */
  409.   double classifyInstanceLeaveOneOut(Instance instance, double [] instA)
  410.        throws Exception {
  411.     hashKey thekey;
  412.     double [] tempDist;
  413.     double [] normDist;
  414.     thekey = new hashKey(instA);
  415.     if (m_classIsNominal) {
  416.       // if this one is not in the table
  417.       if ((tempDist = (double [])m_entries.get(thekey)) == null) {
  418. throw new Error("This should never happen!");
  419.       } else {
  420. normDist = new double [tempDist.length];
  421. System.arraycopy(tempDist,0,normDist,0,tempDist.length);
  422. normDist[(int)instance.classValue()] -= instance.weight();
  423. // update the table
  424. // first check to see if the class counts are all zero now
  425. boolean ok = false;
  426. for (int i=0;i<normDist.length;i++) {
  427.   if (!Utils.eq(normDist[i],0.0)) {
  428.     ok = true;
  429.     break;
  430.   }
  431. }
  432. if (ok) {
  433.   Utils.normalize(normDist);
  434.   return Utils.maxIndex(normDist);
  435. } else {
  436.   return m_majority;
  437. }
  438.       }
  439.       //      return Utils.maxIndex(tempDist);
  440.     } else {
  441.       // see if this one is already in the table
  442.       if ((tempDist = (double[])m_entries.get(thekey)) != null) {
  443. normDist = new double [tempDist.length];
  444. System.arraycopy(tempDist,0,normDist,0,tempDist.length);
  445. normDist[0] -= (instance.classValue() * instance.weight());
  446. normDist[1] -= instance.weight();
  447. if (Utils.eq(normDist[1],0.0)) {
  448.     return m_majority;
  449. } else {
  450.   return (normDist[0] / normDist[1]);
  451. }
  452.       } else {
  453. throw new Error("This should never happen!");
  454.       }
  455.     }
  456.     
  457.     // shouldn't get here 
  458.     // return 0.0;
  459.   }
  460.   /**
  461.    * Calculates the accuracy on a test fold for internal cross validation
  462.    * of feature sets
  463.    *
  464.    * @param fold set of instances to be "left out" and classified
  465.    * @param fs currently selected feature set
  466.    * @return the accuracy for the fold
  467.    */
  468.   double classifyFoldCV(Instances fold, int [] fs) throws Exception {
  469.     int i;
  470.     int ruleCount = 0;
  471.     int numFold = fold.numInstances();
  472.     int numCl = m_theInstances.classAttribute().numValues();
  473.     double [][] class_distribs = new double [numFold][numCl];
  474.     double [] instA = new double [fs.length];
  475.     double [] normDist;
  476.     hashKey thekey;
  477.     double acc = 0.0;
  478.     int classI = m_theInstances.classIndex();
  479.     Instance inst;
  480.     if (m_classIsNominal) {
  481.       normDist = new double [numCl];
  482.     } else {
  483.       normDist = new double [2];
  484.     }
  485.     // first *remove* instances
  486.     for (i=0;i<numFold;i++) {
  487.       inst = fold.instance(i);
  488.       for (int j=0;j<fs.length;j++) {
  489. if (fs[j] == classI) {
  490.   instA[j] = Double.MAX_VALUE; // missing for the class
  491. } else if (inst.isMissing(fs[j])) {
  492.   instA[j] = Double.MAX_VALUE;
  493. } else{
  494.   instA[j] = inst.value(fs[j]);
  495. }
  496.       }
  497.       thekey = new hashKey(instA);
  498.       if ((class_distribs[i] = (double [])m_entries.get(thekey)) == null) {
  499. throw new Error("This should never happen!");
  500.       } else {
  501. if (m_classIsNominal) {
  502.   class_distribs[i][(int)inst.classValue()] -= inst.weight();
  503. } else {
  504.   class_distribs[i][0] -= (inst.classValue() * inst.weight());
  505.   class_distribs[i][1] -= inst.weight();
  506. }
  507. ruleCount++;
  508.       }
  509.     }
  510.     // now classify instances
  511.     for (i=0;i<numFold;i++) {
  512.       inst = fold.instance(i);
  513.       System.arraycopy(class_distribs[i],0,normDist,0,normDist.length);
  514.       if (m_classIsNominal) {
  515. boolean ok = false;
  516. for (int j=0;j<normDist.length;j++) {
  517.   if (!Utils.eq(normDist[j],0.0)) {
  518.     ok = true;
  519.     break;
  520.   }
  521. }
  522. if (ok) {
  523.   Utils.normalize(normDist);
  524.   if (Utils.maxIndex(normDist) == inst.classValue())
  525.     acc += inst.weight();
  526. } else {
  527.   if (inst.classValue() == m_majority) {
  528.     acc += inst.weight();
  529.   }
  530. }
  531.       } else {
  532. if (Utils.eq(normDist[1],0.0)) {
  533.     acc += ((inst.weight() * (m_majority - inst.classValue())) * 
  534.     (inst.weight() * (m_majority - inst.classValue())));
  535. } else {
  536.   double t = (normDist[0] / normDist[1]);
  537.   acc += ((inst.weight() * (t - inst.classValue())) * 
  538.   (inst.weight() * (t - inst.classValue())));
  539. }
  540.       }
  541.     }
  542.     // now re-insert instances
  543.     for (i=0;i<numFold;i++) {
  544.       inst = fold.instance(i);
  545.       if (m_classIsNominal) {
  546. class_distribs[i][(int)inst.classValue()] += inst.weight();
  547.       } else {
  548. class_distribs[i][0] += (inst.classValue() * inst.weight());
  549. class_distribs[i][1] += inst.weight();
  550.       }
  551.     }
  552.     return acc;
  553.   }
  554.   /**
  555.    * Evaluates a feature subset by cross validation
  556.    *
  557.    * @param feature_set the subset to be evaluated
  558.    * @param num_atts the number of attributes in the subset
  559.    * @return the estimated accuracy
  560.    * @exception Exception if subset can't be evaluated
  561.    */
  562.   private double estimateAccuracy(BitSet feature_set, int num_atts)
  563.     throws Exception {
  564.     int i;
  565.     Instances newInstances;
  566.     int [] fs = new int [num_atts];
  567.     double acc = 0.0;
  568.     double [][] evalArray;
  569.     double [] instA = new double [num_atts];
  570.     int classI = m_theInstances.classIndex();
  571.     
  572.     int index = 0;
  573.     for (i=0;i<m_numAttributes;i++) {
  574.       if (feature_set.get(i)) {
  575. fs[index++] = i;
  576.       }
  577.     }
  578.     // create new hash table
  579.     m_entries = new Hashtable((int)(m_theInstances.numInstances() * 1.5));
  580.     // insert instances into the hash table
  581.     for (i=0;i<m_numInstances;i++) {
  582.       Instance inst = m_theInstances.instance(i);
  583.       for (int j=0;j<fs.length;j++) {
  584. if (fs[j] == classI) {
  585.   instA[j] = Double.MAX_VALUE; // missing for the class
  586. } else if (inst.isMissing(fs[j])) {
  587.   instA[j] = Double.MAX_VALUE;
  588. } else {
  589.   instA[j] = inst.value(fs[j]);
  590. }
  591.       }
  592.       insertIntoTable(inst, instA);
  593.     }
  594.     
  595.     
  596.     if (m_CVFolds == 1) {
  597.       // calculate leave one out error
  598.       for (i=0;i<m_numInstances;i++) {
  599. Instance inst = m_theInstances.instance(i);
  600. for (int j=0;j<fs.length;j++) {
  601.   if (fs[j] == classI) {
  602.     instA[j] = Double.MAX_VALUE; // missing for the class
  603.   } else if (inst.isMissing(fs[j])) {
  604.     instA[j] = Double.MAX_VALUE;
  605.   } else {
  606.     instA[j] = inst.value(fs[j]);
  607.   }
  608. }
  609. double t = classifyInstanceLeaveOneOut(inst, instA);
  610. if (m_classIsNominal) {
  611.   if (t == inst.classValue()) {
  612.     acc+=inst.weight();
  613.   }
  614. } else {
  615.   acc += ((inst.weight() * (t - inst.classValue())) * 
  616.   (inst.weight() * (t - inst.classValue())));
  617. }
  618. // weight_sum += inst.weight();
  619.       }
  620.     } else {
  621.       m_theInstances.randomize(m_rr);
  622.       m_theInstances.stratify(m_CVFolds);
  623.       // calculate 10 fold cross validation error
  624.       for (i=0;i<m_CVFolds;i++) {
  625. Instances insts = m_theInstances.testCV(m_CVFolds,i);
  626. acc += classifyFoldCV(insts, fs);
  627.       }
  628.     }
  629.   
  630.     if (m_classIsNominal) {
  631.       return (acc / m_theInstances.sumOfWeights());
  632.     } else {
  633.       return -(Math.sqrt(acc / m_theInstances.sumOfWeights()));   
  634.     }
  635.   }
  636.   /**
  637.    * Returns a String representation of a feature subset
  638.    *
  639.    * @param sub BitSet representation of a subset
  640.    * @return String containing subset
  641.    */
  642.   private String printSub(BitSet sub) {
  643.     int i;
  644.     String s="";
  645.     for (int jj=0;jj<m_numAttributes;jj++) {
  646.       if (sub.get(jj)) {
  647. s += " "+(jj+1);
  648.       }
  649.     }
  650.     return s;
  651.   }
  652.     
  653.   /**
  654.    * Does a best first search 
  655.    */
  656.   private void best_first() throws Exception {
  657.     int i,j,classI,count=0,fc,tree_count=0;
  658.     int evals=0;
  659.     BitSet best_group, temp_group;
  660.     int [] stale;
  661.     double [] best_merit;
  662.     double merit;
  663.     boolean z;
  664.     boolean added;
  665.     Link tl;
  666.   
  667.     Hashtable lookup = new Hashtable((int)(200.0*m_numAttributes*1.5));
  668.     LinkedList bfList = new LinkedList();
  669.     best_merit = new double[1]; best_merit[0] = 0.0;
  670.     stale = new int[1]; stale[0] = 0;
  671.     best_group = new BitSet(m_numAttributes);
  672.     // Add class to initial subset
  673.     classI = m_theInstances.classIndex();
  674.     best_group.set(classI);
  675.     best_merit[0] = estimateAccuracy(best_group, 1);
  676.     if (m_debug)
  677.       System.out.println("Accuracy of initial subset: "+best_merit[0]);
  678.     // add the initial group to the list
  679.     bfList.addToList(best_group,best_merit[0]);
  680.     // add initial subset to the hashtable
  681.     lookup.put(best_group,"");
  682.     while (stale[0] < m_maxStale) {
  683.       added = false;
  684.       // finished search?
  685.       if (bfList.size()==0) {
  686. stale[0] = m_maxStale;
  687. break;
  688.       }
  689.       // copy the feature set at the head of the list
  690.       tl = bfList.getLinkAt(0);
  691.       temp_group = (BitSet)(tl.getGroup().clone());
  692.       // remove the head of the list
  693.       bfList.removeLinkAt(0);
  694.       for (i=0;i<m_numAttributes;i++) {
  695. // if (search_direction == 1)
  696. z = ((i != classI) && (!temp_group.get(i)));
  697. if (z) {
  698.   // set the bit (feature to add/delete) */
  699.   temp_group.set(i);
  700.   
  701.   /* if this subset has been seen before, then it is already in 
  702.      the list (or has been fully expanded) */
  703.   BitSet tt = (BitSet)temp_group.clone();
  704.   if (lookup.containsKey(tt) == false) {
  705.     fc = 0;
  706.     for (int jj=0;jj<m_numAttributes;jj++) {
  707.       if (tt.get(jj)) {
  708. fc++;
  709.       }
  710.     }
  711.     merit = estimateAccuracy(tt, fc);
  712.     if (m_debug) {
  713.       System.out.println("evaluating: "+printSub(tt)+" "+merit); 
  714.     }
  715.     
  716.     // is this better than the best?
  717.     // if (search_direction == 1)
  718.     z = ((merit - best_merit[0]) > 0.00001);
  719.  
  720.     // else
  721.     // z = ((best_merit[0] - merit) > 0.00001);
  722.     if (z) {
  723.       if (m_debug) {
  724. System.out.println("new best feature set: "+printSub(tt)+
  725.    " "+merit);
  726.       }
  727.       added = true;
  728.       stale[0] = 0;
  729.       best_merit[0] = merit;
  730.       best_group = (BitSet)(temp_group.clone());
  731.     }
  732.     // insert this one in the list and the hash table
  733.     bfList.addToList(tt, merit);
  734.     lookup.put(tt,"");
  735.     count++;
  736.   }
  737.   // unset this addition(deletion)
  738.   temp_group.clear(i);
  739. }
  740.       }
  741.       /* if we haven't added a new feature subset then full expansion 
  742.  of this node hasn't resulted in anything better */
  743.       if (!added) {
  744. stale[0]++;
  745.       }
  746.     }
  747.    
  748.     // set selected features
  749.     for (i=0,j=0;i<m_numAttributes;i++) {
  750.       if (best_group.get(i)) {
  751. j++;
  752.       }
  753.     }
  754.     
  755.     m_decisionFeatures = new int[j];
  756.     for (i=0,j=0;i<m_numAttributes;i++) {
  757.       if (best_group.get(i)) {
  758. m_decisionFeatures[j++] = i;    
  759.       }
  760.     }
  761.   }
  762.  
  763.   /**
  764.    * Resets the options.
  765.    */
  766.   protected void resetOptions()  {
  767.     m_entries = null;
  768.     m_decisionFeatures = null;
  769.     m_debug = false;
  770.     m_useIBk = false;
  771.     m_CVFolds = 1;
  772.     m_maxStale = 5;
  773.     m_displayRules = false;
  774.   }
  775.    /**
  776.    * Constructor for a DecisionTable
  777.    */
  778.   public DecisionTable() {
  779.     resetOptions();
  780.   }
  781.   /**
  782.    * Returns an enumeration describing the available options.
  783.    *
  784.    * @return an enumeration of all the available options.
  785.    */
  786.   public Enumeration listOptions() {
  787.     Vector newVector = new Vector(5);
  788.     newVector.addElement(new Option(
  789.               "tNumber of fully expanded non improving subsets to considern" +
  790.       "tbefore terminating a best first search.n" +
  791.       "tUse in conjunction with -B. (Default = 5)",
  792.               "S", 1, "-S <number of non improving nodes>"));
  793.     
  794.     newVector.addElement(new Option(
  795.               "tUse cross validation to evaluate features.n" +
  796.       "tUse number of folds = 1 for leave one out CV.n" +
  797.       "t(Default = leave one out CV)",
  798.               "X", 1, "-X <number of folds>"));
  799.      newVector.addElement(new Option(
  800.               "tUse nearest neighbour instead of global table majority.n",
  801.               "I", 0, "-I"));
  802.      newVector.addElement(new Option(
  803.               "tDisplay decision table rules.n",
  804.               "R", 0, "-R")); 
  805.     return newVector.elements();
  806.   }
  807.   /**
  808.    * Sets the number of folds for cross validation (1 = leave one out)
  809.    *
  810.    * @param folds the number of folds
  811.    */
  812.   public void setCrossVal(int folds) {
  813.     m_CVFolds = folds;
  814.   }
  815.   /**
  816.    * Gets the number of folds for cross validation
  817.    *
  818.    * @return the number of cross validation folds
  819.    */
  820.   public int getCrossVal() {
  821.     return m_CVFolds;
  822.   }
  823.   /**
  824.    * Sets the number of non improving decision tables to consider
  825.    * before abandoning the search.
  826.    *
  827.    * @param stale the number of nodes
  828.    */
  829.   public void setMaxStale(int stale) {
  830.     m_maxStale = stale;
  831.   }
  832.   /**
  833.    * Gets the number of non improving decision tables
  834.    *
  835.    * @return the number of non improving decision tables
  836.    */
  837.   public int getMaxStale() {
  838.     return m_maxStale;
  839.   }
  840.   /**
  841.    * Sets whether IBk should be used instead of the majority class
  842.    *
  843.    * @param ibk true if IBk is to be used
  844.    */
  845.   public void setUseIBk(boolean ibk) {
  846.     m_useIBk = ibk;
  847.   }
  848.   
  849.   /**
  850.    * Gets whether IBk is being used instead of the majority class
  851.    *
  852.    * @return true if IBk is being used
  853.    */
  854.   public boolean getUseIBk() {
  855.     return m_useIBk;
  856.   }
  857.   /**
  858.    * Sets whether rules are to be printed
  859.    *
  860.    * @param rules true if rules are to be printed
  861.    */
  862.   public void setDisplayRules(boolean rules) {
  863.     m_displayRules = rules;
  864.   }
  865.   
  866.   /**
  867.    * Gets whether rules are being printed
  868.    *
  869.    * @return true if rules are being printed
  870.    */
  871.   public boolean getDisplayRules() {
  872.     return m_displayRules;
  873.   }
  874.   /**
  875.    * Parses the options for this object.
  876.    *
  877.    * Valid options are: <p>
  878.    *
  879.    * -S num <br>
  880.    * Number of fully expanded non improving subsets to consider
  881.    * before terminating a best first search.
  882.    * (Default = 5) <p>
  883.    *
  884.    * -X num <br>
  885.    * Use cross validation to evaluate features. Use number of folds = 1 for
  886.    * leave one out CV. (Default = leave one out CV) <p>
  887.    * 
  888.    * -I <br>
  889.    * Use nearest neighbour instead of global table majority. <p>
  890.    *
  891.    * -R <br>
  892.    * Prints the decision table. <p>
  893.    *
  894.    * @param options the list of options as an array of strings
  895.    * @exception Exception if an option is not supported
  896.    */
  897.   public void setOptions(String[] options) throws Exception {
  898.     String optionString;
  899.     resetOptions();
  900.     optionString = Utils.getOption('X',options);
  901.     if (optionString.length() != 0) {
  902.       m_CVFolds = Integer.parseInt(optionString);
  903.     }
  904.     optionString = Utils.getOption('S',options);
  905.     if (optionString.length() != 0) {
  906.       m_maxStale = Integer.parseInt(optionString);
  907.     }
  908.     m_useIBk = Utils.getFlag('I',options);
  909.     m_displayRules = Utils.getFlag('R',options);
  910.   }
  911.   /**
  912.    * Gets the current settings of the classifier.
  913.    *
  914.    * @return an array of strings suitable for passing to setOptions
  915.    */
  916.   public String [] getOptions() {
  917.     String [] options = new String [7];
  918.     int current = 0;
  919.     options[current++] = "-X"; options[current++] = "" + m_CVFolds;
  920.     options[current++] = "-S"; options[current++] = "" + m_maxStale;
  921.     if (m_useIBk) {
  922.       options[current++] = "-I";
  923.     }
  924.     if (m_displayRules) {
  925.       options[current++] = "-R";
  926.     }
  927.     while (current < options.length) {
  928.       options[current++] = "";
  929.     }
  930.     return options;
  931.   }
  932.   
  933.   /**
  934.    * Generates the classifier.
  935.    *
  936.    * @param data set of instances serving as training data 
  937.    * @exception Exception if the classifier has not been generated successfully
  938.    */
  939.   public void buildClassifier(Instances data) throws Exception {
  940.     int i;
  941.     m_rr = new Random(1);
  942.     m_theInstances = new Instances(data);
  943.     m_theInstances.deleteWithMissingClass();
  944.     if (m_theInstances.numInstances() == 0) {
  945.       throw new Exception("No training instances without missing class!");
  946.     }
  947.     if (m_theInstances.checkForStringAttributes()) {
  948.       throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
  949.     }
  950.     if (m_theInstances.classAttribute().isNumeric()) {
  951.       m_disTransform = new weka.filters.unsupervised.attribute.Discretize();
  952.       m_classIsNominal = false;
  953.       
  954.       // use binned discretisation if the class is numeric
  955.       ((weka.filters.unsupervised.attribute.Discretize)m_disTransform).
  956. setBins(10);
  957.       ((weka.filters.unsupervised.attribute.Discretize)m_disTransform).
  958. setInvertSelection(true);
  959.       
  960.       // Discretize all attributes EXCEPT the class 
  961.       String rangeList = "";
  962.       rangeList+=(m_theInstances.classIndex()+1);
  963.       //System.out.println("The class col: "+m_theInstances.classIndex());
  964.       
  965.       ((weka.filters.unsupervised.attribute.Discretize)m_disTransform).
  966. setAttributeIndices(rangeList);
  967.     } else {
  968.       m_disTransform = new weka.filters.supervised.attribute.Discretize();
  969.       ((weka.filters.supervised.attribute.Discretize)m_disTransform).setUseBetterEncoding(true);
  970.       m_classIsNominal = true;
  971.     }
  972.     m_disTransform.setInputFormat(m_theInstances);
  973.     m_theInstances = Filter.useFilter(m_theInstances, m_disTransform);
  974.     
  975.     m_numAttributes = m_theInstances.numAttributes();
  976.     m_numInstances = m_theInstances.numInstances();
  977.     m_majority = m_theInstances.meanOrMode(m_theInstances.classAttribute());
  978.     
  979.     best_first();
  980.     
  981.     // reduce instances to selected features
  982.     m_delTransform = new Remove();
  983.     m_delTransform.setInvertSelection(true);
  984.     
  985.     // set features to keep
  986.     m_delTransform.setAttributeIndicesArray(m_decisionFeatures); 
  987.     m_delTransform.setInputFormat(m_theInstances);
  988.     m_theInstances = Filter.useFilter(m_theInstances, m_delTransform);
  989.     
  990.     // reset the number of attributes
  991.     m_numAttributes = m_theInstances.numAttributes();
  992.     
  993.     // create hash table
  994.     m_entries = new Hashtable((int)(m_theInstances.numInstances() * 1.5));
  995.     
  996.     // insert instances into the hash table
  997.     for (i=0;i<m_numInstances;i++) {
  998.       Instance inst = m_theInstances.instance(i);
  999.       insertIntoTable(inst, null);
  1000.     }
  1001.     
  1002.     // Replace the global table majority with nearest neighbour?
  1003.     if (m_useIBk) {
  1004.       m_ibk = new IBk();
  1005.       m_ibk.buildClassifier(m_theInstances);
  1006.     }
  1007.     
  1008.     // Save memory
  1009.     m_theInstances = new Instances(m_theInstances, 0);
  1010.   }
  1011.   /**
  1012.    * Calculates the class membership probabilities for the given 
  1013.    * test instance.
  1014.    *
  1015.    * @param instance the instance to be classified
  1016.    * @return predicted class probability distribution
  1017.    * @exception Exception if distribution can't be computed
  1018.    */
  1019.   public double [] distributionForInstance(Instance instance)
  1020.        throws Exception {
  1021.     hashKey thekey;
  1022.     double [] tempDist;
  1023.     double [] normDist;
  1024.     m_disTransform.input(instance);
  1025.     m_disTransform.batchFinished();
  1026.     instance = m_disTransform.output();
  1027.     m_delTransform.input(instance);
  1028.     m_delTransform.batchFinished();
  1029.     instance = m_delTransform.output();
  1030.     thekey = new hashKey(instance, instance.numAttributes());
  1031.     
  1032.     // if this one is not in the table
  1033.     if ((tempDist = (double [])m_entries.get(thekey)) == null) {
  1034.       if (m_useIBk) {
  1035. tempDist = m_ibk.distributionForInstance(instance);
  1036.       } else {
  1037. if (!m_classIsNominal) {
  1038.   tempDist = new double[1];
  1039.   tempDist[0] = m_majority;
  1040. } else {
  1041.   tempDist = new double [m_theInstances.classAttribute().numValues()];
  1042.   tempDist[(int)m_majority] = 1.0;
  1043. }
  1044.       }
  1045.     } else {
  1046.       if (!m_classIsNominal) {
  1047. normDist = new double[1];
  1048. normDist[0] = (tempDist[0] / tempDist[1]);
  1049. tempDist = normDist;
  1050.       } else {
  1051. // normalise distribution
  1052. normDist = new double [tempDist.length];
  1053. System.arraycopy(tempDist,0,normDist,0,tempDist.length);
  1054. Utils.normalize(normDist);
  1055. tempDist = normDist;
  1056.       }
  1057.     }
  1058.     return tempDist;
  1059.   }
  1060.   /**
  1061.    * Returns a string description of the features selected
  1062.    *
  1063.    * @return a string of features
  1064.    */
  1065.   public String printFeatures() {
  1066.     int i;
  1067.     String s = "";
  1068.    
  1069.     for (i=0;i<m_decisionFeatures.length;i++) {
  1070.       if (i==0) {
  1071. s = ""+(m_decisionFeatures[i]+1);
  1072.       } else {
  1073. s += ","+(m_decisionFeatures[i]+1);
  1074.       }
  1075.     }
  1076.     return s;
  1077.   }
  1078.   /**
  1079.    * Returns the number of rules
  1080.    * @return the number of rules
  1081.    */
  1082.   public double measureNumRules() {
  1083.     return m_entries.size();
  1084.   }
  1085.   /**
  1086.    * Returns an enumeration of the additional measure names
  1087.    * @return an enumeration of the measure names
  1088.    */
  1089.   public Enumeration enumerateMeasures() {
  1090.     Vector newVector = new Vector(1);
  1091.     newVector.addElement("measureNumRules");
  1092.     return newVector.elements();
  1093.   }
  1094.   /**
  1095.    * Returns the value of the named measure
  1096.    * @param measureName the name of the measure to query for its value
  1097.    * @return the value of the named measure
  1098.    * @exception IllegalArgumentException if the named measure is not supported
  1099.    */
  1100.   public double getMeasure(String additionalMeasureName) {
  1101.     if (additionalMeasureName.compareTo("measureNumRules") == 0) {
  1102.       return measureNumRules();
  1103.     } else {
  1104.       throw new IllegalArgumentException(additionalMeasureName 
  1105.   + " not supported (DecisionTable)");
  1106.     }
  1107.   }
  1108.   /**
  1109.    * Returns a description of the classifier.
  1110.    *
  1111.    * @return a description of the classifier as a string.
  1112.    */
  1113.   public String toString() {
  1114.     if (m_entries == null) {
  1115.       return "Decision Table: No model built yet.";
  1116.     } else {
  1117.       StringBuffer text = new StringBuffer();
  1118.       
  1119.       text.append("Decision Table:"+
  1120.   "nnNumber of training instances: "+m_numInstances+
  1121.   "nNumber of Rules : "+m_entries.size()+"n");
  1122.       
  1123.       if (m_useIBk) {
  1124. text.append("Non matches covered by IB1.n");
  1125.       } else {
  1126. text.append("Non matches covered by Majority class.n");
  1127.       }
  1128.       
  1129.       text.append("Best first search for feature set,nterminated after "+
  1130.   m_maxStale+" non improving subsets.n");
  1131.       
  1132.       text.append("Evaluation (for feature selection): CV ");
  1133.       if (m_CVFolds > 1) {
  1134. text.append("("+m_CVFolds+" fold) ");
  1135.       } else {
  1136.   text.append("(leave one out) ");
  1137.       }
  1138.       text.append("nFeature set: "+printFeatures());
  1139.       
  1140.       if (m_displayRules) {
  1141. // find out the max column width
  1142. int maxColWidth = 0;
  1143. for (int i=0;i<m_theInstances.numAttributes();i++) {
  1144.   if (m_theInstances.attribute(i).name().length() > maxColWidth) {
  1145.     maxColWidth = m_theInstances.attribute(i).name().length();
  1146.   }
  1147.   if (m_classIsNominal || (i != m_theInstances.classIndex())) {
  1148.     Enumeration e = m_theInstances.attribute(i).enumerateValues();
  1149.     while (e.hasMoreElements()) {
  1150.       String ss = (String)e.nextElement();
  1151.       if (ss.length() > maxColWidth) {
  1152. maxColWidth = ss.length();
  1153.       }
  1154.     }
  1155.   }
  1156. }
  1157. text.append("nnRules:n");
  1158. StringBuffer tm = new StringBuffer();
  1159. for (int i=0;i<m_theInstances.numAttributes();i++) {
  1160.   if (m_theInstances.classIndex() != i) {
  1161.     int d = maxColWidth - m_theInstances.attribute(i).name().length();
  1162.     tm.append(m_theInstances.attribute(i).name());
  1163.     for (int j=0;j<d+1;j++) {
  1164.       tm.append(" ");
  1165.     }
  1166.   }
  1167. }
  1168. tm.append(m_theInstances.attribute(m_theInstances.classIndex()).name()+"  ");
  1169. for (int i=0;i<tm.length()+10;i++) {
  1170.   text.append("=");
  1171. }
  1172. text.append("n");
  1173. text.append(tm);
  1174. text.append("n");
  1175. for (int i=0;i<tm.length()+10;i++) {
  1176.   text.append("=");
  1177. }
  1178. text.append("n");
  1179. Enumeration e = m_entries.keys();
  1180. while (e.hasMoreElements()) {
  1181.   hashKey tt = (hashKey)e.nextElement();
  1182.   text.append(tt.toString(m_theInstances,maxColWidth));
  1183.   double [] ClassDist = (double []) m_entries.get(tt);
  1184.   if (m_classIsNominal) {
  1185.     int m = Utils.maxIndex(ClassDist);
  1186.     try {
  1187.       text.append(m_theInstances.classAttribute().value(m)+"n");
  1188.     } catch (Exception ee) {
  1189.       System.out.println(ee.getMessage());
  1190.     }
  1191.   } else {
  1192.     text.append((ClassDist[0] / ClassDist[1])+"n");
  1193.   }
  1194. }
  1195. for (int i=0;i<tm.length()+10;i++) {
  1196.   text.append("=");
  1197. }
  1198. text.append("n");
  1199. text.append("n");
  1200.       }
  1201.       return text.toString();
  1202.     }
  1203.   }
  1204.   /**
  1205.    * Main method for testing this class.
  1206.    *
  1207.    * @param argv the command-line options
  1208.    */
  1209.   public static void main(String [] argv) {
  1210.     
  1211.     Classifier scheme;
  1212.     
  1213.     try {
  1214.       scheme = new DecisionTable();
  1215.       System.out.println(Evaluation.evaluateModel(scheme,argv));
  1216.     }
  1217.     catch (Exception e) {
  1218.       e.printStackTrace();
  1219.       System.out.println(e.getMessage());
  1220.     }
  1221.   }
  1222. }