SigmoidUnit.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 3k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  *    This program is free software; you can redistribute it and/or modify
  3.  *    it under the terms of the GNU General Public License as published by
  4.  *    the Free Software Foundation; either version 2 of the License, or
  5.  *    (at your option) any later version.
  6.  *
  7.  *    This program is distributed in the hope that it will be useful,
  8.  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  *    GNU General Public License for more details.
  11.  *
  12.  *    You should have received a copy of the GNU General Public License
  13.  *    along with this program; if not, write to the Free Software
  14.  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15.  */
  16. /*
  17.  *    SigmoidUnit.java
  18.  *    Copyright (C) 2001 Malcolm Ware
  19.  */
  20. package weka.classifiers.functions.neural;
  21. /**
  22.  * This can be used by the 
  23.  * neuralnode to perform all it's computations (as a sigmoid unit).
  24.  *
  25.  * @author Malcolm Ware (mfw4@cs.waikato.ac.nz)
  26.  * @version $Revision: 1.3 $
  27.  */
  28. public class SigmoidUnit implements NeuralMethod {
  29.   
  30.   /**
  31.    * This function calculates what the output value should be.
  32.    * @param node The node to calculate the value for.
  33.    * @return The value.
  34.    */
  35.   public double outputValue(NeuralNode node) {
  36.     double[] weights = node.getWeights();
  37.     NeuralConnection[] inputs = node.getInputs();
  38.     double value = weights[0];
  39.     for (int noa = 0; noa < node.getNumInputs(); noa++) {
  40.       
  41.       value += inputs[noa].outputValue(true) 
  42. * weights[noa+1];
  43.     }
  44.      
  45.     //this I got from the Neural Network faq to combat overflow
  46.     //pretty simple solution really :)
  47.     if (value < -45) {
  48.       value = 0;
  49.     }
  50.     else if (value > 45) {
  51.       value = 1;
  52.     }
  53.     else {
  54.       value = 1 / (1 + Math.exp(-value));
  55.     }  
  56.     return value;
  57.   }
  58.   
  59.   /**
  60.    * This function calculates what the error value should be.
  61.    * @param node The node to calculate the error for.
  62.    * @return The error.
  63.    */
  64.   public double errorValue(NeuralNode node) {
  65.     //then calculate the error.
  66.     
  67.     NeuralConnection[] outputs = node.getOutputs();
  68.     int[] oNums = node.getOutputNums();
  69.     double error = 0;
  70.     
  71.     for (int noa = 0; noa < node.getNumOutputs(); noa++) {
  72.       error += outputs[noa].errorValue(true) 
  73. * outputs[noa].weightValue(oNums[noa]);
  74.     }
  75.     double value = node.outputValue(false);
  76.     error *= value * (1 - value);
  77.     
  78.     return error;
  79.   }
  80.   /**
  81.    * This function will calculate what the change in weights should be
  82.    * and also update them.
  83.    * @param node The node to update the weights for.
  84.    * @param learn The learning rate to use.
  85.    * @param momentum The momentum to use.
  86.    */
  87.   public void updateWeights(NeuralNode node, double learn, double momentum) {
  88.     NeuralConnection[] inputs = node.getInputs();
  89.     double[] cWeights = node.getChangeInWeights();
  90.     double[] weights = node.getWeights();
  91.     double learnTimesError = 0;
  92.     try {
  93.       learnTimesError = learn * node.errorValue(false);
  94.     } catch(Exception e) {}
  95.     double c = learnTimesError + momentum * cWeights[0];
  96.     weights[0] += c;
  97.     cWeights[0] = c;
  98.  
  99.     int stopValue = node.getNumInputs() + 1;
  100.     for (int noa = 1; noa < stopValue; noa++) {
  101.       
  102.       c = learnTimesError * inputs[noa-1].outputValue(false);
  103.       c += momentum * cWeights[noa];
  104.       
  105.       weights[noa] += c;
  106.       cWeights[noa] = c; 
  107.     }
  108.   }
  109.     
  110. }