SimpleKMeans.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 13k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  *    This program is free software; you can redistribute it and/or modify
  3.  *    it under the terms of the GNU General Public License as published by
  4.  *    the Free Software Foundation; either version 2 of the License, or
  5.  *    (at your option) any later version.
  6.  *
  7.  *    This program is distributed in the hope that it will be useful,
  8.  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  *    GNU General Public License for more details.
  11.  *
  12.  *    You should have received a copy of the GNU General Public License
  13.  *    along with this program; if not, write to the Free Software
  14.  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15.  */
  16. /*
  17.  *    SimpleKMeans.java
  18.  *    Copyright (C) 2000 Mark Hall
  19.  *
  20.  */
  21. package weka.clusterers;
  22. import  java.io.*;
  23. import  java.util.*;
  24. import  weka.core.*;
  25. import  weka.filters.Filter;
  26. import  weka.filters.ReplaceMissingValuesFilter;
  27. /**
  28.  * Simple k means clustering class.
  29.  *
  30.  * Valid options are:<p>
  31.  *
  32.  * -N <number of clusters> <br>
  33.  * Specify the number of clusters to generate. <p>
  34.  *
  35.  * -S <seed> <br>
  36.  * Specify random number seed. <p>
  37.  *
  38.  * @author Mark Hall (mhall@cs.waikato.ac.nz)
  39.  * @version $Revision: 1.6 $
  40.  * @see Clusterer
  41.  * @see OptionHandler
  42.  */
  43. public class SimpleKMeans extends Clusterer implements OptionHandler {
  44.   /**
  45.    * training instances
  46.    */
  47.   private Instances m_instances;
  48.   /**
  49.    * replace missing values in training instances
  50.    */
  51.   private ReplaceMissingValuesFilter m_ReplaceMissingFilter;
  52.   /**
  53.    * number of clusters to generate
  54.    */
  55.   private int m_NumClusters = 2;
  56.   /**
  57.    * holds the cluster centroids
  58.    */
  59.   private Instances m_ClusterCentroids;
  60.   /**
  61.    * temporary variable holding cluster assignments while iterating
  62.    */
  63.   private int [] m_ClusterAssignments;
  64.   /**
  65.    * random seed
  66.    */
  67.   private int m_Seed = 10;
  68.   /**
  69.    * attribute min values
  70.    */
  71.   private double [] m_Min;
  72.   
  73.   /**
  74.    * attribute max values
  75.    */
  76.   private double [] m_Max;
  77.   /**
  78.    * Returns a string describing this clusterer
  79.    * @return a description of the evaluator suitable for
  80.    * displaying in the explorer/experimenter gui
  81.    */
  82.   public String globalInfo() {
  83.     return "Cluster data using the k means algorithm";
  84.   }
  85.   /**
  86.    * Generates a clusterer. Has to initialize all fields of the clusterer
  87.    * that are not being set via options.
  88.    *
  89.    * @param data set of instances serving as training data 
  90.    * @exception Exception if the clusterer has not been 
  91.    * generated successfully
  92.    */
  93.   public void buildClusterer(Instances data) throws Exception {
  94.     
  95.     if (data.checkForStringAttributes()) {
  96.       throw  new Exception("Can't handle string attributes!");
  97.     }
  98.     m_ReplaceMissingFilter = new ReplaceMissingValuesFilter();
  99.     m_ReplaceMissingFilter.setInputFormat(data);
  100.     m_instances = Filter.useFilter(data, m_ReplaceMissingFilter);
  101.     m_Min = new double [m_instances.numAttributes()];
  102.     m_Max = new double [m_instances.numAttributes()];
  103.     for (int i = 0; i < m_instances.numAttributes(); i++) {
  104.       m_Min[i] = m_Max[i] = Double.NaN;
  105.     }
  106.     for (int i = 0; i < m_instances.numInstances(); i++) {
  107.       updateMinMax(m_instances.instance(i));
  108.     }
  109.     
  110.     m_ClusterCentroids = new Instances(m_instances, m_NumClusters);
  111.     m_ClusterAssignments = new int [m_instances.numInstances()];
  112.     Random RandomO = new Random(m_Seed);
  113.     for (int i = 0; i < m_NumClusters; i++) {
  114.       int instIndex = Math.abs(RandomO.nextInt()) % m_instances.numInstances();
  115.       m_ClusterCentroids.add(m_instances.instance(instIndex));
  116.     }
  117.     boolean converged = false;
  118.     while (!converged) {
  119.       converged = true;
  120.       for (int i = 0; i < m_instances.numInstances(); i++) {
  121. Instance toCluster = m_instances.instance(i);
  122. int newC = clusterProcessedInstance(toCluster);
  123. if (newC != m_ClusterAssignments[i]) {
  124.   converged = false;
  125. }
  126. m_ClusterAssignments[i] = newC;
  127.       }
  128.       
  129.       Instances [] tempI = new Instances[m_NumClusters];
  130.       // update centroids
  131.       m_ClusterCentroids = new Instances(m_instances, m_NumClusters);
  132.       for (int i = 0; i < m_NumClusters; i++) {
  133. tempI[i] = new Instances(m_instances, 0);
  134.       }
  135.       for (int i = 0; i < m_instances.numInstances(); i++) {
  136. tempI[m_ClusterAssignments[i]].add(m_instances.instance(i));
  137.       }
  138.       for (int i = 0; i < m_NumClusters; i++) {
  139. double [] vals = new double[m_instances.numAttributes()];
  140. for (int j = 0; j < m_instances.numAttributes(); j++) {
  141.   vals[j] = tempI[i].meanOrMode(j);
  142. }
  143. m_ClusterCentroids.add(new Instance(1.0, vals));
  144.       }
  145.     }
  146.   }
  147.   /**
  148.    * clusters an instance that has been through the filters
  149.    *
  150.    * @param instance the instance to assign a cluster to
  151.    * @return a cluster number
  152.    */
  153.   private int clusterProcessedInstance(Instance instance) {
  154.     double minDist = Integer.MAX_VALUE;
  155.     int bestCluster = 0;
  156.     for (int i = 0; i < m_NumClusters; i++) {
  157.       double dist = distance(instance, m_ClusterCentroids.instance(i));
  158.       if (dist < minDist) {
  159. minDist = dist;
  160. bestCluster = i;
  161.       }
  162.     }
  163.     return bestCluster;
  164.   }
  165.   /**
  166.    * Classifies a given instance.
  167.    *
  168.    * @param instance the instance to be assigned to a cluster
  169.    * @return the number of the assigned cluster as an interger
  170.    * if the class is enumerated, otherwise the predicted value
  171.    * @exception Exception if instance could not be classified
  172.    * successfully
  173.    */
  174.   public int clusterInstance(Instance instance) throws Exception {
  175.     m_ReplaceMissingFilter.input(instance);
  176.     m_ReplaceMissingFilter.batchFinished();
  177.     Instance inst = m_ReplaceMissingFilter.output();
  178.     return clusterProcessedInstance(inst);
  179.   }
  180.   /**
  181.    * Calculates the distance between two instances
  182.    *
  183.    * @param test the first instance
  184.    * @param train the second instance
  185.    * @return the distance between the two given instances, between 0 and 1
  186.    */          
  187.   private double distance(Instance first, Instance second) {  
  188.     double distance = 0;
  189.     int firstI, secondI;
  190.     for (int p1 = 0, p2 = 0; 
  191.  p1 < first.numValues() || p2 < second.numValues();) {
  192.       if (p1 >= first.numValues()) {
  193. firstI = m_instances.numAttributes();
  194.       } else {
  195. firstI = first.index(p1); 
  196.       }
  197.       if (p2 >= second.numValues()) {
  198. secondI = m_instances.numAttributes();
  199.       } else {
  200. secondI = second.index(p2);
  201.       }
  202.       if (firstI == m_instances.classIndex()) {
  203. p1++; continue;
  204.       } 
  205.       if (secondI == m_instances.classIndex()) {
  206. p2++; continue;
  207.       } 
  208.       double diff;
  209.       if (firstI == secondI) {
  210. diff = difference(firstI, 
  211.   first.valueSparse(p1),
  212.   second.valueSparse(p2));
  213. p1++; p2++;
  214.       } else if (firstI > secondI) {
  215. diff = difference(secondI, 
  216.   0, second.valueSparse(p2));
  217. p2++;
  218.       } else {
  219. diff = difference(firstI, 
  220.   first.valueSparse(p1), 0);
  221. p1++;
  222.       }
  223.       distance += diff * diff;
  224.     }
  225.     
  226.     return Math.sqrt(distance / m_instances.numAttributes());
  227.   }
  228.   /**
  229.    * Computes the difference between two given attribute
  230.    * values.
  231.    */
  232.   private double difference(int index, double val1, double val2) {
  233.     switch (m_instances.attribute(index).type()) {
  234.     case Attribute.NOMINAL:
  235.       
  236.       // If attribute is nominal
  237.       if (Instance.isMissingValue(val1) || 
  238.   Instance.isMissingValue(val2) ||
  239.   ((int)val1 != (int)val2)) {
  240. return 1;
  241.       } else {
  242. return 0;
  243.       }
  244.     case Attribute.NUMERIC:
  245.       // If attribute is numeric
  246.       if (Instance.isMissingValue(val1) || 
  247.   Instance.isMissingValue(val2)) {
  248. if (Instance.isMissingValue(val1) && 
  249.     Instance.isMissingValue(val2)) {
  250.   return 1;
  251. } else {
  252.   double diff;
  253.   if (Instance.isMissingValue(val2)) {
  254.     diff = norm(val1, index);
  255.   } else {
  256.     diff = norm(val2, index);
  257.   }
  258.   if (diff < 0.5) {
  259.     diff = 1.0 - diff;
  260.   }
  261.   return diff;
  262. }
  263.       } else {
  264. return norm(val1, index) - norm(val2, index);
  265.       }
  266.     default:
  267.       return 0;
  268.     }
  269.   }
  270.   /**
  271.    * Normalizes a given value of a numeric attribute.
  272.    *
  273.    * @param x the value to be normalized
  274.    * @param i the attribute's index
  275.    */
  276.   private double norm(double x, int i) {
  277.     if (Double.isNaN(m_Min[i]) || Utils.eq(m_Max[i],m_Min[i])) {
  278.       return 0;
  279.     } else {
  280.       return (x - m_Min[i]) / (m_Max[i] - m_Min[i]);
  281.     }
  282.   }
  283.   /**
  284.    * Updates the minimum and maximum values for all the attributes
  285.    * based on a new instance.
  286.    *
  287.    * @param instance the new instance
  288.    */
  289.   private void updateMinMax(Instance instance) {  
  290.     for (int j = 0;j < m_instances.numAttributes(); j++) {
  291.       if (!instance.isMissing(j)) {
  292. if (Double.isNaN(m_Min[j])) {
  293.   m_Min[j] = instance.value(j);
  294.   m_Max[j] = instance.value(j);
  295. } else {
  296.   if (instance.value(j) < m_Min[j]) {
  297.     m_Min[j] = instance.value(j);
  298.   } else {
  299.     if (instance.value(j) > m_Max[j]) {
  300.       m_Max[j] = instance.value(j);
  301.     }
  302.   }
  303. }
  304.       }
  305.     }
  306.   }
  307.   
  308.   /**
  309.    * Returns the number of clusters.
  310.    *
  311.    * @return the number of clusters generated for a training dataset.
  312.    * @exception Exception if number of clusters could not be returned
  313.    * successfully
  314.    */
  315.   public int numberOfClusters() throws Exception {
  316.     return m_NumClusters;
  317.   }
  318.   /**
  319.    * Returns an enumeration describing the available options. <p>
  320.    *
  321.    * Valid options are:<p>
  322.    *
  323.    * -N <number of clusters> <br>
  324.    * Specify the number of clusters to generate. If omitted,
  325.    * EM will use cross validation to select the number of clusters
  326.    * automatically. <p>
  327.    *
  328.    * -S <seed> <br>
  329.    * Specify random number seed. <p>
  330.    *
  331.    * @return an enumeration of all the available options
  332.    *
  333.    **/
  334.   public Enumeration listOptions () {
  335.     Vector newVector = new Vector(2);
  336.      newVector.addElement(new Option("tnumber of clusters. (default = 2)." 
  337.     , "N", 1, "-N <num>"));
  338.      newVector.addElement(new Option("trandom number seed.n (default 10)"
  339.      , "S", 1, "-S <num>"));
  340.      return  newVector.elements();
  341.   }
  342.   /**
  343.    * Returns the tip text for this property
  344.    * @return tip text for this property suitable for
  345.    * displaying in the explorer/experimenter gui
  346.    */
  347.   public String numClustersTipText() {
  348.     return "set number of clusters";
  349.   }
  350.   /**
  351.    * set the number of clusters to generate
  352.    *
  353.    * @param n the number of clusters to generate
  354.    */
  355.   public void setNumClusters(int n) {
  356.     m_NumClusters = n;
  357.   }
  358.   /**
  359.    * gets the number of clusters to generate
  360.    *
  361.    * @return the number of clusters to generate
  362.    */
  363.   public int getNumClusters() {
  364.     return m_NumClusters;
  365.   }
  366.     
  367.   /**
  368.    * Returns the tip text for this property
  369.    * @return tip text for this property suitable for
  370.    * displaying in the explorer/experimenter gui
  371.    */
  372.   public String seedTipText() {
  373.     return "random number seed";
  374.   }
  375.   /**
  376.    * Set the random number seed
  377.    *
  378.    * @param s the seed
  379.    */
  380.   public void setSeed (int s) {
  381.     m_Seed = s;
  382.   }
  383.   /**
  384.    * Get the random number seed
  385.    *
  386.    * @return the seed
  387.    */
  388.   public int getSeed () {
  389.     return  m_Seed;
  390.   }
  391.   /**
  392.    * Parses a given list of options.
  393.    * @param options the list of options as an array of strings
  394.    * @exception Exception if an option is not supported
  395.    *
  396.    **/
  397.   public void setOptions (String[] options)
  398.     throws Exception {
  399.     String optionString = Utils.getOption('N', options);
  400.     if (optionString.length() != 0) {
  401.       setNumClusters(Integer.parseInt(optionString));
  402.     }
  403.     optionString = Utils.getOption('S', options);
  404.     
  405.     if (optionString.length() != 0) {
  406.       setSeed(Integer.parseInt(optionString));
  407.     }
  408.   }
  409.   /**
  410.    * Gets the current settings of SimpleKMeans
  411.    *
  412.    * @return an array of strings suitable for passing to setOptions()
  413.    */
  414.   public String[] getOptions () {
  415.     String[] options = new String[4];
  416.     int current = 0;
  417.     
  418.     options[current++] = "-N";
  419.     options[current++] = "" + getNumClusters();
  420.     options[current++] = "-S";
  421.     options[current++] = "" + getSeed();
  422.     
  423.     while (current < options.length) {
  424.       options[current++] = "";
  425.     }
  426.     return  options;
  427.   }
  428.   /**
  429.    * return a string describing this clusterer
  430.    *
  431.    * @return a description of the clusterer as a string
  432.    */
  433.   public String toString() {
  434.     StringBuffer temp = new StringBuffer();
  435.     temp.append("nkMeansn======n");
  436.     temp.append("nCluster centroids:n");
  437.     for (int i = 0; i < m_NumClusters; i++) {
  438.       temp.append("nCluster "+i+"nt");
  439.       for (int j = 0; j < m_ClusterCentroids.numAttributes(); j++) {
  440. if (m_ClusterCentroids.attribute(j).isNominal()) {
  441.   temp.append(" "+m_ClusterCentroids.attribute(j).
  442.       value((int)m_ClusterCentroids.instance(i).value(j)));
  443. } else {
  444.   temp.append(" "+m_ClusterCentroids.instance(i).value(j));
  445. }
  446.       }
  447.     }
  448.     return temp.toString();
  449.   }
  450.   /**
  451.    * Main method for testing this class.
  452.    *
  453.    * @param argv should contain the following arguments: <p>
  454.    * -t training file [-N number of clusters]
  455.    */
  456.   public static void main (String[] argv) {
  457.     try {
  458.       System.out.println(ClusterEvaluation.
  459.  evaluateClusterer(new SimpleKMeans(), argv));
  460.     }
  461.     catch (Exception e) {
  462.       System.out.println(e.getMessage());
  463.       e.printStackTrace();
  464.     }
  465.   }
  466. }