ReliefFAttributeEval.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 31k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  *    This program is free software; you can redistribute it and/or modify
  3.  *    it under the terms of the GNU General Public License as published by
  4.  *    the Free Software Foundation; either version 2 of the License, or
  5.  *    (at your option) any later version.
  6.  *
  7.  *    This program is distributed in the hope that it will be useful,
  8.  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  *    GNU General Public License for more details.
  11.  *
  12.  *    You should have received a copy of the GNU General Public License
  13.  *    along with this program; if not, write to the Free Software
  14.  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15.  */
  16. /*
  17.  *    ReliefFAttributeEval.java
  18.  *    Copyright (C) 1999 Mark Hall
  19.  *
  20.  */
  21. package  weka.attributeSelection;
  22. import  java.io.*;
  23. import  java.util.*;
  24. import  weka.core.*;
  25. /** 
  26.  * Class for Evaluating attributes individually using ReliefF. <p>
  27.  *
  28.  * For more information see: <p>
  29.  *
  30.  * Kira, K. and Rendell, L. A. (1992). A practical approach to feature
  31.  * selection. In D. Sleeman and P. Edwards, editors, <i>Proceedings of
  32.  * the International Conference on Machine Learning,</i> pages 249-256.
  33.  * Morgan Kaufmann. <p>
  34.  *
  35.  * Kononenko, I. (1994). Estimating attributes: analysis and extensions of
  36.  * Relief. In De Raedt, L. and Bergadano, F., editors, <i> Machine Learning:
  37.  * ECML-94, </i> pages 171-182. Springer Verlag. <p>
  38.  *
  39.  * Marko Robnik Sikonja, Igor Kononenko: An adaptation of Relief for attribute
  40.  * estimation on regression. In D.Fisher (ed.): <i> Machine Learning, 
  41.  * Proceedings of 14th International Conference on Machine Learning ICML'97, 
  42.  * </i> Nashville, TN, 1997. <p>
  43.  *
  44.  *
  45.  * Valid options are:
  46.  *
  47.  * -M <number of instances> <br>
  48.  * Specify the number of instances to sample when estimating attributes. <br>
  49.  * If not specified then all instances will be used. <p>
  50.  *
  51.  * -D <seed> <br>
  52.  * Seed for randomly sampling instances. <p>
  53.  *
  54.  * -K <number of neighbours> <br>
  55.  * Number of nearest neighbours to use for estimating attributes. <br>
  56.  * (Default is 10). <p>
  57.  *
  58.  * -W <br>
  59.  * Weight nearest neighbours by distance. <p>
  60.  *
  61.  * -A <sigma> <br>
  62.  * Specify sigma value (used in an exp function to control how quickly <br>
  63.  * weights decrease for more distant instances). Use in conjunction with <br>
  64.  * -W. Sensible values = 1/5 to 1/10 the number of nearest neighbours. <br>
  65.  *
  66.  * @author Mark Hall (mhall@cs.waikato.ac.nz)
  67.  * @version $Revision: 1.14 $
  68.  */
  69. public class ReliefFAttributeEval
  70.   extends AttributeEvaluator
  71.   implements OptionHandler
  72. {
  73.   /** The training instances */
  74.   private Instances m_trainInstances;
  75.   /** The class index */
  76.   private int m_classIndex;
  77.   /** The number of attributes */
  78.   private int m_numAttribs;
  79.   /** The number of instances */
  80.   private int m_numInstances;
  81.   /** Numeric class */
  82.   private boolean m_numericClass;
  83.   /** The number of classes if class is nominal */
  84.   private int m_numClasses;
  85.   /** 
  86.    * Used to hold the probability of a different class val given nearest
  87.    * instances (numeric class)
  88.    */
  89.   private double m_ndc;
  90.   /** 
  91.    * Used to hold the prob of different value of an attribute given
  92.    * nearest instances (numeric class case)
  93.    */
  94.   private double[] m_nda;
  95.   /**
  96.    * Used to hold the prob of a different class val and different att
  97.    * val given nearest instances (numeric class case)
  98.    */
  99.   private double[] m_ndcda;
  100.   /** Holds the weights that relief assigns to attributes */
  101.   private double[] m_weights;
  102.   /** Prior class probabilities (discrete class case) */
  103.   private double[] m_classProbs;
  104.   /** 
  105.    * The number of instances to sample when estimating attributes
  106.    * default == -1, use all instances
  107.    */
  108.   private int m_sampleM;
  109.   /** The number of nearest hits/misses */
  110.   private int m_Knn;
  111.   /** k nearest scores + instance indexes for n classes */
  112.   private double[][][] m_karray;
  113.   /** Upper bound for numeric attributes */
  114.   private double[] m_maxArray;
  115.   /** Lower bound for numeric attributes */
  116.   private double[] m_minArray;
  117.   /** Keep track of the farthest instance for each class */
  118.   private double[] m_worst;
  119.   /** Index in the m_karray of the farthest instance for each class */
  120.   private int[] m_index;
  121.   /** Number of nearest neighbours stored of each class */
  122.   private int[] m_stored;
  123.  
  124.   /** Random number seed used for sampling instances */
  125.   private int m_seed;
  126.   /**
  127.    *  used to (optionally) weight nearest neighbours by their distance
  128.    *  from the instance in question. Each entry holds 
  129.    *  exp(-((rank(r_i, i_j)/sigma)^2)) where rank(r_i,i_j) is the rank of
  130.    *  instance i_j in a sequence of instances ordered by the distance
  131.    *  from r_i. sigma is a user defined parameter, default=20
  132.    **/
  133.   private double[] m_weightsByRank;
  134.   private int m_sigma;
  135.   
  136.   /** Weight by distance rather than equal weights */
  137.   private boolean m_weightByDistance;
  138.   /**
  139.    * Returns a string describing this attribute evaluator
  140.    * @return a description of the evaluator suitable for
  141.    * displaying in the explorer/experimenter gui
  142.    */
  143.   public String globalInfo() {
  144.     return "ReliefFAttributeEval :nnEvaluates the worth of an attribute by "
  145.       +"repeatedly sampling an instance and considering the value of the "
  146.       +"given attribute for the nearest instance of the same and different "
  147.       +"class. Can operate on both discrete and continuous class data.n";
  148.   }
  149.   /**
  150.    * Constructor
  151.    */
  152.   public ReliefFAttributeEval () {
  153.     resetOptions();
  154.   }
  155.   /**
  156.    * Returns an enumeration describing the available options.
  157.    * @return an enumeration of all the available options.
  158.    **/
  159.   public Enumeration listOptions () {
  160.     Vector newVector = new Vector(4);
  161.     newVector
  162.       .addElement(new Option("tSpecify the number of instances ton" 
  163.      + "tsample when estimating attributes.n" 
  164.      + "tIf not specified, then all instancesn" 
  165.      + "twill be used.", "M", 1
  166.      , "-M <num instances>"));
  167.     newVector.
  168.       addElement(new Option("tSeed for randomly sampling instances.n" 
  169.     + "t(Default = 1)", "D", 1
  170.     , "-D <seed>"));
  171.     newVector.
  172.       addElement(new Option("tNumber of nearest neighbours (k) usedn" 
  173.     + "tto estimate attribute relevancesn" 
  174.     + "t(Default = 10).", "K", 1
  175.     , "-K <number of neighbours>"));
  176.     newVector.
  177.       addElement(new Option("tWeight nearest neighbours by distancen", "W"
  178.     , 0, "-W"));
  179.     newVector.
  180.       addElement(new Option("tSpecify sigma value (used in an expn" 
  181.     + "tfunction to control how quicklyn" 
  182.     + "tweights for more distant instancesn" 
  183.     + "tdecrease. Use in conjunction with -W.n" 
  184.     + "tSensible value=1/5 to 1/10 of then" 
  185.     + "tnumber of nearest neighbours.n" 
  186.     + "t(Default = 2)", "A", 1, "-A <num>"));
  187.     return  newVector.elements();
  188.   }
  189.   /**
  190.    * Parses a given list of options.
  191.    *
  192.    * Valid options are: <p>
  193.    *
  194.    * -M <number of instances> <br>
  195.    * Specify the number of instances to sample when estimating attributes. <br>
  196.    * If not specified then all instances will be used. <p>
  197.    *
  198.    * -D <seed> <br>
  199.    * Seed for randomly sampling instances. <p>
  200.    *
  201.    * -K <number of neighbours> <br>
  202.    * Number of nearest neighbours to use for estimating attributes. <br>
  203.    * (Default is 10). <p>
  204.    *
  205.    * -W <br>
  206.    * Weight nearest neighbours by distance. <p>
  207.    *
  208.    * -A <sigma> <br>
  209.    * Specify sigma value (used in an exp function to control how quickly <br>
  210.    * weights decrease for more distant instances). Use in conjunction with <br>
  211.    * -W. Sensible values = 1/5 to 1/10 the number of nearest neighbours. <br>
  212.    *
  213.    * @param options the list of options as an array of strings
  214.    * @exception Exception if an option is not supported
  215.    *
  216.    **/
  217.   public void setOptions (String[] options)
  218.     throws Exception
  219.   {
  220.     String optionString;
  221.     resetOptions();
  222.     setWeightByDistance(Utils.getFlag('W', options));
  223.     optionString = Utils.getOption('M', options);
  224.     if (optionString.length() != 0) {
  225.       setSampleSize(Integer.parseInt(optionString));
  226.     }
  227.     optionString = Utils.getOption('D', options);
  228.     if (optionString.length() != 0) {
  229.       setSeed(Integer.parseInt(optionString));
  230.     }
  231.     optionString = Utils.getOption('K', options);
  232.     if (optionString.length() != 0) {
  233.       setNumNeighbours(Integer.parseInt(optionString));
  234.     }
  235.     optionString = Utils.getOption('A', options);
  236.     if (optionString.length() != 0) {
  237.       setWeightByDistance(true); // turn on weighting by distance
  238.       setSigma(Integer.parseInt(optionString));
  239.     }
  240.   }
  241.   /**
  242.    * Returns the tip text for this property
  243.    * @return tip text for this property suitable for
  244.    * displaying in the explorer/experimenter gui
  245.    */
  246.   public String sigmaTipText() {
  247.     return "Set influence of nearest neighbours. Used in an exp function to "
  248.       +"control how quickly weights decrease for more distant instances. "
  249.       +"Use in conjunction with weightByDistance. Sensible values = 1/5 to "
  250.       +"1/10 the number of nearest neighbours.";
  251.   }
  252.   /**
  253.    * Sets the sigma value.
  254.    *
  255.    * @param s the value of sigma (> 0)
  256.    * @exception Exception if s is not positive
  257.    */
  258.   public void setSigma (int s)
  259.     throws Exception
  260.   {
  261.     if (s <= 0) {
  262.       throw  new Exception("value of sigma must be > 0!");
  263.     }
  264.     m_sigma = s;
  265.   }
  266.   /**
  267.    * Get the value of sigma.
  268.    *
  269.    * @return the sigma value.
  270.    */
  271.   public int getSigma () {
  272.     return  m_sigma;
  273.   }
  274.   /**
  275.    * Returns the tip text for this property
  276.    * @return tip text for this property suitable for
  277.    * displaying in the explorer/experimenter gui
  278.    */
  279.   public String numNeighboursTipText() {
  280.     return "Number of nearest neighbours for attribute estimation.";
  281.   }
  282.   /**
  283.    * Set the number of nearest neighbours
  284.    *
  285.    * @param n the number of nearest neighbours.
  286.    */
  287.   public void setNumNeighbours (int n) {
  288.     m_Knn = n;
  289.   }
  290.   /**
  291.    * Get the number of nearest neighbours
  292.    *
  293.    * @return the number of nearest neighbours
  294.    */
  295.   public int getNumNeighbours () {
  296.     return  m_Knn;
  297.   }
  298.   /**
  299.    * Returns the tip text for this property
  300.    * @return tip text for this property suitable for
  301.    * displaying in the explorer/experimenter gui
  302.    */
  303.   public String seedTipText() {
  304.     return "Random seed for sampling instances.";
  305.   }
  306.   /**
  307.    * Set the random number seed for randomly sampling instances.
  308.    *
  309.    * @param s the random number seed.
  310.    */
  311.   public void setSeed (int s) {
  312.     m_seed = s;
  313.   }
  314.   /**
  315.    * Get the seed used for randomly sampling instances.
  316.    *
  317.    * @return the random number seed.
  318.    */
  319.   public int getSeed () {
  320.     return  m_seed;
  321.   }
  322.   /**
  323.    * Returns the tip text for this property
  324.    * @return tip text for this property suitable for
  325.    * displaying in the explorer/experimenter gui
  326.    */
  327.   public String sampleSizeTipText() {
  328.     return "Number of instances to sample. Default (-1) indicates that all "
  329.       +"instances will be used for attribute estimation.";
  330.   }
  331.   /**
  332.    * Set the number of instances to sample for attribute estimation
  333.    *
  334.    * @param s the number of instances to sample.
  335.    */
  336.   public void setSampleSize (int s) {
  337.     m_sampleM = s;
  338.   }
  339.   /**
  340.    * Get the number of instances used for estimating attributes
  341.    *
  342.    * @return the number of instances.
  343.    */
  344.   public int getSampleSize () {
  345.     return  m_sampleM;
  346.   }
  347.   /**
  348.    * Returns the tip text for this property
  349.    * @return tip text for this property suitable for
  350.    * displaying in the explorer/experimenter gui
  351.    */
  352.   public String weightByDistanceTipText() {
  353.     return "Weight nearest neighbours by their distance.";
  354.   }
  355.   /**
  356.    * Set the nearest neighbour weighting method
  357.    *
  358.    * @param b true nearest neighbours are to be weighted by distance.
  359.    */
  360.   public void setWeightByDistance (boolean b) {
  361.     m_weightByDistance = b;
  362.   }
  363.   /**
  364.    * Get whether nearest neighbours are being weighted by distance
  365.    *
  366.    * @return m_weightByDiffernce
  367.    */
  368.   public boolean getWeightByDistance () {
  369.     return  m_weightByDistance;
  370.   }
  371.   /**
  372.    * Gets the current settings of ReliefFAttributeEval.
  373.    *
  374.    * @return an array of strings suitable for passing to setOptions()
  375.    */
  376.   public String[] getOptions () {
  377.     String[] options = new String[9];
  378.     int current = 0;
  379.     if (getWeightByDistance()) {
  380.       options[current++] = "-W";
  381.     }
  382.     options[current++] = "-M";
  383.     options[current++] = "" + getSampleSize();
  384.     options[current++] = "-D";
  385.     options[current++] = "" + getSeed();
  386.     options[current++] = "-K";
  387.     options[current++] = "" + getNumNeighbours();
  388.     options[current++] = "-A";
  389.     options[current++] = "" + getSigma();
  390.     while (current < options.length) {
  391.       options[current++] = "";
  392.     }
  393.     return  options;
  394.   }
  395.   /**
  396.    * Return a description of the ReliefF attribute evaluator.
  397.    *
  398.    * @return a description of the evaluator as a String.
  399.    */
  400.   public String toString () {
  401.     StringBuffer text = new StringBuffer();
  402.     if (m_trainInstances == null) {
  403.       text.append("ReliefF feature evaluator has not been built yetn");
  404.     }
  405.     else {
  406.       text.append("tReliefF Ranking Filter");
  407.       text.append("ntInstances sampled: ");
  408.       if (m_sampleM == -1) {
  409.         text.append("alln");
  410.       }
  411.       else {
  412.         text.append(m_sampleM + "n");
  413.       }
  414.       text.append("tNumber of nearest neighbours (k): " + m_Knn + "n");
  415.       if (m_weightByDistance) {
  416.         text.append("tExponentially decreasing (with distance) " 
  417.     + "influence forn" 
  418.     + "tnearest neighbours. Sigma: " 
  419.     + m_sigma + "n");
  420.       }
  421.       else {
  422.         text.append("tEqual influence nearest neighboursn");
  423.       }
  424.     }
  425.     return  text.toString();
  426.   }
  427.   /**
  428.    * Initializes a ReliefF attribute evaluator. 
  429.    *
  430.    * @param data set of instances serving as training data 
  431.    * @exception Exception if the evaluator has not been 
  432.    * generated successfully
  433.    */
  434.   public void buildEvaluator (Instances data)
  435.     throws Exception
  436.   {
  437.     int z, totalInstances;
  438.     Random r = new Random(m_seed);
  439.     if (data.checkForStringAttributes()) {
  440.       throw  new UnsupportedAttributeTypeException("Can't handle string attributes!");
  441.     }
  442.     m_trainInstances = data;
  443.     m_classIndex = m_trainInstances.classIndex();
  444.     m_numAttribs = m_trainInstances.numAttributes();
  445.     m_numInstances = m_trainInstances.numInstances();
  446.     if (m_trainInstances.attribute(m_classIndex).isNumeric()) {
  447.       m_numericClass = true;
  448.     }
  449.     else {
  450.       m_numericClass = false;
  451.     }
  452.     if (!m_numericClass) {
  453.       m_numClasses = m_trainInstances.attribute(m_classIndex).numValues();
  454.     }
  455.     else {
  456.       m_ndc = 0;
  457.       m_numClasses = 1;
  458.       m_nda = new double[m_numAttribs];
  459.       m_ndcda = new double[m_numAttribs];
  460.     }
  461.     if (m_weightByDistance) // set up the rank based weights
  462.       {
  463. m_weightsByRank = new double[m_Knn];
  464. for (int i = 0; i < m_Knn; i++) {
  465.   m_weightsByRank[i] = 
  466.     Math.exp(-((i/(double)m_sigma)*(i/(double)m_sigma)));
  467. }
  468.       }
  469.     // the final attribute weights
  470.     m_weights = new double[m_numAttribs];
  471.     // num classes (1 for numeric class) knn neighbours, 
  472.     // and 0 = distance, 1 = instance index
  473.     m_karray = new double[m_numClasses][m_Knn][2];
  474.     if (!m_numericClass) {
  475.       m_classProbs = new double[m_numClasses];
  476.       for (int i = 0; i < m_numInstances; i++) {
  477.         m_classProbs[(int)m_trainInstances.instance(i).value(m_classIndex)]++;
  478.       }
  479.       for (int i = 0; i < m_numClasses; i++) {
  480.         m_classProbs[i] /= m_numInstances;
  481.       }
  482.     }
  483.     m_worst = new double[m_numClasses];
  484.     m_index = new int[m_numClasses];
  485.     m_stored = new int[m_numClasses];
  486.     m_minArray = new double[m_numAttribs];
  487.     m_maxArray = new double[m_numAttribs];
  488.     for (int i = 0; i < m_numAttribs; i++) {
  489.       m_minArray[i] = m_maxArray[i] = Double.NaN;
  490.     }
  491.     for (int i = 0; i < m_numInstances; i++) {
  492.       updateMinMax(m_trainInstances.instance(i));
  493.     }
  494.     
  495.     if ((m_sampleM > m_numInstances) || (m_sampleM < 0)) {
  496.       totalInstances = m_numInstances;
  497.     }
  498.     else {
  499.       totalInstances = m_sampleM;
  500.     }
  501.     // process each instance, updating attribute weights
  502.     for (int i = 0; i < totalInstances; i++) {
  503.       if (totalInstances == m_numInstances) {
  504.         z = i;
  505.       }
  506.       else {
  507.         z = r.nextInt()%m_numInstances;
  508.       }
  509.       if (z < 0) {
  510.         z *= -1;
  511.       }
  512.       if (!(m_trainInstances.instance(z).isMissing(m_classIndex))) {
  513.         // first clear the knn and worst index stuff for the classes
  514.         for (int j = 0; j < m_numClasses; j++) {
  515.           m_index[j] = m_stored[j] = 0;
  516.           for (int k = 0; k < m_Knn; k++) {
  517.             m_karray[j][k][0] = m_karray[j][k][1] = 0;
  518.           }
  519.         }
  520.         findKHitMiss(z);
  521.         if (m_numericClass) {
  522.           updateWeightsNumericClass(z);
  523.         }
  524.         else {
  525.           updateWeightsDiscreteClass(z);
  526.         }
  527.       }
  528.     }
  529.     // now scale weights by 1/m_numInstances (nominal class) or
  530.     // calculate weights numeric class
  531.     // System.out.println("num inst:"+m_numInstances+" r_ndc:"+r_ndc);
  532.     for (int i = 0; i < m_numAttribs; i++) {if (i != m_classIndex) {
  533.       if (m_numericClass) {
  534.         m_weights[i] = m_ndcda[i]/m_ndc - 
  535.   ((m_nda[i] - m_ndcda[i])/((double)totalInstances - m_ndc));
  536.       }
  537.       else {
  538.         m_weights[i] *= (1.0/(double)totalInstances);
  539.       }
  540.       //   System.out.println(r_weights[i]);
  541.     }
  542.     }
  543.   }
  544.   /**
  545.    * Evaluates an individual attribute using ReliefF's instance based approach.
  546.    * The actual work is done by buildEvaluator which evaluates all features.
  547.    *
  548.    * @param attribute the index of the attribute to be evaluated
  549.    * @exception Exception if the attribute could not be evaluated
  550.    */
  551.   public double evaluateAttribute (int attribute)
  552.     throws Exception
  553.   {
  554.     return  m_weights[attribute];
  555.   }
  556.   /**
  557.    * Reset options to their default values
  558.    */
  559.   protected void resetOptions () {
  560.     m_trainInstances = null;
  561.     m_sampleM = -1;
  562.     m_Knn = 10;
  563.     m_sigma = 2;
  564.     m_weightByDistance = false;
  565.     m_seed = 1;
  566.   }
  567.   /**
  568.    * Normalizes a given value of a numeric attribute.
  569.    *
  570.    * @param x the value to be normalized
  571.    * @param i the attribute's index
  572.    */
  573.   private double norm (double x, int i) {
  574.     if (Double.isNaN(m_minArray[i]) || 
  575. Utils.eq(m_maxArray[i], m_minArray[i])) {
  576.       return  0;
  577.     }
  578.     else {
  579.       return  (x - m_minArray[i])/(m_maxArray[i] - m_minArray[i]);
  580.     }
  581.   }
  582.   /**
  583.    * Updates the minimum and maximum values for all the attributes
  584.    * based on a new instance.
  585.    *
  586.    * @param instance the new instance
  587.    */
  588.   private void updateMinMax (Instance instance) {
  589.     //    for (int j = 0; j < m_numAttribs; j++) {
  590.     try {
  591.       for (int j = 0; j < instance.numValues(); j++) {
  592. if ((instance.attributeSparse(j).isNumeric()) && 
  593.     (!instance.isMissingSparse(j))) {
  594.   if (Double.isNaN(m_minArray[instance.index(j)])) {
  595.     m_minArray[instance.index(j)] = instance.valueSparse(j);
  596.     m_maxArray[instance.index(j)] = instance.valueSparse(j);
  597.   }
  598. else {
  599.   if (instance.valueSparse(j) < m_minArray[instance.index(j)]) {
  600.     m_minArray[instance.index(j)] = instance.valueSparse(j);
  601.   }
  602.   else {
  603.     if (instance.valueSparse(j) > m_maxArray[instance.index(j)]) {
  604.       m_maxArray[instance.index(j)] = instance.valueSparse(j);
  605.     }
  606.   }
  607. }
  608. }
  609.       }
  610.     } catch (Exception ex) {
  611.       System.err.println(ex);
  612.       ex.printStackTrace();
  613.     }
  614.   }
  615.   /**
  616.    * Computes the difference between two given attribute
  617.    * values.
  618.    */
  619.   private double difference(int index, double val1, double val2) {
  620.     switch (m_trainInstances.attribute(index).type()) {
  621.     case Attribute.NOMINAL:
  622.       
  623.       // If attribute is nominal
  624.       if (Instance.isMissingValue(val1) || 
  625.   Instance.isMissingValue(val2)) {
  626. return (1.0 - (1.0/((double)m_trainInstances.
  627.     attribute(index).numValues())));
  628.       } else if ((int)val1 != (int)val2) {
  629. return 1;
  630.       } else {
  631. return 0;
  632.       }
  633.     case Attribute.NUMERIC:
  634.       // If attribute is numeric
  635.       if (Instance.isMissingValue(val1) || 
  636.   Instance.isMissingValue(val2)) {
  637. if (Instance.isMissingValue(val1) && 
  638.     Instance.isMissingValue(val2)) {
  639.   return 1;
  640. } else {
  641.   double diff;
  642.   if (Instance.isMissingValue(val2)) {
  643.     diff = norm(val1, index);
  644.   } else {
  645.     diff = norm(val2, index);
  646.   }
  647.   if (diff < 0.5) {
  648.     diff = 1.0 - diff;
  649.   }
  650.   return diff;
  651. }
  652.       } else {
  653. return Math.abs(norm(val1, index) - norm(val2, index));
  654.       }
  655.     default:
  656.       return 0;
  657.     }
  658.   }
  659.   /**
  660.    * Calculates the distance between two instances
  661.    *
  662.    * @param test the first instance
  663.    * @param train the second instance
  664.    * @return the distance between the two given instances, between 0 and 1
  665.    */          
  666.   private double distance(Instance first, Instance second) {  
  667.     double distance = 0;
  668.     int firstI, secondI;
  669.     for (int p1 = 0, p2 = 0; 
  670.  p1 < first.numValues() || p2 < second.numValues();) {
  671.       if (p1 >= first.numValues()) {
  672. firstI = m_trainInstances.numAttributes();
  673.       } else {
  674. firstI = first.index(p1); 
  675.       }
  676.       if (p2 >= second.numValues()) {
  677. secondI = m_trainInstances.numAttributes();
  678.       } else {
  679. secondI = second.index(p2);
  680.       }
  681.       if (firstI == m_trainInstances.classIndex()) {
  682. p1++; continue;
  683.       } 
  684.       if (secondI == m_trainInstances.classIndex()) {
  685. p2++; continue;
  686.       } 
  687.       double diff;
  688.       if (firstI == secondI) {
  689. diff = difference(firstI, 
  690.   first.valueSparse(p1),
  691.   second.valueSparse(p2));
  692. p1++; p2++;
  693.       } else if (firstI > secondI) {
  694. diff = difference(secondI, 
  695.   0, second.valueSparse(p2));
  696. p2++;
  697.       } else {
  698. diff = difference(firstI, 
  699.   first.valueSparse(p1), 0);
  700. p1++;
  701.       }
  702.       //      distance += diff * diff;
  703.       distance += diff;
  704.     }
  705.     
  706.     //    return Math.sqrt(distance / m_NumAttributesUsed);
  707.     return distance;
  708.   }
  709.   /**
  710.    * update attribute weights given an instance when the class is numeric
  711.    *
  712.    * @param instNum the index of the instance to use when updating weights
  713.    */
  714.   private void updateWeightsNumericClass (int instNum) {
  715.     int i, j;
  716.     double temp,temp2;
  717.     int[] tempSorted = null;
  718.     double[] tempDist = null;
  719.     double distNorm = 1.0;
  720.     int firstI, secondI;
  721.     Instance inst = m_trainInstances.instance(instNum);
  722.    
  723.     // sort nearest neighbours and set up normalization variable
  724.     if (m_weightByDistance) {
  725.       tempDist = new double[m_stored[0]];
  726.       for (j = 0, distNorm = 0; j < m_stored[0]; j++) {
  727. // copy the distances
  728. tempDist[j] = m_karray[0][j][0];
  729. // sum normalizer
  730. distNorm += m_weightsByRank[j];
  731.       }
  732.       tempSorted = Utils.sort(tempDist);
  733.     }
  734.     for (i = 0; i < m_stored[0]; i++) {
  735.       // P diff prediction (class) given nearest instances
  736.       if (m_weightByDistance) {
  737. temp = difference(m_classIndex, 
  738.   inst.value(m_classIndex),
  739.   m_trainInstances.
  740.   instance((int)m_karray[0][tempSorted[i]][1]).
  741.   value(m_classIndex));
  742. temp *= (m_weightsByRank[i]/distNorm);
  743.       }
  744.       else {
  745. temp = difference(m_classIndex, 
  746.   inst.value(m_classIndex), 
  747.   m_trainInstances.
  748.   instance((int)m_karray[0][i][1]).
  749.   value(m_classIndex));
  750. temp *= (1.0/(double)m_stored[0]); // equal influence
  751.       }
  752.       m_ndc += temp;
  753.       Instance cmp;
  754.       cmp = (m_weightByDistance) 
  755. ? m_trainInstances.instance((int)m_karray[0][tempSorted[i]][1])
  756. : m_trainInstances.instance((int)m_karray[0][i][1]);
  757.  
  758.       double temp_diffP_diffA_givNearest = 
  759. difference(m_classIndex, inst.value(m_classIndex),
  760.    cmp.value(m_classIndex));
  761.       // now the attributes
  762.       for (int p1 = 0, p2 = 0; 
  763.    p1 < inst.numValues() || p2 < cmp.numValues();) {
  764. if (p1 >= inst.numValues()) {
  765.   firstI = m_trainInstances.numAttributes();
  766. } else {
  767.   firstI = inst.index(p1); 
  768. }
  769. if (p2 >= cmp.numValues()) {
  770.   secondI = m_trainInstances.numAttributes();
  771. } else {
  772.   secondI = cmp.index(p2);
  773. }
  774. if (firstI == m_trainInstances.classIndex()) {
  775.   p1++; continue;
  776. if (secondI == m_trainInstances.classIndex()) {
  777.   p2++; continue;
  778. temp = 0.0;
  779. temp2 = 0.0;
  780.       
  781. if (firstI == secondI) {
  782.   j = firstI;
  783.   temp = difference(j, inst.valueSparse(p1), cmp.valueSparse(p2)); 
  784.   p1++;p2++;
  785. } else if (firstI > secondI) {
  786.   j = secondI;
  787.   temp = difference(j, 0, cmp.valueSparse(p2));
  788.   p2++;
  789. } else {
  790.   j = firstI;
  791.   temp = difference(j, inst.valueSparse(p1), 0);
  792.   p1++;
  793.        
  794. temp2 = temp_diffP_diffA_givNearest * temp; 
  795. // P of different prediction and different att value given
  796. // nearest instances
  797. if (m_weightByDistance) {
  798.   temp2 *= (m_weightsByRank[i]/distNorm);
  799. }
  800. else {
  801.   temp2 *= (1.0/(double)m_stored[0]); // equal influence
  802. }
  803. m_ndcda[j] += temp2;
  804.        
  805. // P of different attribute val given nearest instances
  806. if (m_weightByDistance) {
  807.   temp *= (m_weightsByRank[i]/distNorm);
  808. }
  809. else {
  810.   temp *= (1.0/(double)m_stored[0]); // equal influence
  811. }
  812. m_nda[j] += temp;
  813.       }
  814.     }
  815.   }
  816.   /**
  817.    * update attribute weights given an instance when the class is discrete
  818.    *
  819.    * @param instNum the index of the instance to use when updating weights
  820.    */
  821.   private void updateWeightsDiscreteClass (int instNum) {
  822.     int i, j, k;
  823.     int cl;
  824.     double cc = m_numInstances;
  825.     double temp, temp_diff, w_norm = 1.0;
  826.     double[] tempDistClass;
  827.     int[] tempSortedClass = null;
  828.     double distNormClass = 1.0;
  829.     double[] tempDistAtt;
  830.     int[][] tempSortedAtt = null;
  831.     double[] distNormAtt = null;
  832.     int firstI, secondI;
  833.     // store the indexes (sparse instances) of non-zero elements
  834.     Instance inst = m_trainInstances.instance(instNum);
  835.     // get the class of this instance
  836.     cl = (int)m_trainInstances.instance(instNum).value(m_classIndex);
  837.     // sort nearest neighbours and set up normalization variables
  838.     if (m_weightByDistance) {
  839.       // do class (hits) first
  840.       // sort the distances
  841.       tempDistClass = new double[m_stored[cl]];
  842.       for (j = 0, distNormClass = 0; j < m_stored[cl]; j++) {
  843. // copy the distances
  844. tempDistClass[j] = m_karray[cl][j][0];
  845. // sum normalizer
  846. distNormClass += m_weightsByRank[j];
  847.       }
  848.       tempSortedClass = Utils.sort(tempDistClass);
  849.       // do misses (other classes)
  850.       tempSortedAtt = new int[m_numClasses][1];
  851.       distNormAtt = new double[m_numClasses];
  852.       for (k = 0; k < m_numClasses; k++) {
  853. if (k != cl) // already done cl
  854.   {
  855.     // sort the distances
  856.     tempDistAtt = new double[m_stored[k]];
  857.     for (j = 0, distNormAtt[k] = 0; j < m_stored[k]; j++) {
  858.       // copy the distances
  859.       tempDistAtt[j] = m_karray[k][j][0];
  860.       // sum normalizer
  861.       distNormAtt[k] += m_weightsByRank[j];
  862.     }
  863.     tempSortedAtt[k] = Utils.sort(tempDistAtt);
  864.   }
  865.       }
  866.     }
  867.     if (m_numClasses > 2) {
  868.       // the amount of probability space left after removing the
  869.       // probability of this instance's class value
  870.       w_norm = (1.0 - m_classProbs[cl]);
  871.     }
  872.     
  873.     // do the k nearest hits of the same class
  874.     for (j = 0, temp_diff = 0.0; j < m_stored[cl]; j++) {
  875.       Instance cmp;
  876.       cmp = (m_weightByDistance) 
  877. ? m_trainInstances.
  878. instance((int)m_karray[cl][tempSortedClass[j]][1])
  879. : m_trainInstances.instance((int)m_karray[cl][j][1]);
  880.       for (int p1 = 0, p2 = 0; 
  881.    p1 < inst.numValues() || p2 < cmp.numValues();) {
  882. if (p1 >= inst.numValues()) {
  883.   firstI = m_trainInstances.numAttributes();
  884. } else {
  885.   firstI = inst.index(p1); 
  886. }
  887. if (p2 >= cmp.numValues()) {
  888.   secondI = m_trainInstances.numAttributes();
  889. } else {
  890.   secondI = cmp.index(p2);
  891. }
  892. if (firstI == m_trainInstances.classIndex()) {
  893.   p1++; continue;
  894. if (secondI == m_trainInstances.classIndex()) {
  895.   p2++; continue;
  896. if (firstI == secondI) {
  897.   i = firstI;
  898.   temp_diff = difference(i, inst.valueSparse(p1), 
  899.  cmp.valueSparse(p2)); 
  900.   p1++;p2++;
  901. } else if (firstI > secondI) {
  902.   i = secondI;
  903.   temp_diff = difference(i, 0, cmp.valueSparse(p2));
  904.   p2++;
  905. } else {
  906.   i = firstI;
  907.   temp_diff = difference(i, inst.valueSparse(p1), 0);
  908.   p1++;
  909. if (m_weightByDistance) {
  910.   temp_diff *=
  911.     (m_weightsByRank[j]/distNormClass);
  912. } else {
  913.   if (m_stored[cl] > 0) {
  914.     temp_diff /= (double)m_stored[cl];
  915.   }
  916. }
  917. m_weights[i] -= temp_diff;
  918.       }
  919.     }
  920.       
  921.     // now do k nearest misses from each of the other classes
  922.     temp_diff = 0.0;
  923.     for (k = 0; k < m_numClasses; k++) {
  924.       if (k != cl) // already done cl
  925. {
  926.   for (j = 0, temp = 0.0; j < m_stored[k]; j++) {
  927.     Instance cmp;
  928.     cmp = (m_weightByDistance) 
  929.       ? m_trainInstances.
  930.       instance((int)m_karray[k][tempSortedAtt[k][j]][1])
  931.       : m_trainInstances.instance((int)m_karray[k][j][1]);
  932.     for (int p1 = 0, p2 = 0; 
  933.  p1 < inst.numValues() || p2 < cmp.numValues();) {
  934.       if (p1 >= inst.numValues()) {
  935. firstI = m_trainInstances.numAttributes();
  936.       } else {
  937. firstI = inst.index(p1); 
  938.       }
  939.       if (p2 >= cmp.numValues()) {
  940. secondI = m_trainInstances.numAttributes();
  941.       } else {
  942. secondI = cmp.index(p2);
  943.       }
  944.       if (firstI == m_trainInstances.classIndex()) {
  945. p1++; continue;
  946.       } 
  947.       if (secondI == m_trainInstances.classIndex()) {
  948. p2++; continue;
  949.       } 
  950.       if (firstI == secondI) {
  951. i = firstI;
  952. temp_diff = difference(i, inst.valueSparse(p1), 
  953.        cmp.valueSparse(p2)); 
  954. p1++;p2++;
  955.       } else if (firstI > secondI) {
  956. i = secondI;
  957. temp_diff = difference(i, 0, cmp.valueSparse(p2));
  958. p2++;
  959.       } else {
  960. i = firstI;
  961. temp_diff = difference(i, inst.valueSparse(p1), 0);
  962. p1++;
  963.       } 
  964.       if (m_weightByDistance) {
  965. temp_diff *=
  966.   (m_weightsByRank[j]/distNormAtt[k]);
  967.       }
  968.       else {
  969. if (m_stored[k] > 0) {
  970.   temp_diff /= (double)m_stored[k];
  971. }
  972.       }
  973.       if (m_numClasses > 2) {
  974. m_weights[i] += ((m_classProbs[k]/w_norm)*temp_diff);
  975.       } else {
  976. m_weights[i] += temp_diff;
  977.       }
  978.     }
  979.   }
  980. }
  981.     }
  982.   }
  983.   /**
  984.    * Find the K nearest instances to supplied instance if the class is numeric,
  985.    * or the K nearest Hits (same class) and Misses (K from each of the other
  986.    * classes) if the class is discrete.
  987.    *
  988.    * @param instNum the index of the instance to find nearest neighbours of
  989.    */
  990.   private void findKHitMiss (int instNum) {
  991.     int i, j;
  992.     int cl;
  993.     double ww;
  994.     double temp_diff = 0.0;
  995.     Instance thisInst = m_trainInstances.instance(instNum);
  996.     for (i = 0; i < m_numInstances; i++) {
  997.       if (i != instNum) {
  998. Instance cmpInst = m_trainInstances.instance(i);
  999. temp_diff = distance(cmpInst, thisInst);
  1000. // class of this training instance or 0 if numeric
  1001. if (m_numericClass) {
  1002.   cl = 0;
  1003. }
  1004. else {
  1005.   cl = (int)m_trainInstances.instance(i).value(m_classIndex);
  1006. }
  1007. // add this diff to the list for the class of this instance
  1008. if (m_stored[cl] < m_Knn) {
  1009.   m_karray[cl][m_stored[cl]][0] = temp_diff;
  1010.   m_karray[cl][m_stored[cl]][1] = i;
  1011.   m_stored[cl]++;
  1012.   // note the worst diff for this class
  1013.   for (j = 0, ww = -1.0; j < m_stored[cl]; j++) {
  1014.     if (m_karray[cl][j][0] > ww) {
  1015.       ww = m_karray[cl][j][0];
  1016.       m_index[cl] = j;
  1017.     }
  1018.   }
  1019.   m_worst[cl] = ww;
  1020. }
  1021. else 
  1022.   /* if we already have stored knn for this class then check to
  1023.      see if this instance is better than the worst */
  1024.   {
  1025.     if (temp_diff < m_karray[cl][m_index[cl]][0]) {
  1026.       m_karray[cl][m_index[cl]][0] = temp_diff;
  1027.       m_karray[cl][m_index[cl]][1] = i;
  1028.       for (j = 0, ww = -1.0; j < m_stored[cl]; j++) {
  1029. if (m_karray[cl][j][0] > ww) {
  1030.   ww = m_karray[cl][j][0];
  1031.   m_index[cl] = j;
  1032. }
  1033.       }
  1034.       m_worst[cl] = ww;
  1035.     }
  1036.   }
  1037.       }
  1038.     }
  1039.   }
  1040.   // ============
  1041.   // Test method.
  1042.   // ============
  1043.   /**
  1044.    * Main method for testing this class.
  1045.    *
  1046.    * @param args the options
  1047.    */
  1048.   public static void main (String[] args) {
  1049.     try {
  1050.       System.out.println(AttributeSelection.
  1051.  SelectAttributes(new ReliefFAttributeEval(), args));
  1052.     }
  1053.     catch (Exception e) {
  1054.       e.printStackTrace();
  1055.       System.out.println(e.getMessage());
  1056.     }
  1057.   }
  1058. }