SMO.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 32k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  *    This program is free software; you can redistribute it and/or modify
  3.  *    it under the terms of the GNU General Public License as published by
  4.  *    the Free Software Foundation; either version 2 of the License, or
  5.  *    (at your option) any later version.
  6.  *
  7.  *    This program is distributed in the hope that it will be useful,
  8.  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  *    GNU General Public License for more details.
  11.  *
  12.  *    You should have received a copy of the GNU General Public License
  13.  *    along with this program; if not, write to the Free Software
  14.  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15.  */
  16. /*
  17.  *    SMO.java
  18.  *    Copyright (C) 1999 Eibe Frank
  19.  */
  20. package weka.classifiers;
  21. import java.util.*;
  22. import java.io.*;
  23. import weka.core.*;
  24. import weka.filters.*;
  25. /**
  26.  * Implements John C. Platt's sequential minimal optimization
  27.  * algorithm for training a support vector classifier using polynomial
  28.  * kernels. Transforms output of SVM into probabilities by applying a
  29.  * standard sigmoid function that is not fitted to the data.
  30.  *
  31.  * This implementation globally replaces all missing values and
  32.  * transforms nominal attributes into binary ones. For more
  33.  * information on the SMO algorithm, see<p>
  34.  *
  35.  * J. Platt (1998). <i>Fast Training of Support Vector
  36.  * Machines using Sequential Minimal Optimization</i>. Advances in Kernel
  37.  * Methods - Support Vector Learning, B. Sch鰈kopf, C. Burges, and
  38.  * A. Smola, eds., MIT Press. <p>
  39.  *
  40.  * S.S. Keerthi, S.K. Shevade, C. Bhattacharyya, K.R.K. Murthy (2001).
  41.  * <i> Improvements to Platt's SMO Algorithm for SVM Classifier
  42.  * Design.  Neural Computation, 13(3), pp 637-649, 2001. <p>
  43.  *
  44.  * Note: for improved speed normalization should be turned off when
  45.  * operating on SparseInstances.<p>
  46.  *
  47.  * Valid options are:<p>
  48.  *
  49.  * -C num <br>
  50.  * The complexity constant C. (default 1)<p>
  51.  *
  52.  * -E num <br>
  53.  * The exponent for the polynomial kernel. (default 1)<p>
  54.  *
  55.  * -N <br>
  56.  * Don't normalize the training instances. <p>
  57.  *
  58.  * -L <br>
  59.  * Rescale kernel. <p>
  60.  *
  61.  * -O <br>
  62.  * Use lower-order terms. <p>
  63.  *
  64.  * -A num <br>
  65.  * Sets the size of the kernel cache. Should be a prime number. 
  66.  * (default 1000003) <p>
  67.  *
  68.  * -T num <br>
  69.  * Sets the tolerance parameter. (default 1.0e-3)<p>
  70.  *
  71.  * -P num <br>
  72.  * Sets the epsilon for round-off error. (default 1.0e-12)<p>
  73.  *
  74.  * @author Eibe Frank (eibe@cs.waikato.ac.nz)
  75.  * @author Shane Legg (shane@intelligenesis.net) (sparse vector code)
  76.  * @author Stuart Inglis (stuart@intelligenesis.net) (sparse vector code)
  77.  * @version $Revision: 1.23.2.2 $ 
  78.  */
  79. public class SMO extends DistributionClassifier implements OptionHandler {
  80.   /**
  81.    * Stores a set of a given size.
  82.    */
  83.   private class SMOset implements Serializable {
  84.     /** The current number of elements in the set */
  85.     private int m_number;
  86.     /** The first element in the set */
  87.     private int m_first;
  88.     /** Indicators */
  89.     private boolean[] m_indicators;
  90.     /** The next element for each element */
  91.     private int[] m_next;
  92.     /** The previous element for each element */
  93.     private int[] m_previous;
  94.     /**
  95.      * Creates a new set of the given size.
  96.      */
  97.     private SMOset(int size) {
  98.       
  99.       m_indicators = new boolean[size];
  100.       m_next = new int[size];
  101.       m_previous = new int[size];
  102.       m_number = 0;
  103.       m_first = -1;
  104.     }
  105.  
  106.     /**
  107.      * Checks whether an element is in the set.
  108.      */
  109.     private boolean contains(int index) {
  110.       return m_indicators[index];
  111.     }
  112.     /**
  113.      * Deletes an element from the set.
  114.      */
  115.     private void delete(int index) {
  116.       if (m_indicators[index]) {
  117. if (m_first == index) {
  118.   m_first = m_next[index];
  119. } else {
  120.   m_next[m_previous[index]] = m_next[index];
  121. }
  122. if (m_next[index] != -1) {
  123.   m_previous[m_next[index]] = m_previous[index];
  124. }
  125. m_indicators[index] = false;
  126. m_number--;
  127.       }
  128.     }
  129.     /**
  130.      * Inserts an element into the set.
  131.      */
  132.     private void insert(int index) {
  133.       if (!m_indicators[index]) {
  134. if (m_number == 0) {
  135.   m_first = index;
  136.   m_next[index] = -1;
  137.   m_previous[index] = -1;
  138. } else {
  139.   m_previous[m_first] = index;
  140.   m_next[index] = m_first;
  141.   m_previous[index] = -1;
  142.   m_first = index;
  143. }
  144. m_indicators[index] = true;
  145. m_number++;
  146.       }
  147.     }
  148.     /** 
  149.      * Gets the next element in the set. -1 gets the first one.
  150.      */
  151.     private int getNext(int index) {
  152.       if (index == -1) {
  153. return m_first;
  154.       } else {
  155. return m_next[index];
  156.       }
  157.     }
  158.     /**
  159.      * Prints all the current elements in the set.
  160.      */
  161.     private void printElements() {
  162.       for (int i = getNext(-1); i != -1; i = getNext(i)) {
  163. System.err.print(i + " ");
  164.       }
  165.       System.err.println();
  166.       for (int i = 0; i < m_indicators.length; i++) {
  167. if (m_indicators[i]) {
  168.   System.err.print(i + " ");
  169. }
  170.       }
  171.       System.err.println();
  172.       System.err.println(m_number);
  173.     }
  174.     /** 
  175.      * Returns the number of elements in the set.
  176.      */
  177.     private int numElements() {
  178.       
  179.       return m_number;
  180.     }
  181.   }
  182.   /** The exponent for the polnomial kernel. */
  183.   private double m_exponent = 1.0;
  184.   /** The complexity parameter. */
  185.   private double m_C = 1.0;
  186.   /** Epsilon for rounding. */
  187.   private double m_eps = 1.0e-12;
  188.   
  189.   /** Tolerance for accuracy of result. */
  190.   private double m_tol = 1.0e-3;
  191.   /** The Lagrange multipliers. */
  192.   private double[] m_alpha;
  193.   /** The thresholds. */
  194.   private double m_b, m_bLow, m_bUp;
  195.   /** The indices for m_bLow and m_bUp */
  196.   private int m_iLow, m_iUp;
  197.   /** The training data. */
  198.   private Instances m_data;
  199.   /** Weight vector for linear machine. */
  200.   private double[] m_weights;
  201.   /** Kernel function cache */
  202.   private double[] m_storage;
  203.   private long[] m_keys;
  204.   /** The transformed class values. */
  205.   private double[] m_class;
  206.   /** The current set of errors for all non-bound examples. */
  207.   private double[] m_errors;
  208.   /** The five different sets used by the algorithm. */
  209.   private SMOset m_I0; // {i: 0 < m_alpha[i] < C}
  210.   private SMOset m_I1; // {i: m_class[i] = 1, m_alpha[i] = 0}
  211.   private SMOset m_I2; // {i: m_class[i] = -1, m_alpha[i] =C}
  212.   private SMOset m_I3; // {i: m_class[i] = 1, m_alpha[i] = C}
  213.   private SMOset m_I4; // {i: m_class[i] = -1, m_alpha[i] = 0}
  214.   /** The set of support vectors */
  215.   private SMOset m_supportVectors; // {i: 0 < m_alpha[i]}
  216.   /** The filter used to make attributes numeric. */
  217.   private NominalToBinaryFilter m_NominalToBinary;
  218.   /** The filter used to normalize all values. */
  219.   private NormalizationFilter m_Normalization;
  220.   /** The filter used to get rid of missing values. */
  221.   private ReplaceMissingValuesFilter m_Missing;
  222.   /** Counts the number of kernel evaluations. */
  223.   private int m_kernelEvals;
  224.   /** The size of the cache (a prime number) */
  225.   private int m_cacheSize = 1000003;
  226.   /** True if we don't want to normalize */
  227.   private boolean m_Normalize = true;
  228.   /** Rescale? */
  229.   private boolean m_rescale = false;
  230.   
  231.   /** Use lower-order terms? */
  232.   private boolean m_lowerOrder = false;
  233.   /** Only numeric attributes in the dataset? */
  234.   private boolean m_onlyNumeric;
  235.   /** Precision constant for updating sets */
  236.   private static double m_Del = 1000 * Double.MIN_VALUE;
  237.   /**
  238.    * Method for building the classifier.
  239.    *
  240.    * @param insts the set of training instances
  241.    * @exception Exception if the classifier can't be built successfully
  242.    */
  243.   public void buildClassifier(Instances insts) throws Exception {
  244.     int m_kernelEvals = 0;
  245.     if (insts.checkForStringAttributes()) {
  246.       throw new Exception("Can't handle string attributes!");
  247.     }
  248.     if (insts.numClasses() > 2) {
  249.       throw new Exception("Can only handle two-class datasets!");
  250.     }
  251.     if (insts.classAttribute().isNumeric()) {
  252.       throw new Exception("SMO can't handle a numeric class!");
  253.     }
  254.     m_data = insts;
  255.     m_onlyNumeric = true;
  256.     for (int i = 0; i < m_data.numAttributes(); i++) {
  257.       if (i != m_data.classIndex()) {
  258. if (!m_data.attribute(i).isNumeric()) {
  259.   m_onlyNumeric = false;
  260.   break;
  261. }
  262.       }
  263.     }
  264.     m_Missing = new ReplaceMissingValuesFilter();
  265.     m_Missing.setInputFormat(m_data);
  266.     m_data = Filter.useFilter(m_data, m_Missing); 
  267.     if (m_Normalize) {
  268.       m_Normalization = new NormalizationFilter();
  269.       m_Normalization.setInputFormat(m_data);
  270.       m_data = Filter.useFilter(m_data, m_Normalization); 
  271.     } else {
  272.       m_Normalization = null;
  273.     }
  274.     if (!m_onlyNumeric) {
  275.       m_NominalToBinary = new NominalToBinaryFilter();
  276.       m_NominalToBinary.setInputFormat(m_data);
  277.       m_data = Filter.useFilter(m_data, m_NominalToBinary);
  278.     } else {
  279.       m_NominalToBinary = null;
  280.     }
  281.     // If machine is linear, reserve space for weights
  282.     if (m_exponent == 1.0) {
  283.       m_weights = new double[m_data.numAttributes()];
  284.     } else {
  285.       m_weights = null;
  286.     }
  287.     // Initialize alpha array to zero
  288.     m_alpha = new double[m_data.numInstances()];
  289.     // Initialize thresholds
  290.     m_bUp = -1; m_bLow = 1; m_b = 0;
  291.     // Initialize sets
  292.     m_supportVectors = new SMOset(m_data.numInstances());
  293.     m_I0 = new SMOset(m_data.numInstances());
  294.     m_I1 = new SMOset(m_data.numInstances());
  295.     m_I2 = new SMOset(m_data.numInstances());
  296.     m_I3 = new SMOset(m_data.numInstances());
  297.     m_I4 = new SMOset(m_data.numInstances());
  298.     // Set class values
  299.     m_class = new double[m_data.numInstances()];
  300.     m_iUp = -1; m_iLow = -1;
  301.     for (int i = 0; i < m_class.length; i++) {
  302.       if ((int) m_data.instance(i).classValue() == 0) {
  303. m_class[i] = -1; m_iLow = i;
  304.       } else {
  305. m_class[i] = 1; m_iUp = i;
  306.       }
  307.     }
  308.     if ((m_iUp == -1) || (m_iLow == -1)) {
  309.       if ((m_iUp == -1) && (m_iLow == -1)) {
  310. throw new Exception ("No instances without missing class values!");
  311.       } else {
  312. if (m_iUp == -1) {
  313.   m_b = 1;
  314. } else {
  315.   m_b = -1;
  316. }
  317. return;
  318.       }
  319.     }
  320.     // Initialize error cache
  321.     m_errors = new double[m_data.numInstances()];
  322.     m_errors[m_iLow] = 1; m_errors[m_iUp] = -1;
  323.     // The kernel calculations are cached
  324.     m_storage = new double[m_cacheSize];
  325.     m_keys = new long[m_cacheSize];
  326.     // Build up I1 and I4
  327.     for (int i = 0; i < m_class.length; i++ ) {
  328.       if (m_class[i] == 1) {
  329. m_I1.insert(i);
  330.       } else {
  331. m_I4.insert(i);
  332.       }
  333.     }
  334.     // Loop to find all the support vectors
  335.     int numChanged = 0;
  336.     boolean examineAll = true;
  337.     while ((numChanged > 0) || examineAll) {
  338.       numChanged = 0;
  339.       if (examineAll) {
  340. for (int i = 0; i < m_alpha.length; i++) {
  341.   if (examineExample(i)) {
  342.     numChanged++;
  343.   }
  344. }
  345.       } else {
  346. // This code implements Modification 1 from Keerthi et al.'s paper
  347. for (int i = 0; i < m_alpha.length; i++) {
  348.   if ((m_alpha[i] > 0) &&  (m_alpha[i] < m_C)) {
  349.     if (examineExample(i)) {
  350.       numChanged++;
  351.     }
  352.     
  353.     // Is optimality on unbound vectors obtained?
  354.     if (m_bUp > m_bLow - 2 * m_tol) {
  355.       numChanged = 0;
  356.       break;
  357.     }
  358.   }
  359. }
  360. //This is the code for Modification 2 from Keerthi et al.'s paper
  361. /*boolean innerLoopSuccess = true; 
  362. numChanged = 0;
  363. while ((m_bUp < m_bLow - 2 * m_tol) && (innerLoopSuccess == true)) {
  364.   innerLoopSuccess = takeStep(m_iUp, m_iLow, m_errors[m_iLow]);
  365.   }*/
  366.       }
  367.       if (examineAll) {
  368. examineAll = false;
  369.       } else if (numChanged == 0) {
  370. examineAll = true;
  371.       }
  372.     }
  373.     
  374.     // Set threshold
  375.     m_b = (m_bLow + m_bUp) / 2.0;
  376.     // Save memory
  377.     m_storage = null; m_keys = null; m_errors = null;
  378.     m_I0 = m_I1 = m_I2 = m_I3 = m_I4 = null;
  379.     // If machine is linear, delete training data
  380.     if (m_exponent == 1.0) {
  381.       m_data = new Instances(m_data, 0);
  382.     }
  383.   }
  384.   
  385.   /**
  386.    * Computes SVM output for given instance.
  387.    *
  388.    * @param index the instance for which output is to be computed
  389.    * @param inst the instance 
  390.    * @return the output of the SVM for the given instance
  391.    */
  392.   private double SVMOutput(int index, Instance inst) throws Exception {
  393.     double result = 0;
  394.     // Is the machine linear?
  395.     if (m_exponent == 1.0) {
  396.       int n1 = inst.numValues(); int classIndex = m_data.classIndex();
  397.       for (int p = 0; p < n1; p++) {
  398. if (inst.index(p) != classIndex) {
  399.   result += m_weights[inst.index(p)] * inst.valueSparse(p);
  400. }
  401.       }
  402.     } else {
  403.       for (int i = m_supportVectors.getNext(-1); i != -1; 
  404.    i = m_supportVectors.getNext(i)) {
  405. result += m_class[i] * m_alpha[i] * kernel(index, i, inst);
  406.       }
  407.     }
  408.     result -= m_b;
  409.     
  410.     return result;
  411.   }
  412.   /**
  413.    * Outputs the distribution for the given output.
  414.    *
  415.    * Pipes output of SVM through sigmoid function.
  416.    * @param inst the instance for which distribution is to be computed
  417.    * @return the distribution
  418.    * @exception Exception if something goes wrong
  419.    */
  420.   public double[] distributionForInstance(Instance inst) throws Exception {
  421.     // Filter instance
  422.     m_Missing.input(inst);
  423.     m_Missing.batchFinished();
  424.     inst = m_Missing.output();
  425.     
  426.     if (m_Normalize) {
  427.       m_Normalization.input(inst);
  428.       m_Normalization.batchFinished();
  429.       inst = m_Normalization.output();
  430.     }
  431.     if (!m_onlyNumeric) {
  432.       m_NominalToBinary.input(inst);
  433.       m_NominalToBinary.batchFinished();
  434.       inst = m_NominalToBinary.output();
  435.     }
  436.     // Get probabilities
  437.     double output = SVMOutput(-1, inst);
  438.     double[] result = new double[2];
  439.     result[1] = 1.0 / (1.0 + Math.exp(-output));
  440.     result[0] = 1.0 - result[1];
  441.     return result;
  442.   }
  443.   /**
  444.    * Returns an enumeration describing the available options
  445.    *
  446.    * @return an enumeration of all the available options
  447.    */
  448.   public Enumeration listOptions() {
  449.     Vector newVector = new Vector(8);
  450.     newVector.addElement(new Option("tThe complexity constant C. (default 1)",
  451.     "C", 1, "-C <double>"));
  452.     newVector.addElement(new Option("tThe exponent for the "
  453.     + "polynomial kernel. (default 1)",
  454.     "E", 1, "-E <double>"));
  455.     newVector.addElement(new Option("tDon't normalize the data.",
  456.     "N", 0, "-N"));
  457.     newVector.addElement(new Option("tRescale the kernel.",
  458.     "L", 0, "-L"));
  459.     newVector.addElement(new Option("tUse lower-order terms.",
  460.     "O", 0, "-O"));
  461.     newVector.addElement(new Option("tThe size of the kernel cache. " +
  462.     "(default 1000003)",
  463.     "A", 1, "-A <int>"));
  464.     newVector.addElement(new Option("tThe tolerance parameter. " +
  465.     "(default 1.0e-3)",
  466.     "T", 1, "-T <double>"));
  467.     newVector.addElement(new Option("tThe epsilon for round-off error. " +
  468.     "(default 1.0e-12)",
  469.     "P", 1, "-P <double>"));
  470.     
  471.     return newVector.elements();
  472.   }
  473.   /**
  474.    * Parses a given list of options. Valid options are:<p>
  475.    *
  476.    * -C num <br>
  477.    * The complexity constant C. (default 1)<p>
  478.    *
  479.    * -E num <br>
  480.    * The exponent for the polynomial kernel. (default 1) <p>
  481.    *
  482.    * -N <br>
  483.    * Don't normalize the training instances. <p>
  484.    *
  485.    * -L <br>
  486.    * Rescale kernel. <p>
  487.    *
  488.    * -O <br>
  489.    * Use lower-order terms. <p>
  490.    *
  491.    * -A num <br>
  492.    * Sets the size of the kernel cache. Should be a prime number. (default 1000003) <p>
  493.    *
  494.    * -T num <br>
  495.    * Sets the tolerance parameter. (default 1.0e-3)<p>
  496.    *
  497.    * -P num <br>
  498.    * Sets the epsilon for round-off error. (default 1.0e-12)<p>
  499.    *
  500.    * @param options the list of options as an array of strings
  501.    * @exception Exception if an option is not supported
  502.    */
  503.   public void setOptions(String[] options) throws Exception {
  504.     
  505.     String complexityString = Utils.getOption('C', options);
  506.     if (complexityString.length() != 0) {
  507.       m_C = (new Double(complexityString)).doubleValue();
  508.     } else {
  509.       m_C = 1.0;
  510.     }
  511.     String exponentsString = Utils.getOption('E', options);
  512.     if (exponentsString.length() != 0) {
  513.       m_exponent = (new Double(exponentsString)).doubleValue();
  514.     } else {
  515.       m_exponent = 1.0;
  516.     }
  517.     String cacheString = Utils.getOption('A', options);
  518.     if (cacheString.length() != 0) {
  519.       m_cacheSize = Integer.parseInt(cacheString);
  520.     } else {
  521.       m_cacheSize = 1000003;
  522.     }
  523.     String toleranceString = Utils.getOption('T', options);
  524.     if (toleranceString.length() != 0) {
  525.       m_tol = (new Double(toleranceString)).doubleValue();
  526.     } else {
  527.       m_tol = 1.0e-3;
  528.     }
  529.     String epsilonString = Utils.getOption('P', options);
  530.     if (epsilonString.length() != 0) {
  531.       m_eps = (new Double(epsilonString)).doubleValue();
  532.     } else {
  533.       m_eps = 1.0e-12;
  534.     }
  535.     m_Normalize = !Utils.getFlag('N', options);
  536.     m_rescale = Utils.getFlag('L', options);
  537.     if ((m_exponent == 1.0) && (m_rescale)) {
  538.       throw new Exception("Can't use rescaling with linear machine.");
  539.     }
  540.     m_lowerOrder = Utils.getFlag('O', options);
  541.     if ((m_exponent == 1.0) && (m_lowerOrder)) {
  542.       throw new Exception("Can't use lower-order terms with linear machine.");
  543.     }
  544.   }
  545.   /**
  546.    * Gets the current settings of the classifier.
  547.    *
  548.    * @return an array of strings suitable for passing to setOptions
  549.    */
  550.   public String [] getOptions() {
  551.     String [] options = new String [13];
  552.     int current = 0;
  553.     options[current++] = "-C"; options[current++] = "" + m_C;
  554.     options[current++] = "-E"; options[current++] = "" + m_exponent;
  555.     options[current++] = "-A"; options[current++] = "" + m_cacheSize;
  556.     options[current++] = "-T"; options[current++] = "" + m_tol;
  557.     options[current++] = "-P"; options[current++] = "" + m_eps;
  558.     if (!m_Normalize) {
  559.       options[current++] = "-N";
  560.     }
  561.     if (m_rescale) {
  562.       options[current++] = "-L";
  563.     }
  564.     if (m_lowerOrder) {
  565.       options[current++] = "-O";
  566.     }
  567.     while (current < options.length) {
  568.       options[current++] = "";
  569.     }
  570.     return options;
  571.   }
  572.   /**
  573.    * Prints out the classifier.
  574.    *
  575.    * @return a description of the classifier as a string
  576.    */
  577.   public String toString() {
  578.     StringBuffer text = new StringBuffer();
  579.     int printed = 0;
  580.     if (m_alpha == null) {
  581.       return "SMO: No model built yet.";
  582.     }
  583.     try {
  584.       text.append("SMOnn");
  585.       // If machine linear, print weight vector
  586.       if (m_exponent == 1.0) {
  587. text.append("Machine linear: showing attribute weights, ");
  588. text.append("not support vectors.nn");
  589. for (int i = 0; i < m_weights.length; i++) {
  590.   if (i != (int)m_data.classIndex()) {
  591.     if (printed > 0) {
  592.       text.append(" + ");
  593.     } else {
  594.       text.append("   ");
  595.     }
  596.     text.append(m_weights[i] + " * " + m_data.attribute(i).name()+"n");
  597.     printed++;
  598.   }
  599. }
  600.       } else {
  601. for (int i = 0; i < m_alpha.length; i++) {
  602.   if (m_supportVectors.contains(i)) {
  603.     if (printed > 0) {
  604.       text.append(" + ");
  605.     } else {
  606.       text.append("   ");
  607.     }
  608.     text.append(((int)m_class[i]) + " * " +
  609. m_alpha[i] + " * K[X(" + i + ") * X]n");
  610.     printed++;
  611.   }
  612. }
  613.       }
  614.       text.append(" - " + m_b);
  615.       text.append("nnNumber of support vectors: " + m_supportVectors.numElements());
  616.       text.append("nnNumber of kernel evaluations: " + m_kernelEvals);
  617.     } catch (Exception e) {
  618.       return "Can't print SMO classifier.";
  619.     }
  620.     
  621.     return text.toString();
  622.   }
  623.   
  624.   /**
  625.    * Get the value of exponent. 
  626.    *
  627.    * @return Value of exponent.
  628.    */
  629.   public double getExponent() {
  630.     
  631.     return m_exponent;
  632.   }
  633.   
  634.   /**
  635.    * Set the value of exponent. If linear kernel
  636.    * is used, rescaling and lower-order terms are
  637.    * turned off.
  638.    *
  639.    * @param v  Value to assign to exponent.
  640.    */
  641.   public void setExponent(double v) {
  642.     
  643.     if (v == 1.0) {
  644.       m_rescale = false;
  645.       m_lowerOrder = false;
  646.     }
  647.     m_exponent = v;
  648.   }
  649.   
  650.   /**
  651.    * Get the value of C.
  652.    *
  653.    * @return Value of C.
  654.    */
  655.   public double getC() {
  656.     
  657.     return m_C;
  658.   }
  659.   
  660.   /**
  661.    * Set the value of C.
  662.    *
  663.    * @param v  Value to assign to C.
  664.    */
  665.   public void setC(double v) {
  666.     
  667.     m_C = v;
  668.   }
  669.   
  670.   /**
  671.    * Get the value of tolerance parameter.
  672.    * @return Value of tolerance parameter.
  673.    */
  674.   public double getToleranceParameter() {
  675.     
  676.     return m_tol;
  677.   }
  678.   
  679.   /**
  680.    * Set the value of tolerance parameter.
  681.    * @param v  Value to assign to tolerance parameter.
  682.    */
  683.   public void setToleranceParameter(double v) {
  684.     
  685.     m_tol = v;
  686.   }
  687.   
  688.   /**
  689.    * Get the value of epsilon.
  690.    * @return Value of epsilon.
  691.    */
  692.   public double getEpsilon() {
  693.     
  694.     return m_eps;
  695.   }
  696.   
  697.   /**
  698.    * Set the value of epsilon.
  699.    * @param v  Value to assign to epsilon.
  700.    */
  701.   public void setEpsilon(double v) {
  702.     
  703.     m_eps = v;
  704.   }
  705.   
  706.   /**
  707.    * Get the size of the kernel cache
  708.    * @return Size of kernel cache.
  709.    */
  710.   public int getCacheSize() {
  711.     
  712.     return m_cacheSize;
  713.   }
  714.   
  715.   /**
  716.    * Set the value of the kernel cache.
  717.    * @param v  Size of kernel cache.
  718.    */
  719.   public void setCacheSize(int v) {
  720.     
  721.     m_cacheSize = v;
  722.   }
  723.   
  724.   /**
  725.    * Check whether data is to be normalized.
  726.    * @return true if data is to be normalized
  727.    */
  728.   public boolean getNormalizeData() {
  729.     
  730.     return m_Normalize;
  731.   }
  732.   
  733.   /**
  734.    * Set whether data is to be normalized.
  735.    * @param v  true if data is to be normalized
  736.    */
  737.   public void setNormalizeData(boolean v) {
  738.     
  739.     m_Normalize = v;
  740.   }
  741.   
  742.   /**
  743.    * Check whether kernel is being rescaled.
  744.    * @return Value of rescale.
  745.    */
  746.   public boolean getRescaleKernel() throws Exception {
  747.     return m_rescale;
  748.   }
  749.   
  750.   /**
  751.    * Set whether kernel is to be rescaled. Defaults
  752.    * to false if a linear machine is built.
  753.    * @param v  Value to assign to rescale.
  754.    */
  755.   public void setRescaleKernel(boolean v) throws Exception {
  756.     
  757.     if (m_exponent == 1.0) {
  758.       m_rescale = false;
  759.     } else {
  760.       m_rescale = v;
  761.     }
  762.   }
  763.   
  764.   /**
  765.    * Check whether lower-order terms are being used.
  766.    * @return Value of lowerOrder.
  767.    */
  768.   public boolean getLowerOrderTerms() {
  769.     
  770.     return m_lowerOrder;
  771.   }
  772.   
  773.   /**
  774.    * Set whether lower-order terms are to be used. Defaults
  775.    * to false if a linear machine is built.
  776.    * @param v  Value to assign to lowerOrder.
  777.    */
  778.   public void setLowerOrderTerms(boolean v) {
  779.     
  780.     if (m_exponent == 1.0) {
  781.       m_lowerOrder = false;
  782.     } else {
  783.       m_lowerOrder = v;
  784.     }
  785.   }
  786.   /**
  787.    * Computes the result of the kernel function for two instances.
  788.    *
  789.    * @param id1 the index of the first instance
  790.    * @param id2 the index of the second instance
  791.    * @param inst the instance corresponding to id1
  792.    * @return the result of the kernel function
  793.    */
  794.   private double kernel(int id1, int id2, Instance inst1) throws Exception {
  795.     double result = 0;
  796.     long key = -1;
  797.     int location = -1;
  798.     // we can only cache if we know the indexes
  799.     if (id1 >= 0) {
  800.       if (id1 > id2) {
  801. key = (long)id1 * m_alpha.length + id2;
  802.       } else {
  803. key = (long)id2 * m_alpha.length + id1;
  804.       }
  805.       if (key < 0) {
  806. throw new Exception("Cache overflow detected!");
  807.       }
  808.       location = (int)(key % m_keys.length);
  809.       if (m_keys[location] == (key + 1)) {
  810. return m_storage[location];
  811.       }
  812.     }
  813.     // we can do a fast dot product
  814.     Instance inst2 = m_data.instance(id2);
  815.     int n1 = inst1.numValues(); int n2 = inst2.numValues();
  816.     int classIndex = m_data.classIndex();
  817.     for (int p1 = 0, p2 = 0; p1 < n1 && p2 < n2;) {
  818.       int ind1 = inst1.index(p1); 
  819.       int ind2 = inst2.index(p2);
  820.       if (ind1 == ind2) {
  821. if (ind1 != classIndex) {
  822.   result += inst1.valueSparse(p1) * inst2.valueSparse(p2);
  823. }
  824. p1++; p2++;
  825.       } else if (ind1 > ind2) {
  826. p2++;
  827.       } else { 
  828. p1++;
  829.       }
  830.     }
  831.     
  832.     // Use lower order terms?
  833.     if (m_lowerOrder) {
  834.       result += 1.0;
  835.     }
  836.     // Rescale kernel?
  837.     if (m_rescale) {
  838.       result /= (double)m_data.numAttributes() - 1;
  839.     }      
  840.     
  841.     if (m_exponent != 1.0) {
  842.       result = Math.pow(result, m_exponent);
  843.     }
  844.     m_kernelEvals++;
  845.     
  846.     // store result in cache 
  847.     if (key != -1){
  848.       m_storage[location] = result;
  849.       m_keys[location] = (key + 1);
  850.     }
  851.     return result;
  852.   }
  853.   /**
  854.    * Examines instance.
  855.    *
  856.    * @param i2 index of instance to examine
  857.    * @return true if examination was successfull
  858.    * @exception Exception if something goes wrong
  859.    */
  860.   private boolean examineExample(int i2) throws Exception {
  861.     
  862.     double y2, alph2, F2;
  863.     int i1 = -1;
  864.     
  865.     y2 = m_class[i2];
  866.     alph2 = m_alpha[i2];
  867.     if (m_I0.contains(i2)) {
  868.       F2 = m_errors[i2];
  869.     } else {
  870.       F2 = SVMOutput(i2, m_data.instance(i2)) + m_b - y2;
  871.       m_errors[i2] = F2;
  872.       
  873.       // Update thresholds
  874.       if ((m_I1.contains(i2) || m_I2.contains(i2)) && (F2 < m_bUp)) {
  875. m_bUp = F2; m_iUp = i2;
  876.       } else if ((m_I3.contains(i2) || m_I4.contains(i2)) && (F2 > m_bLow)) {
  877. m_bLow = F2; m_iLow = i2;
  878.       }
  879.     }
  880.     // Check optimality using current bLow and bUp and, if
  881.     // violated, find an index i1 to do joint optimization
  882.     // with i2...
  883.     boolean optimal = true;
  884.     if (m_I0.contains(i2) || m_I1.contains(i2) || m_I2.contains(i2)) {
  885.       if (m_bLow - F2 > 2 * m_tol) {
  886. optimal = false; i1 = m_iLow;
  887.       }
  888.     }
  889.     if (m_I0.contains(i2) || m_I3.contains(i2) || m_I4.contains(i2)) {
  890.       if (F2 - m_bUp > 2 * m_tol) {
  891. optimal = false; i1 = m_iUp;
  892.       }
  893.     }
  894.     if (optimal) {
  895.       return false;
  896.     }
  897.     // For i2 unbound choose the better i1...
  898.     if (m_I0.contains(i2)) {
  899.       if (m_bLow - F2 > F2 - m_bUp) {
  900. i1 = m_iLow;
  901.       } else {
  902. i1 = m_iUp;
  903.       }
  904.     }
  905.     if (i1 == -1) {
  906.       throw new Exception("This should never happen!");
  907.     }
  908.     return takeStep(i1, i2, F2);
  909.   }
  910.   /**
  911.    * Method solving for the Lagrange multipliers for
  912.    * two instances.
  913.    *
  914.    * @param i1 index of the first instance
  915.    * @param i2 index of the second instance
  916.    * @return true if multipliers could be found
  917.    * @exception Exception if something goes wrong
  918.    */
  919.   private boolean takeStep(int i1, int i2, double F2) throws Exception {
  920.     double alph1, alph2, y1, y2, F1, s, L, H, k11, k12, k22, eta,
  921.       a1, a2, f1, f2, v1, v2, Lobj, Hobj, b1, b2, bOld;
  922.     // Don't do anything if the two instances are the same
  923.     if (i1 == i2) {
  924.       return false;
  925.     }
  926.     // Initialize variables
  927.     alph1 = m_alpha[i1]; alph2 = m_alpha[i2];
  928.     y1 = m_class[i1]; y2 = m_class[i2];
  929.     F1 = m_errors[i1];
  930.     s = y1 * y2;
  931.     // Find the constraints on a2
  932.     if (y1 != y2) {
  933.       L = Math.max(0, alph2 - alph1); 
  934.       H = Math.min(m_C, m_C + alph2 - alph1);
  935.     } else {
  936.       L = Math.max(0, alph1 + alph2 - m_C);
  937.       H = Math.min(m_C, alph1 + alph2);
  938.     }
  939.     if (L >= H) {       
  940.       return false;
  941.     }
  942.     // Compute second derivative of objective function
  943.     k11 = kernel(i1, i1, m_data.instance(i1));
  944.     k12 = kernel(i1, i2, m_data.instance(i1));
  945.     k22 = kernel(i2, i2, m_data.instance(i2));
  946.     eta = 2 * k12 - k11 - k22;
  947.     // Check if second derivative is negative
  948.     if (eta < 0) {
  949.       // Compute unconstrained maximum
  950.       a2 = alph2 - y2 * (F1 - F2) / eta;
  951.       // Compute constrained maximum
  952.       if (a2 < L) {
  953. a2 = L;
  954.       } else if (a2 > H) {
  955. a2 = H;
  956.       }
  957.     } else {
  958.       // Look at endpoints of diagonal
  959.       f1 = SVMOutput(i1, m_data.instance(i1));
  960.       f2 = SVMOutput(i2, m_data.instance(i2));
  961.       v1 = f1 + m_b - y1 * alph1 * k11 - y2 * alph2 * k12; 
  962.       v2 = f2 + m_b - y1 * alph1 * k12 - y2 * alph2 * k22; 
  963.       double gamma = alph1 + s * alph2;
  964.       Lobj = (gamma - s * L) + L - 0.5 * k11 * (gamma - s * L) * (gamma - s * L) - 
  965. 0.5 * k22 * L * L - s * k12 * (gamma - s * L) * L - 
  966. y1 * (gamma - s * L) * v1 - y2 * L * v2;
  967.       Hobj = (gamma - s * H) + H - 0.5 * k11 * (gamma - s * H) * (gamma - s * H) - 
  968. 0.5 * k22 * H * H - s * k12 * (gamma - s * H) * H - 
  969. y1 * (gamma - s * H) * v1 - y2 * H * v2;
  970.       if (Lobj > Hobj + m_eps) {
  971. a2 = L;
  972.       } else if (Lobj < Hobj - m_eps) {
  973. a2 = H;
  974.       } else {
  975. a2 = alph2;
  976.       }
  977.     }
  978.     if (Math.abs(a2 - alph2) < m_eps * (a2 + alph2 + m_eps)) {
  979.       return false;
  980.     }
  981.     // To prevent precision problems
  982.     if (a2 > m_C - m_Del * m_C) {
  983.       a2 = m_C;
  984.     } else if (a2 <= m_Del * m_C) {
  985.       a2 = 0;
  986.     }
  987.     // Recompute a1
  988.     a1 = alph1 + s * (alph2 - a2);
  989.     // To prevent precision problems
  990.     if (a1 > m_C - m_Del * m_C) {
  991.       a1 = m_C;
  992.     } else if (a1 <= m_Del * m_C) {
  993.       a1 = 0;
  994.     }
  995.     // Update sets
  996.     if (a1 > 0) {
  997.       m_supportVectors.insert(i1);
  998.     } else {
  999.       m_supportVectors.delete(i1);
  1000.     }
  1001.     if ((a1 > 0) && (a1 < m_C)) {
  1002.       m_I0.insert(i1);
  1003.     } else {
  1004.       m_I0.delete(i1);
  1005.     }
  1006.     if ((y1 == 1) && (a1 == 0)) {
  1007.       m_I1.insert(i1);
  1008.     } else {
  1009.       m_I1.delete(i1);
  1010.     }
  1011.     if ((y1 == -1) && (a1 == m_C)) {
  1012.       m_I2.insert(i1);
  1013.     } else {
  1014.       m_I2.delete(i1);
  1015.     }
  1016.     if ((y1 == 1) && (a1 == m_C)) {
  1017.       m_I3.insert(i1);
  1018.     } else {
  1019.       m_I3.delete(i1);
  1020.     }
  1021.     if ((y1 == -1) && (a1 == 0)) {
  1022.       m_I4.insert(i1);
  1023.     } else {
  1024.       m_I4.delete(i1);
  1025.     }
  1026.     if (a2 > 0) {
  1027.       m_supportVectors.insert(i2);
  1028.     } else {
  1029.       m_supportVectors.delete(i2);
  1030.     }
  1031.     if ((a2 > 0) && (a2 < m_C)) {
  1032.       m_I0.insert(i2);
  1033.     } else {
  1034.       m_I0.delete(i2);
  1035.     }
  1036.     if ((y2 == 1) && (a2 == 0)) {
  1037.       m_I1.insert(i2);
  1038.     } else {
  1039.       m_I1.delete(i2);
  1040.     }
  1041.     if ((y2 == -1) && (a2 == m_C)) {
  1042.       m_I2.insert(i2);
  1043.     } else {
  1044.       m_I2.delete(i2);
  1045.     }
  1046.     if ((y2 == 1) && (a2 == m_C)) {
  1047.       m_I3.insert(i2);
  1048.     } else {
  1049.       m_I3.delete(i2);
  1050.     }
  1051.     if ((y2 == -1) && (a2 == 0)) {
  1052.       m_I4.insert(i2);
  1053.     } else {
  1054.       m_I4.delete(i2);
  1055.     }
  1056.     
  1057.     // Update weight vector to reflect change a1 and a2, if linear SVM
  1058.     if (m_exponent == 1.0) {
  1059.       Instance inst1 = m_data.instance(i1);
  1060.       for (int p1 = 0; p1 < inst1.numValues(); p1++) {
  1061. if (inst1.index(p1) != m_data.classIndex()) {
  1062.   m_weights[inst1.index(p1)] += 
  1063.     y1 * (a1 - alph1) * inst1.valueSparse(p1);
  1064. }
  1065.       }
  1066.       Instance inst2 = m_data.instance(i2);
  1067.       for (int p2 = 0; p2 < inst2.numValues(); p2++) {
  1068. if (inst2.index(p2) != m_data.classIndex()) {
  1069.   m_weights[inst2.index(p2)] += 
  1070.     y2 * (a2 - alph2) * inst2.valueSparse(p2);
  1071. }
  1072.       }
  1073.     }
  1074.     // Update error cache using new Lagrange multipliers
  1075.     for (int j = m_I0.getNext(-1); j != -1; j = m_I0.getNext(j)) {
  1076.       if ((j != i1) && (j != i2)) {
  1077. m_errors[j] += 
  1078.   y1 * (a1 - alph1) * kernel(i1, j, m_data.instance(i1)) + 
  1079.   y2 * (a2 - alph2) * kernel(i2, j, m_data.instance(i2));
  1080.       }
  1081.     }
  1082.     // Update error cache for i1 and i2
  1083.     m_errors[i1] += y1 * (a1 - alph1) * k11 + y2 * (a2 - alph2) * k12;
  1084.     m_errors[i2] += y1 * (a1 - alph1) * k12 + y2 * (a2 - alph2) * k22;
  1085.     // Update array with Lagrange multipliers
  1086.     m_alpha[i1] = a1;
  1087.     m_alpha[i2] = a2;
  1088.       
  1089.     // Update thresholds
  1090.     m_bLow = -Double.MAX_VALUE; m_bUp = Double.MAX_VALUE;
  1091.     m_iLow = -1; m_iUp = -1;
  1092.     for (int j = m_I0.getNext(-1); j != -1; j = m_I0.getNext(j)) {
  1093.       if (m_errors[j] < m_bUp) {
  1094. m_bUp = m_errors[j]; m_iUp = j;
  1095.       }
  1096.       if (m_errors[j] > m_bLow) {
  1097. m_bLow = m_errors[j]; m_iLow = j;
  1098.       }
  1099.     }
  1100.     if (!m_I0.contains(i1)) {
  1101.       if (m_I3.contains(i1) || m_I4.contains(i1)) {
  1102. if (m_errors[i1] > m_bLow) {
  1103.   m_bLow = m_errors[i1]; m_iLow = i1;
  1104.       } else {
  1105. if (m_errors[i1] < m_bUp) {
  1106.   m_bUp = m_errors[i1]; m_iUp = i1;
  1107. }
  1108.       }
  1109.     }
  1110.     if (!m_I0.contains(i2)) {
  1111.       if (m_I3.contains(i2) || m_I4.contains(i2)) {
  1112. if (m_errors[i2] > m_bLow) {
  1113.   m_bLow = m_errors[i2]; m_iLow = i2;
  1114. }
  1115.       } else {
  1116. if (m_errors[i2] < m_bUp) {
  1117.   m_bUp = m_errors[i2]; m_iUp = i2;
  1118. }
  1119.       }
  1120.     }
  1121.     if ((m_iLow == -1) || (m_iUp == -1)) {
  1122.       throw new Exception("This should never happen!");
  1123.     }
  1124.     // Made some progress.
  1125.     return true;
  1126.   }
  1127.   
  1128.   /**
  1129.    * Quick and dirty check whether the quadratic programming problem is solved.
  1130.    */
  1131.   private void checkClassifier() throws Exception {
  1132.     double sum = 0;
  1133.     for (int i = 0; i < m_alpha.length; i++) {
  1134.       if (m_alpha[i] > 0) {
  1135. sum += m_class[i] * m_alpha[i];
  1136.       }
  1137.     }
  1138.     System.err.println("Sum of y(i) * alpha(i): " + sum);
  1139.     for (int i = 0; i < m_alpha.length; i++) {
  1140.       double output = SVMOutput(i, m_data.instance(i));
  1141.       if (Utils.eq(m_alpha[i], 0)) {
  1142. if (Utils.sm(m_class[i] * output, 1)) {
  1143.   System.err.println("KKT condition 1 violated: " + m_class[i] * output);
  1144. }
  1145.       } 
  1146.       if (Utils.gr(m_alpha[i], 0) && Utils.sm(m_alpha[i], m_C)) {
  1147. if (!Utils.eq(m_class[i] * output, 1)) {
  1148.   System.err.println("KKT condition 2 violated: " + m_class[i] * output);
  1149. }
  1150.       } 
  1151.       if (Utils.eq(m_alpha[i], m_C)) {
  1152. if (Utils.gr(m_class[i] * output, 1)) {
  1153.   System.err.println("KKT condition 3 violated: " + m_class[i] * output);
  1154. }
  1155.       } 
  1156.     }
  1157.   }  
  1158.   /**
  1159.    * Main method for testing this class.
  1160.    */
  1161.   public static void main(String[] argv) {
  1162.     Classifier scheme;
  1163.     try {
  1164.       scheme = new SMO();
  1165.       System.out.println(Evaluation.evaluateModel(scheme, argv));
  1166.     } catch (Exception e) {
  1167.       System.err.println(e.getMessage());
  1168.     }
  1169.   }
  1170.