OneR.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 13k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  *    This program is free software; you can redistribute it and/or modify
  3.  *    it under the terms of the GNU General Public License as published by
  4.  *    the Free Software Foundation; either version 2 of the License, or
  5.  *    (at your option) any later version.
  6.  *
  7.  *    This program is distributed in the hope that it will be useful,
  8.  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  *    GNU General Public License for more details.
  11.  *
  12.  *    You should have received a copy of the GNU General Public License
  13.  *    along with this program; if not, write to the Free Software
  14.  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15.  */
  16. /*
  17.  *    OneR.java
  18.  *    Copyright (C) 1999 Ian H. Witten
  19.  *
  20.  */
  21. package weka.classifiers.rules;
  22. import weka.classifiers.Classifier;
  23. import weka.classifiers.Evaluation;
  24. import java.io.*;
  25. import java.util.*;
  26. import weka.core.*;
  27. /**
  28.  * Class for building and using a 1R classifier. For more information, see<p>
  29.  *
  30.  * R.C. Holte (1993). <i>Very simple classification rules
  31.  * perform well on most commonly used datasets</i>. Machine Learning,
  32.  * Vol. 11, pp. 63-91.<p>
  33.  *
  34.  * Valid options are:<p>
  35.  *
  36.  * -B num <br>
  37.  * Specify the minimum number of objects in a bucket (default: 6). <p>
  38.  * 
  39.  * @author Ian H. Witten (ihw@cs.waikato.ac.nz)
  40.  * @version $Revision: 1.15 $ 
  41. */
  42. public class OneR extends Classifier implements OptionHandler {
  43.   /**
  44.    * Class for storing store a 1R rule.
  45.    */
  46.   private class OneRRule implements Serializable {
  47.     /** The class attribute. */
  48.     private Attribute m_class;
  49.     /** The number of instances used for building the rule. */
  50.     private int m_numInst;
  51.     /** Attribute to test */
  52.     private Attribute m_attr; 
  53.     /** Training set examples this rule gets right */
  54.     private int m_correct; 
  55.     /** Predicted class for each value of attr */
  56.     private int[] m_classifications; 
  57.     /** Predicted class for missing values */
  58.     private int m_missingValueClass = -1; 
  59.     /** Breakpoints (numeric attributes only) */
  60.     private double[] m_breakpoints; 
  61.   
  62.     /**
  63.      * Constructor for nominal attribute.
  64.      */
  65.     public OneRRule(Instances data, Attribute attribute) throws Exception {
  66.       m_class = data.classAttribute();
  67.       m_numInst = data.numInstances();
  68.       m_attr = attribute;
  69.       m_correct = 0;
  70.       m_classifications = new int[m_attr.numValues()];
  71.     }
  72.     /**
  73.      * Constructor for numeric attribute.
  74.      */
  75.     public OneRRule(Instances data, Attribute attribute, int nBreaks) throws Exception {
  76.       m_class = data.classAttribute();
  77.       m_numInst = data.numInstances();
  78.       m_attr = attribute;
  79.       m_correct = 0;
  80.       m_classifications = new int[nBreaks];
  81.       m_breakpoints = new double[nBreaks - 1]; // last breakpoint is infinity
  82.     }
  83.     
  84.     /**
  85.      * Returns a description of the rule.
  86.      */
  87.     public String toString() {
  88.       try {
  89. StringBuffer text = new StringBuffer();
  90. text.append(m_attr.name() + ":n");
  91. for (int v = 0; v < m_classifications.length; v++) {
  92.   text.append("t");
  93.   if (m_attr.isNominal()) {
  94.     text.append(m_attr.value(v));
  95.   } else if (v < m_breakpoints.length) {
  96.     text.append("< " + m_breakpoints[v]);
  97.   } else if (v > 0) {
  98.     text.append(">= " + m_breakpoints[v - 1]);
  99.   } else {
  100.     text.append("not ?");
  101.   }
  102.   text.append("t-> " + m_class.value(m_classifications[v]) + "n");
  103. }
  104. if (m_missingValueClass != -1) {
  105.   text.append("t?t-> " + m_class.value(m_missingValueClass) + "n");
  106. }
  107. text.append("(" + m_correct + "/" + m_numInst + " instances correct)n");
  108. return text.toString();
  109.       } catch (Exception e) {
  110. return "Can't print OneR classifier!";
  111.       }
  112.     }
  113.   }
  114.   
  115.   /** A 1-R rule */
  116.   private OneRRule m_rule;
  117.   /** The minimum bucket size */
  118.   private int m_minBucketSize = 6;
  119.   /**
  120.    * Classifies a given instance.
  121.    *
  122.    * @param inst the instance to be classified
  123.    */
  124.   public double classifyInstance(Instance inst) {
  125.     int v = 0;
  126.     if (inst.isMissing(m_rule.m_attr)) {
  127.       if (m_rule.m_missingValueClass != -1) {
  128. return m_rule.m_missingValueClass;
  129.       } else {
  130. return 0;  // missing values occur in test but not training set    
  131.       }
  132.     }
  133.     if (m_rule.m_attr.isNominal()) {
  134.       v = (int) inst.value(m_rule.m_attr);
  135.     } else {
  136.       while (v < m_rule.m_breakpoints.length &&
  137.      inst.value(m_rule.m_attr) >= m_rule.m_breakpoints[v]) {
  138. v++;
  139.       }
  140.     }
  141.     return m_rule.m_classifications[v];
  142.   }
  143.   /**
  144.    * Generates the classifier.
  145.    *
  146.    * @param instances the instances to be used for building the classifier
  147.    * @exception Exception if the classifier can't be built successfully
  148.    */
  149.   public void buildClassifier(Instances instances) 
  150.     throws Exception {
  151.     
  152.     boolean noRule = true;
  153.     if (instances.checkForStringAttributes()) {
  154.       throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
  155.     }
  156.     if (instances.classAttribute().isNumeric()) {
  157.       throw new UnsupportedClassTypeException("Can't handle numeric class!");
  158.     }
  159.     Instances data = new Instances(instances);
  160.     // new dataset without missing class values
  161.     data.deleteWithMissingClass();
  162.     if (data.numInstances() == 0) {
  163.       throw new Exception("No instances with a class value!");
  164.     }
  165.     // for each attribute ...
  166.     Enumeration enum = instances.enumerateAttributes();
  167.     while (enum.hasMoreElements()) {
  168.       try {
  169. OneRRule r = newRule((Attribute) enum.nextElement(), data);
  170. // if this attribute is the best so far, replace the rule
  171. if (noRule || r.m_correct > m_rule.m_correct) {
  172.   m_rule = r;
  173. }
  174. noRule = false;
  175.       } catch (Exception ex) {
  176.       }
  177.     }
  178.   }
  179.   /**
  180.    * Create a rule branching on this attribute.
  181.    *
  182.    * @param attr the attribute to branch on
  183.    * @param data the data to be used for creating the rule
  184.    * @exception Exception if the rule can't be built successfully
  185.    */
  186.   public OneRRule newRule(Attribute attr, Instances data) throws Exception {
  187.     OneRRule r;
  188.     // ... create array to hold the missing value counts
  189.     int[] missingValueCounts =
  190.       new int [data.classAttribute().numValues()];
  191.     
  192.     if (attr.isNominal()) {
  193.       r = newNominalRule(attr, data, missingValueCounts);
  194.     } else {
  195.       r = newNumericRule(attr, data, missingValueCounts);
  196.     }
  197.     r.m_missingValueClass = Utils.maxIndex(missingValueCounts);
  198.     if (missingValueCounts[r.m_missingValueClass] == 0) {
  199.       r.m_missingValueClass = -1; // signal for no missing value class
  200.     } else {
  201.       r.m_correct += missingValueCounts[r.m_missingValueClass];
  202.     }
  203.     return r;
  204.   }
  205.   /**
  206.    * Create a rule branching on this nominal attribute.
  207.    *
  208.    * @param attr the attribute to branch on
  209.    * @param data the data to be used for creating the rule
  210.    * @param missingValueCounts to be filled in
  211.    * @exception Exception if the rule can't be built successfully
  212.    */
  213.   public OneRRule newNominalRule(Attribute attr, Instances data,
  214.                                  int[] missingValueCounts) throws Exception {
  215.     // ... create arrays to hold the counts
  216.     int[][] counts = new int [attr.numValues()]
  217.                              [data.classAttribute().numValues()];
  218.       
  219.     // ... calculate the counts
  220.     Enumeration enum = data.enumerateInstances();
  221.     while (enum.hasMoreElements()) {
  222.       Instance i = (Instance) enum.nextElement();
  223.       if (i.isMissing(attr)) {
  224. missingValueCounts[(int) i.classValue()]++; 
  225.       } else {
  226. counts[(int) i.value(attr)][(int) i.classValue()]++;
  227.       }
  228.     }
  229.     OneRRule r = new OneRRule(data, attr); // create a new rule
  230.     for (int value = 0; value < attr.numValues(); value++) {
  231.       int best = Utils.maxIndex(counts[value]);
  232.       r.m_classifications[value] = best;
  233.       r.m_correct += counts[value][best];
  234.     }
  235.     return r;
  236.   }
  237.   /**
  238.    * Create a rule branching on this numeric attribute
  239.    *
  240.    * @param attr the attribute to branch on
  241.    * @param data the data to be used for creating the rule
  242.    * @param missingValueCounts to be filled in
  243.    * @exception Exception if the rule can't be built successfully
  244.    */
  245.   public OneRRule newNumericRule(Attribute attr, Instances data,
  246.                              int[] missingValueCounts) throws Exception {
  247.     // ... can't be more than numInstances buckets
  248.     int [] classifications = new int[data.numInstances()];
  249.     double [] breakpoints = new double[data.numInstances()];
  250.     // create array to hold the counts
  251.     int [] counts = new int[data.classAttribute().numValues()];
  252.     int correct = 0;
  253.     int lastInstance = data.numInstances();
  254.     // missing values get sorted to the end of the instances
  255.     data.sort(attr);
  256.     while (lastInstance > 0 && 
  257.            data.instance(lastInstance-1).isMissing(attr)) {
  258.       lastInstance--;
  259.       missingValueCounts[(int) data.instance(lastInstance).
  260.                          classValue()]++; 
  261.     }
  262.     int i = 0; 
  263.     int cl = 0; // index of next bucket to create
  264.     int it;
  265.     while (i < lastInstance) { // start a new bucket
  266.       for (int j = 0; j < counts.length; j++) counts[j] = 0;
  267.       do { // fill it until it has enough of the majority class
  268.         it = (int) data.instance(i++).classValue();
  269.         counts[it]++;
  270.       } while (counts[it] < m_minBucketSize && i < lastInstance);
  271.       // while class remains the same, keep on filling
  272.       while (i < lastInstance && 
  273.              (int) data.instance(i).classValue() == it) { 
  274.         counts[it]++; 
  275.         i++;
  276.       }
  277.       while (i < lastInstance && // keep on while attr value is the same
  278.              (data.instance(i - 1).value(attr) 
  279.       == data.instance(i).value(attr))) {
  280.         counts[(int) data.instance(i++).classValue()]++;
  281.       }
  282.       for (int j = 0; j < counts.length; j++) {
  283.         if (counts[j] > counts[it]) { 
  284.   it = j;
  285. }
  286.       }
  287.       if (cl > 0) { // can we coalesce with previous class?
  288.         if (counts[classifications[cl - 1]] == counts[it]) {
  289.           it = classifications[cl - 1];
  290. }
  291.         if (it == classifications[cl - 1]) {
  292.   cl--; // yes!
  293. }
  294.       }
  295.       correct += counts[it];
  296.       classifications[cl] = it;
  297.       if (i < lastInstance) {
  298.         breakpoints[cl] = (data.instance(i - 1).value(attr)
  299.    + data.instance(i).value(attr)) / 2;
  300.       }
  301.       cl++;
  302.     }
  303.     if (cl == 0) {
  304.       throw new Exception("Only missing values in the training data!");
  305.     }
  306.     OneRRule r = new OneRRule(data, attr, cl); // new rule with cl branches
  307.     r.m_correct = correct;
  308.     for (int v = 0; v < cl; v++) {
  309.       r.m_classifications[v] = classifications[v];
  310.       if (v < cl-1) {
  311. r.m_breakpoints[v] = breakpoints[v];
  312.       }
  313.     }
  314.     return r;
  315.   }
  316.   /**
  317.    * Returns an enumeration describing the available options..
  318.    *
  319.    * @return an enumeration of all the available options.
  320.    */
  321.   public Enumeration listOptions() {
  322.     String string = "tThe minimum number of objects in a bucket (default: 6).";
  323.     Vector newVector = new Vector(1);
  324.     newVector.addElement(new Option(string, "B", 1, 
  325.     "-B <minimum bucket size>"));
  326.     return newVector.elements();
  327.   }
  328.   /**
  329.    * Parses a given list of options. Valid options are:<p>
  330.    *
  331.    * -B num <br>
  332.    * Specify the minimum number of objects in a bucket (default: 6). <p>
  333.    *
  334.    * @param options the list of options as an array of strings
  335.    * @exception Exception if an option is not supported
  336.    */
  337.   public void setOptions(String[] options) throws Exception {
  338.     
  339.     String bucketSizeString = Utils.getOption('B', options);
  340.     if (bucketSizeString.length() != 0) {
  341.       m_minBucketSize = Integer.parseInt(bucketSizeString);
  342.     } else {
  343.       m_minBucketSize = 6;
  344.     }
  345.   }
  346.   /**
  347.    * Gets the current settings of the OneR classifier.
  348.    *
  349.    * @return an array of strings suitable for passing to setOptions
  350.    */
  351.   public String [] getOptions() {
  352.     String [] options = new String [2];
  353.     int current = 0;
  354.     options[current++] = "-B"; options[current++] = "" + m_minBucketSize;
  355.     while (current < options.length) {
  356.       options[current++] = "";
  357.     }
  358.     return options;
  359.   }
  360.   /**
  361.    * Returns a description of the classifier
  362.    */
  363.   public String toString() {
  364.     if (m_rule == null) {
  365.       return "OneR: No model built yet.";
  366.     }
  367.     return m_rule.toString();
  368.   }
  369.   
  370.   /**
  371.    * Get the value of minBucketSize.
  372.    * @return Value of minBucketSize.
  373.    */
  374.   public int getMinBucketSize() {
  375.     
  376.     return m_minBucketSize;
  377.   }
  378.   
  379.   /**
  380.    * Set the value of minBucketSize.
  381.    * @param v  Value to assign to minBucketSize.
  382.    */
  383.   public void setMinBucketSize(int v) {
  384.     
  385.     m_minBucketSize = v;
  386.   }
  387.   
  388.   /**
  389.    * Main method for testing this class
  390.    */
  391.   public static void main(String [] argv) {
  392.     try {
  393.       System.out.println(Evaluation.evaluateModel(new OneR(), argv));
  394.     } catch (Exception e) {
  395.       System.err.println(e.getMessage());
  396.     }
  397.   }
  398. }