PaceRegression.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 19k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  *    This program is free software; you can redistribute it and/or modify
  3.  *    it under the terms of the GNU General Public License as published by
  4.  *    the Free Software Foundation; either version 2 of the License, or
  5.  *    (at your option) any later version.
  6.  *
  7.  *    This program is distributed in the hope that it will be useful,
  8.  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  *    GNU General Public License for more details.
  11.  *
  12.  *    You should have received a copy of the GNU General Public License
  13.  *    along with this program; if not, write to the Free Software
  14.  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.  */
  15. /*
  16.  *    PaceRegression.java
  17.  *    Copyright (C) 2002 Yong Wang
  18.  */
  19. package weka.classifiers.functions.pace;
  20. import weka.classifiers.Classifier;
  21. import weka.classifiers.Evaluation;
  22. import java.io.*;
  23. import java.util.*;
  24. import weka.core.*;
  25. /**
  26.  * Class for building pace regression linear models and using them for
  27.  * prediction. <p>
  28.  * 
  29.  * Under regularity conditions, pace regression is provably optimal when
  30.  * the number of coefficients tends to infinity. It consists of a group of
  31.  * estimators that are either overall optimal or optimal under certain
  32.  * conditions. <p>
  33.  *
  34.  * The current work of the pace regression theory, and therefore also this
  35.  * implementation, do not handle: <p>
  36.  *
  37.  * - missing values <br>
  38.  * - non-binary nominal attributes <br>
  39.  * - the case that n - k is small where n is number of instances and k is  
  40.  *   number of coefficients (the threshold used in this implmentation is 20) 
  41.  * <p>
  42.  *  
  43.  * Valid options are:<p>
  44.  *
  45.  * -D <br>
  46.  * Produce debugging output. <p>
  47.  * -E estimator <br>
  48.  * The estimator can be one of the following: <br>
  49.  * <ul>
  50.  * <li>eb -- Empirical Bayes estimator for noraml mixture (default) <br>
  51.  * <li>nested -- Optimal nested model selector for normal mixture <br>
  52.  * <li>subset -- Optimal subset selector for normal mixture <br>
  53.  * <li>pace2 -- PACE2 for Chi-square mixture <br>
  54.  * <li>pace4 -- PACE4 for Chi-square mixture<br>
  55.  * <li>pace6 -- PACE6 for Chi-square mixture <br>
  56.  * <li>ols -- Ordinary least squares estimator <br>
  57.  * <li>aic -- AIC estimator <br>
  58.  * <li>bic -- BIC estimator <br>
  59.  * <li>ric -- RIC estimator <br>
  60.  * <li>olsc -- Ordinary least squares subset selector with a threshold <br>
  61.  * </ul>
  62.  * -S <threshold value <br>
  63.  * Threshold for the olsc estimator<p>
  64.  *
  65.  * <p>
  66.  * REFERENCES <p>
  67.  * 
  68.  * Wang, Y. (2000). "A new approach to fitting linear models in high
  69.  * dimensional spaces." PhD Thesis. Department of Computer Science,
  70.  * University of Waikato, New Zealand. <p>
  71.  * 
  72.  * Wang, Y. and Witten, I. H. (2002). "Modeling for optimal probability
  73.  * prediction." Proceedings of ICML'2002. Sydney. <p>
  74.  *
  75.  * @author Yong Wang (yongwang@cs.waikato.ac.nz)
  76.  * @author Gabi Schmidberger (gabi@cs.waikato.ac.nz)
  77.  * @version $Revision: 1.4 $ */
  78. public class PaceRegression extends Classifier implements OptionHandler,
  79.        WeightedInstancesHandler {
  80.   /** The model used */
  81.   Instances m_Model = null;
  82.   /** Array for storing coefficients of linear regression. */
  83.   private double[] m_Coefficients;
  84.   /** The index of the class attribute */
  85.   private int m_ClassIndex;
  86.   /** True if debug output will be printed */
  87.   private boolean m_Debug;
  88.   private static final int olsEstimator = 0;
  89.   private static final int ebEstimator = 1;
  90.   private static final int nestedEstimator = 2;
  91.   private static final int subsetEstimator = 3; 
  92.   private static final int pace2Estimator = 4; 
  93.   private static final int pace4Estimator = 5; 
  94.   private static final int pace6Estimator = 6; 
  95.   private static final int olscEstimator = 7;
  96.   private static final int aicEstimator = 8;
  97.   private static final int bicEstimator = 9;
  98.   private static final int ricEstimator = 10;
  99.   public static final Tag [] TAGS_ESTIMATOR = {
  100.     new Tag(olsEstimator, "Ordinary least squares"),
  101.     new Tag(ebEstimator, "Empirical Bayes"),
  102.     new Tag(nestedEstimator, "Nested model selector"),
  103.     new Tag(subsetEstimator, "Subset selector"),
  104.     new Tag(pace2Estimator, "PACE2"),
  105.     new Tag(pace4Estimator, "PACE4"),
  106.     new Tag(pace6Estimator, "PACE6"),
  107.     new Tag(olscEstimator, "Ordinary least squares selection"),
  108.     new Tag(aicEstimator, "AIC"),
  109.     new Tag(bicEstimator, "BIC"),
  110.     new Tag(ricEstimator, "RIC")
  111.   };
  112.   private int paceEstimator = ebEstimator;  
  113.   private double olscThreshold = 2;  // AIC
  114.   
  115.   /**
  116.    * Builds a pace regression model for the given data.
  117.    *
  118.    * @param data the training data to be used for generating the
  119.    * linear regression function
  120.    * @exception Exception if the classifier could not be built successfully
  121.    */
  122.   public void buildClassifier(Instances data) throws Exception {
  123.     //  Checks on data model and instances
  124.     try {
  125.     if (!data.classAttribute().isNumeric()) {
  126.       throw new UnsupportedClassTypeException("Class attribute has to be numeric"+
  127.       " for pace regression!");
  128.     }
  129.     } catch (UnassignedClassException e) {
  130.       System.err.println(data);
  131.       System.err.println(data.classIndex());
  132.     }
  133.     if (data.numInstances() == 0) {
  134.       throw new Exception("No instances in training file!");
  135.     }
  136.     if (data.checkForStringAttributes()) {
  137.       throw new UnsupportedAttributeTypeException("Can't handle string attributes!");
  138.     }
  139.     if (checkForNonBinary(data)) {
  140.       throw new UnsupportedAttributeTypeException("Can only deal with numeric and binary attributes!");
  141.     }
  142.     // check for missing data and throw exception if some are found
  143.     if (checkForMissing(data)) {
  144.       throw new NoSupportForMissingValuesException("Can't handle missing values!");
  145.     }
  146.     // n - k should be >= 20
  147.     if (data.numInstances() - data.numAttributes() < 20) {
  148.       throw new IllegalArgumentException("Not enough instances. Ratio of number of instances (n) to number of "
  149.                           + "attributes (k) is too small (n - k < 20).");
  150.     }
  151.     
  152.     /*
  153.      * initialize the following
  154.      */
  155.     m_Model = new Instances(data, 0);
  156.     m_ClassIndex = data.classIndex();
  157.     double[][] transformedDataMatrix = 
  158.     getTransformedDataMatrix(data, m_ClassIndex);
  159.     double[] classValueVector = data.attributeToDoubleArray(m_ClassIndex);
  160.     
  161.     m_Coefficients = null;
  162.     /* 
  163.      * Perform pace regression
  164.      */
  165.     m_Coefficients = pace(transformedDataMatrix, classValueVector);
  166.   }
  167.   /**
  168.    * pace regression
  169.    *
  170.    * @param matrix_X matrix with observations
  171.    * @param vector_Y vektor with class values
  172.    * @return vector with coefficients
  173.    * @exception Exception if pace regression cannot be done successfully
  174.    */
  175.   private double [] pace(double[][] matrix_X, double [] vector_Y) {
  176.     
  177.     PaceMatrix X = new PaceMatrix( matrix_X );
  178.     PaceMatrix Y = new PaceMatrix( vector_Y, vector_Y.length );
  179.     IntVector pvt = IntVector.seq(0, X.getColumnDimension()-1);
  180.     int n = X.getRowDimension();
  181.     int kr = X.getColumnDimension();
  182.     X.lsqrSelection( Y, pvt, 1 );
  183.     X.positiveDiagonal( Y, pvt );
  184.     
  185.     int k = pvt.size();
  186.     PaceMatrix sol = (PaceMatrix) Y.clone();
  187.     X.rsolve( sol, pvt, pvt.size() );
  188.     DoubleVector betaHat = sol.getColumn(0).unpivoting(pvt, kr); 
  189.     DoubleVector r = Y.getColumn( pvt.size(), n-1, 0);
  190.     double sde = Math.sqrt(r.sum2() / r.size());
  191.     
  192.     DoubleVector aHat = Y.getColumn( 0, pvt.size()-1, 0).times( 1./sde );
  193.     DoubleVector aTilde = null;
  194.     switch( paceEstimator) {
  195.     case ebEstimator: 
  196.     case nestedEstimator:
  197.     case subsetEstimator:
  198.       NormalMixture d = new NormalMixture();
  199.       d.fit( aHat, MixtureDistribution.NNMMethod ); 
  200.       if( paceEstimator == ebEstimator ) 
  201. aTilde = d.empiricalBayesEstimate( aHat );
  202.       else if( paceEstimator == ebEstimator ) 
  203. aTilde = d.subsetEstimate( aHat );
  204.       else aTilde = d.nestedEstimate( aHat );
  205.       break;
  206.     case pace2Estimator: 
  207.     case pace4Estimator:
  208.     case pace6Estimator:
  209.       DoubleVector AHat = aHat.square();
  210.       ChisqMixture dc = new ChisqMixture();
  211.       dc.fit( AHat, MixtureDistribution.NNMMethod ); 
  212.       DoubleVector ATilde; 
  213.       if( paceEstimator == pace6Estimator ) 
  214. ATilde = dc.pace6( AHat );
  215.       else if( paceEstimator == pace2Estimator ) 
  216. ATilde = dc.pace2( AHat );
  217.       else ATilde = dc.pace4( AHat );
  218.       aTilde = ATilde.sqrt().times( aHat.sign() );
  219.       break;
  220.     case olsEstimator: 
  221.       aTilde = aHat.copy();
  222.       break;
  223.     case aicEstimator: 
  224.     case bicEstimator:
  225.     case ricEstimator: 
  226.     case olscEstimator:
  227.       if(paceEstimator == aicEstimator) olscThreshold = 2;
  228.       else if(paceEstimator == bicEstimator) olscThreshold = Math.log( n );
  229.       else if(paceEstimator == ricEstimator) olscThreshold = 2*Math.log( kr );
  230.       aTilde = aHat.copy();
  231.       for( int i = 0; i < aTilde.size(); i++ )
  232. if( Math.abs(aTilde.get(i)) < Math.sqrt(olscThreshold) ) 
  233.   aTilde.set(i, 0);
  234.     }
  235.     PaceMatrix YTilde = new PaceMatrix((new PaceMatrix(aTilde)).times( sde ));
  236.     X.rsolve( YTilde, pvt, pvt.size() );
  237.     DoubleVector betaTilde = YTilde.getColumn(0).unpivoting( pvt, kr );
  238.     
  239.     return betaTilde.getArrayCopy();
  240.   }
  241.   /**
  242.    * Checks if instances have a missing value.
  243.    * @param data the data set
  244.    * @return true if missing value is present in data set
  245.    */
  246.   public boolean checkForMissing(Instances data) {
  247.     for (int i = 0; i < data.numInstances(); i++) {
  248.       Instance inst = data.instance(i);
  249.       for (int j = 0; j < data.numAttributes(); j++) {
  250. if (inst.isMissing(j)) {
  251.   return true;
  252. }
  253.       }
  254.     }
  255.     return false;
  256.   }
  257.   /**
  258.    * Checks if an instance has a missing value.
  259.    * @param instance the instance
  260.    * @return true if missing value is present
  261.    */
  262.   public boolean checkForMissing(Instance instance, Instances model) {
  263.     for (int j = 0; j < instance.numAttributes(); j++) {
  264.       if (j != model.classIndex()) {
  265. if (instance.isMissing(j)) {
  266.   return true;
  267. }
  268.       }
  269.     }
  270.     return false;
  271.   }
  272.   /**
  273.    * Checks if any of the nominal attributes is non-binary.
  274.    * @param data the data set
  275.    * @return true if non binary attribute is present
  276.    */
  277.   public boolean checkForNonBinary(Instances data) {
  278.     for (int i = 0; i < data.numAttributes(); i++) {
  279.       if (data.attribute(i).isNominal()) {
  280. if (data.attribute(i).numValues() != 2)
  281.   return true;
  282.       }           
  283.     }
  284.     return false;
  285.   }
  286.   /**
  287.    * Transforms dataset into a two-dimensional array.
  288.    *
  289.    * @param data dataset
  290.    * @param classIndex index of the class attribute
  291.    */
  292.   private double [][] getTransformedDataMatrix(Instances data, 
  293.        int classIndex) {
  294.     int numInstances = data.numInstances();
  295.     int numAttributes = data.numAttributes();
  296.     int middle = classIndex;
  297.     if (middle < 0) { 
  298.       middle = numAttributes;
  299.     }
  300.     double[][] result = new double[numInstances]
  301.     [numAttributes];
  302.     for (int i = 0; i < numInstances; i++) {
  303.       Instance inst = data.instance(i);
  304.       
  305.       result[i][0] = 1.0;
  306.       // the class value (lies on index middle) is left out
  307.       for (int j = 0; j < middle; j++) {
  308. result[i][j + 1] = inst.value(j);
  309.       }
  310.       for (int j = middle + 1; j < numAttributes; j++) {
  311. result[i][j] = inst.value(j);
  312.       }
  313.     }
  314.     return result;
  315.   }
  316.   /**
  317.    * Classifies the given instance using the linear regression function.
  318.    *
  319.    * @param instance the test instance
  320.    * @return the classification
  321.    * @exception Exception if classification can't be done successfully
  322.    */
  323.   public double classifyInstance(Instance instance) throws Exception {
  324.     
  325.     if (m_Coefficients == null) {
  326.       throw new Exception("Pace Regression: No model built yet.");
  327.     }
  328.     
  329.     // check for missing data and throw exception if some are found
  330.     if (checkForMissing(instance, m_Model)) {
  331.       throw new NoSupportForMissingValuesException("Can't handle missing values!");
  332.     }
  333.     // Calculate the dependent variable from the regression model
  334.     return regressionPrediction(instance,
  335. m_Coefficients);
  336.   }
  337.   /**
  338.    * Outputs the linear regression model as a string.
  339.    */
  340.   public String toString() {
  341.     if (m_Coefficients == null) {
  342.       return "Pace Regression: No model built yet.";
  343.     }
  344.     //    try {
  345.     StringBuffer text = new StringBuffer();
  346.     
  347.     text.append("nPace Regression Modelnn");
  348.     
  349.     text.append(m_Model.classAttribute().name()+" =nn");
  350.     int index = 0;   
  351.     
  352.     text.append(Utils.doubleToString(m_Coefficients[0],
  353.      12, 4) );
  354.     
  355.     for (int i = 1; i < m_Coefficients.length; i++) {
  356.       
  357.       // jump over the class attribute
  358.       if (index == m_ClassIndex) index++;
  359.       
  360.       if (m_Coefficients[i] != 0.0) {
  361. // output a coefficient if unequal zero
  362. text.append(" +n");
  363. text.append(Utils.doubleToString(m_Coefficients[i], 12, 4)
  364.     + " * ");
  365. text.append(m_Model.attribute(index).name());
  366.       }
  367.       index ++;
  368.     }
  369.     
  370.     return text.toString();
  371.   }
  372.   
  373.   /**
  374.    * Returns an enumeration describing the available options.
  375.    *
  376.    * @return an enumeration of all the available options.
  377.    */
  378.   public Enumeration listOptions() {
  379.     
  380.     Vector newVector = new Vector(2);
  381.     newVector.addElement(new Option("tProduce debugging output.n"
  382.     + "t(default no debugging output)",
  383.     "D", 0, "-D"));
  384.     newVector.addElement(new Option("tThe estimator can be one of the following:n" + 
  385.     "ttebtEmpirical Bayes(default)n" +
  386.     "ttnestedtOptimal nested modeln" + 
  387.     "ttsubsettOptimal subsetn" +
  388.     "ttpace2tPACE2n" +
  389.     "ttpace4tPACE4n" +
  390.     "ttpace6tPACE6nn" + 
  391.     "ttolstOrdinary least squaresn" +  
  392.     "ttaictAICn" +  
  393.     "ttbictBICn" +  
  394.     "ttrictRICn" +  
  395.     "ttolsctOLSC", 
  396.     "E", 0, "-E <estimator>"));
  397.     newVector.addElement(new Option("tThreshold value for the OLSC estimator",
  398.     "S", 0, "-S <threshold value>"));
  399.     return newVector.elements();
  400.   }
  401.   /**
  402.    * Parses a given list of options. <p>
  403.    * @param options the list of options as an array of strings
  404.    * @exception Exception if an option is not supported
  405.    */
  406.   public void setOptions(String[] options) throws Exception {
  407.     
  408.     setDebug(Utils.getFlag('D', options));
  409.     String estimator = Utils.getOption('E', options);
  410.     if ( estimator.equals("ols") ) paceEstimator = olsEstimator;
  411.     else if ( estimator.equals("olsc") ) paceEstimator = olscEstimator;
  412.     else if( estimator.equals("eb") || estimator.equals("") ) 
  413.       paceEstimator = ebEstimator;
  414.     else if ( estimator.equals("nested") ) paceEstimator = nestedEstimator;
  415.     else if ( estimator.equals("subset") ) paceEstimator = subsetEstimator;
  416.     else if ( estimator.equals("pace2") ) paceEstimator = pace2Estimator; 
  417.     else if ( estimator.equals("pace4") ) paceEstimator = pace4Estimator;
  418.     else if ( estimator.equals("pace6") ) paceEstimator = pace6Estimator;
  419.     else if ( estimator.equals("aic") ) paceEstimator = aicEstimator;
  420.     else if ( estimator.equals("bic") ) paceEstimator = bicEstimator;
  421.     else if ( estimator.equals("ric") ) paceEstimator = ricEstimator;
  422.     else throw new WekaException("unknown estimator " + estimator + 
  423.  " for -E option" );
  424.     String string = Utils.getOption('S', options);
  425.     if( ! string.equals("") ) olscThreshold = Double.parseDouble( string );
  426.     
  427.   }
  428.   /**
  429.    * Returns the coefficients for this linear model.
  430.    */
  431.   public double[] coefficients() {
  432.     double[] coefficients = new double[m_Coefficients.length];
  433.     for (int i = 0; i < coefficients.length; i++) {
  434.       coefficients[i] = m_Coefficients[i];
  435.     }
  436.     return coefficients;
  437.   }
  438.   /**
  439.    * Gets the current settings of the classifier.
  440.    *
  441.    * @return an array of strings suitable for passing to setOptions
  442.    */
  443.   public String [] getOptions() {
  444.     String [] options = new String [6];
  445.     int current = 0;
  446.     if (getDebug()) {
  447.       options[current++] = "-D";
  448.     }
  449.     options[current++] = "-E";
  450.     switch (paceEstimator) {
  451.     case olsEstimator: options[current++] = "ols";
  452.       break;
  453.     case olscEstimator: options[current++] = "olsc";
  454.       options[current++] = "-S";
  455.       options[current++] = "" + olscThreshold;
  456.       break;
  457.     case ebEstimator: options[current++] = "eb";
  458.       break;
  459.     case nestedEstimator: options[current++] = "nested";
  460.       break;
  461.     case subsetEstimator: options[current++] = "subset";
  462.       break;
  463.     case pace2Estimator: options[current++] = "pace2";
  464.       break; 
  465.     case pace4Estimator: options[current++] = "pace4";
  466.       break;
  467.     case pace6Estimator: options[current++] = "pace6";
  468.       break;
  469.     case aicEstimator: options[current++] = "aic";
  470.       break;
  471.     case bicEstimator: options[current++] = "bic";
  472.       break;
  473.     case ricEstimator: options[current++] = "ric";
  474.       break;
  475.     }
  476.     while (current < options.length) {
  477.       options[current++] = "";
  478.     }
  479.     return options;
  480.   }
  481.   
  482.   /**
  483.    * Get the number of coefficients used in the model
  484.    *
  485.    * @return the number of coefficients
  486.    */
  487.   public int numParameters()
  488.   {
  489.     return m_Coefficients.length-1;
  490.   }
  491.   /**
  492.    * Controls whether debugging output will be printed
  493.    *
  494.    * @param debug true if debugging output should be printed
  495.    */
  496.   public void setDebug(boolean debug) {
  497.     m_Debug = debug;
  498.   }
  499.   /**
  500.    * Controls whether debugging output will be printed
  501.    *
  502.    * @param debug true if debugging output should be printed
  503.    */
  504.   public boolean getDebug() {
  505.     return m_Debug;
  506.   }
  507.   /**
  508.    * Gets the estimator
  509.    *
  510.    * @return the estimator
  511.    */
  512.   public SelectedTag getEstimator() {
  513.     return new SelectedTag(paceEstimator, TAGS_ESTIMATOR);
  514.   }
  515.   
  516.   /**
  517.    * Sets the estimator.
  518.    *
  519.    * @param estimator the new estimator
  520.    */
  521.   public void setEstimator(SelectedTag estimator) {
  522.     
  523.     if (estimator.getTags() == TAGS_ESTIMATOR) {
  524.       paceEstimator = estimator.getSelectedTag().getID();
  525.     }
  526.   }
  527.   /**
  528.    * Set threshold for the olsc estimator
  529.    *
  530.    * @param threshold the threshold for the olsc estimator
  531.    */
  532.   public void setThreshold(double newThreshold) {
  533.     olscThreshold = newThreshold;
  534.   }
  535.   /**
  536.    * Gets the threshold for olsc estimator
  537.    *
  538.    * @return the threshold
  539.    */
  540.   public double getThreshold() {
  541.     return olscThreshold;
  542.   }
  543.   /**
  544.    * Calculate the dependent value for a given instance for a
  545.    * given regression model.
  546.    *
  547.    * @param transformedInstance the input instance
  548.    * @param selectedAttributes an array of flags indicating which 
  549.    * attributes are included in the regression model
  550.    * @param coefficients an array of coefficients for the regression
  551.    * model
  552.    * @return the regression value for the instance.
  553.    * @exception Exception if the class attribute of the input instance
  554.    * is not assigned
  555.    */
  556.   private double regressionPrediction(Instance transformedInstance,
  557.       double [] coefficients) 
  558.     throws Exception {
  559.     int column = 0;
  560.     double result = coefficients[column];
  561.     for (int j = 0; j < transformedInstance.numAttributes(); j++) {
  562.       if (m_ClassIndex != j) {
  563. column++;
  564. result += coefficients[column] * transformedInstance.value(j);
  565.       }
  566.     }
  567.     
  568.     return result;
  569.   }
  570.   /**
  571.    * Generates a linear regression function predictor.
  572.    *
  573.    * @param String the options
  574.    */
  575.   public static void main(String argv[]) {
  576.     
  577.     Classifier scheme;
  578.     try {
  579.       scheme = new PaceRegression();
  580.       System.out.println(Evaluation.evaluateModel(scheme, argv));
  581.     } catch (Exception e) {
  582.       e.printStackTrace();
  583.       // System.out.println(e.getMessage());
  584.     }
  585.   }
  586. }