Logistic.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 12k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  *    This program is free software; you can redistribute it and/or modify
  3.  *    it under the terms of the GNU General Public License as published by
  4.  *    the Free Software Foundation; either version 2 of the License, or
  5.  *    (at your option) any later version.
  6.  *
  7.  *    This program is distributed in the hope that it will be useful,
  8.  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  *    GNU General Public License for more details.
  11.  *
  12.  *    You should have received a copy of the GNU General Public License
  13.  *    along with this program; if not, write to the Free Software
  14.  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15.  */
  16. /*
  17.  *    Logistic.java
  18.  *    Copyright (C) 2002 Eibe Frank
  19.  *
  20.  */
  21. package weka.classifiers.functions;
  22. import weka.classifiers.meta.LogitBoost;
  23. import weka.classifiers.functions.LinearRegression;
  24. import weka.classifiers.Evaluation;
  25. import weka.classifiers.DistributionClassifier;
  26. import weka.classifiers.Classifier;
  27. import weka.core.UnsupportedClassTypeException;
  28. import weka.core.Instances;
  29. import weka.core.Instance;
  30. import weka.core.OptionHandler;
  31. import weka.core.SelectedTag;
  32. import weka.core.Utils;
  33. import weka.core.Attribute;
  34. import weka.core.Option;
  35. import weka.core.UnsupportedAttributeTypeException;
  36. import weka.filters.unsupervised.attribute.NominalToBinary;
  37. import weka.filters.unsupervised.attribute.ReplaceMissingValues;
  38. import weka.filters.unsupervised.attribute.Remove;
  39. import weka.filters.Filter;
  40. import java.util.Enumeration;
  41. import java.util.Vector;
  42. /**
  43.  * Implements linear logistic regression using LogitBoost and
  44.  * LinearRegression.<p>
  45.  *
  46.  * Missing values are replaced using ReplaceMissingValues, and
  47.  * nominal attributes are transformed into numeric attributes using
  48.  * NominalToBinary.<p>
  49.  *
  50.  * -P precision <br>
  51.  * Set the precision of stopping criterion based on average loglikelihood.
  52.  * (default 1.0e-13) <p>
  53.  *
  54.  * -R ridge <br>
  55.  * Set the ridge parameter for the linear regression models.
  56.  * (default 1.0e-8)<p>
  57.  *
  58.  * -M num <br>
  59.  * Set the maximum number of iterations.
  60.  * (default 200)<p>
  61.  *
  62.  * @author Eibe Frank (eibe@cs.waikato.ac.nz)
  63.  * @version $Revision: 1.22 $ 
  64.  */
  65. public class Logistic extends DistributionClassifier 
  66.   implements OptionHandler {
  67.   /* The coefficients */
  68.   private double[][] m_Coefficients = null;
  69.   /* The index of the class */
  70.   private int m_ClassIndex = -1;
  71.   /* An attribute filter */
  72.   private Remove m_AttFilter = null;
  73.   /* The header info */
  74.   private Instances m_Header = null;
  75.     
  76.   /** The filter used to make attributes numeric. */
  77.   private NominalToBinary m_NominalToBinary = null;
  78.   
  79.   /** The filter used to get rid of missing values. */
  80.   private ReplaceMissingValues m_ReplaceMissingValues = null;
  81.     
  82.   /** The ridge parameter. */
  83.   private double m_Ridge = 1e-8;
  84.   /** The precision parameter */   
  85.   private double m_Precision = 1.0e-13;
  86.   
  87.   /** The maximum number of iterations. */
  88.   private int m_MaxIts = 200;
  89.     
  90.   /**
  91.    * Returns an enumeration describing the available options.
  92.    *
  93.    * @return an enumeration of all the available options.
  94.    */
  95.   public Enumeration listOptions() {
  96.     
  97.     Vector newVector = new Vector(3);
  98.     newVector.addElement(new Option("tSet the precision of stopping criterion based onn" + 
  99.     "tchange in average loglikelihood (default 1.0e-13).",
  100.     "P", 1, "-P <precision>"));
  101.     newVector.addElement(new Option("tSet the ridge for the linear regression models (default 1.0e-8).",
  102.     "R", 1, "-R <ridge>"));
  103.     newVector.addElement(new Option("tSet the maximum number of iterations (default 200).",
  104.     "M", 1, "-M <number>"));
  105.     return newVector.elements();
  106.   }
  107.   
  108.   /**
  109.    * Parses a given list of options. Valid options are:<p>
  110.    *
  111.    * -P precision <br>
  112.    * Set the precision of stopping criterion based on average loglikelihood.
  113.    * (default 1.0e-13) <p>
  114.    *
  115.    * -R ridge <br>
  116.    * Set the ridge parameter for the linear regression models.
  117.    * (default 1.0e-8)<p>
  118.    *
  119.    * -M num <br>
  120.    * Set the maximum number of iterations.
  121.    * (default 200)<p>
  122.    *
  123.    * @param options the list of options as an array of strings
  124.    * @exception Exception if an option is not supported
  125.    */
  126.   public void setOptions(String[] options) throws Exception {
  127.     
  128.     String precisionString = Utils.getOption('P', options);
  129.     if (precisionString.length() != 0) 
  130.       m_Precision = Double.parseDouble(precisionString);
  131.     else 
  132.       m_Precision = 1.0e-13;
  133.       
  134.     String ridgeString = Utils.getOption('R', options);
  135.     if (ridgeString.length() != 0) 
  136.       m_Ridge = Double.parseDouble(ridgeString);
  137.     else 
  138.       m_Ridge = 1.0e-8;
  139.       
  140.     String maxItsString = Utils.getOption('M', options);
  141.     if (maxItsString.length() != 0) 
  142.       m_MaxIts = Integer.parseInt(maxItsString);
  143.     else 
  144.       m_MaxIts = 200;
  145.   }
  146.   
  147.   /**
  148.    * Gets the current settings of the classifier.
  149.    *
  150.    * @return an array of strings suitable for passing to setOptions
  151.    */
  152.   public String [] getOptions() {
  153.     
  154.     String [] options = new String [6];
  155.     int current = 0;
  156.     
  157.     options[current++] = "-P";
  158.     options[current++] = ""+m_Precision;
  159.     options[current++] = "-R";
  160.     options[current++] = ""+m_Ridge;
  161.     options[current++] = "-M";
  162.     options[current++] = ""+m_MaxIts;
  163.     
  164.     while (current < options.length) 
  165.       options[current++] = "";
  166.     return options;
  167.   }
  168.   
  169.   /**
  170.    * Builds the model.
  171.    */
  172.   public void buildClassifier(Instances data) throws Exception {
  173.     if (data.classAttribute().type() != Attribute.NOMINAL) {
  174.       throw new UnsupportedClassTypeException("Class attribute must be nominal.");
  175.     }
  176.     if (data.checkForStringAttributes()) {
  177.       throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
  178.     }
  179.     data = new Instances(data);
  180.     data.deleteWithMissingClass();
  181.     if (data.numInstances() == 0) {
  182.       throw new Exception("No train instances without missing class value!");
  183.     }
  184.     m_ReplaceMissingValues = new ReplaceMissingValues();
  185.     m_ReplaceMissingValues.setInputFormat(data);
  186.     data = Filter.useFilter(data, m_ReplaceMissingValues);
  187.     m_NominalToBinary = new NominalToBinary();
  188.     m_NominalToBinary.setInputFormat(data);
  189.     data = Filter.useFilter(data, m_NominalToBinary);
  190.     // Find attributes that should be deleted because of
  191.     // zero variance
  192.     int[] indices = new int[data.numAttributes() - 1];
  193.     int numDeleted = 0;
  194.     for (int j = 0; j < data.numAttributes(); j++) {
  195.       if (j != data.classIndex()) {
  196.         double var = data.variance(j);
  197. if (var == 0) {
  198.   indices[numDeleted++] = j;
  199. }
  200.       }
  201.     }
  202.     int[] temp = new int[numDeleted];
  203.     System.arraycopy(indices, 0, temp, 0, numDeleted);
  204.     indices = temp;
  205.     // Remove useless attributes
  206.     m_AttFilter = new Remove();
  207.     m_AttFilter.setAttributeIndicesArray(indices);
  208.     m_AttFilter.setInvertSelection(false);
  209.     m_AttFilter.setInputFormat(data);
  210.     data = Filter.useFilter(data, m_AttFilter);
  211.     // Set class index
  212.     m_ClassIndex = data.classIndex();
  213.     // Standardize data
  214.     double[][] values = 
  215.       new double[data.numInstances()][data.numAttributes()];
  216.     double[] means = new double[data.numAttributes()];
  217.     double[] stdDevs = new double[data.numAttributes()];
  218.     for (int j = 0; j < data.numAttributes(); j++) {
  219.       if (j != data.classIndex()) {
  220. means[j] = data.meanOrMode(j);
  221. stdDevs[j] = Math.sqrt(data.variance(j));
  222. for (int i = 0; i < data.numInstances(); i++) {
  223.   values[i][j] = (data.instance(i).value(j) - means[j]) / 
  224.     stdDevs[j];
  225. }
  226.       } else {
  227. for (int i = 0; i < data.numInstances(); i++) {
  228.   values[i][j] = data.instance(i).value(j);
  229. }
  230.       }
  231.     }
  232.     Instances newData = new Instances(data, data.numInstances());
  233.     for (int i = 0; i < data.numInstances(); i++) {
  234.       newData.add(new Instance(data.instance(i).weight(), values[i]));
  235.     }
  236.     // Use LogitBoost to build model
  237.     LogitBoost boostedModel = new LogitBoost();
  238.     boostedModel.setLikelihoodThreshold(m_Precision);
  239.     boostedModel.setMaxIterations(m_MaxIts);
  240.     LinearRegression lr = new LinearRegression();
  241.     lr.setEliminateColinearAttributes(false);
  242.     lr.setAttributeSelectionMethod(new SelectedTag(LinearRegression.
  243.    SELECTION_NONE,
  244.    LinearRegression.
  245.    TAGS_SELECTION));
  246.     lr.turnChecksOff();
  247.     lr.setRidge(m_Ridge);
  248.     boostedModel.setClassifier(lr);
  249.     boostedModel.buildClassifier(newData);
  250.     // Extract coefficients
  251.     Classifier[][] models = boostedModel.classifiers();
  252.     m_Coefficients = new double[newData.numClasses()]
  253.       [newData.numAttributes() + 1];
  254.     for (int j = 0; j < newData.numClasses(); j++) {
  255.       for (int i = 0; i < models[j].length; i++) {
  256. double[] locCoefficients = 
  257.   ((LinearRegression)models[j][i]).coefficients();
  258. for (int k = 0; k <= newData.numAttributes(); k++) {
  259.   if (k != newData.classIndex()) {
  260.     m_Coefficients[j][k] += locCoefficients[k];
  261.   }
  262. }
  263.       }
  264.     }
  265.    
  266.     // Convert coefficients into original scale
  267.     for(int j = 0; j < data.numClasses(); j++){
  268.       for(int i = 0; i < data.numAttributes(); i++) {
  269. if ((i != newData.classIndex()) &&
  270.     (stdDevs[i] > 0)) {
  271.   m_Coefficients[j][i] /= stdDevs[i];
  272.   m_Coefficients[j][data.numAttributes()] -= 
  273.     m_Coefficients[j][i] * means[i];
  274. }
  275.       }
  276.     }
  277.     m_Header = new Instances(data, 0);
  278.   }
  279.   /**
  280.    * Classifies an instance.
  281.    */
  282.   public double[] distributionForInstance(Instance inst) 
  283.     throws Exception {
  284.     // Filter instance
  285.     m_ReplaceMissingValues.input(inst);
  286.     inst = m_ReplaceMissingValues.output();
  287.     m_NominalToBinary.input(inst);
  288.     inst = m_NominalToBinary.output();
  289.     m_AttFilter.input(inst);
  290.     m_AttFilter.batchFinished();
  291.     inst = m_AttFilter.output();
  292.     // Compute prediction
  293.     double[] preds = new double[m_Coefficients.length];
  294.     for (int j = 0; j < inst.numClasses(); j++) {
  295.       for (int i = 0; i < inst.numAttributes(); i++) {
  296. if (i != inst.classIndex()) {
  297.   preds[j] += inst.value(i) * m_Coefficients[j][i];
  298. }
  299.       }
  300.       preds[j] += m_Coefficients[j][inst.numAttributes()];
  301.     }
  302.     return probs(preds);
  303.   }
  304.   /**
  305.    * Computes probabilities from F scores
  306.    */
  307.   private double[] probs(double[] Fs) {
  308.     double maxF = -Double.MAX_VALUE;
  309.     for (int i = 0; i < Fs.length; i++) {
  310.       if (Fs[i] > maxF) {
  311. maxF = Fs[i];
  312.       }
  313.     }
  314.     double sum = 0;
  315.     double[] probs = new double[Fs.length];
  316.     for (int i = 0; i < Fs.length; i++) {
  317.       probs[i] = Math.exp(Fs[i] - maxF);
  318.       sum += probs[i];
  319.     }
  320.     Utils.normalize(probs, sum);
  321.     return probs;
  322.   }
  323.   /**
  324.    * Prints the model.
  325.    */
  326.   public String toString() {
  327.     if (m_Coefficients == null) {
  328.       return "No model has been built yet!";
  329.     } 
  330.     StringBuffer text = new StringBuffer();
  331.     for (int j = 0; j < m_Coefficients.length; j++) {
  332.       text.append("nModel for class: " + 
  333.   m_Header.classAttribute().value(j) + "nn");
  334.       for (int i = 0; i < m_Coefficients[j].length; i++) {
  335. if (i != m_ClassIndex) {
  336.   if (i > 0) {
  337.     text.append(" + ");
  338.   } else {
  339.     text.append("   ");
  340.   }
  341.   text.append(Utils.doubleToString(m_Coefficients[j][i], 12, 4));
  342.   if (i < m_Coefficients[j].length - 1) {
  343.     text.append(" * " 
  344. + m_Header.attribute(i).name() + "n");
  345.   }
  346. }
  347.       }
  348.       text.append("n");
  349.     }
  350.     return text.toString();
  351.   }
  352.   
  353.   /**
  354.    * Get the value of MaxIts.
  355.    *
  356.    * @return Value of MaxIts.
  357.    */
  358.   public int getMaxIts() {
  359.     
  360.     return m_MaxIts;
  361.   }
  362.   
  363.   /**
  364.    * Set the value of MaxIts.
  365.    *
  366.    * @param newMaxIts Value to assign to MaxIts.
  367.    */
  368.   public void setMaxIts(int newMaxIts) {
  369.     
  370.     m_MaxIts = newMaxIts;
  371.   }
  372.   
  373.   /**
  374.    * Sets the precision of stopping criterion in Newton method.
  375.    *
  376.    * @param precision the precision
  377.    */
  378.   public void setPrecision(double precision) {
  379.     m_Precision = precision;
  380.   }
  381.     
  382.   /**
  383.    * Gets the precision of stopping criterion in Newton method.
  384.    *
  385.    * @return the precision
  386.    */
  387.   public double getPrecision() {
  388.     return m_Precision;
  389.   }
  390.   /**
  391.    * Sets the ridge parameter.
  392.    *
  393.    * @param ridge the ridge
  394.    */
  395.   public void setRidge(double ridge) {
  396.     m_Ridge = ridge;
  397.   }
  398.     
  399.   /**
  400.    * Gets the ridge parameter.
  401.    *
  402.    * @return the ridge
  403.    */
  404.   public double getRidge() {
  405.     return m_Ridge;
  406.   }
  407.   /**
  408.    * Main method for testing this class.
  409.    */
  410.   public static void main(String[] argv) {
  411.     try {
  412.       System.out.println(Evaluation.evaluateModel(new Logistic(), argv));
  413.     } catch (Exception e) {
  414.       e.printStackTrace();
  415.       System.err.println(e.getMessage());
  416.     }
  417.   }
  418. }