CfsSubsetEval.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 25k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  *    This program is free software; you can redistribute it and/or modify
  3.  *    it under the terms of the GNU General Public License as published by
  4.  *    the Free Software Foundation; either version 2 of the License, or
  5.  *    (at your option) any later version.
  6.  *
  7.  *    This program is distributed in the hope that it will be useful,
  8.  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  *    GNU General Public License for more details.
  11.  *
  12.  *    You should have received a copy of the GNU General Public License
  13.  *    along with this program; if not, write to the Free Software
  14.  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15.  */
  16. /*
  17.  *    CfsSubsetEval.java
  18.  *    Copyright (C) 1999 Mark Hall
  19.  *
  20.  */
  21. package  weka.attributeSelection;
  22. import  java.io.*;
  23. import  java.util.*;
  24. import  weka.core.*;
  25. import  weka.classifiers.*;
  26. import  weka.filters.*;
  27. /** 
  28.  * CFS attribute subset evaluator.
  29.  * For more information see: <p>
  30.  *
  31.  * Hall, M. A. (1998). Correlation-based Feature Subset Selection for Machine 
  32.  * Learning. Thesis submitted in partial fulfilment of the requirements of the
  33.  * degree of Doctor of Philosophy at the University of Waikato. <p>
  34.  *
  35.  * Valid options are:
  36.  *
  37.  * -M <br>
  38.  * Treat missing values as a seperate value. <p>
  39.  * 
  40.  * -L <br>
  41.  * Include locally predictive attributes. <p>
  42.  *
  43.  * @author Mark Hall (mhall@cs.waikato.ac.nz)
  44.  * @version $Revision: 1.14 $
  45.  */
  46. public class CfsSubsetEval
  47.   extends SubsetEvaluator
  48.   implements OptionHandler
  49. {
  50.   /** The training instances */
  51.   private Instances m_trainInstances;
  52.   /** Discretise attributes when class in nominal */
  53.   private DiscretizeFilter m_disTransform;
  54.   /** The class index */
  55.   private int m_classIndex;
  56.   /** Is the class numeric */
  57.   private boolean m_isNumeric;
  58.   /** Number of attributes in the training data */
  59.   private int m_numAttribs;
  60.   /** Number of instances in the training data */
  61.   private int m_numInstances;
  62.   /** Treat missing values as seperate values */
  63.   private boolean m_missingSeperate;
  64.   /** Include locally predicitive attributes */
  65.   private boolean m_locallyPredictive;
  66.   /** Holds the matrix of attribute correlations */
  67.   private Matrix m_corr_matrix;
  68.   /** Standard deviations of attributes (when using pearsons correlation) */
  69.   private double[] m_std_devs;
  70.   /** Threshold for admitting locally predictive features */
  71.   private double m_c_Threshold;
  72.   /**
  73.    * Returns a string describing this attribute evaluator
  74.    * @return a description of the evaluator suitable for
  75.    * displaying in the explorer/experimenter gui
  76.    */
  77.   public String globalInfo() {
  78.     return "CfsSubsetEval :nnEvaluates the worth of a subset of attributes "
  79.       +"by considering the individual predictive ability of each feature "
  80.       +"along with the degree of redundancy between them.nn"
  81.       +"Subsets of features that are highly correlated with the class "
  82.       +"while having low intercorrelation are preferred.n";
  83.   }
  84.   /**
  85.    * Constructor
  86.    */
  87.   public CfsSubsetEval () {
  88.     resetOptions();
  89.   }
  90.   /**
  91.    * Returns an enumeration describing the available options
  92.    * @return an enumeration of all the available options
  93.    *
  94.    **/
  95.   public Enumeration listOptions () {
  96.     Vector newVector = new Vector(3);
  97.     newVector.addElement(new Option("tTreat missing values as a seperate" 
  98.     + "ntvalue.", "M", 0, "-M"));
  99.     newVector.addElement(new Option("tInclude locally predictive attributes" 
  100.     + ".", "L", 0, "-L"));
  101.     return  newVector.elements();
  102.   }
  103.   /**
  104.    * Parses and sets a given list of options. <p>
  105.    *
  106.    * Valid options are:
  107.    *
  108.    * -M <br>
  109.    * Treat missing values as a seperate value. <p>
  110.    * 
  111.    * -L <br>
  112.    * Include locally predictive attributes. <p>
  113.    *
  114.    * @param options the list of options as an array of strings
  115.    * @exception Exception if an option is not supported
  116.    *
  117.    **/
  118.   public void setOptions (String[] options)
  119.     throws Exception
  120.   {
  121.     String optionString;
  122.     resetOptions();
  123.     setMissingSeperate(Utils.getFlag('M', options));
  124.     setLocallyPredictive(Utils.getFlag('L', options));
  125.   }
  126.   /**
  127.    * Returns the tip text for this property
  128.    * @return tip text for this property suitable for
  129.    * displaying in the explorer/experimenter gui
  130.    */
  131.   public String locallyPredictiveTipText() {
  132.     return "Identify locally predictive attributes. Iteratively adds "
  133.       +"attributes with the highest correlation with the class as long "
  134.       +"as there is not already an attribute in the subset that has a "
  135.       +"higher correlation with the attribute in question";
  136.   }
  137.   /**
  138.    * Include locally predictive attributes
  139.    *
  140.    * @param b true or false
  141.    */
  142.   public void setLocallyPredictive (boolean b) {
  143.     m_locallyPredictive = b;
  144.   }
  145.   /**
  146.    * Return true if including locally predictive attributes
  147.    *
  148.    * @return true if locally predictive attributes are to be used
  149.    */
  150.   public boolean getLocallyPredictive () {
  151.     return  m_locallyPredictive;
  152.   }
  153.   /**
  154.    * Returns the tip text for this property
  155.    * @return tip text for this property suitable for
  156.    * displaying in the explorer/experimenter gui
  157.    */
  158.   public String missingSeperateTipText() {
  159.     return "Treat missing as a separate value. Otherwise, counts for missing "
  160.       +"values are distributed across other values in proportion to their "
  161.       +"frequency.";
  162.   }
  163.   /**
  164.    * Treat missing as a seperate value
  165.    *
  166.    * @param b true or false
  167.    */
  168.   public void setMissingSeperate (boolean b) {
  169.     m_missingSeperate = b;
  170.   }
  171.   /**
  172.    * Return true is missing is treated as a seperate value
  173.    *
  174.    * @return true if missing is to be treated as a seperate value
  175.    */
  176.   public boolean getMissingSeperate () {
  177.     return  m_missingSeperate;
  178.   }
  179.   /**
  180.    * Gets the current settings of CfsSubsetEval
  181.    *
  182.    * @return an array of strings suitable for passing to setOptions()
  183.    */
  184.   public String[] getOptions () {
  185.     String[] options = new String[2];
  186.     int current = 0;
  187.     if (getMissingSeperate()) {
  188.       options[current++] = "-M";
  189.     }
  190.     if (getLocallyPredictive()) {
  191.       options[current++] = "-L";
  192.     }
  193.     while (current < options.length) {
  194.       options[current++] = "";
  195.     }
  196.     return  options;
  197.   }
  198.   /**
  199.    * Generates a attribute evaluator. Has to initialize all fields of the 
  200.    * evaluator that are not being set via options.
  201.    *
  202.    * CFS also discretises attributes (if necessary) and initializes
  203.    * the correlation matrix.
  204.    *
  205.    * @param data set of instances serving as training data 
  206.    * @exception Exception if the evaluator has not been 
  207.    * generated successfully
  208.    */
  209.   public void buildEvaluator (Instances data)
  210.     throws Exception
  211.   {
  212.     if (data.checkForStringAttributes()) {
  213.       throw  new Exception("Can't handle string attributes!");
  214.     }
  215.     m_trainInstances = data;
  216.     m_trainInstances.deleteWithMissingClass();
  217.     m_classIndex = m_trainInstances.classIndex();
  218.     m_numAttribs = m_trainInstances.numAttributes();
  219.     m_numInstances = m_trainInstances.numInstances();
  220.     m_isNumeric = m_trainInstances.attribute(m_classIndex).isNumeric();
  221.     if (!m_isNumeric) {
  222.       m_disTransform = new DiscretizeFilter();
  223.       m_disTransform.setUseBetterEncoding(true);
  224.       m_disTransform.setInputFormat(m_trainInstances);
  225.       m_trainInstances = Filter.useFilter(m_trainInstances, m_disTransform);
  226.     }
  227.     m_std_devs = new double[m_numAttribs];
  228.     m_corr_matrix = new Matrix(m_numAttribs, m_numAttribs);
  229.     for (int i = 0; i < m_corr_matrix.numRows(); i++) {
  230.       m_corr_matrix.setElement(i, i, 1.0);
  231.       m_std_devs[i] = 1.0;
  232.     }
  233.     for (int i = 0; i < m_numAttribs; i++) {
  234.       for (int j = i + 1; j < m_numAttribs; j++) {
  235.         m_corr_matrix.setElement(i, j, -999);
  236.         m_corr_matrix.setElement(j, i, -999);
  237.       }
  238.     }
  239.   }
  240.   /**
  241.    * evaluates a subset of attributes
  242.    *
  243.    * @param subset a bitset representing the attribute subset to be 
  244.    * evaluated 
  245.    * @exception Exception if the subset could not be evaluated
  246.    */
  247.   public double evaluateSubset (BitSet subset)
  248.     throws Exception
  249.   {
  250.     double num = 0.0;
  251.     double denom = 0.0;
  252.     double corr;
  253.     // do numerator
  254.     for (int i = 0; i < m_numAttribs; i++) {
  255.       if (i != m_classIndex) {
  256.         if (subset.get(i)) {
  257.           if (m_corr_matrix.getElement(i, m_classIndex) == -999) {
  258.             corr = correlate(i, m_classIndex);
  259.             m_corr_matrix.setElement(i, m_classIndex, corr);
  260.             m_corr_matrix.setElement(m_classIndex, i, corr);
  261.             num += (m_std_devs[i] * corr);
  262.           }
  263.           else {num += (m_std_devs[i] * 
  264. m_corr_matrix.getElement(i, m_classIndex));
  265.   }
  266. }
  267.       }
  268.     }
  269.     // do denominator
  270.     for (int i = 0; i < m_numAttribs; i++) {
  271.       if (i != m_classIndex) {
  272. if (subset.get(i)) {
  273.   denom += (1.0 * m_std_devs[i] * m_std_devs[i]);
  274.   for (int j = i + 1; j < m_numAttribs; j++) {if (subset.get(j)) {
  275.     if (m_corr_matrix.getElement(i, j) == -999) {
  276.       corr = correlate(i, j);
  277.       m_corr_matrix.setElement(i, j, corr);
  278.       m_corr_matrix.setElement(j, i, corr);
  279.       denom += (2.0 * m_std_devs[i] * m_std_devs[j] * corr);
  280.     }
  281.     else {denom += (2.0 * m_std_devs[i] * m_std_devs[j] * 
  282.     m_corr_matrix.getElement(i, j));
  283.     }
  284.   }
  285.   }
  286. }
  287.       }
  288.     }
  289.     if (denom < 0.0) {
  290.       denom *= -1.0;
  291.     }
  292.     if (denom == 0.0) {
  293.       return  (0.0);
  294.     }
  295.     double merit = (num/Math.sqrt(denom));
  296.     if (merit < 0.0) {
  297.       merit *= -1.0;
  298.     }
  299.     return  merit;
  300.   }
  301.   private double correlate (int att1, int att2) {
  302.     if (!m_isNumeric) {
  303.       return  symmUncertCorr(att1, att2);
  304.     }
  305.     boolean att1_is_num = (m_trainInstances.attribute(att1).isNumeric());
  306.     boolean att2_is_num = (m_trainInstances.attribute(att2).isNumeric());
  307.     if (att1_is_num && att2_is_num) {
  308.       return  num_num(att1, att2);
  309.     }
  310.     else {if (att2_is_num) {
  311.       return  num_nom2(att1, att2);
  312.     }
  313.     else {if (att1_is_num) {
  314.       return  num_nom2(att2, att1);
  315.     }
  316.     }
  317.     }
  318.     return  nom_nom(att1, att2);
  319.   }
  320.   private double symmUncertCorr (int att1, int att2) {
  321.     int i, j, k, ii, jj;
  322.     int nnj, nni, ni, nj;
  323.     double sum = 0.0;
  324.     double sumi[], sumj[];
  325.     double counts[][];
  326.     Instance inst;
  327.     double corr_measure;
  328.     boolean flag = false;
  329.     double temp = 0.0;
  330.     if (att1 == m_classIndex || att2 == m_classIndex) {
  331.       flag = true;
  332.     }
  333.     ni = m_trainInstances.attribute(att1).numValues() + 1;
  334.     nj = m_trainInstances.attribute(att2).numValues() + 1;
  335.     counts = new double[ni][nj];
  336.     sumi = new double[ni];
  337.     sumj = new double[nj];
  338.     for (i = 0; i < ni; i++) {
  339.       sumi[i] = 0.0;
  340.       for (j = 0; j < nj; j++) {
  341. sumj[j] = 0.0;
  342. counts[i][j] = 0.0;
  343.       }
  344.     }
  345.     // Fill the contingency table
  346.     for (i = 0; i < m_numInstances; i++) {
  347.       inst = m_trainInstances.instance(i);
  348.       if (inst.isMissing(att1)) {
  349. ii = ni - 1;
  350.       }
  351.       else {
  352. ii = (int)inst.value(att1);
  353.       }
  354.       if (inst.isMissing(att2)) {
  355. jj = nj - 1;
  356.       }
  357.       else {
  358. jj = (int)inst.value(att2);
  359.       }
  360.       counts[ii][jj]++;
  361.     }
  362.     // get the row totals
  363.     for (i = 0; i < ni; i++) {
  364.       sumi[i] = 0.0;
  365.       for (j = 0; j < nj; j++) {
  366. sumi[i] += counts[i][j];
  367. sum += counts[i][j];
  368.       }
  369.     }
  370.     // get the column totals
  371.     for (j = 0; j < nj; j++) {
  372.       sumj[j] = 0.0;
  373.       for (i = 0; i < ni; i++) {
  374. sumj[j] += counts[i][j];
  375.       }
  376.     }
  377.     // distribute missing counts
  378.     if (!m_missingSeperate && 
  379. (sumi[ni-1] < m_numInstances) && 
  380. (sumj[nj-1] < m_numInstances)) {
  381.       double[] i_copy = new double[sumi.length];
  382.       double[] j_copy = new double[sumj.length];
  383.       double[][] counts_copy = new double[sumi.length][sumj.length];
  384.       for (i = 0; i < ni; i++) {
  385. System.arraycopy(counts[i], 0, counts_copy[i], 0, sumj.length);
  386.       }
  387.       System.arraycopy(sumi, 0, i_copy, 0, sumi.length);
  388.       System.arraycopy(sumj, 0, j_copy, 0, sumj.length);
  389.       double total_missing = 
  390. (sumi[ni - 1] + sumj[nj - 1] - counts[ni - 1][nj - 1]);
  391.       // do the missing i's
  392.       if (sumi[ni - 1] > 0.0) {
  393. for (j = 0; j < nj - 1; j++) {
  394.   if (counts[ni - 1][j] > 0.0) {
  395.     for (i = 0; i < ni - 1; i++) {
  396.       temp = ((i_copy[i]/(sum - i_copy[ni - 1]))*counts[ni - 1][j]);
  397.       counts[i][j] += temp;
  398.       sumi[i] += temp;
  399.     }
  400.     counts[ni - 1][j] = 0.0;
  401.   }
  402. }
  403.       }
  404.       sumi[ni - 1] = 0.0;
  405.       // do the missing j's
  406.       if (sumj[nj - 1] > 0.0) {
  407. for (i = 0; i < ni - 1; i++) {
  408.   if (counts[i][nj - 1] > 0.0) {
  409.     for (j = 0; j < nj - 1; j++) {
  410.       temp = ((j_copy[j]/(sum - j_copy[nj - 1]))*counts[i][nj - 1]);
  411.       counts[i][j] += temp;
  412.       sumj[j] += temp;
  413.     }
  414.     counts[i][nj - 1] = 0.0;
  415.   }
  416. }
  417.       }
  418.       sumj[nj - 1] = 0.0;
  419.       // do the both missing
  420.       if (counts[ni - 1][nj - 1] > 0.0 && total_missing != sum) {
  421. for (i = 0; i < ni - 1; i++) {
  422.   for (j = 0; j < nj - 1; j++) {
  423.     temp = (counts_copy[i][j]/(sum - total_missing)) * 
  424.       counts_copy[ni - 1][nj - 1];
  425.     
  426.     counts[i][j] += temp;
  427.     sumi[i] += temp;
  428.     sumj[j] += temp;
  429.   }
  430. }
  431. counts[ni - 1][nj - 1] = 0.0;
  432.       }
  433.     }
  434.     // corr_measure = Correlate.symm_uncert(counts,sumi,sumj,sum,ni,nj,flag);
  435.     corr_measure = ContingencyTables.symmetricalUncertainty(counts);
  436.     // corr_measure = ContingencyTables.gainRatio(counts);
  437.     if (Utils.eq(corr_measure, 0.0)) {
  438.       if (flag == true) {
  439. return  (0.0);
  440.       }
  441.       else {
  442. return  (1.0);
  443.       }
  444.     }
  445.     else {
  446.       return  (corr_measure);
  447.     }
  448.   }
  449.   private double num_num (int att1, int att2) {
  450.     int i;
  451.     Instance inst;
  452.     double r, diff1, diff2, num = 0.0, sx = 0.0, sy = 0.0;
  453.     double mx = m_trainInstances.meanOrMode(m_trainInstances.attribute(att1));
  454.     double my = m_trainInstances.meanOrMode(m_trainInstances.attribute(att2));
  455.     for (i = 0; i < m_numInstances; i++) {
  456.       inst = m_trainInstances.instance(i);
  457.       diff1 = (inst.isMissing(att1))? 0.0 : (inst.value(att1) - mx);
  458.       diff2 = (inst.isMissing(att2))? 0.0 : (inst.value(att2) - my);
  459.       num += (diff1*diff2);
  460.       sx += (diff1*diff1);
  461.       sy += (diff2*diff2);
  462.     }
  463.     if (sx != 0.0) {
  464.       if (m_std_devs[att1] == 1.0) {
  465. m_std_devs[att1] = Math.sqrt((sx/m_numInstances));
  466.       }
  467.     }
  468.     if (sy != 0.0) {
  469.       if (m_std_devs[att2] == 1.0) {
  470. m_std_devs[att2] = Math.sqrt((sy/m_numInstances));
  471.       }
  472.     }
  473.     if ((sx*sy) > 0.0) {
  474.       r = (num/(Math.sqrt(sx*sy)));
  475.       return  ((r < 0.0)? -r : r);
  476.     }
  477.     else {
  478.       if (att1 != m_classIndex && att2 != m_classIndex) {
  479. return  1.0;
  480.       }
  481.       else {
  482. return  0.0;
  483.       }
  484.     }
  485.   }
  486.   private double num_nom2 (int att1, int att2) {
  487.     int i, ii, k;
  488.     double temp;
  489.     Instance inst;
  490.     int mx = (int)m_trainInstances.
  491.       meanOrMode(m_trainInstances.attribute(att1));
  492.     double my = m_trainInstances.
  493.       meanOrMode(m_trainInstances.attribute(att2));
  494.     double stdv_num = 0.0;
  495.     double diff1, diff2;
  496.     double r = 0.0, rr, max_corr = 0.0;
  497.     int nx = (!m_missingSeperate) 
  498.       ? m_trainInstances.attribute(att1).numValues() 
  499.       : m_trainInstances.attribute(att1).numValues() + 1;
  500.     double[] prior_nom = new double[nx];
  501.     double[] stdvs_nom = new double[nx];
  502.     double[] covs = new double[nx];
  503.     for (i = 0; i < nx; i++) {
  504.       stdvs_nom[i] = covs[i] = prior_nom[i] = 0.0;
  505.     }
  506.     // calculate frequencies (and means) of the values of the nominal 
  507.     // attribute
  508.     for (i = 0; i < m_numInstances; i++) {
  509.       inst = m_trainInstances.instance(i);
  510.       if (inst.isMissing(att1)) {
  511. if (!m_missingSeperate) {
  512.   ii = mx;
  513. }
  514. else {
  515.   ii = nx - 1;
  516. }
  517.       }
  518.       else {
  519. ii = (int)inst.value(att1);
  520.       }
  521.       // increment freq for nominal
  522.       prior_nom[ii]++;
  523.     }
  524.     for (k = 0; k < m_numInstances; k++) {
  525.       inst = m_trainInstances.instance(k);
  526.       // std dev of numeric attribute
  527.       diff2 = (inst.isMissing(att2))? 0.0 : (inst.value(att2) - my);
  528.       stdv_num += (diff2*diff2);
  529.       // 
  530.       for (i = 0; i < nx; i++) {
  531. if (inst.isMissing(att1)) {
  532.   if (!m_missingSeperate) {
  533.     temp = (i == mx)? 1.0 : 0.0;
  534.   }
  535.   else {
  536.     temp = (i == (nx - 1))? 1.0 : 0.0;
  537.   }
  538. }
  539. else {
  540.   temp = (i == inst.value(att1))? 1.0 : 0.0;
  541. }
  542. diff1 = (temp - (prior_nom[i]/m_numInstances));
  543. stdvs_nom[i] += (diff1*diff1);
  544. covs[i] += (diff1*diff2);
  545.       }
  546.     }
  547.     // calculate weighted correlation
  548.     for (i = 0, temp = 0.0; i < nx; i++) {
  549.       // calculate the weighted variance of the nominal
  550.       temp += ((prior_nom[i]/m_numInstances)*(stdvs_nom[i]/m_numInstances));
  551.       if ((stdvs_nom[i]*stdv_num) > 0.0) {
  552. //System.out.println("Stdv :"+stdvs_nom[i]);
  553. rr = (covs[i]/(Math.sqrt(stdvs_nom[i]*stdv_num)));
  554. if (rr < 0.0) {
  555.   rr = -rr;
  556. }
  557. r += ((prior_nom[i]/m_numInstances)*rr);
  558.       }
  559.       /* if there is zero variance for the numeric att at a specific 
  560.  level of the catergorical att then if neither is the class then 
  561.  make this correlation at this level maximally bad i.e. 1.0. 
  562.  If either is the class then maximally bad correlation is 0.0 */
  563.       else {if (att1 != m_classIndex && att2 != m_classIndex) {
  564. r += ((prior_nom[i]/m_numInstances)*1.0);
  565.       }
  566.       }
  567.     }
  568.     // set the standard deviations for these attributes if necessary
  569.     // if ((att1 != classIndex) && (att2 != classIndex)) // =============
  570.     if (temp != 0.0) {
  571.       if (m_std_devs[att1] == 1.0) {
  572. m_std_devs[att1] = Math.sqrt(temp);
  573.       }
  574.     }
  575.     if (stdv_num != 0.0) {
  576.       if (m_std_devs[att2] == 1.0) {
  577. m_std_devs[att2] = Math.sqrt((stdv_num/m_numInstances));
  578.       }
  579.     }
  580.     if (r == 0.0) {
  581.       if (att1 != m_classIndex && att2 != m_classIndex) {
  582. r = 1.0;
  583.       }
  584.     }
  585.     return  r;
  586.   }
  587.   private double nom_nom (int att1, int att2) {
  588.     int i, j, ii, jj, z;
  589.     double temp1, temp2;
  590.     Instance inst;
  591.     int mx = (int)m_trainInstances.
  592.       meanOrMode(m_trainInstances.attribute(att1));
  593.     int my = (int)m_trainInstances.
  594.       meanOrMode(m_trainInstances.attribute(att2));
  595.     double diff1, diff2;
  596.     double r = 0.0, rr, max_corr = 0.0;
  597.     int nx = (!m_missingSeperate) 
  598.       ? m_trainInstances.attribute(att1).numValues() 
  599.       : m_trainInstances.attribute(att1).numValues() + 1;
  600.     int ny = (!m_missingSeperate)
  601.       ? m_trainInstances.attribute(att2).numValues() 
  602.       : m_trainInstances.attribute(att2).numValues() + 1;
  603.     double[][] prior_nom = new double[nx][ny];
  604.     double[] sumx = new double[nx];
  605.     double[] sumy = new double[ny];
  606.     double[] stdvsx = new double[nx];
  607.     double[] stdvsy = new double[ny];
  608.     double[][] covs = new double[nx][ny];
  609.     for (i = 0; i < nx; i++) {
  610.       sumx[i] = stdvsx[i] = 0.0;
  611.     }
  612.     for (j = 0; j < ny; j++) {
  613.       sumy[j] = stdvsy[j] = 0.0;
  614.     }
  615.     for (i = 0; i < nx; i++) {
  616.       for (j = 0; j < ny; j++) {
  617. covs[i][j] = prior_nom[i][j] = 0.0;
  618.       }
  619.     }
  620.     // calculate frequencies (and means) of the values of the nominal 
  621.     // attribute
  622.     for (i = 0; i < m_numInstances; i++) {
  623.       inst = m_trainInstances.instance(i);
  624.       if (inst.isMissing(att1)) {
  625. if (!m_missingSeperate) {
  626.   ii = mx;
  627. }
  628. else {
  629.   ii = nx - 1;
  630. }
  631.       }
  632.       else {
  633. ii = (int)inst.value(att1);
  634.       }
  635.       if (inst.isMissing(att2)) {
  636. if (!m_missingSeperate) {
  637.   jj = my;
  638. }
  639. else {
  640.   jj = ny - 1;
  641. }
  642.       }
  643.       else {
  644. jj = (int)inst.value(att2);
  645.       }
  646.       // increment freq for nominal
  647.       prior_nom[ii][jj]++;
  648.       sumx[ii]++;
  649.       sumy[jj]++;
  650.     }
  651.     for (z = 0; z < m_numInstances; z++) {
  652.       inst = m_trainInstances.instance(z);
  653.       for (j = 0; j < ny; j++) {
  654. if (inst.isMissing(att2)) {
  655.   if (!m_missingSeperate) {
  656.     temp2 = (j == my)? 1.0 : 0.0;
  657.   }
  658.   else {
  659.     temp2 = (j == (ny - 1))? 1.0 : 0.0;
  660.   }
  661. }
  662. else {
  663.   temp2 = (j == inst.value(att2))? 1.0 : 0.0;
  664. }
  665. diff2 = (temp2 - (sumy[j]/m_numInstances));
  666. stdvsy[j] += (diff2*diff2);
  667.       }
  668.       // 
  669.       for (i = 0; i < nx; i++) {
  670. if (inst.isMissing(att1)) {
  671.   if (!m_missingSeperate) {
  672.     temp1 = (i == mx)? 1.0 : 0.0;
  673.   }
  674.   else {
  675.     temp1 = (i == (nx - 1))? 1.0 : 0.0;
  676.   }
  677. }
  678. else {
  679.   temp1 = (i == inst.value(att1))? 1.0 : 0.0;
  680. }
  681. diff1 = (temp1 - (sumx[i]/m_numInstances));
  682. stdvsx[i] += (diff1*diff1);
  683. for (j = 0; j < ny; j++) {
  684.   if (inst.isMissing(att2)) {
  685.     if (!m_missingSeperate) {
  686.       temp2 = (j == my)? 1.0 : 0.0;
  687.     }
  688.     else {
  689.       temp2 = (j == (ny - 1))? 1.0 : 0.0;
  690.     }
  691.   }
  692.   else {
  693.     temp2 = (j == inst.value(att2))? 1.0 : 0.0;
  694.   }
  695.   diff2 = (temp2 - (sumy[j]/m_numInstances));
  696.   covs[i][j] += (diff1*diff2);
  697. }
  698.       }
  699.     }
  700.     // calculate weighted correlation
  701.     for (i = 0; i < nx; i++) {
  702.       for (j = 0; j < ny; j++) {
  703. if ((stdvsx[i]*stdvsy[j]) > 0.0) {
  704.   //System.out.println("Stdv :"+stdvs_nom[i]);
  705.   rr = (covs[i][j]/(Math.sqrt(stdvsx[i]*stdvsy[j])));
  706.   if (rr < 0.0) {
  707.     rr = -rr;
  708.   }
  709.   r += ((prior_nom[i][j]/m_numInstances)*rr);
  710. }
  711. // if there is zero variance for either of the categorical atts then if
  712. // neither is the class then make this
  713. // correlation at this level maximally bad i.e. 1.0. If either is 
  714. // the class then maximally bad correlation is 0.0
  715. else {if (att1 != m_classIndex && att2 != m_classIndex) {
  716.   r += ((prior_nom[i][j]/m_numInstances)*1.0);
  717. }
  718. }
  719.       }
  720.     }
  721.     // calculate weighted standard deviations for these attributes
  722.     // (if necessary)
  723.     for (i = 0, temp1 = 0.0; i < nx; i++) {
  724.       temp1 += ((sumx[i]/m_numInstances)*(stdvsx[i]/m_numInstances));
  725.     }
  726.     if (temp1 != 0.0) {
  727.       if (m_std_devs[att1] == 1.0) {
  728. m_std_devs[att1] = Math.sqrt(temp1);
  729.       }
  730.     }
  731.     for (j = 0, temp2 = 0.0; j < ny; j++) {
  732.       temp2 += ((sumy[j]/m_numInstances)*(stdvsy[j]/m_numInstances));
  733.     }
  734.     if (temp2 != 0.0) {
  735.       if (m_std_devs[att2] == 1.0) {
  736. m_std_devs[att2] = Math.sqrt(temp2);
  737.       }
  738.     }
  739.     if (r == 0.0) {
  740.       if (att1 != m_classIndex && att2 != m_classIndex) {
  741. r = 1.0;
  742.       }
  743.     }
  744.     return  r;
  745.   }
  746.   /**
  747.    * returns a string describing CFS
  748.    *
  749.    * @return the description as a string
  750.    */
  751.   public String toString () {
  752.     StringBuffer text = new StringBuffer();
  753.     if (m_trainInstances == null) {
  754.       text.append("CFS subset evaluator has not been built yetn");
  755.     }
  756.     else {
  757.       text.append("tCFS Subset Evaluatorn");
  758.       if (m_missingSeperate) {
  759. text.append("tTreating missing values as a seperate valuen");
  760.       }
  761.       if (m_locallyPredictive) {
  762. text.append("tIncluding locally predictive attributesn");
  763.       }
  764.     }
  765.     return  text.toString();
  766.   }
  767.   private void addLocallyPredictive (BitSet best_group) {
  768.     int i, j;
  769.     boolean done = false;
  770.     boolean ok = true;
  771.     double temp_best = -1.0;
  772.     double corr;
  773.     j = 0;
  774.     BitSet temp_group = (BitSet)best_group.clone();
  775.     while (!done) {
  776.       temp_best = -1.0;
  777.       // find best not already in group
  778.       for (i = 0; i < m_numAttribs; i++) {
  779. if ((!temp_group.get(i)) && (i != m_classIndex)) {
  780.   if (m_corr_matrix.getElement(i, m_classIndex) == -999) {
  781.     corr = correlate(i, m_classIndex);
  782.     m_corr_matrix.setElement(i, m_classIndex, corr);
  783.     m_corr_matrix.setElement(m_classIndex, i, corr);
  784.   }
  785.   if (m_corr_matrix.getElement(i, m_classIndex) > temp_best) {
  786.     temp_best = m_corr_matrix.getElement(i, m_classIndex);
  787.     j = i;
  788.   }
  789. }
  790.       }
  791.       if (temp_best == -1.0) {
  792. done = true;
  793.       }
  794.       else {
  795. ok = true;
  796. temp_group.set(j);
  797. // check the best against correlations with others already
  798. // in group 
  799. for (i = 0; i < m_numAttribs; i++) {
  800.   if (best_group.get(i)) {
  801.     if (m_corr_matrix.getElement(i, j) == -999) {
  802.       corr = correlate(i, j);
  803.       m_corr_matrix.setElement(i, j, corr);
  804.       m_corr_matrix.setElement(j, i, corr);
  805.     }
  806.     if (m_corr_matrix.getElement(i, j) > temp_best - m_c_Threshold) {
  807.       ok = false;
  808.       break;
  809.     }
  810.   }
  811. }
  812. // if ok then add to best_group
  813. if (ok) {
  814.   best_group.set(j);
  815. }
  816.       }
  817.     }
  818.   }
  819.   /**
  820.    * Calls locallyPredictive in order to include locally predictive
  821.    * attributes (if requested).
  822.    *
  823.    * @param attributeSet the set of attributes found by the search
  824.    * @return a possibly ranked list of postprocessed attributes
  825.    * @exception Exception if postprocessing fails for some reason
  826.    */
  827.   public int[] postProcess (int[] attributeSet)
  828.     throws Exception
  829.   {
  830.     int j = 0;
  831.     if (!m_locallyPredictive) {
  832.       //      m_trainInstances = new Instances(m_trainInstances,0);
  833.       return  attributeSet;
  834.     }
  835.     BitSet bestGroup = new BitSet(m_numAttribs);
  836.     for (int i = 0; i < attributeSet.length; i++) {
  837.       bestGroup.set(attributeSet[i]);
  838.     }
  839.     addLocallyPredictive(bestGroup);
  840.     // count how many are set
  841.     for (int i = 0; i < m_numAttribs; i++) {
  842.       if (bestGroup.get(i)) {
  843. j++;
  844.       }
  845.     }
  846.     int[] newSet = new int[j];
  847.     j = 0;
  848.     for (int i = 0; i < m_numAttribs; i++) {
  849.       if (bestGroup.get(i)) {
  850. newSet[j++] = i;
  851.       }
  852.     }
  853.     //    m_trainInstances = new Instances(m_trainInstances,0);
  854.     return  newSet;
  855.   }
  856.   protected void resetOptions () {
  857.     m_trainInstances = null;
  858.     m_missingSeperate = false;
  859.     m_locallyPredictive = false;
  860.     m_c_Threshold = 0.0;
  861.   }
  862.   /**
  863.    * Main method for testing this class.
  864.    *
  865.    * @param args the options
  866.    */
  867.   public static void main (String[] args) {
  868.     try {
  869.       System.out.println(AttributeSelection.
  870.  SelectAttributes(new CfsSubsetEval(), args));
  871.     }
  872.     catch (Exception e) {
  873.       e.printStackTrace();
  874.       System.out.println(e.getMessage());
  875.     }
  876.   }
  877. }