SVMAttributeEval.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 15k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  * This program is free software; you can redistribute it and/or modify
  3.  * it under the terms of the GNU General Public License as published by
  4.  * the Free Software Foundation; either version 2 of the License, or
  5.  * (at your option) any later version.
  6.  * 
  7.  * This program is distributed in the hope that it will be useful,
  8.  * but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  * GNU General Public License for more details.
  11.  * 
  12.  * You should have received a copy of the GNU General Public License
  13.  * along with this program; if not, write to the Free Software
  14.  * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15.  */
  16. /*
  17.  * SVMAttributeEval.java
  18.  * Copyright (C) 2002 Eibe Frank
  19.  * Mod by Kieran Holland
  20.  * 
  21.  */
  22. package weka.attributeSelection;
  23. import java.io.*;
  24. import java.util.*;
  25. import weka.core.*;
  26. import weka.classifiers.functions.SMO;
  27. import weka.filters.Filter;
  28. import weka.filters.unsupervised.attribute.Remove;
  29. import weka.attributeSelection.*;
  30. /**
  31.  * Class for Evaluating attributes individually by using the SVM
  32.  * classifier. <p>
  33.  * 
  34.  * Valid options are: <p>
  35.  * 
  36.  * -E <constant rate of elimination> <br>
  37.  * Specify constant rate at which attributes are eliminated per invocation
  38.  * of the support vector machine. Default = 1.<p>
  39.  * 
  40.  * -P <percent rate of elimination> <br>
  41.  * Specify the percentage rate at which attributes are eliminated per invocation
  42.  * of the support vector machine. This setting trumps the constant rate setting.
  43.  * Default = 0 (percentage rate ignored).<p>
  44.  * 
  45.  * -T <threshold for percent elimination> <br>
  46.  * Specify the threshold below which the percentage elimination method
  47.  * reverts to the constant elimination method.<p>
  48.  * 
  49.  * -C <complexity parameter> <br>
  50.  * Specify the value of C - the complexity parameter to be passed on
  51.  * to the support vector machine. <p>
  52.  * 
  53.  * @author Eibe Frank (eibe@cs.waikato.ac.nz)
  54.  * @author Mark Hall (mhall@cs.waikato.ac.nz)
  55.  * @version $Revision: 1.8 $
  56.  */
  57. public class SVMAttributeEval extends AttributeEvaluator 
  58.   implements OptionHandler {
  59.   /**
  60.    * The attribute scores
  61.    */
  62.   private double[] m_attScores;
  63.   /**
  64.    * Constant rate of attribute elimination per iteration
  65.    */
  66.   private int      m_numToEliminate = 1;
  67.   /**
  68.    * Percentage rate of attribute elimination, trumps constant
  69.    * rate (above threshold), ignored if = 0
  70.    */
  71.   private int      m_percentToEliminate = 0;
  72.   /**
  73.    * Threshold below which percent elimination switches to
  74.    * constant elimination
  75.    */
  76.   private int      m_percentThreshold = 0;
  77.   /**
  78.    * Complexity parameter to pass on to SMO
  79.    */
  80.   private double   m_smoCParameter = 1.0;
  81.   /**
  82.    * Returns a string describing this attribute evaluator
  83.    * @return a description of the evaluator suitable for
  84.    * displaying in the explorer/experimenter gui
  85.    */
  86.   public String globalInfo() {
  87.     return "SVMAttributeEval :nnEvaluates the worth of an attribute by " 
  88.    + "using an SVM classifier.n";
  89.   } 
  90.   /**
  91.    * Constructor
  92.    */
  93.   public SVMAttributeEval() {
  94.     resetOptions();
  95.   }
  96.   /**
  97.    * Returns an enumeration describing all the available options
  98.    * 
  99.    * @return an enumeration of options
  100.    */
  101.   public Enumeration listOptions() {
  102.     Vector newVector = new Vector(4);
  103.     newVector
  104.       .addElement(new Option("tSpecify the constant rate of attributen" 
  105.      + "telimination per invocation ofn" 
  106.      + "tthe support vector machine.n" 
  107.      + "tDefault = 1.", "E", 1, 
  108.      "-N <constant rate of elimination>"));
  109.     newVector
  110.       .addElement(new Option("tSpecify the percentage rate of attributes ton" 
  111.      + "telimination per invocation ofn" 
  112.      + "tthe support vector machine.n" 
  113.      + "tTrumps constant rate (above threshold).n" 
  114.      + "tDefault = 0.", "P", 1, 
  115.      "-P <percent rate of elimination>"));
  116.     newVector
  117.       .addElement(new Option("tSpecify the threshold below which n" 
  118.      + "tpercentage attribute eliminationn" 
  119.      + "treverts to the constant method.n", "T", 1, 
  120.      "-T <threshold for percent elimination>"));
  121.     newVector.addElement(new Option("tSpecify the value of C (complexityn" 
  122.     + "tparameter) to pass on to then" 
  123.     + "tsupport vector machine.n" 
  124.     + "tDefault = 1.0", "C", 1, 
  125.     "-C <complexity>"));
  126.     return newVector.elements();
  127.   } 
  128.   /**
  129.    * Parses a given list of options
  130.    * 
  131.    * Valid options are: <p>
  132.    * 
  133.    * -E <constant rate of elimination> <br>
  134.    * Specify constant rate at which attributes are eliminated per invocation
  135.    * of the support vector machine. Default = 1.<p>
  136.    * 
  137.    * -P <percent rate of elimination> <br>
  138.    * Specify the percentage rate at which attributes are eliminated per invocation
  139.    * of the support vector machine. This setting trumps the constant rate setting.
  140.    * Default = 0 (percentage rate ignored).<p>
  141.    * 
  142.    * -T <threshold for percent elimination> <br>
  143.    * Specify the threshold below which the percentage elimination method
  144.    * reverts to the constant elimination method.<p>
  145.    * 
  146.    * -C <complexity parameter> <br>
  147.    * Specify the value of C - the complexity parameter to be passed on
  148.    * to the support vector machine. <p>
  149.    * 
  150.    * @param options the list of options as an array of strings
  151.    * @exception Exception if an error occurs
  152.    */
  153.   public void setOptions(String[] options) throws Exception {
  154.     String optionString;
  155.     optionString = Utils.getOption('E', options);
  156.     if (optionString.length() != 0) {
  157.       setAttsToEliminatePerIteration(Integer.parseInt(optionString));
  158.     } 
  159.     optionString = Utils.getOption('P', options);
  160.     if (optionString.length() != 0) {
  161.       setPercentToEliminatePerIteration(Integer.parseInt(optionString));
  162.     } 
  163.     optionString = Utils.getOption('T', options);
  164.     if (optionString.length() != 0) {
  165.       setPercentThreshold(Integer.parseInt(optionString));
  166.     } 
  167.     optionString = Utils.getOption('C', options);
  168.     if (optionString.length() != 0) {
  169.       setComplexityParameter((new Double(optionString)).doubleValue());
  170.     } 
  171.     Utils.checkForRemainingOptions(options);
  172.   } 
  173.   /**
  174.    * Gets the current settings of SVMAttributeEval
  175.    * 
  176.    * @return an array of strings suitable for passing to setOptions()
  177.    */
  178.   public String[] getOptions() {
  179.     String[] options = new String[8];
  180.     int      current = 0;
  181.     options[current++] = "-E";
  182.     options[current++] = "" + getAttsToEliminatePerIteration();
  183.     options[current++] = "-P";
  184.     options[current++] = "" + getPercentToEliminatePerIteration();
  185.     options[current++] = "-T";
  186.     options[current++] = "" + getPercentThreshold();
  187.     options[current++] = "-C";
  188.     options[current++] = "" + getComplexityParameter();
  189.     while (current < options.length) {
  190.       options[current++] = "";
  191.     } 
  192.     return options;
  193.   } 
  194.   // ________________________________________________________________________
  195.   /**
  196.    * Returns a tip text for this property suitable for display in the
  197.    * GUI
  198.    * 
  199.    * @return tip text string describing this property
  200.    */
  201.   public String attsToEliminatePerIterationTipText() {
  202.     return "Constant rate of attribute elimination.";
  203.   } 
  204.   /**
  205.    * Returns a tip text for this property suitable for display in the
  206.    * GUI
  207.    * 
  208.    * @return tip text string describing this property
  209.    */
  210.   public String percentToEliminatePerIterationTipText() {
  211.     return "Percent rate of attribute elimination.";
  212.   } 
  213.   /**
  214.    * Returns a tip text for this property suitable for display in the
  215.    * GUI
  216.    * 
  217.    * @return tip text string describing this property
  218.    */
  219.   public String percentThresholdTipText() {
  220.     return "Threshold below which percent elimination reverts to constant elimination.";
  221.   } 
  222.   /**
  223.    * Returns a tip text for this property suitable for display in the
  224.    * GUI
  225.    * 
  226.    * @return tip text string describing this property
  227.    */
  228.   public String complexityParameterTipText() {
  229.     return "C complexity parameter to pass to the SVM";
  230.   } 
  231.   // ________________________________________________________________________
  232.   /**
  233.    * Set the constant rate of attribute elimination per iteration
  234.    * 
  235.    * @param cRate the constant rate of attribute elimination per iteration
  236.    */
  237.   public void setAttsToEliminatePerIteration(int cRate) {
  238.     m_numToEliminate = cRate;
  239.   } 
  240.   /**
  241.    * Get the constant rate of attribute elimination per iteration
  242.    * 
  243.    * @return the constant rate of attribute elimination per iteration
  244.    */
  245.   public int getAttsToEliminatePerIteration() {
  246.     return m_numToEliminate;
  247.   } 
  248.   /**
  249.    * Set the percentage of attributes to eliminate per iteration
  250.    * 
  251.    * @param pRate percent of attributes to eliminate per iteration
  252.    */
  253.   public void setPercentToEliminatePerIteration(int pRate) {
  254.     m_percentToEliminate = pRate;
  255.   } 
  256.   /**
  257.    * Get the percentage rate of attribute elimination per iteration
  258.    * 
  259.    * @return the percentage rate of attribute elimination per iteration
  260.    */
  261.   public int getPercentToEliminatePerIteration() {
  262.     return m_percentToEliminate;
  263.   } 
  264.   /**
  265.    * Set the threshold below which percentage elimination reverts to
  266.    * constant elimination.
  267.    * 
  268.    * @param thresh percent of attributes to eliminate per iteration
  269.    */
  270.   public void setPercentThreshold(int thresh) {
  271.     m_percentThreshold = thresh;
  272.   } 
  273.   /**
  274.    * Get the threshold below which percentage elimination reverts to
  275.    * constant elimination.
  276.    * 
  277.    * @return the threshold below which percentage elimination stops
  278.    */
  279.   public int getPercentThreshold() {
  280.     return m_percentThreshold;
  281.   } 
  282.   /**
  283.    * Set the value of C for SMO
  284.    * 
  285.    * @param svmC the value of C
  286.    */
  287.   public void setComplexityParameter(double svmC) {
  288.     m_smoCParameter = svmC;
  289.   } 
  290.   /**
  291.    * Get the value of C used with SMO
  292.    * 
  293.    * @return the value of C
  294.    */
  295.   public double getComplexityParameter() {
  296.     return m_smoCParameter;
  297.   } 
  298.   // ________________________________________________________________________
  299.   /**
  300.    * Initializes the evaluator.
  301.    * 
  302.    * @param data set of instances serving as training data
  303.    * @exception Exception if the evaluator has not been
  304.    * generated successfully
  305.    */
  306.   public void buildEvaluator(Instances data) throws Exception {
  307.     if (data.checkForStringAttributes()) {
  308.       throw new UnsupportedAttributeTypeException("Can't handle string attributes!");
  309.     } 
  310.     if (!data.classAttribute().isNominal()) {
  311.       throw new Exception("Class must be nominal!");
  312.     } 
  313.     if (data.classAttribute().numValues() != 2) {
  314.       throw new Exception("Can only deal with binary class problems!");
  315.     } 
  316.     // Holds a mapping into the original array of attribute indices
  317.     int[] origIndices = new int[data.numAttributes()];
  318.     for (int i = 0; i < origIndices.length; i++) {
  319.       if (data.attribute(i).isNominal() 
  320.       && (data.attribute(i).numValues() != 2)) {
  321. throw new Exception("All nominal attributes must be binary!");
  322.       } 
  323.       origIndices[i] = i;
  324.     } 
  325.     // We need to repeat the following loop until we've computed
  326.     // a weight for every attribute (excluding the class)
  327.     m_attScores = new double[data.numAttributes()];
  328.     Instances trainCopy = new Instances(data);
  329.     m_numToEliminate = (m_numToEliminate > 1) ? m_numToEliminate : 1;
  330.     m_percentToEliminate = (m_percentToEliminate < 100) 
  331.    ? m_percentToEliminate : 100;
  332.     m_percentToEliminate = (m_percentToEliminate > 0) ? m_percentToEliminate 
  333.    : 0;
  334.     m_percentThreshold = (m_percentThreshold < m_attScores.length) 
  335.  ? m_percentThreshold : m_attScores.length - 1;
  336.     m_percentThreshold = (m_percentThreshold > 0) ? m_percentThreshold : 0;
  337.     int    i = 0;
  338.     double pctToElim = ((double) m_percentToEliminate) / 100.0;
  339.     while (trainCopy.numAttributes() > 1) {
  340.       int numToElim;
  341.       if (pctToElim > 0) {
  342. numToElim = (int) (trainCopy.numAttributes() * pctToElim);
  343. numToElim = (numToElim > 1) ? numToElim : 1;
  344. if (m_attScores.length - i - numToElim <= m_percentThreshold) {
  345.   pctToElim = 0;
  346.   numToElim = m_attScores.length - i - m_percentThreshold;
  347.       } else {
  348. numToElim = (m_attScores.length - i - 1 >= m_numToEliminate) 
  349.     ? m_numToEliminate : m_attScores.length - i - 1;
  350.       } 
  351.       // System.out.println("Progress: " + trainCopy.numAttributes());
  352.       // Build the linear SVM with default parameters
  353.       SMO smo = new SMO();
  354.       smo.setC(m_smoCParameter);
  355.       smo.buildClassifier(trainCopy);
  356.       // Find the attribute with maximum weight^2
  357.       FastVector weightsAndIndices = smo.weights();
  358.       double[]   weightsSparse = (double[]) weightsAndIndices.elementAt(0);
  359.       int[]      indicesSparse = (int[]) weightsAndIndices.elementAt(1);
  360.       double[]   weights = new double[trainCopy.numAttributes()];
  361.       for (int j = 0; j < weightsSparse.length; j++) {
  362. weights[indicesSparse[j]] = weightsSparse[j] * weightsSparse[j];
  363.       } 
  364.       weights[trainCopy.classIndex()] = Double.MAX_VALUE;
  365.       int       minWeightIndex;
  366.       int[]     featArray = new int[numToElim];
  367.       boolean[] eliminated = new boolean[origIndices.length];
  368.       for (int j = 0; j < numToElim; j++) {
  369. minWeightIndex = Utils.minIndex(weights);
  370. m_attScores[origIndices[minWeightIndex]] = i + j + 1;
  371. featArray[j] = minWeightIndex;
  372. eliminated[minWeightIndex] = true;
  373. weights[minWeightIndex] = Double.MAX_VALUE;
  374.       } 
  375.       // Delete the best attribute.
  376.       Remove delTransform = new Remove();
  377.       delTransform.setInvertSelection(false);
  378.       delTransform.setAttributeIndicesArray(featArray);
  379.       delTransform.setInputFormat(trainCopy);
  380.       trainCopy = Filter.useFilter(trainCopy, delTransform);
  381.       // Update the array of indices
  382.       int[] temp = new int[origIndices.length - numToElim];
  383.       int   k = 0;
  384.       for (int j = 0; j < origIndices.length; j++) {
  385. if (!eliminated[j]) {
  386.   temp[k++] = origIndices[j];
  387.       } 
  388.       origIndices = temp;
  389.       i += numToElim;
  390.     } 
  391.   } 
  392.   /**
  393.    * Resets options to defaults.
  394.    */
  395.   protected void resetOptions() {
  396.     m_attScores = null;
  397.   } 
  398.   /**
  399.    * Evaluates an attribute by returning the square of its coefficient in a
  400.    * linear support vector machine.
  401.    * 
  402.    * @param attribute the index of the attribute to be evaluated
  403.    * @exception Exception if the attribute could not be evaluated
  404.    */
  405.   public double evaluateAttribute(int attribute) throws Exception {
  406.     return m_attScores[attribute];
  407.   } 
  408.   /**
  409.    * Return a description of the evaluator
  410.    * @return description as a string
  411.    */
  412.   public String toString() {
  413.     StringBuffer text = new StringBuffer();
  414.     if (m_attScores == null) {
  415.       text.append("tSVM feature evaluator has not been built yet");
  416.     } else {
  417.       text.append("tSVM feature evaluator");
  418.     } 
  419.     text.append("n");
  420.     return text.toString();
  421.   } 
  422.   /**
  423.    * Main method for testing this class.
  424.    * 
  425.    * @param args the options
  426.    */
  427.   public static void main(String[] args) {
  428.     try {
  429.       File  arff = new File("d:\weka331\data\golub.arff");
  430.       BufferedReader     br = new BufferedReader(new FileReader(arff));
  431.       Instances  test = new Instances(br);
  432.       AttributeSelection as = new AttributeSelection();
  433.       SVMAttributeEval   svm = new SVMAttributeEval();
  434.       svm.setAttsToEliminatePerIteration(1);
  435.       svm.setPercentThreshold(100);
  436.       svm.setPercentToEliminatePerIteration(10);
  437.       test.setClassIndex(0);
  438.       as.setEvaluator(svm);
  439.       as.setSearch(new Ranker());
  440.       as.SelectAttributes(test);
  441.       System.out.println(as.toResultsString());
  442.     } catch (Exception e) {
  443.       e.printStackTrace();
  444.       System.out.println(e.getMessage());
  445.     } 
  446.   } 
  447. }