KernelEstimator.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 8k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  *    This program is free software; you can redistribute it and/or modify
  3.  *    it under the terms of the GNU General Public License as published by
  4.  *    the Free Software Foundation; either version 2 of the License, or
  5.  *    (at your option) any later version.
  6.  *
  7.  *    This program is distributed in the hope that it will be useful,
  8.  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  *    GNU General Public License for more details.
  11.  *
  12.  *    You should have received a copy of the GNU General Public License
  13.  *    along with this program; if not, write to the Free Software
  14.  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15.  */
  16. /*
  17.  *    KernelEstimator.java
  18.  *    Copyright (C) 1999 Len Trigg
  19.  *
  20.  */
  21. package weka.estimators;
  22. import java.util.*;
  23. import weka.core.*;
  24. /** 
  25.  * Simple kernel density estimator. Uses one gaussian kernel per observed
  26.  * data value.
  27.  *
  28.  * @author Len Trigg (trigg@cs.waikato.ac.nz)
  29.  * @version $Revision: 1.4 $
  30.  */
  31. public class KernelEstimator implements Estimator {
  32.   /** Vector containing all of the values seen */
  33.   private double [] m_Values;
  34.   /** Vector containing the associated weights */
  35.   private double [] m_Weights;
  36.   /** Number of values stored in m_Weights and m_Values so far */
  37.   private int m_NumValues;
  38.   /** The sum of the weights so far */
  39.   private double m_SumOfWeights;
  40.   /** The standard deviation */
  41.   private double m_StandardDev;
  42.   /** The precision of data values */
  43.   private double m_Precision;
  44.   /** Whether we can optimise the kernel summation */
  45.   private boolean m_AllWeightsOne;
  46.   /** Maximum percentage error permitted in probability calculations */
  47.   private static double MAX_ERROR = 0.01;
  48.   /**
  49.    * Execute a binary search to locate the nearest data value
  50.    *
  51.    * @param the data value to locate
  52.    * @return the index of the nearest data value
  53.    */
  54.   private int findNearestValue(double key) {
  55.     int low = 0; 
  56.     int high = m_NumValues;
  57.     int middle = 0;
  58.     while (low < high) {
  59.       middle = (low + high) / 2;
  60.       double current = m_Values[middle];
  61.       if (current == key) {
  62. return middle;
  63.       }
  64.       if (current > key) {
  65. high = middle;
  66.       } else if (current < key) {
  67. low = middle + 1;
  68.       }
  69.     }
  70.     return low;
  71.   }
  72.   /**
  73.    * Round a data value using the defined precision for this estimator
  74.    *
  75.    * @param data the value to round
  76.    * @return the rounded data value
  77.    */
  78.   private double round(double data) {
  79.     return Math.rint(data / m_Precision) * m_Precision;
  80.   }
  81.   
  82.   // ===============
  83.   // Public methods.
  84.   // ===============
  85.   
  86.   /**
  87.    * Constructor that takes a precision argument.
  88.    *
  89.    * @param precision the  precision to which numeric values are given. For
  90.    * example, if the precision is stated to be 0.1, the values in the
  91.    * interval (0.25,0.35] are all treated as 0.3. 
  92.    */
  93.   public KernelEstimator(double precision) {
  94.     m_Values = new double [50];
  95.     m_Weights = new double [50];
  96.     m_NumValues = 0;
  97.     m_SumOfWeights = 0;
  98.     m_AllWeightsOne = true;
  99.     m_Precision = precision;
  100.     //    m_StandardDev = 1e10 * m_Precision; // Set the standard deviation initially very wide
  101.     m_StandardDev = m_Precision / (2 * 3);
  102.   }
  103.   /**
  104.    * Add a new data value to the current estimator.
  105.    *
  106.    * @param data the new data value 
  107.    * @param weight the weight assigned to the data value 
  108.    */
  109.   public void addValue(double data, double weight) {
  110.     if (weight == 0) {
  111.       return;
  112.     }
  113.     data = round(data);
  114.     int insertIndex = findNearestValue(data);
  115.     if ((m_NumValues <= insertIndex) || (m_Values[insertIndex] != data)) {
  116.       if (m_NumValues < m_Values.length) {
  117. int left = m_NumValues - insertIndex; 
  118. System.arraycopy(m_Values, insertIndex, 
  119.  m_Values, insertIndex + 1, left);
  120. System.arraycopy(m_Weights, insertIndex, 
  121.  m_Weights, insertIndex + 1, left);
  122. m_Values[insertIndex] = data;
  123. m_Weights[insertIndex] = weight;
  124. m_NumValues++;
  125.       } else {
  126. double [] newValues = new double [m_Values.length * 2];
  127. double [] newWeights = new double [m_Values.length * 2];
  128. int left = m_NumValues - insertIndex; 
  129. System.arraycopy(m_Values, 0, newValues, 0, insertIndex);
  130. System.arraycopy(m_Weights, 0, newWeights, 0, insertIndex);
  131. newValues[insertIndex] = data;
  132. newWeights[insertIndex] = weight;
  133. System.arraycopy(m_Values, insertIndex, 
  134.  newValues, insertIndex + 1, left);
  135. System.arraycopy(m_Weights, insertIndex, 
  136.  newWeights, insertIndex + 1, left);
  137. m_NumValues++;
  138. m_Values = newValues;
  139. m_Weights = newWeights;
  140.       }
  141.       if (weight != 1) {
  142. m_AllWeightsOne = false;
  143.       }
  144.     } else {
  145.       m_Weights[insertIndex] += weight;
  146.       m_AllWeightsOne = false;      
  147.     }
  148.     m_SumOfWeights += weight;
  149.     double range = m_Values[m_NumValues - 1] - m_Values[0];
  150.     if (range > 0) {
  151.       m_StandardDev = Math.max(range / Math.sqrt(m_SumOfWeights), 
  152.        // allow at most 3 sds within one interval
  153.        m_Precision / (2 * 3));
  154.     }
  155.   }
  156.   /**
  157.    * Get a probability estimate for a value.
  158.    *
  159.    * @param data the value to estimate the probability of
  160.    * @return the estimated probability of the supplied value
  161.    */
  162.   public double getProbability(double data) {
  163.     double delta = 0, sum = 0, currentProb = 0;
  164.     double zLower = 0, zUpper = 0;
  165.     if (m_NumValues == 0) {
  166.       zLower = (data - (m_Precision / 2)) / m_StandardDev;
  167.       zUpper = (data + (m_Precision / 2)) / m_StandardDev;
  168.       return (Statistics.normalProbability(zUpper)
  169.       - Statistics.normalProbability(zLower));
  170.     }
  171.     double weightSum = 0;
  172.     int start = findNearestValue(data);
  173.     for (int i = start; i < m_NumValues; i++) {
  174.       delta = m_Values[i] - data;
  175.       zLower = (delta - (m_Precision / 2)) / m_StandardDev;
  176.       zUpper = (delta + (m_Precision / 2)) / m_StandardDev;
  177.       currentProb = (Statistics.normalProbability(zUpper)
  178.      - Statistics.normalProbability(zLower));
  179.       sum += currentProb * m_Weights[i];
  180.       /*
  181.       System.out.print("zL" + (i + 1) + ": " + zLower + " ");
  182.       System.out.print("zU" + (i + 1) + ": " + zUpper + " ");
  183.       System.out.print("P" + (i + 1) + ": " + currentProb + " ");
  184.       System.out.println("total: " + (currentProb * m_Weights[i]) + " ");
  185.       */
  186.       weightSum += m_Weights[i];
  187.       if (currentProb * (m_SumOfWeights - weightSum) < sum * MAX_ERROR) {
  188. break;
  189.       }
  190.     }
  191.     for (int i = start - 1; i >= 0; i--) {
  192.       delta = m_Values[i] - data;
  193.       zLower = (delta - (m_Precision / 2)) / m_StandardDev;
  194.       zUpper = (delta + (m_Precision / 2)) / m_StandardDev;
  195.       currentProb = (Statistics.normalProbability(zUpper)
  196.      - Statistics.normalProbability(zLower));
  197.       sum += currentProb * m_Weights[i];
  198.       weightSum += m_Weights[i];
  199.       if (currentProb * (m_SumOfWeights - weightSum) < sum * MAX_ERROR) {
  200. break;
  201.       }
  202.     }
  203.     return sum / m_SumOfWeights;
  204.   }
  205.   /** Display a representation of this estimator */
  206.   public String toString() {
  207.     String result = m_NumValues + " Normal Kernels. nStandardDev = " 
  208.       + Utils.doubleToString(m_StandardDev,6,4)
  209.       + " Precision = " + m_Precision;
  210.     if (m_NumValues == 0) {
  211.       result += "  nMean = 0";
  212.     } else {
  213.       result += "  nMeans =";
  214.       for (int i = 0; i < m_NumValues; i++) {
  215. result += " " + m_Values[i];
  216.       }
  217.       if (!m_AllWeightsOne) {
  218. result += "nWeights = ";
  219. for (int i = 0; i < m_NumValues; i++) {
  220.   result += " " + m_Weights[i];
  221. }
  222.       }
  223.     }
  224.     return result + "n";
  225.   }
  226.   /**
  227.    * Main method for testing this class.
  228.    *
  229.    * @param argv should contain a sequence of numeric values
  230.    */
  231.   public static void main(String [] argv) {
  232.     try {
  233.       if (argv.length < 2) {
  234. System.out.println("Please specify a set of instances.");
  235. return;
  236.       }
  237.       KernelEstimator newEst = new KernelEstimator(0.01);
  238.       for (int i = 0; i < argv.length - 3; i += 2) {
  239. newEst.addValue(Double.valueOf(argv[i]).doubleValue(), 
  240. Double.valueOf(argv[i + 1]).doubleValue());
  241.       }
  242.       System.out.println(newEst);
  243.       double start = Double.valueOf(argv[argv.length - 2]).doubleValue();
  244.       double finish = Double.valueOf(argv[argv.length - 1]).doubleValue();
  245.       for (double current = start; current < finish; 
  246.   current += (finish - start) / 50) {
  247. System.out.println("Data: " + current + " " 
  248.    + newEst.getProbability(current));
  249.       }
  250.     } catch (Exception e) {
  251.       System.out.println(e.getMessage());
  252.     }
  253.   }
  254. }