EM.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 25k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  *    This program is free software; you can redistribute it and/or modify
  3.  *    it under the terms of the GNU General Public License as published by
  4.  *    the Free Software Foundation; either version 2 of the License, or
  5.  *    (at your option) any later version.
  6.  *
  7.  *    This program is distributed in the hope that it will be useful,
  8.  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  *    GNU General Public License for more details.
  11.  *
  12.  *    You should have received a copy of the GNU General Public License
  13.  *    along with this program; if not, write to the Free Software
  14.  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15.  */
  16. /*
  17.  *    EM.java
  18.  *    Copyright (C) 1999 Mark Hall
  19.  *
  20.  */
  21. package  weka.clusterers;
  22. import  java.io.*;
  23. import  java.util.*;
  24. import  weka.core.*;
  25. import  weka.estimators.*;
  26. /**
  27.  * Simple EM (expectation maximisation) class. <p>
  28.  * 
  29.  * EM assigns a probability distribution to each instance which
  30.  * indicates the probability of it belonging to each of the clusters.
  31.  * EM can decide how many clusters to create by cross validation, or you
  32.  * may specify apriori how many clusters to generate. <p>
  33.  *
  34.  * Valid options are:<p>
  35.  *
  36.  * -V <br>
  37.  * Verbose. <p>
  38.  *
  39.  * -N <number of clusters> <br>
  40.  * Specify the number of clusters to generate. If omitted,
  41.  * EM will use cross validation to select the number of clusters
  42.  * automatically. <p>
  43.  *
  44.  * -I <max iterations> <br>
  45.  * Terminate after this many iterations if EM has not converged. <p>
  46.  *
  47.  * -S <seed> <br>
  48.  * Specify random number seed. <p>
  49.  *
  50.  * -M <num> <br>
  51.  * Set the minimum allowable standard deviation for normal density calculation.
  52.  * <p>
  53.  *
  54.  * @author Mark Hall (mhall@cs.waikato.ac.nz)
  55.  * @version $Revision: 1.17 $
  56.  */
  57. public class EM
  58.   extends DistributionClusterer
  59.   implements OptionHandler
  60. {
  61.   /** hold the discrete estimators for each cluster */
  62.   private Estimator m_model[][];
  63.   /** hold the normal estimators for each cluster */
  64.   private double m_modelNormal[][][];
  65.   /** default minimum standard deviation */
  66.   private double m_minStdDev = 1e-6;
  67.   /** hold the weights of each instance for each cluster */
  68.   private double m_weights[][];
  69.   /** the prior probabilities for clusters */
  70.   private double m_priors[];
  71.   /** the loglikelihood of the data */
  72.   private double m_loglikely;
  73.   /** training instances */
  74.   private Instances m_theInstances = null;
  75.   /** number of clusters selected by the user or cross validation */
  76.   private int m_num_clusters;
  77.   /** the initial number of clusters requested by the user--- -1 if
  78.       xval is to be used to find the number of clusters */
  79.   private int m_initialNumClusters;
  80.   /** number of attributes */
  81.   private int m_num_attribs;
  82.   /** number of training instances */
  83.   private int m_num_instances;
  84.   /** maximum iterations to perform */
  85.   private int m_max_iterations;
  86.   /** random numbers and seed */
  87.   private Random m_rr;
  88.   private int m_rseed;
  89.   /** Constant for normal distribution. */
  90.   private static double m_normConst = Math.sqrt(2*Math.PI);
  91.   /** Verbose? */
  92.   private boolean m_verbose;
  93.   /**
  94.    * Returns a string describing this clusterer
  95.    * @return a description of the evaluator suitable for
  96.    * displaying in the explorer/experimenter gui
  97.    */
  98.   public String globalInfo() {
  99.     return "Cluster data using expectation maximization";
  100.   }
  101.   /**
  102.    * Returns an enumeration describing the available options.. <p>
  103.    *
  104.    * Valid options are:<p>
  105.    *
  106.    * -V <br>
  107.    * Verbose. <p>
  108.    *
  109.    * -N <number of clusters> <br>
  110.    * Specify the number of clusters to generate. If omitted,
  111.    * EM will use cross validation to select the number of clusters
  112.    * automatically. <p>
  113.    *
  114.    * -I <max iterations> <br>
  115.    * Terminate after this many iterations if EM has not converged. <p>
  116.    *
  117.    * -S <seed> <br>
  118.    * Specify random number seed. <p>
  119.    *
  120.    * -M <num> <br>
  121.    *  Set the minimum allowable standard deviation for normal density 
  122.    * calculation. <p>
  123.    *
  124.    * @return an enumeration of all the available options.
  125.    *
  126.    **/
  127.   public Enumeration listOptions () {
  128.     Vector newVector = new Vector(6);
  129.     newVector.addElement(new Option("tnumber of clusters. If omitted or" 
  130.     + "nt-1 specified, then cross " 
  131.     + "validation is used tontselect the " 
  132.     + "number of clusters.", "N", 1
  133.     , "-N <num>"));
  134.     newVector.addElement(new Option("tmax iterations.n(default 100)", "I"
  135.     , 1, "-I <num>"));
  136.     newVector.addElement(new Option("trandom number seed.n(default 1)"
  137.     , "S", 1, "-S <num>"));
  138.     newVector.addElement(new Option("tverbose.", "V", 0, "-V"));
  139.     newVector.addElement(new Option("tminimum allowable standard deviation "
  140.     +"for normal density computation "
  141.     +"nt(default 1e-6)"
  142.     ,"M",1,"-M <num>"));
  143.     return  newVector.elements();
  144.   }
  145.   /**
  146.    * Parses a given list of options.
  147.    * @param options the list of options as an array of strings
  148.    * @exception Exception if an option is not supported
  149.    *
  150.    **/
  151.   public void setOptions (String[] options)
  152.     throws Exception
  153.   {
  154.     resetOptions();
  155.     setDebug(Utils.getFlag('V', options));
  156.     String optionString = Utils.getOption('I', options);
  157.     if (optionString.length() != 0) {
  158.       setMaxIterations(Integer.parseInt(optionString));
  159.     }
  160.     optionString = Utils.getOption('N', options);
  161.     if (optionString.length() != 0) {
  162.       setNumClusters(Integer.parseInt(optionString));
  163.     }
  164.     optionString = Utils.getOption('S', options);
  165.     if (optionString.length() != 0) {
  166.       setSeed(Integer.parseInt(optionString));
  167.     }
  168.     optionString = Utils.getOption('M', options);
  169.     if (optionString.length() != 0) {
  170.       setMinStdDev((new Double(optionString)).doubleValue());
  171.     }
  172.   }
  173.   /**
  174.    * Returns the tip text for this property
  175.    * @return tip text for this property suitable for
  176.    * displaying in the explorer/experimenter gui
  177.    */
  178.   public String minStdDevTipText() {
  179.     return "set minimum allowable standard deviation";
  180.   }
  181.   /**
  182.    * Set the minimum value for standard deviation when calculating
  183.    * normal density. Reducing this value can help prevent arithmetic
  184.    * overflow resulting from multiplying large densities (arising from small
  185.    * standard deviations) when there are many singleton or near singleton
  186.    * values.
  187.    * @param m minimum value for standard deviation
  188.    */
  189.   public void setMinStdDev(double m) {
  190.     m_minStdDev = m;
  191.   }
  192.   /**
  193.    * Get the minimum allowable standard deviation.
  194.    * @return the minumum allowable standard deviation
  195.    */
  196.   public double getMinStdDev() {
  197.     return m_minStdDev;
  198.   }
  199.   /**
  200.    * Returns the tip text for this property
  201.    * @return tip text for this property suitable for
  202.    * displaying in the explorer/experimenter gui
  203.    */
  204.   public String seedTipText() {
  205.     return "random number seed";
  206.   }
  207.   /**
  208.    * Set the random number seed
  209.    *
  210.    * @param s the seed
  211.    */
  212.   public void setSeed (int s) {
  213.     m_rseed = s;
  214.   }
  215.   /**
  216.    * Get the random number seed
  217.    *
  218.    * @return the seed
  219.    */
  220.   public int getSeed () {
  221.     return  m_rseed;
  222.   }
  223.   /**
  224.    * Returns the tip text for this property
  225.    * @return tip text for this property suitable for
  226.    * displaying in the explorer/experimenter gui
  227.    */
  228.   public String numClustersTipText() {
  229.     return "set number of clusters. -1 to select number of clusters "
  230.       +"automatically by cross validation.";
  231.   }
  232.   /**
  233.    * Set the number of clusters (-1 to select by CV).
  234.    *
  235.    * @param n the number of clusters
  236.    * @exception Exception if n is 0
  237.    */
  238.   public void setNumClusters (int n)
  239.     throws Exception {
  240.     
  241.     if (n == 0) {
  242.       throw  new Exception("Number of clusters must be > 0. (or -1 to " 
  243.    + "select by cross validation).");
  244.     }
  245.     if (n < 0) {
  246.       m_num_clusters = -1;
  247.       m_initialNumClusters = -1;
  248.     }
  249.     else {
  250.       m_num_clusters = n;
  251.       m_initialNumClusters = n;
  252.     }
  253.   }
  254.   /**
  255.    * Get the number of clusters
  256.    *
  257.    * @return the number of clusters.
  258.    */
  259.   public int getNumClusters () {
  260.     return  m_initialNumClusters;
  261.   }
  262.   /**
  263.    * Returns the tip text for this property
  264.    * @return tip text for this property suitable for
  265.    * displaying in the explorer/experimenter gui
  266.    */
  267.   public String maxIterationsTipText() {
  268.     return "maximum number of iterations";
  269.   }
  270.   /**
  271.    * Set the maximum number of iterations to perform
  272.    *
  273.    * @param i the number of iterations
  274.    * @exception Exception if i is less than 1
  275.    */
  276.   public void setMaxIterations (int i)
  277.     throws Exception
  278.   {
  279.     if (i < 1) {
  280.       throw  new Exception("Maximum number of iterations must be > 0!");
  281.     }
  282.     m_max_iterations = i;
  283.   }
  284.   /**
  285.    * Get the maximum number of iterations
  286.    *
  287.    * @return the number of iterations
  288.    */
  289.   public int getMaxIterations () {
  290.     return  m_max_iterations;
  291.   }
  292.   /**
  293.    * Set debug mode - verbose output
  294.    *
  295.    * @param v true for verbose output
  296.    */
  297.   public void setDebug (boolean v) {
  298.     m_verbose = v;
  299.   }
  300.   /**
  301.    * Get debug mode
  302.    *
  303.    * @return true if debug mode is set
  304.    */
  305.   public boolean getDebug () {
  306.     return  m_verbose;
  307.   }
  308.   /**
  309.    * Gets the current settings of EM.
  310.    *
  311.    * @return an array of strings suitable for passing to setOptions()
  312.    */
  313.   public String[] getOptions () {
  314.     String[] options = new String[9];
  315.     int current = 0;
  316.     if (m_verbose) {
  317.       options[current++] = "-V";
  318.     }
  319.     options[current++] = "-I";
  320.     options[current++] = "" + m_max_iterations;
  321.     options[current++] = "-N";
  322.     options[current++] = "" + getNumClusters();
  323.     options[current++] = "-S";
  324.     options[current++] = "" + m_rseed;
  325.     options[current++] = "-M";
  326.     options[current++] = ""+getMinStdDev();
  327.     while (current < options.length) {
  328.       options[current++] = "";
  329.     }
  330.     return  options;
  331.   }
  332.   /**
  333.    * Initialised estimators and storage.
  334.    *
  335.    * @param inst the instances
  336.    * @param num_cl the number of clusters
  337.    **/
  338.   private void EM_Init (Instances inst, int num_cl)
  339.     throws Exception
  340.   {
  341.     m_weights = new double[inst.numInstances()][num_cl];
  342.     int z;
  343.     m_model = new Estimator[num_cl][m_num_attribs];
  344.     m_modelNormal = new double[num_cl][m_num_attribs][3];
  345.     m_priors = new double[num_cl];
  346.     for (int i = 0; i < inst.numInstances(); i++) {
  347.       for (int j = 0; j < num_cl; j++) {
  348.         m_weights[i][j] = m_rr.nextDouble();
  349.       }
  350.       Utils.normalize(m_weights[i]);
  351.     }
  352.     // initial priors
  353.     estimate_priors(inst, num_cl);
  354.   }
  355.   /**
  356.    * calculate prior probabilites for the clusters
  357.    *
  358.    * @param inst the instances
  359.    * @param num_cl the number of clusters
  360.    * @exception Exception if priors can't be calculated
  361.    **/
  362.   private void estimate_priors (Instances inst, int num_cl)
  363.     throws Exception
  364.   {
  365.     for (int i = 0; i < num_cl; i++) {
  366.       m_priors[i] = 0.0;
  367.     }
  368.     for (int i = 0; i < inst.numInstances(); i++) {
  369.       for (int j = 0; j < num_cl; j++) {
  370.         m_priors[j] += m_weights[i][j];
  371.       }
  372.     }
  373.     Utils.normalize(m_priors);
  374.   }
  375.   /**
  376.    * Density function of normal distribution.
  377.    * @param x input value
  378.    * @param mean mean of distribution
  379.    * @param stdDev standard deviation of distribution
  380.    */
  381.   private double normalDens (double x, double mean, double stdDev) {
  382.     double diff = x - mean;
  383.    
  384.     return  (1/(m_normConst*stdDev))*Math.exp(-(diff*diff/(2*stdDev*stdDev)));
  385.   }
  386.   /**
  387.    * New probability estimators for an iteration
  388.    *
  389.    * @param num_cl the numbe of clusters
  390.    */
  391.   private void new_estimators (int num_cl) {
  392.     for (int i = 0; i < num_cl; i++) {
  393.       for (int j = 0; j < m_num_attribs; j++) {
  394.         if (m_theInstances.attribute(j).isNominal()) {
  395.           m_model[i][j] = new DiscreteEstimator(m_theInstances.
  396. attribute(j).numValues()
  397. , true);
  398.         }
  399.         else {
  400.           m_modelNormal[i][j][0] = m_modelNormal[i][j][1] = 
  401.     m_modelNormal[i][j][2] = 0.0;
  402.         }
  403.       }
  404.     }
  405.   }
  406.   /**
  407.    * The M step of the EM algorithm.
  408.    * @param inst the training instances
  409.    * @param num_cl the number of clusters
  410.    */
  411.   private void M (Instances inst, int num_cl)
  412.     throws Exception
  413.   {
  414.     int i, j, l;
  415.     new_estimators(num_cl);
  416.     for (i = 0; i < num_cl; i++) {
  417.       for (j = 0; j < m_num_attribs; j++) {
  418.         for (l = 0; l < inst.numInstances(); l++) {
  419.           if (!inst.instance(l).isMissing(j)) {
  420.             if (inst.attribute(j).isNominal()) {
  421.               m_model[i][j].addValue(inst.instance(l).value(j), 
  422.      m_weights[l][i]);
  423.             }
  424.             else {
  425.               m_modelNormal[i][j][0] += (inst.instance(l).value(j) * 
  426.  m_weights[l][i]);
  427.               m_modelNormal[i][j][2] += m_weights[l][i];
  428.               m_modelNormal[i][j][1] += (inst.instance(l).value(j) * 
  429.  inst.instance(l).value(j)*m_weights[l][i]);
  430.             }
  431.           }
  432.         }
  433.       }
  434.     }
  435.     
  436.        // calcualte mean and std deviation for numeric attributes
  437.     for (j = 0; j < m_num_attribs; j++) {
  438.       if (!inst.attribute(j).isNominal()) {
  439.         for (i = 0; i < num_cl; i++) {
  440.           if (m_modelNormal[i][j][2] < 0) {
  441.             m_modelNormal[i][j][1] = 0;
  442.           } else {
  443.   // variance
  444.     m_modelNormal[i][j][1] = (m_modelNormal[i][j][1] - 
  445.       (m_modelNormal[i][j][0] * 
  446.        m_modelNormal[i][j][0] / 
  447.        m_modelNormal[i][j][2])) / 
  448.       m_modelNormal[i][j][2];
  449.     // std dev      
  450.     m_modelNormal[i][j][1] = Math.sqrt(m_modelNormal[i][j][1]); 
  451.     if (m_modelNormal[i][j][1] <= m_minStdDev 
  452. || Double.isNaN(m_modelNormal[i][j][1])) {
  453.       m_modelNormal[i][j][1] = 
  454. m_minStdDev;
  455.     }
  456.     
  457.     // mean
  458.     if (m_modelNormal[i][j][2] > 0.0) {
  459.       m_modelNormal[i][j][0] /= m_modelNormal[i][j][2];
  460.     }
  461.   }        
  462.         }
  463.       }
  464.     }
  465.   }
  466.   /**
  467.    * The E step of the EM algorithm. Estimate cluster membership 
  468.    * probabilities.
  469.    *
  470.    * @param inst the training instances
  471.    * @param num_cl the number of clusters
  472.    * @return the average log likelihood
  473.    */
  474.   private double E (Instances inst, int num_cl)
  475.     throws Exception
  476.   {
  477.     int i, j, l;
  478.     double prob;
  479.     double loglk = 0.0;
  480.     for (l = 0; l < inst.numInstances(); l++) {
  481.       for (i = 0; i < num_cl; i++) {
  482. m_weights[l][i] = m_priors[i];
  483.       }
  484.       for (j = 0; j < m_num_attribs; j++) {
  485. double max = 0;
  486. for (i = 0; i < num_cl; i++) {
  487.   
  488.           if (!inst.instance(l).isMissing(j)) {
  489.             if (inst.attribute(j).isNominal()) {
  490.               m_weights[l][i] *= 
  491. m_model[i][j].getProbability(inst.instance(l).value(j));
  492.       
  493.             }
  494.             else {
  495.               // numeric attribute
  496.               m_weights[l][i] *= normalDens(inst.instance(l).value(j), 
  497.     m_modelNormal[i][j][0], 
  498.     m_modelNormal[i][j][1]);
  499.       if (Double.isInfinite(m_weights[l][i])) {
  500. throw new Exception("Joint density has overflowed. Try "
  501.     +"increasing the minimum allowable "
  502.     +"standard deviation for normal "
  503.     +"density calculation.");
  504.       }
  505.             }
  506.     if (m_weights[l][i] > max) {
  507.       max = m_weights[l][i];
  508.     }
  509.           }
  510.         }
  511. if (max > 0 && max < 1e-75) { // check for underflow
  512.   for (int zz = 0; zz < num_cl; zz++) {
  513.     // rescale
  514.     m_weights[l][zz] *= 1e75;
  515.   }
  516. }
  517.       }
  518.       
  519.       double temp1 = 0;
  520.       
  521.       for (i = 0; i < num_cl; i++) {
  522.         temp1 += m_weights[l][i];
  523.       }
  524.       
  525.       if (temp1 > 0) {
  526.         loglk += Math.log(temp1);
  527.       }
  528.       
  529.       // normalise the weights for this instance
  530.       try {
  531. Utils.normalize(m_weights[l]);
  532.       } catch (Exception e) {
  533. throw new Exception("An instance has zero cluster memberships. Try "
  534.     +"increasing the minimum allowable "
  535.     +"standard deviation for normal "
  536.     +"density calculation.");
  537.       }
  538.     }
  539.     
  540.     // reestimate priors
  541.     estimate_priors(inst, num_cl);
  542.     return  loglk/inst.numInstances();
  543.   }
  544.   
  545.   
  546.   /**
  547.    * Constructor.
  548.    *
  549.    **/
  550.   public EM () {
  551.     resetOptions();
  552.   }
  553.   /**
  554.    * Reset to default options
  555.    */
  556.   protected void resetOptions () {
  557.     m_minStdDev = 1e-6;
  558.     m_max_iterations = 100;
  559.     m_rseed = 100;
  560.     m_num_clusters = -1;
  561.     m_initialNumClusters = -1;
  562.     m_verbose = false;
  563.   }
  564.   /**
  565.    * Outputs the generated clusters into a string.
  566.    */
  567.   public String toString () {
  568.     StringBuffer text = new StringBuffer();
  569.     text.append("nEMn==n");
  570.     if (m_initialNumClusters == -1) {
  571.       text.append("nNumber of clusters selected by cross validation: "
  572.   +m_num_clusters+"n");
  573.     } else {
  574.       text.append("nNumber of clusters: " + m_num_clusters + "n");
  575.     }
  576.     for (int j = 0; j < m_num_clusters; j++) {
  577.       text.append("nCluster: " + j + " Prior probability: " 
  578.   + Utils.doubleToString(m_priors[j], 4) + "nn");
  579.       for (int i = 0; i < m_num_attribs; i++) {
  580.         text.append("Attribute: " + m_theInstances.attribute(i).name() + "n");
  581.         if (m_theInstances.attribute(i).isNominal()) {
  582.           if (m_model[j][i] != null) {
  583.             text.append(m_model[j][i].toString());
  584.           }
  585.         }
  586.         else {
  587.           text.append("Normal Distribution. Mean = " 
  588.       + Utils.doubleToString(m_modelNormal[j][i][0], 4) 
  589.       + " StdDev = " 
  590.       + Utils.doubleToString(m_modelNormal[j][i][1], 4) 
  591.       + "n");
  592.         }
  593.       }
  594.     }
  595.     return  text.toString();
  596.   }
  597.   /**
  598.    * verbose output for debugging
  599.    * @param inst the training instances
  600.    */
  601.   private void EM_Report (Instances inst) {
  602.     int i, j, l, m;
  603.     System.out.println("======================================");
  604.     for (j = 0; j < m_num_clusters; j++) {
  605.       for (i = 0; i < m_num_attribs; i++) {
  606. System.out.println("Clust: " + j + " att: " + i + "n");
  607. if (m_theInstances.attribute(i).isNominal()) {
  608.   if (m_model[j][i] != null) {
  609.     System.out.println(m_model[j][i].toString());
  610.   }
  611. }
  612. else {
  613.   System.out.println("Normal Distribution. Mean = " 
  614.      + Utils.doubleToString(m_modelNormal[j][i][0]
  615.     , 8, 4) 
  616.      + " StandardDev = " 
  617.      + Utils.doubleToString(m_modelNormal[j][i][1]
  618.     , 8, 4) 
  619.      + " WeightSum = " 
  620.      + Utils.doubleToString(m_modelNormal[j][i][2]
  621.     , 8, 4));
  622. }
  623.       }
  624.     }
  625.     
  626.     for (l = 0; l < inst.numInstances(); l++) {
  627.       m = Utils.maxIndex(m_weights[l]);
  628.       System.out.print("Inst " + Utils.doubleToString((double)l, 5, 0) 
  629.        + " Class " + m + "t");
  630.       for (j = 0; j < m_num_clusters; j++) {
  631. System.out.print(Utils.doubleToString(m_weights[l][j], 7, 5) + "  ");
  632.       }
  633.       System.out.println();
  634.     }
  635.   }
  636.   /**
  637.    * estimate the number of clusters by cross validation on the training
  638.    * data.
  639.    *
  640.    * @return the number of clusters selected
  641.    */
  642.   private int CVClusters ()
  643.     throws Exception
  644.   {
  645.     double CVLogLikely = -Double.MAX_VALUE;
  646.     double templl, tll;
  647.     boolean CVdecreased = true;
  648.     int num_cl = 1;
  649.     int i;
  650.     Random cvr;
  651.     Instances trainCopy;
  652.     int numFolds = (m_theInstances.numInstances() < 10) 
  653.       ? m_theInstances.numInstances() 
  654.       : 10;
  655.     while (CVdecreased) {
  656.       CVdecreased = false;
  657.       cvr = new Random(m_rseed);
  658.       trainCopy = new Instances(m_theInstances);
  659.       trainCopy.randomize(cvr);
  660.       // theInstances.stratify(10);
  661.       templl = 0.0;
  662.       for (i = 0; i < numFolds; i++) {
  663. Instances cvTrain = trainCopy.trainCV(numFolds, i);
  664. Instances cvTest = trainCopy.testCV(numFolds, i);
  665. EM_Init(cvTrain, num_cl);
  666. iterate(cvTrain, num_cl, false);
  667. tll = E(cvTest, num_cl);
  668. if (m_verbose) {
  669.   System.out.println("# clust: " + num_cl + " Fold: " + i 
  670.      + " Loglikely: " + tll);
  671. }
  672. templl += tll;
  673.       }
  674.       templl /= (double)numFolds;
  675.       if (m_verbose) {
  676. System.out.println("===================================" 
  677.    + "==============n# clust: " 
  678.    + num_cl 
  679.    + " Mean Loglikely: " 
  680.    + templl 
  681.    + "n================================" 
  682.    + "=================");
  683.       }
  684.       if (templl > CVLogLikely) {
  685. CVLogLikely = templl;
  686. CVdecreased = true;
  687. num_cl++;
  688.       }
  689.     }
  690.     if (m_verbose) {
  691.       System.out.println("Number of clusters: " + (num_cl - 1));
  692.     }
  693.     return  num_cl - 1;
  694.   }
  695.   /**
  696.    * Returns the number of clusters.
  697.    *
  698.    * @return the number of clusters generated for a training dataset.
  699.    * @exception Exception if number of clusters could not be returned
  700.    * successfully
  701.    */
  702.   public int numberOfClusters ()
  703.     throws Exception
  704.   {
  705.     if (m_num_clusters == -1) {
  706.       throw  new Exception("Haven't generated any clusters!");
  707.     }
  708.     return  m_num_clusters;
  709.   }
  710.   /**
  711.    * Generates a clusterer. Has to initialize all fields of the clusterer
  712.    * that are not being set via options.
  713.    *
  714.    * @param data set of instances serving as training data 
  715.    * @exception Exception if the clusterer has not been 
  716.    * generated successfully
  717.    */
  718.   public void buildClusterer (Instances data)
  719.     throws Exception {
  720.     if (data.checkForStringAttributes()) {
  721.       throw  new Exception("Can't handle string attributes!");
  722.     }
  723.     m_theInstances = data;
  724.     doEM();
  725.     
  726.     // save memory
  727.     m_theInstances = new Instances(m_theInstances,0);
  728.   }
  729.   /**
  730.    * Computes the density for a given instance.
  731.    * 
  732.    * @param inst the instance to compute the density for
  733.    * @return the density.
  734.    * @exception Exception if the density could not be computed
  735.    * successfully
  736.    */
  737.   public double densityForInstance(Instance inst) throws Exception {
  738.     return Utils.sum(weightsForInstance(inst));
  739.   }
  740.   /**
  741.    * Predicts the cluster memberships for a given instance.
  742.    *
  743.    * @param data set of test instances
  744.    * @param instance the instance to be assigned a cluster.
  745.    * @return an array containing the estimated membership 
  746.    * probabilities of the test instance in each cluster (this 
  747.    * should sum to at most 1)
  748.    * @exception Exception if distribution could not be 
  749.    * computed successfully
  750.    */
  751.   public double[] distributionForInstance (Instance inst)
  752.     throws Exception {
  753.     double [] distrib = weightsForInstance(inst);
  754.     Utils.normalize(distrib);
  755.     return distrib;
  756.   }
  757.   /**
  758.    * Returns the weights (indicating cluster membership) for a given instance
  759.    * 
  760.    * @param inst the instance to be assigned a cluster
  761.    * @return an array of weights
  762.    * @exception Exception if weights could not be computed
  763.    */
  764.   protected double[] weightsForInstance(Instance inst)
  765.     throws Exception {
  766.     int i, j;
  767.     double prob;
  768.     double[] wghts = new double[m_num_clusters];
  769.     for (i = 0; i < m_num_clusters; i++) {
  770.       prob = 1.0;
  771.       for (j = 0; j < m_num_attribs; j++) {
  772. if (!inst.isMissing(j)) {
  773.   if (inst.attribute(j).isNominal()) {
  774.     prob *= m_model[i][j].getProbability(inst.value(j));
  775.   }
  776.   else { // numeric attribute
  777.     prob *= normalDens(inst.value(j), 
  778.        m_modelNormal[i][j][0], 
  779.        m_modelNormal[i][j][1]);
  780.   }
  781. }
  782.       }
  783.       wghts[i] = (prob*m_priors[i]);
  784.     }
  785.     return  wghts;
  786.   }
  787.   /**
  788.    * Perform the EM algorithm
  789.    */
  790.   private void doEM ()
  791.     throws Exception
  792.   {
  793.     if (m_verbose) {
  794.       System.out.println("Seed: " + m_rseed);
  795.     }
  796.     m_rr = new Random(m_rseed);
  797.     m_num_instances = m_theInstances.numInstances();
  798.     m_num_attribs = m_theInstances.numAttributes();
  799.     if (m_verbose) {
  800.       System.out.println("Number of instances: " 
  801.  + m_num_instances 
  802.  + "nNumber of atts: " 
  803.  + m_num_attribs 
  804.  + "n");
  805.     }
  806.     // setDefaultStdDevs(theInstances);
  807.     // cross validate to determine number of clusters?
  808.     if (m_initialNumClusters == -1) {
  809.       if (m_theInstances.numInstances() > 9) {
  810. m_num_clusters = CVClusters();
  811.       } else {
  812. m_num_clusters = 1;
  813.       }
  814.     }
  815.     // fit full training set
  816.     EM_Init(m_theInstances, m_num_clusters);
  817.     m_loglikely = iterate(m_theInstances, m_num_clusters, m_verbose);
  818.   }
  819.   /**
  820.    * iterates the M and E steps until the log likelihood of the data
  821.    * converges.
  822.    *
  823.    * @param inst the training instances.
  824.    * @param num_cl the number of clusters.
  825.    * @param report be verbose.
  826.    * @return the log likelihood of the data
  827.    */
  828.   private double iterate (Instances inst, int num_cl, boolean report)
  829.     throws Exception
  830.   {
  831.     int i;
  832.     double llkold = 0.0;
  833.     double llk = 0.0;
  834.     if (report) {
  835.       EM_Report(inst);
  836.     }
  837.     for (i = 0; i < m_max_iterations; i++) {
  838.       M(inst, num_cl);
  839.       llkold = llk;
  840.       llk = E(inst, num_cl);
  841.       if (report) {
  842. System.out.println("Loglikely: " + llk);
  843.       }
  844.       if (i > 0) {
  845. if ((llk - llkold) < 1e-6) {
  846.   break;
  847. }
  848.       }
  849.     }
  850.     if (report) {
  851.       EM_Report(inst);
  852.     }
  853.     return  llk;
  854.   }
  855.   // ============
  856.   // Test method.
  857.   // ============
  858.   /**
  859.    * Main method for testing this class.
  860.    *
  861.    * @param argv should contain the following arguments: <p>
  862.    * -t training file [-T test file] [-N number of clusters] [-S random seed]
  863.    */
  864.   public static void main (String[] argv) {
  865.     try {
  866.       System.out.println(ClusterEvaluation.
  867.  evaluateClusterer(new EM(), argv));
  868.     }
  869.     catch (Exception e) {
  870.       System.out.println(e.getMessage());
  871.       e.printStackTrace();
  872.     }
  873.   }
  874. }