NeuralNetwork.java
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 71k
Category:

Windows Develop

Development Platform:

Java

  1. /*
  2.  *    This program is free software; you can redistribute it and/or modify
  3.  *    it under the terms of the GNU General Public License as published by
  4.  *    the Free Software Foundation; either version 2 of the License, or
  5.  *    (at your option) any later version.
  6.  *
  7.  *    This program is distributed in the hope that it will be useful,
  8.  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
  9.  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  10.  *    GNU General Public License for more details.
  11.  *
  12.  *    You should have received a copy of the GNU General Public License
  13.  *    along with this program; if not, write to the Free Software
  14.  *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15.  */
  16. /*
  17.  *    NeuralConnection.java
  18.  *    Copyright (C) 2000 Malcolm Ware
  19.  */
  20. package weka.classifiers.functions.neural;
  21. import java.util.*;
  22. import java.awt.*;
  23. import java.awt.event.*;
  24. import javax.swing.*;
  25. import weka.classifiers.*;
  26. import weka.core.*;
  27. import weka.filters.unsupervised.attribute.NominalToBinary;
  28. import weka.filters.Filter;
  29. /** 
  30.  * A Classifier that uses backpropagation to classify instances.
  31.  * This network can be built by hand, created by an algorithm or both.
  32.  * The network can also be monitored and modified during training time.
  33.  * The nodes in this network are all sigmoid (except for when the class
  34.  * is numeric in which case the the output nodes become unthresholded linear
  35.  * units).
  36.  *
  37.  * @author Malcolm Ware (mfw4@cs.waikato.ac.nz)
  38.  * @version $Revision: 1.10 $
  39.  */
  40. public class NeuralNetwork extends DistributionClassifier 
  41.   implements OptionHandler, WeightedInstancesHandler {
  42.   
  43.   /**
  44.    * Main method for testing this class.
  45.    *
  46.    * @param argv should contain command line options (see setOptions)
  47.    */
  48.   public static void main(String [] argv) {
  49.     
  50.     try {
  51.       System.out.println(Evaluation.evaluateModel(new NeuralNetwork(), argv));
  52.     } catch (Exception e) {
  53.       System.err.println(e.getMessage());
  54.       e.printStackTrace();
  55.     }
  56.     System.exit(0);
  57.   }
  58.   
  59.   /** 
  60.    * This inner class is used to connect the nodes in the network up to
  61.    * the data that they are classifying, Note that objects of this class are
  62.    * only suitable to go on the attribute side or class side of the network
  63.    * and not both.
  64.    */
  65.   protected class NeuralEnd extends NeuralConnection {
  66.     
  67.     
  68.     /** 
  69.      * the value that represents the instance value this node represents. 
  70.      * For an input it is the attribute number, for an output, if nominal
  71.      * it is the class value. 
  72.      */
  73.     private int m_link;
  74.     
  75.     /** True if node is an input, False if it's an output. */
  76.     private boolean m_input;
  77.     public NeuralEnd(String id) {
  78.       super(id);
  79.       m_link = 0;
  80.       m_input = true;
  81.       
  82.     }
  83.   
  84.     /**
  85.      * Call this function to determine if the point at x,y is on the unit.
  86.      * @param g The graphics context for font size info.
  87.      * @param x The x coord.
  88.      * @param y The y coord.
  89.      * @param w The width of the display.
  90.      * @param h The height of the display.
  91.      * @return True if the point is on the unit, false otherwise.
  92.      */
  93.     public boolean onUnit(Graphics g, int x, int y, int w, int h) {
  94.       
  95.       FontMetrics fm = g.getFontMetrics();
  96.       int l = (int)(m_x * w) - fm.stringWidth(m_id) / 2;
  97.       int t = (int)(m_y * h) - fm.getHeight() / 2;
  98.       if (x < l || x > l + fm.stringWidth(m_id) + 4 
  99.   || y < t || y > t + fm.getHeight() + fm.getDescent() + 4) {
  100. return false;
  101.       }
  102.       return true;
  103.       
  104.     }
  105.    
  106.     /**
  107.      * This will draw the node id to the graphics context.
  108.      * @param g The graphics context.
  109.      * @param w The width of the drawing area.
  110.      * @param h The height of the drawing area.
  111.      */
  112.     public void drawNode(Graphics g, int w, int h) {
  113.       
  114.       if ((m_type & PURE_INPUT) == PURE_INPUT) {
  115. g.setColor(Color.green);
  116.       }
  117.       else {
  118. g.setColor(Color.orange);
  119.       }
  120.       
  121.       FontMetrics fm = g.getFontMetrics();
  122.       int l = (int)(m_x * w) - fm.stringWidth(m_id) / 2;
  123.       int t = (int)(m_y * h) - fm.getHeight() / 2;
  124.       g.fill3DRect(l, t, fm.stringWidth(m_id) + 4
  125.    , fm.getHeight() + fm.getDescent() + 4
  126.    , true);
  127.       g.setColor(Color.black);
  128.       
  129.       g.drawString(m_id, l + 2, t + fm.getHeight() + 2);
  130.     }
  131.     /**
  132.      * Call this function to draw the node highlighted.
  133.      * @param g The graphics context.
  134.      * @param w The width of the drawing area.
  135.      * @param h The height of the drawing area.
  136.      */
  137.     public void drawHighlight(Graphics g, int w, int h) {
  138.       
  139.       g.setColor(Color.black);
  140.       FontMetrics fm = g.getFontMetrics();
  141.       int l = (int)(m_x * w) - fm.stringWidth(m_id) / 2;
  142.       int t = (int)(m_y * h) - fm.getHeight() / 2;
  143.       g.fillRect(l - 2, t - 2, fm.stringWidth(m_id) + 8
  144.  , fm.getHeight() + fm.getDescent() + 8); 
  145.       drawNode(g, w, h);
  146.     }
  147.     
  148.     /**
  149.      * Call this to get the output value of this unit. 
  150.      * @param calculate True if the value should be calculated if it hasn't 
  151.      * been already.
  152.      * @return The output value, or NaN, if the value has not been calculated.
  153.      */
  154.     public double outputValue(boolean calculate) {
  155.      
  156.       if (Double.isNaN(m_unitValue) && calculate) {
  157. if (m_input) {
  158.   if (m_currentInstance.isMissing(m_link)) {
  159.     m_unitValue = 0;
  160.   }
  161.   else {
  162.     
  163.     m_unitValue = m_currentInstance.value(m_link);
  164.   }
  165. }
  166. else {
  167.   //node is an output.
  168.   m_unitValue = 0;
  169.   for (int noa = 0; noa < m_numInputs; noa++) {
  170.     m_unitValue += m_inputList[noa].outputValue(true);
  171.    
  172.   }
  173.   if (m_numeric && m_normalizeClass) {
  174.     //then scale the value;
  175.     //this scales linearly from between -1 and 1
  176.     m_unitValue = m_unitValue * 
  177.       m_attributeRanges[m_instances.classIndex()] + 
  178.       m_attributeBases[m_instances.classIndex()];
  179.   }
  180. }
  181.       }
  182.       return m_unitValue;
  183.       
  184.       
  185.     }
  186.     
  187.     /**
  188.      * Call this to get the error value of this unit, which in this case is
  189.      * the difference between the predicted class, and the actual class.
  190.      * @param calculate True if the value should be calculated if it hasn't 
  191.      * been already.
  192.      * @return The error value, or NaN, if the value has not been calculated.
  193.      */
  194.     public double errorValue(boolean calculate) {
  195.       
  196.       if (!Double.isNaN(m_unitValue) && Double.isNaN(m_unitError) 
  197.   && calculate) {
  198. if (m_input) {
  199.   m_unitError = 0;
  200.   for (int noa = 0; noa < m_numOutputs; noa++) {
  201.     m_unitError += m_outputList[noa].errorValue(true);
  202.   }
  203. }
  204. else {
  205.   if (m_currentInstance.classIsMissing()) {
  206.     m_unitError = .1;  
  207.   }
  208.   else if (m_instances.classAttribute().isNominal()) {
  209.     if (m_currentInstance.classValue() == m_link) {
  210.       m_unitError = 1 - m_unitValue;
  211.     }
  212.     else {
  213.       m_unitError = 0 - m_unitValue;
  214.     }
  215.   }
  216.   else if (m_numeric) {
  217.     
  218.     if (m_normalizeClass) {
  219.       if (m_attributeRanges[m_instances.classIndex()] == 0) {
  220. m_unitError = 0;
  221.       }
  222.       else {
  223. m_unitError = (m_currentInstance.classValue() - m_unitValue ) /
  224.   m_attributeRanges[m_instances.classIndex()];
  225. //m_numericRange;
  226.       }
  227.     }
  228.     else {
  229.       m_unitError = m_currentInstance.classValue() - m_unitValue;
  230.     }
  231.   }
  232. }
  233.       }
  234.       return m_unitError;
  235.     }
  236.     
  237.     
  238.     /**
  239.      * Call this to reset the value and error for this unit, ready for the next
  240.      * run. This will also call the reset function of all units that are 
  241.      * connected as inputs to this one.
  242.      * This is also the time that the update for the listeners will be 
  243.      * performed.
  244.      */
  245.     public void reset() {
  246.       
  247.       if (!Double.isNaN(m_unitValue) || !Double.isNaN(m_unitError)) {
  248. m_unitValue = Double.NaN;
  249. m_unitError = Double.NaN;
  250. m_weightsUpdated = false;
  251. for (int noa = 0; noa < m_numInputs; noa++) {
  252.   m_inputList[noa].reset();
  253. }
  254.       }
  255.     }
  256.     
  257.     
  258.     /** 
  259.      * Call this function to set What this end unit represents.
  260.      * @param input True if this unit is used for entering an attribute,
  261.      * False if it's used for determining a class value.
  262.      * @param val The attribute number or class type that this unit represents.
  263.      * (for nominal attributes).
  264.      */
  265.     public void setLink(boolean input, int val) throws Exception {
  266.       m_input = input;
  267.       
  268.       if (input) {
  269. m_type = PURE_INPUT;
  270.       }
  271.       else {
  272. m_type = PURE_OUTPUT;
  273.       }
  274.       if (val < 0 || (input && val > m_instances.numAttributes()) 
  275.   || (!input && m_instances.classAttribute().isNominal() 
  276.       && val > m_instances.classAttribute().numValues())) {
  277. m_link = 0;
  278.       }
  279.       else {
  280. m_link = val;
  281.       }
  282.     }
  283.     
  284.     /**
  285.      * @return link for this node.
  286.      */
  287.     public int getLink() {
  288.       return m_link;
  289.     }
  290.     
  291.   }
  292.   
  293.  
  294.   /** Inner class used to draw the nodes onto.(uses the node lists!!) 
  295.    * This will also handle the user input. */
  296.   private class NodePanel extends JPanel {
  297.     
  298.     /**
  299.      * The constructor.
  300.      */
  301.     public NodePanel() {
  302.       
  303.       addMouseListener(new MouseAdapter() {
  304.   
  305.   public void mousePressed(MouseEvent e) {
  306.     
  307.     if (!m_stopped) {
  308.       return;
  309.     }
  310.     if ((e.getModifiers() & e.BUTTON1_MASK) == e.BUTTON1_MASK) {
  311.       Graphics g = NodePanel.this.getGraphics();
  312.       int x = e.getX();
  313.       int y = e.getY();
  314.       int w = NodePanel.this.getWidth();
  315.       int h = NodePanel.this.getHeight();
  316.       int u = 0;
  317.       FastVector tmp = new FastVector(4);
  318.       for (int noa = 0; noa < m_numAttributes; noa++) {
  319. if (m_inputs[noa].onUnit(g, x, y, w, h)) {
  320.   tmp.addElement(m_inputs[noa]);
  321.   selection(tmp, 
  322.     (e.getModifiers() & e.CTRL_MASK) == e.CTRL_MASK
  323.     , true);
  324.   return;
  325. }
  326.       }
  327.       for (int noa = 0; noa < m_numClasses; noa++) {
  328. if (m_outputs[noa].onUnit(g, x, y, w, h)) {
  329.   tmp.addElement(m_outputs[noa]);
  330.   selection(tmp,
  331.     (e.getModifiers() & e.CTRL_MASK) == e.CTRL_MASK
  332.     , true);
  333.   return;
  334. }
  335.       }
  336.       for (int noa = 0; noa < m_neuralNodes.length; noa++) {
  337. if (m_neuralNodes[noa].onUnit(g, x, y, w, h)) {
  338.   tmp.addElement(m_neuralNodes[noa]);
  339.   selection(tmp,
  340.     (e.getModifiers() & e.CTRL_MASK) == e.CTRL_MASK
  341.     , true);
  342.   return;
  343. }
  344.       }
  345.       NeuralNode temp = new NeuralNode(String.valueOf(m_nextId), 
  346.        m_random, m_sigmoidUnit);
  347.       m_nextId++;
  348.       temp.setX((double)e.getX() / w);
  349.       temp.setY((double)e.getY() / h);
  350.       tmp.addElement(temp);
  351.       addNode(temp);
  352.       selection(tmp, (e.getModifiers() & e.CTRL_MASK) == e.CTRL_MASK
  353. , true);
  354.     }
  355.     else {
  356.       //then right click
  357.       Graphics g = NodePanel.this.getGraphics();
  358.       int x = e.getX();
  359.       int y = e.getY();
  360.       int w = NodePanel.this.getWidth();
  361.       int h = NodePanel.this.getHeight();
  362.       int u = 0;
  363.       FastVector tmp = new FastVector(4);
  364.       for (int noa = 0; noa < m_numAttributes; noa++) {
  365. if (m_inputs[noa].onUnit(g, x, y, w, h)) {
  366.   tmp.addElement(m_inputs[noa]);
  367.   selection(tmp, 
  368.     (e.getModifiers() & e.CTRL_MASK) == e.CTRL_MASK
  369.     , false);
  370.   return;
  371. }
  372.       }
  373.       for (int noa = 0; noa < m_numClasses; noa++) {
  374. if (m_outputs[noa].onUnit(g, x, y, w, h)) {
  375.   tmp.addElement(m_outputs[noa]);
  376.   selection(tmp,
  377.     (e.getModifiers() & e.CTRL_MASK) == e.CTRL_MASK
  378.     , false);
  379.   return;
  380. }
  381.       }
  382.       for (int noa = 0; noa < m_neuralNodes.length; noa++) {
  383. if (m_neuralNodes[noa].onUnit(g, x, y, w, h)) {
  384.   tmp.addElement(m_neuralNodes[noa]);
  385.   selection(tmp,
  386.     (e.getModifiers() & e.CTRL_MASK) == e.CTRL_MASK
  387.     , false);
  388.   return;
  389. }
  390.       }
  391.       selection(null, (e.getModifiers() & e.CTRL_MASK) == e.CTRL_MASK
  392. , false);
  393.     }
  394.   }
  395. });
  396.     }
  397.     
  398.     
  399.     /**
  400.      * This function gets called when the user has clicked something
  401.      * It will amend the current selection or connect the current selection
  402.      * to the new selection.
  403.      * Or if nothing was selected and the right button was used it will 
  404.      * delete the node.
  405.      * @param v The units that were selected.
  406.      * @param ctrl True if ctrl was held down.
  407.      * @param left True if it was the left mouse button.
  408.      */
  409.     private void selection(FastVector v, boolean ctrl, boolean left) {
  410.       
  411.       if (v == null) {
  412. //then unselect all.
  413. m_selected.removeAllElements();
  414. repaint();
  415. return;
  416.       }
  417.       
  418.       //then exclusive or the new selection with the current one.
  419.       if ((ctrl || m_selected.size() == 0) && left) {
  420. boolean removed = false;
  421. for (int noa = 0; noa < v.size(); noa++) {
  422.   removed = false;
  423.   for (int nob = 0; nob < m_selected.size(); nob++) {
  424.     if (v.elementAt(noa) == m_selected.elementAt(nob)) {
  425.       //then remove that element
  426.       m_selected.removeElementAt(nob);
  427.       removed = true;
  428.       break;
  429.     }
  430.   }
  431.   if (!removed) {
  432.     m_selected.addElement(v.elementAt(noa));
  433.   }
  434. }
  435. repaint();
  436. return;
  437.       }
  438.       
  439.       if (left) {
  440. //then connect the current selection to the new one.
  441. for (int noa = 0; noa < m_selected.size(); noa++) {
  442.   for (int nob = 0; nob < v.size(); nob++) {
  443.     NeuralConnection
  444.       .connect((NeuralConnection)m_selected.elementAt(noa)
  445.        , (NeuralConnection)v.elementAt(nob));
  446.   }
  447. }
  448.       }
  449.       else if (m_selected.size() > 0) {
  450. //then disconnect the current selection from the new one.
  451. for (int noa = 0; noa < m_selected.size(); noa++) {
  452.   for (int nob = 0; nob < v.size(); nob++) {
  453.     NeuralConnection
  454.       .disconnect((NeuralConnection)m_selected.elementAt(noa)
  455.   , (NeuralConnection)v.elementAt(nob));
  456.     
  457.     NeuralConnection
  458.       .disconnect((NeuralConnection)v.elementAt(nob)
  459.   , (NeuralConnection)m_selected.elementAt(noa));
  460.     
  461.   }
  462. }
  463.       }
  464.       else {
  465. //then remove the selected node. (it was right clicked while 
  466. //no other units were selected
  467. for (int noa = 0; noa < v.size(); noa++) {
  468.   ((NeuralConnection)v.elementAt(noa)).removeAllInputs();
  469.   ((NeuralConnection)v.elementAt(noa)).removeAllOutputs();
  470.   removeNode((NeuralConnection)v.elementAt(noa));
  471. }
  472.       }
  473.       repaint();
  474.     }
  475.     /**
  476.      * This will paint the nodes ontot the panel.
  477.      * @param g The graphics context.
  478.      */
  479.     public void paintComponent(Graphics g) {
  480.       super.paintComponent(g);
  481.       int x = getWidth();
  482.       int y = getHeight();
  483.       if (25 * m_numAttributes > 25 * m_numClasses && 
  484.   25 * m_numAttributes > y) {
  485. setSize(x, 25 * m_numAttributes);
  486.       }
  487.       else if (25 * m_numClasses > y) {
  488. setSize(x, 25 * m_numClasses);
  489.       }
  490.       else {
  491. setSize(x, y);
  492.       }
  493.       y = getHeight();
  494.       for (int noa = 0; noa < m_numAttributes; noa++) {
  495. m_inputs[noa].drawInputLines(g, x, y);
  496.       }
  497.       for (int noa = 0; noa < m_numClasses; noa++) {
  498. m_outputs[noa].drawInputLines(g, x, y);
  499. m_outputs[noa].drawOutputLines(g, x, y);
  500.       }
  501.       for (int noa = 0; noa < m_neuralNodes.length; noa++) {
  502. m_neuralNodes[noa].drawInputLines(g, x, y);
  503.       }
  504.       for (int noa = 0; noa < m_numAttributes; noa++) {
  505. m_inputs[noa].drawNode(g, x, y);
  506.       }
  507.       for (int noa = 0; noa < m_numClasses; noa++) {
  508. m_outputs[noa].drawNode(g, x, y);
  509.       }
  510.       for (int noa = 0; noa < m_neuralNodes.length; noa++) {
  511. m_neuralNodes[noa].drawNode(g, x, y);
  512.       }
  513.       for (int noa = 0; noa < m_selected.size(); noa++) {
  514. ((NeuralConnection)m_selected.elementAt(noa)).drawHighlight(g, x, y);
  515.       }
  516.     }
  517.   }
  518.   /** 
  519.    * This provides the basic controls for working with the neuralnetwork
  520.    * @author Malcolm Ware (mfw4@cs.waikato.ac.nz)
  521.    * @version $Revision: 1.10 $
  522.    */
  523.   class ControlPanel extends JPanel {
  524.     
  525.     /** The start stop button. */
  526.     public JButton m_startStop;
  527.     
  528.     /** The button to accept the network (even if it hasn't done all epochs. */
  529.     public JButton m_acceptButton;
  530.     
  531.     /** A label to state the number of epochs processed so far. */
  532.     public JPanel m_epochsLabel;
  533.     
  534.     /** A label to state the total number of epochs to be processed. */
  535.     public JLabel m_totalEpochsLabel;
  536.     
  537.     /** A text field to allow the changing of the total number of epochs. */
  538.     public JTextField m_changeEpochs;
  539.     
  540.     /** A label to state the learning rate. */
  541.     public JLabel m_learningLabel;
  542.     
  543.     /** A label to state the momentum. */
  544.     public JLabel m_momentumLabel;
  545.     
  546.     /** A text field to allow the changing of the learning rate. */
  547.     public JTextField m_changeLearning;
  548.     
  549.     /** A text field to allow the changing of the momentum. */
  550.     public JTextField m_changeMomentum;
  551.     
  552.     /** A label to state roughly the accuracy of the network.(because the
  553. accuracy is calculated per epoch, but the network is changing 
  554. throughout each epoch train).
  555.     */
  556.     public JPanel m_errorLabel;
  557.     
  558.     /** The constructor. */
  559.     public ControlPanel() { 
  560.       setBorder(BorderFactory.createTitledBorder("Controls"));
  561.       
  562.       m_totalEpochsLabel = new JLabel("Num Of Epochs  ");
  563.       m_epochsLabel = new JPanel(){ 
  564.   public void paintComponent(Graphics g) {
  565.     super.paintComponent(g);
  566.     g.setColor(m_controlPanel.m_totalEpochsLabel.getForeground());
  567.     g.drawString("Epoch  " + m_epoch, 0, 10);
  568.   }
  569. };
  570.       m_epochsLabel.setFont(m_totalEpochsLabel.getFont());
  571.       
  572.       m_changeEpochs = new JTextField();
  573.       m_changeEpochs.setText("" + m_numEpochs);
  574.       m_errorLabel = new JPanel(){
  575.   public void paintComponent(Graphics g) {
  576.     super.paintComponent(g);
  577.     g.setColor(m_controlPanel.m_totalEpochsLabel.getForeground());
  578.     if (m_valSize == 0) {
  579.       g.drawString("Error per Epoch = " + 
  580.    Utils.doubleToString(m_error, 7), 0, 10);
  581.     }
  582.     else {
  583.       g.drawString("Validation Error per Epoch = "
  584.    + Utils.doubleToString(m_error, 7), 0, 10);
  585.     }
  586.   }
  587. };
  588.       m_errorLabel.setFont(m_epochsLabel.getFont());
  589.       
  590.       m_learningLabel = new JLabel("Learning Rate = ");
  591.       m_momentumLabel = new JLabel("Momentum = ");
  592.       m_changeLearning = new JTextField();
  593.       m_changeMomentum = new JTextField();
  594.       m_changeLearning.setText("" + m_learningRate);
  595.       m_changeMomentum.setText("" + m_momentum);
  596.       setLayout(new BorderLayout(15, 10));
  597.       m_stopIt = true;
  598.       m_accepted = false;
  599.       m_startStop = new JButton("Start");
  600.       m_startStop.setActionCommand("Start");
  601.       
  602.       m_acceptButton = new JButton("Accept");
  603.       m_acceptButton.setActionCommand("Accept");
  604.       
  605.       JPanel buttons = new JPanel();
  606.       buttons.setLayout(new BoxLayout(buttons, BoxLayout.Y_AXIS));
  607.       buttons.add(m_startStop);
  608.       buttons.add(m_acceptButton);
  609.       add(buttons, BorderLayout.WEST);
  610.       JPanel data = new JPanel();
  611.       data.setLayout(new BoxLayout(data, BoxLayout.Y_AXIS));
  612.       
  613.       Box ab = new Box(BoxLayout.X_AXIS);
  614.       ab.add(m_epochsLabel);
  615.       data.add(ab);
  616.       
  617.       ab = new Box(BoxLayout.X_AXIS);
  618.       Component b = Box.createGlue();
  619.       ab.add(m_totalEpochsLabel);
  620.       ab.add(m_changeEpochs);
  621.       m_changeEpochs.setMaximumSize(new Dimension(200, 20));
  622.       ab.add(b);
  623.       data.add(ab);
  624.       
  625.       ab = new Box(BoxLayout.X_AXIS);
  626.       ab.add(m_errorLabel);
  627.       data.add(ab);
  628.       
  629.       add(data, BorderLayout.CENTER);
  630.       
  631.       data = new JPanel();
  632.       data.setLayout(new BoxLayout(data, BoxLayout.Y_AXIS));
  633.       ab = new Box(BoxLayout.X_AXIS);
  634.       b = Box.createGlue();
  635.       ab.add(m_learningLabel);
  636.       ab.add(m_changeLearning);
  637.       m_changeLearning.setMaximumSize(new Dimension(200, 20));
  638.       ab.add(b);
  639.       data.add(ab);
  640.       
  641.       ab = new Box(BoxLayout.X_AXIS);
  642.       b = Box.createGlue();
  643.       ab.add(m_momentumLabel);
  644.       ab.add(m_changeMomentum);
  645.       m_changeMomentum.setMaximumSize(new Dimension(200, 20));
  646.       ab.add(b);
  647.       data.add(ab);
  648.       
  649.       add(data, BorderLayout.EAST);
  650.       
  651.       m_startStop.addActionListener(new ActionListener() {
  652.   public void actionPerformed(ActionEvent e) {
  653.     if (e.getActionCommand().equals("Start")) {
  654.       m_stopIt = false;
  655.       m_startStop.setText("Stop");
  656.       m_startStop.setActionCommand("Stop");
  657.       int n = Integer.valueOf(m_changeEpochs.getText()).intValue();
  658.       
  659.       m_numEpochs = n;
  660.       m_changeEpochs.setText("" + m_numEpochs);
  661.       
  662.       double m=Double.valueOf(m_changeLearning.getText()).
  663. doubleValue();
  664.       setLearningRate(m);
  665.       m_changeLearning.setText("" + m_learningRate);
  666.       
  667.       m = Double.valueOf(m_changeMomentum.getText()).doubleValue();
  668.       setMomentum(m);
  669.       m_changeMomentum.setText("" + m_momentum);
  670.       
  671.       blocker(false);
  672.     }
  673.     else if (e.getActionCommand().equals("Stop")) {
  674.       m_stopIt = true;
  675.       m_startStop.setText("Start");
  676.       m_startStop.setActionCommand("Start");
  677.     }
  678.   }
  679. });
  680.       
  681.       m_acceptButton.addActionListener(new ActionListener() {
  682.   public void actionPerformed(ActionEvent e) {
  683.     m_accepted = true;
  684.     blocker(false);
  685.   }
  686. });
  687.       
  688.       m_changeEpochs.addActionListener(new ActionListener() {
  689.   public void actionPerformed(ActionEvent e) {
  690.     int n = Integer.valueOf(m_changeEpochs.getText()).intValue();
  691.     if (n > 0) {
  692.       m_numEpochs = n;
  693.       blocker(false);
  694.     }
  695.   }
  696. });
  697.     }
  698.   }
  699.   
  700.     
  701.   /** The training instances. */
  702.   private Instances m_instances;
  703.   
  704.   /** The current instance running through the network. */
  705.   private Instance m_currentInstance;
  706.   
  707.   /** A flag to say that it's a numeric class. */
  708.   private boolean m_numeric;
  709.   /** The ranges for all the attributes. */
  710.   private double[] m_attributeRanges;
  711.   /** The base values for all the attributes. */
  712.   private double[] m_attributeBases;
  713.   /** The output units.(only feeds the errors, does no calcs) */
  714.   private NeuralEnd[] m_outputs;
  715.   /** The input units.(only feeds the inputs does no calcs) */
  716.   private NeuralEnd[] m_inputs;
  717.   /** All the nodes that actually comprise the logical neural net. */
  718.   private NeuralConnection[] m_neuralNodes;
  719.   /** The number of classes. */
  720.   private int m_numClasses = 0;
  721.   
  722.   /** The number of attributes. */
  723.   private int m_numAttributes = 0; //note the number doesn't include the class.
  724.   
  725.   /** The panel the nodes are displayed on. */
  726.   private NodePanel m_nodePanel;
  727.   
  728.   /** The control panel. */
  729.   private ControlPanel m_controlPanel;
  730.   /** The next id number available for default naming. */
  731.   private int m_nextId;
  732.    
  733.   /** A Vector list of the units currently selected. */
  734.   private FastVector m_selected;
  735.   /** A Vector list of the graphers. */
  736.   private FastVector m_graphers;
  737.   /** The number of epochs to train through. */
  738.   private int m_numEpochs;
  739.   /** a flag to state if the network should be running, or stopped. */
  740.   private boolean m_stopIt;
  741.   /** a flag to state that the network has in fact stopped. */
  742.   private boolean m_stopped;
  743.   /** a flag to state that the network should be accepted the way it is. */
  744.   private boolean m_accepted;
  745.   /** The window for the network. */
  746.   private JFrame m_win;
  747.   /** A flag to tell the build classifier to automatically build a neural net.
  748.    */
  749.   private boolean m_autoBuild;
  750.   /** A flag to state that the gui for the network should be brought up.
  751.       To allow interaction while training. */
  752.   private boolean m_gui;
  753.   /** An int to say how big the validation set should be. */
  754.   private int m_valSize;
  755.   /** The number to to use to quit on validation testing. */
  756.   private int m_driftThreshold;
  757.   /** The number used to seed the random number generator. */
  758.   private long m_randomSeed;
  759.   /** The actual random number generator. */
  760.   private Random m_random;
  761.   /** A flag to state that a nominal to binary filter should be used. */
  762.   private boolean m_useNomToBin;
  763.   
  764.   /** The actual filter. */
  765.   private NominalToBinary m_nominalToBinaryFilter;
  766.   /** The string that defines the hidden layers */
  767.   private String m_hiddenLayers;
  768.   /** This flag states that the user wants the input values normalized. */
  769.   private boolean m_normalizeAttributes;
  770.   /** This flag states that the user wants the learning rate to decay. */
  771.   private boolean m_decay;
  772.   /** This is the learning rate for the network. */
  773.   private double m_learningRate;
  774.   /** This is the momentum for the network. */
  775.   private double m_momentum;
  776.   /** Shows the number of the epoch that the network just finished. */
  777.   private int m_epoch;
  778.   /** Shows the error of the epoch that the network just finished. */
  779.   private double m_error;
  780.   /** This flag states that the user wants the network to restart if it
  781.    * is found to be generating infinity or NaN for the error value. This
  782.    * would restart the network with the current options except that the
  783.    * learning rate would be smaller than before, (perhaps half of its current
  784.    * value). This option will not be available if the gui is chosen (if the
  785.    * gui is open the user can fix the network themselves, it is an 
  786.    * architectural minefield for the network to be reset with the gui open). */
  787.   private boolean m_reset;
  788.   /** This flag states that the user wants the class to be normalized while
  789.    * processing in the network is done. (the final answer will be in the
  790.    * original range regardless). This option will only be used when the class
  791.    * is numeric. */
  792.   private boolean m_normalizeClass;
  793.   /**
  794.    * this is a sigmoid unit. 
  795.    */
  796.   private SigmoidUnit m_sigmoidUnit;
  797.   
  798.   /**
  799.    * This is a linear unit.
  800.    */
  801.   private LinearUnit m_linearUnit;
  802.   
  803.   /**
  804.    * The constructor.
  805.    */
  806.   public NeuralNetwork() {
  807.     m_instances = null;
  808.     m_currentInstance = null;
  809.     m_controlPanel = null;
  810.     m_nodePanel = null;
  811.     m_epoch = 0;
  812.     m_error = 0;
  813.     
  814.     
  815.     m_outputs = new NeuralEnd[0];
  816.     m_inputs = new NeuralEnd[0];
  817.     m_numAttributes = 0;
  818.     m_numClasses = 0;
  819.     m_neuralNodes = new NeuralConnection[0];
  820.     m_selected = new FastVector(4);
  821.     m_graphers = new FastVector(2);
  822.     m_nextId = 0;
  823.     m_stopIt = true;
  824.     m_stopped = true;
  825.     m_accepted = false;
  826.     m_numeric = false;
  827.     m_random = null;
  828.     m_nominalToBinaryFilter = new NominalToBinary();
  829.     m_sigmoidUnit = new SigmoidUnit();
  830.     m_linearUnit = new LinearUnit();
  831.     //setting all the options to their defaults. To completely change these
  832.     //defaults they will also need to be changed down the bottom in the 
  833.     //setoptions function (the text info in the accompanying functions should 
  834.     //also be changed to reflect the new defaults
  835.     m_normalizeClass = true;
  836.     m_normalizeAttributes = true;
  837.     m_autoBuild = true;
  838.     m_gui = false;
  839.     m_useNomToBin = true;
  840.     m_driftThreshold = 20;
  841.     m_numEpochs = 500;
  842.     m_valSize = 0;
  843.     m_randomSeed = 0;
  844.     m_hiddenLayers = "a";
  845.     m_learningRate = .3;
  846.     m_momentum = .2;
  847.     m_reset = true;
  848.     m_decay = false;
  849.   }
  850.   /**
  851.    * @param d True if the learning rate should decay.
  852.    */
  853.   public void setDecay(boolean d) {
  854.     m_decay = d;
  855.   }
  856.   
  857.   /**
  858.    * @return the flag for having the learning rate decay.
  859.    */
  860.   public boolean getDecay() {
  861.     return m_decay;
  862.   }
  863.   /**
  864.    * This sets the network up to be able to reset itself with the current 
  865.    * settings and the learning rate at half of what it is currently. This
  866.    * will only happen if the network creates NaN or infinite errors. Also this
  867.    * will continue to happen until the network is trained properly. The 
  868.    * learning rate will also get set back to it's original value at the end of
  869.    * this. This can only be set to true if the GUI is not brought up.
  870.    * @param r True if the network should restart with it's current options
  871.    * and set the learning rate to half what it currently is.
  872.    */
  873.   public void setReset(boolean r) {
  874.     if (m_gui) {
  875.       r = false;
  876.     }
  877.     m_reset = r;
  878.       
  879.   }
  880.   /**
  881.    * @return The flag for reseting the network.
  882.    */
  883.   public boolean getReset() {
  884.     return m_reset;
  885.   }
  886.   
  887.   /**
  888.    * @param c True if the class should be normalized (the class will only ever
  889.    * be normalized if it is numeric). (Normalization puts the range between
  890.    * -1 - 1).
  891.    */
  892.   public void setNormalizeNumericClass(boolean c) {
  893.     m_normalizeClass = c;
  894.   }
  895.   
  896.   /**
  897.    * @return The flag for normalizing a numeric class.
  898.    */
  899.   public boolean getNormalizeNumericClass() {
  900.     return m_normalizeClass;
  901.   }
  902.   /**
  903.    * @param a True if the attributes should be normalized (even nominal
  904.    * attributes will get normalized here) (range goes between -1 - 1).
  905.    */
  906.   public void setNormalizeAttributes(boolean a) {
  907.     m_normalizeAttributes = a;
  908.   }
  909.   /**
  910.    * @return The flag for normalizing attributes.
  911.    */
  912.   public boolean getNormalizeAttributes() {
  913.     return m_normalizeAttributes;
  914.   }
  915.   /**
  916.    * @param f True if a nominalToBinary filter should be used on the
  917.    * data.
  918.    */
  919.   public void setNominalToBinaryFilter(boolean f) {
  920.     m_useNomToBin = f;
  921.   }
  922.   /**
  923.    * @return The flag for nominal to binary filter use.
  924.    */
  925.   public boolean getNominalToBinaryFilter() {
  926.     return m_useNomToBin;
  927.   }
  928.   /**
  929.    * This seeds the random number generator, that is used when a random
  930.    * number is needed for the network.
  931.    * @param l The seed.
  932.    */
  933.   public void setRandomSeed(long l) {
  934.     if (l >= 0) {
  935.       m_randomSeed = l;
  936.     }
  937.   }
  938.   
  939.   /**
  940.    * @return The seed for the random number generator.
  941.    */
  942.   public long getRandomSeed() {
  943.     return m_randomSeed;
  944.   }
  945.   /**
  946.    * This sets the threshold to use for when validation testing is being done.
  947.    * It works by ending testing once the error on the validation set has 
  948.    * consecutively increased a certain number of times.
  949.    * @param t The threshold to use for this.
  950.    */
  951.   public void setValidationThreshold(int t) {
  952.     if (t > 0) {
  953.       m_driftThreshold = t;
  954.     }
  955.   }
  956.   /**
  957.    * @return The threshold used for validation testing.
  958.    */
  959.   public int getValidationThreshold() {
  960.     return m_driftThreshold;
  961.   }
  962.   
  963.   /**
  964.    * The learning rate can be set using this command.
  965.    * NOTE That this is a static variable so it affect all networks that are
  966.    * running.
  967.    * Must be greater than 0 and no more than 1.
  968.    * @param l The New learning rate. 
  969.    */
  970.   public void setLearningRate(double l) {
  971.     if (l > 0 && l <= 1) {
  972.       m_learningRate = l;
  973.     
  974.       if (m_controlPanel != null) {
  975. m_controlPanel.m_changeLearning.setText("" + l);
  976.       }
  977.     }
  978.   }
  979.   /**
  980.    * @return The learning rate for the nodes.
  981.    */
  982.   public double getLearningRate() {
  983.     return m_learningRate;
  984.   }
  985.   /**
  986.    * The momentum can be set using this command.
  987.    * THE same conditions apply to this as to the learning rate.
  988.    * @param m The new Momentum.
  989.    */
  990.   public void setMomentum(double m) {
  991.     if (m >= 0 && m <= 1) {
  992.       m_momentum = m;
  993.   
  994.       if (m_controlPanel != null) {
  995. m_controlPanel.m_changeMomentum.setText("" + m);
  996.       }
  997.     }
  998.   }
  999.   
  1000.   /**
  1001.    * @return The momentum for the nodes.
  1002.    */
  1003.   public double getMomentum() {
  1004.     return m_momentum;
  1005.   }
  1006.   /**
  1007.    * This will set whether the network is automatically built
  1008.    * or if it is left up to the user. (there is nothing to stop a user
  1009.    * from altering an autobuilt network however). 
  1010.    * @param a True if the network should be auto built.
  1011.    */
  1012.   public void setAutoBuild(boolean a) {
  1013.     if (!m_gui) {
  1014.       a = true;
  1015.     }
  1016.     m_autoBuild = a;
  1017.   }
  1018.   /**
  1019.    * @return The auto build state.
  1020.    */
  1021.   public boolean getAutoBuild() {
  1022.     return m_autoBuild;
  1023.   }
  1024.   /**
  1025.    * This will set what the hidden layers are made up of when auto build is
  1026.    * enabled. Note to have no hidden units, just put a single 0, Any more
  1027.    * 0's will indicate that the string is badly formed and make it unaccepted.
  1028.    * Negative numbers, and floats will do the same. There are also some
  1029.    * wildcards. These are 'a' = (number of attributes + number of classes) / 2,
  1030.    * 'i' = number of attributes, 'o' = number of classes, and 't' = number of
  1031.    * attributes + number of classes.
  1032.    * @param h A string with a comma seperated list of numbers. Each number is 
  1033.    * the number of nodes to be on a hidden layer.
  1034.    */
  1035.   public void setHiddenLayers(String h) {
  1036.     String tmp = "";
  1037.     StringTokenizer tok = new StringTokenizer(h, ",");
  1038.     if (tok.countTokens() == 0) {
  1039.       return;
  1040.     }
  1041.     double dval;
  1042.     int val;
  1043.     String c;
  1044.     boolean first = true;
  1045.     while (tok.hasMoreTokens()) {
  1046.       c = tok.nextToken().trim();
  1047.       if (c.equals("a") || c.equals("i") || c.equals("o") || 
  1048.        c.equals("t")) {
  1049. tmp += c;
  1050.       }
  1051.       else {
  1052. dval = Double.valueOf(c).doubleValue();
  1053. val = (int)dval;
  1054. if ((val == dval && (val != 0 || (tok.countTokens() == 0 && first)) && 
  1055.      val >= 0)) {
  1056.   tmp += val;
  1057. }
  1058. else {
  1059.   return;
  1060. }
  1061.       }
  1062.       
  1063.       first = false;
  1064.       if (tok.hasMoreTokens()) {
  1065. tmp += ", ";
  1066.       }
  1067.     }
  1068.     m_hiddenLayers = tmp;
  1069.   }
  1070.   /**
  1071.    * @return A string representing the hidden layers, each number is the number
  1072.    * of nodes on a hidden layer.
  1073.    */
  1074.   public String getHiddenLayers() {
  1075.     return m_hiddenLayers;
  1076.   }
  1077.   /**
  1078.    * This will set whether A GUI is brought up to allow interaction by the user
  1079.    * with the neural network during training.
  1080.    * @param a True if gui should be created.
  1081.    */
  1082.   public void setGUI(boolean a) {
  1083.     m_gui = a;
  1084.     if (!a) {
  1085.       setAutoBuild(true);
  1086.       
  1087.     }
  1088.     else {
  1089.       setReset(false);
  1090.     }
  1091.   }
  1092.   /**
  1093.    * @return The true if should show gui.
  1094.    */
  1095.   public boolean getGUI() {
  1096.     return m_gui;
  1097.   }
  1098.   /**
  1099.    * This will set the size of the validation set.
  1100.    * @param a The size of the validation set, as a percentage of the whole.
  1101.    */
  1102.   public void setValidationSetSize(int a) {
  1103.     if (a < 0 || a > 99) {
  1104.       return;
  1105.     }
  1106.     m_valSize = a;
  1107.   }
  1108.   /**
  1109.    * @return The percentage size of the validation set.
  1110.    */
  1111.   public int getValidationSetSize() {
  1112.     return m_valSize;
  1113.   }
  1114.   
  1115.   
  1116.   
  1117.   /**
  1118.    * Set the number of training epochs to perform.
  1119.    * Must be greater than 0.
  1120.    * @param n The number of epochs to train through.
  1121.    */
  1122.   public void setTrainingTime(int n) {
  1123.     if (n > 0) {
  1124.       m_numEpochs = n;
  1125.     }
  1126.   }
  1127.   /**
  1128.    * @return The number of epochs to train through.
  1129.    */
  1130.   public int getTrainingTime() {
  1131.     return m_numEpochs;
  1132.   }
  1133.   
  1134.   /**
  1135.    * Call this function to place a node into the network list.
  1136.    * @param n The node to place in the list.
  1137.    */
  1138.   private void addNode(NeuralConnection n) {
  1139.     
  1140.     NeuralConnection[] temp1 = new NeuralConnection[m_neuralNodes.length + 1];
  1141.     for (int noa = 0; noa < m_neuralNodes.length; noa++) {
  1142.       temp1[noa] = m_neuralNodes[noa];
  1143.     }
  1144.     temp1[temp1.length-1] = n;
  1145.     m_neuralNodes = temp1;
  1146.   }
  1147.   /** 
  1148.    * Call this function to remove the passed node from the list.
  1149.    * This will only remove the node if it is in the neuralnodes list.
  1150.    * @param n The neuralConnection to remove.
  1151.    * @return True if removed false if not (because it wasn't there).
  1152.    */
  1153.   private boolean removeNode(NeuralConnection n) {
  1154.     NeuralConnection[] temp1 = new NeuralConnection[m_neuralNodes.length - 1];
  1155.     int skip = 0;
  1156.     for (int noa = 0; noa < m_neuralNodes.length; noa++) {
  1157.       if (n == m_neuralNodes[noa]) {
  1158. skip++;
  1159.       }
  1160.       else if (!((noa - skip) >= temp1.length)) {
  1161. temp1[noa - skip] = m_neuralNodes[noa];
  1162.       }
  1163.       else {
  1164. return false;
  1165.       }
  1166.     }
  1167.     m_neuralNodes = temp1;
  1168.     return true;
  1169.   }
  1170.   /**
  1171.    * This function sets what the m_numeric flag to represent the passed class
  1172.    * it also performs the normalization of the attributes if applicable
  1173.    * and sets up the info to normalize the class. (note that regardless of
  1174.    * the options it will fill an array with the range and base, set to 
  1175.    * normalize all attributes and the class to be between -1 and 1)
  1176.    * @param inst the instances.
  1177.    * @return The modified instances. This needs to be done. If the attributes
  1178.    * are normalized then deep copies will be made of all the instances which
  1179.    * will need to be passed back out.
  1180.    */
  1181.   private Instances setClassType(Instances inst) throws Exception {
  1182.     if (inst != null) {
  1183.       // x bounds
  1184.       double min=Double.POSITIVE_INFINITY;
  1185.       double max=Double.NEGATIVE_INFINITY;
  1186.       double value;
  1187.       m_attributeRanges = new double[inst.numAttributes()];
  1188.       m_attributeBases = new double[inst.numAttributes()];
  1189.       for (int noa = 0; noa < inst.numAttributes(); noa++) {
  1190. min = Double.POSITIVE_INFINITY;
  1191. max = Double.NEGATIVE_INFINITY;
  1192. for (int i=0; i < inst.numInstances();i++) {
  1193.   if (!inst.instance(i).isMissing(noa)) {
  1194.     value = inst.instance(i).value(noa);
  1195.     if (value < min) {
  1196.       min = value;
  1197.     }
  1198.     if (value > max) {
  1199.       max = value;
  1200.     }
  1201.   }
  1202. }
  1203. m_attributeRanges[noa] = (max - min) / 2;
  1204. m_attributeBases[noa] = (max + min) / 2;
  1205. if (noa != inst.classIndex() && m_normalizeAttributes) {
  1206.   for (int i = 0; i < inst.numInstances(); i++) {
  1207.     if (m_attributeRanges[noa] != 0) {
  1208.       inst.instance(i).setValue(noa, (inst.instance(i).value(noa)  
  1209.       - m_attributeBases[noa]) /
  1210. m_attributeRanges[noa]);
  1211.     }
  1212.     else {
  1213.       inst.instance(i).setValue(noa, inst.instance(i).value(noa) - 
  1214. m_attributeBases[noa]);
  1215.     }
  1216.   }
  1217. }
  1218.       }
  1219.       if (inst.classAttribute().isNumeric()) {
  1220. m_numeric = true;
  1221.       }
  1222.       else {
  1223. m_numeric = false;
  1224.       }
  1225.     }
  1226.     return inst;
  1227.   }
  1228.   /**
  1229.    * A function used to stop the code that called buildclassifier
  1230.    * from continuing on before the user has finished the decision tree.
  1231.    * @param tf True to stop the thread, False to release the thread that is
  1232.    * waiting there (if one).
  1233.    */
  1234.   public synchronized void blocker(boolean tf) {
  1235.     if (tf) {
  1236.       try {
  1237. wait();
  1238.       } catch(InterruptedException e) {
  1239.       }
  1240.     }
  1241.     else {
  1242.       notifyAll();
  1243.     }
  1244.   }
  1245.   /**
  1246.    * Call this function to update the control panel for the gui.
  1247.    */
  1248.   private void updateDisplay() {
  1249.     
  1250.     if (m_gui) {
  1251.       m_controlPanel.m_errorLabel.repaint();
  1252.       m_controlPanel.m_epochsLabel.repaint();
  1253.     }
  1254.   }
  1255.   
  1256.   /**
  1257.    * this will reset all the nodes in the network.
  1258.    */
  1259.   private void resetNetwork() {
  1260.     for (int noc = 0; noc < m_numClasses; noc++) {
  1261.       m_outputs[noc].reset();
  1262.     }
  1263.   }
  1264.   
  1265.   /**
  1266.    * This will cause the output values of all the nodes to be calculated.
  1267.    * Note that the m_currentInstance is used to calculate these values.
  1268.    */
  1269.   private void calculateOutputs() {
  1270.     for (int noc = 0; noc < m_numClasses; noc++) {
  1271.       //get the values. 
  1272.       m_outputs[noc].outputValue(true);
  1273.     }
  1274.   }
  1275.   /**
  1276.    * This will cause the error values to be calculated for all nodes.
  1277.    * Note that the m_currentInstance is used to calculate these values.
  1278.    * Also the output values should have been calculated first.
  1279.    * @return The squared error.
  1280.    */
  1281.   private double calculateErrors() throws Exception {
  1282.     double ret = 0, temp = 0; 
  1283.     for (int noc = 0; noc < m_numAttributes; noc++) {
  1284.       //get the errors.
  1285.       m_inputs[noc].errorValue(true);
  1286.       
  1287.     }
  1288.     for (int noc = 0; noc < m_numClasses; noc++) {
  1289.       temp = m_outputs[noc].errorValue(false);
  1290.       ret += temp * temp;
  1291.     }    
  1292.     return ret;
  1293.     
  1294.   }
  1295.   /**
  1296.    * This will cause the weight values to be updated based on the learning
  1297.    * rate, momentum and the errors that have been calculated for each node.
  1298.    * @param l The learning rate to update with.
  1299.    * @param m The momentum to update with.
  1300.    */
  1301.   private void updateNetworkWeights(double l, double m) {
  1302.     for (int noc = 0; noc < m_numClasses; noc++) {
  1303.       //update weights
  1304.       m_outputs[noc].updateWeights(l, m);
  1305.     }
  1306.   }
  1307.   
  1308.   /**
  1309.    * This creates the required input units.
  1310.    */
  1311.   private void setupInputs() throws Exception {
  1312.     m_inputs = new NeuralEnd[m_numAttributes];
  1313.     int now = 0;
  1314.     for (int noa = 0; noa < m_numAttributes+1; noa++) {
  1315.       if (m_instances.classIndex() != noa) {
  1316. m_inputs[noa - now] = new NeuralEnd(m_instances.attribute(noa).name());
  1317. m_inputs[noa - now].setX(.1);
  1318. m_inputs[noa - now].setY((noa - now + 1.0) / (m_numAttributes + 1));
  1319. m_inputs[noa - now].setLink(true, noa);
  1320.       }    
  1321.       else {
  1322. now = 1;
  1323.       }
  1324.     }
  1325.   }
  1326.   /**
  1327.    * This creates the required output units.
  1328.    */
  1329.   private void setupOutputs() throws Exception {
  1330.   
  1331.     m_outputs = new NeuralEnd[m_numClasses];
  1332.     for (int noa = 0; noa < m_numClasses; noa++) {
  1333.       if (m_numeric) {
  1334. m_outputs[noa] = new NeuralEnd(m_instances.classAttribute().name());
  1335.       }
  1336.       else {
  1337. m_outputs[noa]= new NeuralEnd(m_instances.classAttribute().value(noa));
  1338.       }
  1339.       
  1340.       m_outputs[noa].setX(.9);
  1341.       m_outputs[noa].setY((noa + 1.0) / (m_numClasses + 1));
  1342.       m_outputs[noa].setLink(false, noa);
  1343.       NeuralNode temp = new NeuralNode(String.valueOf(m_nextId), m_random,
  1344.        m_sigmoidUnit);
  1345.       m_nextId++;
  1346.       temp.setX(.75);
  1347.       temp.setY((noa + 1.0) / (m_numClasses + 1));
  1348.       addNode(temp);
  1349.       NeuralConnection.connect(temp, m_outputs[noa]);
  1350.     }
  1351.  
  1352.   }
  1353.   
  1354.   /**
  1355.    * Call this function to automatically generate the hidden units
  1356.    */
  1357.   private void setupHiddenLayer()
  1358.   {
  1359.     StringTokenizer tok = new StringTokenizer(m_hiddenLayers, ",");
  1360.     int val = 0;  //num of nodes in a layer
  1361.     int prev = 0; //used to remember the previous layer
  1362.     int num = tok.countTokens(); //number of layers
  1363.     String c;
  1364.     for (int noa = 0; noa < num; noa++) {
  1365.       //note that I am using the Double to get the value rather than the
  1366.       //Integer class, because for some reason the Double implementation can
  1367.       //handle leading white space and the integer version can't!?!
  1368.       c = tok.nextToken().trim();
  1369.       if (c.equals("a")) {
  1370. val = (m_numAttributes + m_numClasses) / 2;
  1371.       }
  1372.       else if (c.equals("i")) {
  1373. val = m_numAttributes;
  1374.       }
  1375.       else if (c.equals("o")) {
  1376. val = m_numClasses;
  1377.       }
  1378.       else if (c.equals("t")) {
  1379. val = m_numAttributes + m_numClasses;
  1380.       }
  1381.       else {
  1382. val = Double.valueOf(c).intValue();
  1383.       }
  1384.       for (int nob = 0; nob < val; nob++) {
  1385. NeuralNode temp = new NeuralNode(String.valueOf(m_nextId), m_random,
  1386.  m_sigmoidUnit);
  1387. m_nextId++;
  1388. temp.setX(.5 / (num) * noa + .25);
  1389. temp.setY((nob + 1.0) / (val + 1));
  1390. addNode(temp);
  1391. if (noa > 0) {
  1392.   //then do connections
  1393.   for (int noc = m_neuralNodes.length - nob - 1 - prev;
  1394.        noc < m_neuralNodes.length - nob - 1; noc++) {
  1395.     NeuralConnection.connect(m_neuralNodes[noc], temp);
  1396.   }
  1397. }
  1398.       }      
  1399.       prev = val;
  1400.     }
  1401.     tok = new StringTokenizer(m_hiddenLayers, ",");
  1402.     c = tok.nextToken();
  1403.     if (c.equals("a")) {
  1404.       val = (m_numAttributes + m_numClasses) / 2;
  1405.     }
  1406.     else if (c.equals("i")) {
  1407.       val = m_numAttributes;
  1408.     }
  1409.     else if (c.equals("o")) {
  1410.       val = m_numClasses;
  1411.     }
  1412.     else if (c.equals("t")) {
  1413.       val = m_numAttributes + m_numClasses;
  1414.     }
  1415.     else {
  1416.       val = Double.valueOf(c).intValue();
  1417.     }
  1418.     
  1419.     if (val == 0) {
  1420.       for (int noa = 0; noa < m_numAttributes; noa++) {
  1421. for (int nob = 0; nob < m_numClasses; nob++) {
  1422.   NeuralConnection.connect(m_inputs[noa], m_neuralNodes[nob]);
  1423. }
  1424.       }
  1425.     }
  1426.     else {
  1427.       for (int noa = 0; noa < m_numAttributes; noa++) {
  1428. for (int nob = m_numClasses; nob < m_numClasses + val; nob++) {
  1429.   NeuralConnection.connect(m_inputs[noa], m_neuralNodes[nob]);
  1430. }
  1431.       }
  1432.       for (int noa = m_neuralNodes.length - prev; noa < m_neuralNodes.length;
  1433.    noa++) {
  1434. for (int nob = 0; nob < m_numClasses; nob++) {
  1435.   NeuralConnection.connect(m_neuralNodes[noa], m_neuralNodes[nob]);
  1436. }
  1437.       }
  1438.     }
  1439.     
  1440.   }
  1441.   
  1442.   /**
  1443.    * This will go through all the nodes and check if they are connected
  1444.    * to a pure output unit. If so they will be set to be linear units.
  1445.    * If not they will be set to be sigmoid units.
  1446.    */
  1447.   private void setEndsToLinear() {
  1448.     for (int noa = 0; noa < m_neuralNodes.length; noa++) {
  1449.       if ((m_neuralNodes[noa].getType() & NeuralConnection.OUTPUT) ==
  1450.   NeuralConnection.OUTPUT) {
  1451. ((NeuralNode)m_neuralNodes[noa]).setMethod(m_linearUnit);
  1452.       }
  1453.       else {
  1454. ((NeuralNode)m_neuralNodes[noa]).setMethod(m_sigmoidUnit);
  1455.       }
  1456.     }
  1457.   }
  1458.   
  1459.   /**
  1460.    * Call this function to build and train a neural network for the training
  1461.    * data provided.
  1462.    * @param i The training data.
  1463.    * @exception Throws exception if can't build classification properly.
  1464.    */
  1465.   public void buildClassifier(Instances i) throws Exception {
  1466.     if (i.checkForStringAttributes()) {
  1467.       throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
  1468.     }
  1469.     if (i.numInstances() == 0) {
  1470.       throw new IllegalArgumentException("No training instances.");
  1471.     }
  1472.     m_epoch = 0;
  1473.     m_error = 0;
  1474.     m_instances = null;
  1475.     m_currentInstance = null;
  1476.     m_controlPanel = null;
  1477.     m_nodePanel = null;
  1478.     
  1479.     
  1480.     m_outputs = new NeuralEnd[0];
  1481.     m_inputs = new NeuralEnd[0];
  1482.     m_numAttributes = 0;
  1483.     m_numClasses = 0;
  1484.     m_neuralNodes = new NeuralConnection[0];
  1485.     
  1486.     m_selected = new FastVector(4);
  1487.     m_graphers = new FastVector(2);
  1488.     m_nextId = 0;
  1489.     m_stopIt = true;
  1490.     m_stopped = true;
  1491.     m_accepted = false;    
  1492.     m_instances = new Instances(i);
  1493.     m_instances.deleteWithMissingClass();
  1494.     if (m_instances.numInstances() == 0) {
  1495.       m_instances = null;
  1496.       throw new IllegalArgumentException("All class values missing.");
  1497.     }
  1498.     m_random = new Random(m_randomSeed);
  1499.     m_instances.randomize(m_random);
  1500.     if (m_useNomToBin) {
  1501.       m_nominalToBinaryFilter = new NominalToBinary();
  1502.       m_nominalToBinaryFilter.setInputFormat(m_instances);
  1503.       m_instances = Filter.useFilter(m_instances,
  1504.      m_nominalToBinaryFilter);
  1505.     }
  1506.     m_numAttributes = m_instances.numAttributes() - 1;
  1507.     m_numClasses = m_instances.numClasses();
  1508.  
  1509.     
  1510.     setClassType(m_instances);
  1511.     
  1512.    
  1513.     //this sets up the validation set.
  1514.     Instances valSet = null;
  1515.     //numinval is needed later
  1516.     int numInVal = (int)(m_valSize / 100.0 * m_instances.numInstances());
  1517.     if (m_valSize > 0) {
  1518.       if (numInVal == 0) {
  1519. numInVal = 1;
  1520.       }
  1521.       valSet = new Instances(m_instances, 0, numInVal);
  1522.     }
  1523.     ///////////
  1524.     setupInputs();
  1525.       
  1526.     setupOutputs();    
  1527.     if (m_autoBuild) {
  1528.       setupHiddenLayer();
  1529.     }
  1530.     
  1531.     /////////////////////////////
  1532.     //this sets up the gui for usage
  1533.     if (m_gui) {
  1534.       m_win = new JFrame();
  1535.       
  1536.       m_win.addWindowListener(new WindowAdapter() {
  1537.   public void windowClosing(WindowEvent e) {
  1538.     boolean k = m_stopIt;
  1539.     m_stopIt = true;
  1540.     int well =JOptionPane.showConfirmDialog(m_win, 
  1541.     "Are You Sure...n"
  1542.     + "Click Yes To Accept"
  1543.     + " The Neural Network" 
  1544.     + "n Click No To Return",
  1545.     "Accept Neural Network", 
  1546.     JOptionPane.YES_NO_OPTION);
  1547.     
  1548.     if (well == 0) {
  1549.       m_win.setDefaultCloseOperation(JFrame.DISPOSE_ON_CLOSE);
  1550.       m_accepted = true;
  1551.       blocker(false);
  1552.     }
  1553.     else {
  1554.       m_win.setDefaultCloseOperation(JFrame.DO_NOTHING_ON_CLOSE);
  1555.     }
  1556.     m_stopIt = k;
  1557.   }
  1558. });
  1559.       
  1560.       m_win.getContentPane().setLayout(new BorderLayout());
  1561.       m_win.setTitle("Neural Network");
  1562.       m_nodePanel = new NodePanel();
  1563.       JScrollPane sp = new JScrollPane(m_nodePanel,
  1564.        JScrollPane.VERTICAL_SCROLLBAR_ALWAYS, 
  1565.        JScrollPane.HORIZONTAL_SCROLLBAR_NEVER);
  1566.       m_controlPanel = new ControlPanel();
  1567.            
  1568.       m_win.getContentPane().add(sp, BorderLayout.CENTER);
  1569.       m_win.getContentPane().add(m_controlPanel, BorderLayout.SOUTH);
  1570.       m_win.setSize(640, 480);
  1571.       m_win.show();
  1572.     }
  1573.    
  1574.     //This sets up the initial state of the gui
  1575.     if (m_gui) {
  1576.       blocker(true);
  1577.       m_controlPanel.m_changeEpochs.setEnabled(false);
  1578.       m_controlPanel.m_changeLearning.setEnabled(false);
  1579.       m_controlPanel.m_changeMomentum.setEnabled(false);
  1580.     } 
  1581.     
  1582.     //For silly situations in which the network gets accepted before training
  1583.     //commenses
  1584.     if (m_numeric) {
  1585.       setEndsToLinear();
  1586.     }
  1587.     if (m_accepted) {
  1588.       m_win.dispose();
  1589.       m_controlPanel = null;
  1590.       m_nodePanel = null;
  1591.       m_instances = new Instances(m_instances, 0);
  1592.       return;
  1593.     }
  1594.     //connections done.
  1595.     double right = 0;
  1596.     double driftOff = 0;
  1597.     double lastRight = Double.POSITIVE_INFINITY;
  1598.     double tempRate;
  1599.     double totalWeight = 0;
  1600.     double totalValWeight = 0;
  1601.     double origRate = m_learningRate; //only used for when reset
  1602.     
  1603.     //ensure that at least 1 instance is trained through.
  1604.     if (numInVal == m_instances.numInstances()) {
  1605.       numInVal--;
  1606.     }
  1607.     if (numInVal < 0) {
  1608.       numInVal = 0;
  1609.     }
  1610.     for (int noa = numInVal; noa < m_instances.numInstances(); noa++) {
  1611.       if (!m_instances.instance(noa).classIsMissing()) {
  1612. totalWeight += m_instances.instance(noa).weight();
  1613.       }
  1614.     }
  1615.     if (m_valSize != 0) {
  1616.       for (int noa = 0; noa < valSet.numInstances(); noa++) {
  1617. if (!valSet.instance(noa).classIsMissing()) {
  1618.   totalValWeight += valSet.instance(noa).weight();
  1619. }
  1620.       }
  1621.     }
  1622.     m_stopped = false;
  1623.      
  1624.     for (int noa = 1; noa < m_numEpochs + 1; noa++) {
  1625.       right = 0;
  1626.       for (int nob = numInVal; nob < m_instances.numInstances(); nob++) {
  1627. m_currentInstance = m_instances.instance(nob);
  1628. if (!m_currentInstance.classIsMissing()) {
  1629.    
  1630.   //this is where the network updating (and training occurs, for the
  1631.   //training set
  1632.   resetNetwork();
  1633.   calculateOutputs();
  1634.   tempRate = m_learningRate * m_currentInstance.weight();  
  1635.   if (m_decay) {
  1636.     tempRate /= noa;
  1637.   }
  1638.   right += (calculateErrors() / m_instances.numClasses()) *
  1639.     m_currentInstance.weight();
  1640.   updateNetworkWeights(tempRate, m_momentum);
  1641.   
  1642. }
  1643.       }
  1644.       right /= totalWeight;
  1645.       if (Double.isInfinite(right) || Double.isNaN(right)) {
  1646. if (!m_reset) {
  1647.   m_instances = null;
  1648.   throw new Exception("Network cannot train. Try restarting with a" +
  1649.       " smaller learning rate.");
  1650. }
  1651. else {
  1652.   //reset the network
  1653.   m_learningRate /= 2;
  1654.   buildClassifier(i);
  1655.   m_learningRate = origRate;
  1656.   m_instances = new Instances(m_instances, 0);   
  1657.   return;
  1658. }
  1659.       }
  1660.       ////////////////////////do validation testing if applicable
  1661.       if (m_valSize != 0) {
  1662. right = 0;
  1663. for (int nob = 0; nob < valSet.numInstances(); nob++) {
  1664.   m_currentInstance = valSet.instance(nob);
  1665.   if (!m_currentInstance.classIsMissing()) {
  1666.     //this is where the network updating occurs, for the validation set
  1667.     resetNetwork();
  1668.     calculateOutputs();
  1669.     right += (calculateErrors() / valSet.numClasses()) 
  1670.       * m_currentInstance.weight();
  1671.     //note 'right' could be calculated here just using
  1672.     //the calculate output values. This would be faster.
  1673.     //be less modular
  1674.   }
  1675.   
  1676. }
  1677. if (right < lastRight) {
  1678.   driftOff = 0;
  1679. }
  1680. else {
  1681.   driftOff++;
  1682. }
  1683. lastRight = right;
  1684. if (driftOff > m_driftThreshold || noa + 1 >= m_numEpochs) {
  1685.   m_accepted = true;
  1686. }
  1687. right /= totalValWeight;
  1688.       }
  1689.       m_epoch = noa;
  1690.       m_error = right;
  1691.       //shows what the neuralnet is upto if a gui exists. 
  1692.       updateDisplay();
  1693.       //This junction controls what state the gui is in at the end of each
  1694.       //epoch, Such as if it is paused, if it is resumable etc...
  1695.       if (m_gui) {
  1696. while ((m_stopIt || (m_epoch >= m_numEpochs && m_valSize == 0)) && 
  1697. !m_accepted) {
  1698.   m_stopIt = true;
  1699.   m_stopped = true;
  1700.   if (m_epoch >= m_numEpochs && m_valSize == 0) {
  1701.     
  1702.     m_controlPanel.m_startStop.setEnabled(false);
  1703.   }
  1704.   else {
  1705.     m_controlPanel.m_startStop.setEnabled(true);
  1706.   }
  1707.   m_controlPanel.m_startStop.setText("Start");
  1708.   m_controlPanel.m_startStop.setActionCommand("Start");
  1709.   m_controlPanel.m_changeEpochs.setEnabled(true);
  1710.   m_controlPanel.m_changeLearning.setEnabled(true);
  1711.   m_controlPanel.m_changeMomentum.setEnabled(true);
  1712.   
  1713.   blocker(true);
  1714.   if (m_numeric) {
  1715.     setEndsToLinear();
  1716.   }
  1717. }
  1718. m_controlPanel.m_changeEpochs.setEnabled(false);
  1719. m_controlPanel.m_changeLearning.setEnabled(false);
  1720. m_controlPanel.m_changeMomentum.setEnabled(false);
  1721. m_stopped = false;
  1722. //if the network has been accepted stop the training loop
  1723. if (m_accepted) {
  1724.   m_win.dispose();
  1725.   m_controlPanel = null;
  1726.   m_nodePanel = null;
  1727.   m_instances = new Instances(m_instances, 0);
  1728.   return;
  1729. }
  1730.       }
  1731.       if (m_accepted) {
  1732. m_instances = new Instances(m_instances, 0);
  1733. return;
  1734.       }
  1735.     }
  1736.     if (m_gui) {
  1737.       m_win.dispose();
  1738.       m_controlPanel = null;
  1739.       m_nodePanel = null;
  1740.     }
  1741.     m_instances = new Instances(m_instances, 0);  
  1742.   }
  1743.   /**
  1744.    * Call this function to predict the class of an instance once a 
  1745.    * classification model has been built with the buildClassifier call.
  1746.    * @param i The instance to classify.
  1747.    * @return A double array filled with the probabilities of each class type.
  1748.    * @exception if can't classify instance.
  1749.    */
  1750.   public double[] distributionForInstance(Instance i) throws Exception {
  1751.     
  1752.     if (m_useNomToBin) {
  1753.       m_nominalToBinaryFilter.input(i);
  1754.       m_currentInstance = m_nominalToBinaryFilter.output();
  1755.     }
  1756.     else {
  1757.       m_currentInstance = i;
  1758.     }
  1759.     
  1760.     if (m_normalizeAttributes) {
  1761.       for (int noa = 0; noa < m_instances.numAttributes(); noa++) {
  1762. if (noa != m_instances.classIndex()) {
  1763.   if (m_attributeRanges[noa] != 0) {
  1764.     m_currentInstance.setValue(noa, (m_currentInstance.value(noa) - 
  1765.      m_attributeBases[noa]) / 
  1766.        m_attributeRanges[noa]);
  1767.   }
  1768.   else {
  1769.     m_currentInstance.setValue(noa, m_currentInstance.value(noa) -
  1770.        m_attributeBases[noa]);
  1771.   }
  1772. }
  1773.       }
  1774.     }
  1775.     resetNetwork();
  1776.     
  1777.     //since all the output values are needed.
  1778.     //They are calculated manually here and the values collected.
  1779.     double[] theArray = new double[m_numClasses];
  1780.     for (int noa = 0; noa < m_numClasses; noa++) {
  1781.       theArray[noa] = m_outputs[noa].outputValue(true);
  1782.     }
  1783.     if (m_instances.classAttribute().isNumeric()) {
  1784.       return theArray;
  1785.     }
  1786.     
  1787.     //now normalize the array
  1788.     double count = 0;
  1789.     for (int noa = 0; noa < m_numClasses; noa++) {
  1790.       count += theArray[noa];
  1791.     }
  1792.     if (count <= 0) {
  1793.       return null;
  1794.     }
  1795.     for (int noa = 0; noa < m_numClasses; noa++) {
  1796.       theArray[noa] /= count;
  1797.     }
  1798.     return theArray;
  1799.   }
  1800.   
  1801.   /**
  1802.    * Returns an enumeration describing the available options.
  1803.    *
  1804.    * @return an enumeration of all the available options.
  1805.    */
  1806.   public Enumeration listOptions() {
  1807.     
  1808.     Vector newVector = new Vector(14);
  1809.     newVector.addElement(new Option(
  1810.       "tLearning Rate for the backpropagation algorithm.n"
  1811.       +"t(Value should be between 0 - 1, Default = 0.3).",
  1812.       "L", 1, "-L <learning rate>"));
  1813.     newVector.addElement(new Option(
  1814.       "tMomentum Rate for the backpropagation algorithm.n"
  1815.       +"t(Value should be between 0 - 1, Default = 0.2).",
  1816.       "M", 1, "-M <momentum>"));
  1817.     newVector.addElement(new Option(
  1818.       "tNumber of epochs to train through.n"
  1819.       +"t(Default = 500).",
  1820.       "N", 1,"-N <number of epochs>"));
  1821.     newVector.addElement(new Option(
  1822.       "tPercentage size of validation set to use to terminate" +
  1823.       " training (if this is non zero it can pre-empt num of epochs.n"
  1824.       +"t(Value should be between 0 - 100, Default = 0).",
  1825.       "V", 1, "-V <percentage size of validation set>"));
  1826.     newVector.addElement(new Option(
  1827.       "tThe value used to seed the random number generator" +
  1828.       "t(Value should be >= 0 and and a long, Default = 0).",
  1829.       "S", 1, "-S <seed>"));
  1830.     newVector.addElement(new Option(
  1831.       "tThe consequetive number of errors allowed for validation" +
  1832.       " testing before the netwrok terminates." +
  1833.       "t(Value should be > 0, Default = 20).",
  1834.       "E", 1, "-E <threshold for number of consequetive errors>"));
  1835.     newVector.addElement(new Option(
  1836.               "tGUI will be opened.n"
  1837.       +"t(Use this to bring up a GUI).",
  1838.       "G", 0,"-G"));
  1839.     newVector.addElement(new Option(
  1840.               "tAutocreation of the network connections will NOT be done.n"
  1841.       +"t(This will be ignored if -G is NOT set)",
  1842.       "A", 0,"-A"));
  1843.     newVector.addElement(new Option(
  1844.               "tA NominalToBinary filter will NOT automatically be used.n"
  1845.       +"t(Set this to not use a NominalToBinary filter).",
  1846.       "B", 0,"-B"));
  1847.     newVector.addElement(new Option(
  1848.       "tThe hidden layers to be created for the network.n"
  1849.       +"t(Value should be a list of comma seperated Natural numbers" +
  1850.       " or the letters 'a' = (attribs + classes) / 2, 'i'" +
  1851.       " = attribs, 'o' = classes, 't' = attribs .+ classes)" +
  1852.       " For wildcard values" +
  1853.       ",Default = a).",
  1854.       "H", 1, "-H <comma seperated numbers for nodes on each layer>"));
  1855.     newVector.addElement(new Option(
  1856.               "tNormalizing a numeric class will NOT be done.n"
  1857.       +"t(Set this to not normalize the class if it's numeric).",
  1858.       "C", 0,"-C"));
  1859.     newVector.addElement(new Option(
  1860.               "tNormalizing the attributes will NOT be done.n"
  1861.       +"t(Set this to not normalize the attributes).",
  1862.       "I", 0,"-I"));
  1863.     newVector.addElement(new Option(
  1864.               "tReseting the network will NOT be allowed.n"
  1865.       +"t(Set this to not allow the network to reset).",
  1866.       "R", 0,"-R"));
  1867.     newVector.addElement(new Option(
  1868.               "tLearning rate decay will occur.n"
  1869.       +"t(Set this to cause the learning rate to decay).",
  1870.       "D", 0,"-D"));
  1871.     
  1872.     
  1873.     return newVector.elements();
  1874.   }
  1875.   /**
  1876.    * Parses a given list of options. Valid options are:<p>
  1877.    *
  1878.    * -L num <br>
  1879.    * Set the learning rate.
  1880.    * (default 0.3) <p>
  1881.    *
  1882.    * -M num <br>
  1883.    * Set the momentum
  1884.    * (default 0.2) <p>
  1885.    *
  1886.    * -N num <br>
  1887.    * Set the number of epochs to train through.
  1888.    * (default 500) <p>
  1889.    *
  1890.    * -V num <br>
  1891.    * Set the percentage size of the validation set from the training to use.
  1892.    * (default 0 (no validation set is used, instead num of epochs is used) <p>
  1893.    *
  1894.    * -S num <br>
  1895.    * Set the seed for the random number generator.
  1896.    * (default 0) <p>
  1897.    *
  1898.    * -E num <br>
  1899.    * Set the threshold for the number of consequetive errors allowed during
  1900.    * validation testing.
  1901.    * (default 20) <p>
  1902.    *
  1903.    * -G <br>
  1904.    * Bring up a GUI for the neural net.
  1905.    * <p>
  1906.    *
  1907.    * -A <br>
  1908.    * Do not automatically create the connections in the net.
  1909.    * (can only be used if -G is specified) <p>
  1910.    *
  1911.    * -B <br>
  1912.    * Do Not automatically Preprocess the instances with a nominal to binary 
  1913.    * filter. <p>
  1914.    *
  1915.    * -H str <br>
  1916.    * Set the number of nodes to be used on each layer. Each number represents
  1917.    * its own layer and the num of nodes on that layer. Each number should be
  1918.    * comma seperated. There are also the wildcards 'a', 'i', 'o', 't'
  1919.    * (default 4) <p>
  1920.    *
  1921.    * -C <br>
  1922.    * Do not automatically Normalize the class if it's numeric. <p>
  1923.    *
  1924.    * -I <br>
  1925.    * Do not automatically Normalize the attributes. <p>
  1926.    *
  1927.    * -R <br>
  1928.    * Do not allow the network to be automatically reset. <p>
  1929.    *
  1930.    * -D <br>
  1931.    * Cause the learning rate to decay as training is done. <p>
  1932.    *
  1933.    * @param options the list of options as an array of strings
  1934.    * @exception Exception if an option is not supported
  1935.    */
  1936.   public void setOptions(String[] options) throws Exception {
  1937.     //the defaults can be found here!!!!
  1938.     String learningString = Utils.getOption('L', options);
  1939.     if (learningString.length() != 0) {
  1940.       setLearningRate((new Double(learningString)).doubleValue());
  1941.     } else {
  1942.       setLearningRate(0.3);
  1943.     }
  1944.     String momentumString = Utils.getOption('M', options);
  1945.     if (momentumString.length() != 0) {
  1946.       setMomentum((new Double(momentumString)).doubleValue());
  1947.     } else {
  1948.       setMomentum(0.2);
  1949.     }
  1950.     String epochsString = Utils.getOption('N', options);
  1951.     if (epochsString.length() != 0) {
  1952.       setTrainingTime(Integer.parseInt(epochsString));
  1953.     } else {
  1954.       setTrainingTime(500);
  1955.     }
  1956.     String valSizeString = Utils.getOption('V', options);
  1957.     if (valSizeString.length() != 0) {
  1958.       setValidationSetSize(Integer.parseInt(valSizeString));
  1959.     } else {
  1960.       setValidationSetSize(0);
  1961.     }
  1962.     String seedString = Utils.getOption('S', options);
  1963.     if (seedString.length() != 0) {
  1964.       setRandomSeed(Long.parseLong(seedString));
  1965.     } else {
  1966.       setRandomSeed(0);
  1967.     }
  1968.     String thresholdString = Utils.getOption('E', options);
  1969.     if (thresholdString.length() != 0) {
  1970.       setValidationThreshold(Integer.parseInt(thresholdString));
  1971.     } else {
  1972.       setValidationThreshold(20);
  1973.     }
  1974.     String hiddenLayers = Utils.getOption('H', options);
  1975.     if (hiddenLayers.length() != 0) {
  1976.       setHiddenLayers(hiddenLayers);
  1977.     } else {
  1978.       setHiddenLayers("a");
  1979.     }
  1980.     if (Utils.getFlag('G', options)) {
  1981.       setGUI(true);
  1982.     } else {
  1983.       setGUI(false);
  1984.     } //small note. since the gui is the only option that can change the other
  1985.     //options this should be set first to allow the other options to set 
  1986.     //properly
  1987.     if (Utils.getFlag('A', options)) {
  1988.       setAutoBuild(false);
  1989.     } else {
  1990.       setAutoBuild(true);
  1991.     }
  1992.     if (Utils.getFlag('B', options)) {
  1993.       setNominalToBinaryFilter(false);
  1994.     } else {
  1995.       setNominalToBinaryFilter(true);
  1996.     }
  1997.     if (Utils.getFlag('C', options)) {
  1998.       setNormalizeNumericClass(false);
  1999.     } else {
  2000.       setNormalizeNumericClass(true);
  2001.     }
  2002.     if (Utils.getFlag('I', options)) {
  2003.       setNormalizeAttributes(false);
  2004.     } else {
  2005.       setNormalizeAttributes(true);
  2006.     }
  2007.     if (Utils.getFlag('R', options)) {
  2008.       setReset(false);
  2009.     } else {
  2010.       setReset(true);
  2011.     }
  2012.     if (Utils.getFlag('D', options)) {
  2013.       setDecay(true);
  2014.     } else {
  2015.       setDecay(false);
  2016.     }
  2017.     
  2018.     Utils.checkForRemainingOptions(options);
  2019.   }
  2020.   
  2021.   /**
  2022.    * Gets the current settings of NeuralNet.
  2023.    *
  2024.    * @return an array of strings suitable for passing to setOptions()
  2025.    */
  2026.   public String [] getOptions() {
  2027.     String [] options = new String [21];
  2028.     int current = 0;
  2029.     options[current++] = "-L"; options[current++] = "" + getLearningRate(); 
  2030.     options[current++] = "-M"; options[current++] = "" + getMomentum();
  2031.     options[current++] = "-N"; options[current++] = "" + getTrainingTime(); 
  2032.     options[current++] = "-V"; options[current++] = "" +getValidationSetSize();
  2033.     options[current++] = "-S"; options[current++] = "" + getRandomSeed();
  2034.     options[current++] = "-E"; options[current++] =""+getValidationThreshold();
  2035.     options[current++] = "-H"; options[current++] = getHiddenLayers();
  2036.     if (getGUI()) {
  2037.       options[current++] = "-G";
  2038.     }
  2039.     if (!getAutoBuild()) {
  2040.       options[current++] = "-A";
  2041.     }
  2042.     if (!getNominalToBinaryFilter()) {
  2043.       options[current++] = "-B";
  2044.     }
  2045.     if (!getNormalizeNumericClass()) {
  2046.       options[current++] = "-C";
  2047.     }
  2048.     if (!getNormalizeAttributes()) {
  2049.       options[current++] = "-I";
  2050.     }
  2051.     if (!getReset()) {
  2052.       options[current++] = "-R";
  2053.     }
  2054.     if (getDecay()) {
  2055.       options[current++] = "-D";
  2056.     }
  2057.     
  2058.     while (current < options.length) {
  2059.       options[current++] = "";
  2060.     }
  2061.     return options;
  2062.   }
  2063.   
  2064.   /**
  2065.    * @return string describing the model.
  2066.    */
  2067.   public String toString() {
  2068.     StringBuffer model = new StringBuffer(m_neuralNodes.length * 100); 
  2069.     //just a rough size guess
  2070.     NeuralNode con;
  2071.     double[] weights;
  2072.     NeuralConnection[] inputs;
  2073.     for (int noa = 0; noa < m_neuralNodes.length; noa++) {
  2074.       con = (NeuralNode) m_neuralNodes[noa];  //this would need a change
  2075.                                               //for items other than nodes!!!
  2076.       weights = con.getWeights();
  2077.       inputs = con.getInputs();
  2078.       if (con.getMethod() instanceof SigmoidUnit) {
  2079. model.append("Sigmoid ");
  2080.       }
  2081.       else if (con.getMethod() instanceof LinearUnit) {
  2082. model.append("Linear ");
  2083.       }
  2084.       model.append("Node " + con.getId() + "n    Inputs    Weightsn");
  2085.       model.append("    Threshold    " + weights[0] + "n");
  2086.       for (int nob = 1; nob < con.getNumInputs() + 1; nob++) {
  2087. if ((inputs[nob - 1].getType() & NeuralConnection.PURE_INPUT) 
  2088.     == NeuralConnection.PURE_INPUT) {
  2089.   model.append("    Attrib " + 
  2090.        m_instances.attribute(((NeuralEnd)inputs[nob-1]).
  2091.      getLink()).name()
  2092.        + "    " + weights[nob] + "n");
  2093. }
  2094. else {
  2095.   model.append("    Node " + inputs[nob-1].getId() + "    " +
  2096.        weights[nob] + "n");
  2097. }
  2098.       }      
  2099.     }
  2100.     //now put in the ends
  2101.     for (int noa = 0; noa < m_outputs.length; noa++) {
  2102.       inputs = m_outputs[noa].getInputs();
  2103.       model.append("Class " + 
  2104.    m_instances.classAttribute().
  2105.    value(m_outputs[noa].getLink()) + 
  2106.    "n    Inputn");
  2107.       for (int nob = 0; nob < m_outputs[noa].getNumInputs(); nob++) {
  2108. if ((inputs[nob].getType() & NeuralConnection.PURE_INPUT)
  2109.     == NeuralConnection.PURE_INPUT) {
  2110.   model.append("    Attrib " +
  2111.        m_instances.attribute(((NeuralEnd)inputs[nob]).
  2112.      getLink()).name() + "n");
  2113. }
  2114. else {
  2115.   model.append("    Node " + inputs[nob].getId() + "n");
  2116. }
  2117.       }
  2118.     }
  2119.     return model.toString();
  2120.   }
  2121.   /**
  2122.    * This will return a string describing the classifier.
  2123.    * @return The string.
  2124.    */
  2125.   public String globalInfo() {
  2126.     return "This neural network uses backpropagation to train.";
  2127.   }
  2128.   
  2129.   /**
  2130.    * @return a string to describe the learning rate option.
  2131.    */
  2132.   public String learningRateTipText() {
  2133.     return "The amount the" + 
  2134.       " weights are updated.";
  2135.   }
  2136.   
  2137.   /**
  2138.    * @return a string to describe the momentum option.
  2139.    */
  2140.   public String momentumTipText() {
  2141.     return "Momentum applied to the weights during updating.";
  2142.   }
  2143.   /**
  2144.    * @return a string to describe the AutoBuild option.
  2145.    */
  2146.   public String autoBuildTipText() {
  2147.     return "Adds and connects up hidden layers in the network.";
  2148.   }
  2149.   /**
  2150.    * @return a string to describe the random seed option.
  2151.    */
  2152.   public String randomSeedTipText() {
  2153.     return "Seed used to initialise the random number generator." +
  2154.       "Random numbers are used for setting the initial weights of the" +
  2155.       " connections betweem nodes, and also for shuffling the training data.";
  2156.   }
  2157.   
  2158.   /**
  2159.    * @return a string to describe the validation threshold option.
  2160.    */
  2161.   public String validationThresholdTipText() {
  2162.     return "Used to terminate validation testing." +
  2163.       "The value here dictates how many times in a row the validation set" +
  2164.       " error can get worse before training is terminated.";
  2165.   }
  2166.   
  2167.   /**
  2168.    * @return a string to describe the GUI option.
  2169.    */
  2170.   public String GUITipText() {
  2171.     return "Brings up a gui interface." +
  2172.       " This will allow the pausing and altering of the nueral network" +
  2173.       " during training.nn" +
  2174.       "* To add a node left click (this node will be automatically selected," +
  2175.       " ensure no other nodes were selected).n" +
  2176.       "* To select a node left click on it either while no other node is" +
  2177.       " selected or while holding down the control key (this toggles that" +
  2178.       " node as being selected and not selected.n" + 
  2179.       "* To connect a node, first have the start node(s) selected, then click"+
  2180.       " either the end node or on an empty space (this will create a new node"+
  2181.       " that is connected with the selected nodes). The selection status of" +
  2182.       " nodes will stay the same after the connection. (Note these are" +
  2183.       " directed connections, also a connection between two nodes will not" +
  2184.       " be established more than once and certain connections that are" + 
  2185.       " deemed to be invalid will not be made).n" +
  2186.       "* To remove a connection select one of the connected node(s) in the" +
  2187.       " connection and then right click the other node (it does not matter" +
  2188.       " whether the node is the start or end the connection will be removed" +
  2189.       ").n" +
  2190.       "* To remove a node right click it while no other nodes (including it)" +
  2191.       " are selected. (This will also remove all connections to it)n." +
  2192.       "* To deselect a node either left click it while holding down control," +
  2193.       " or right click on empty space.n" +
  2194.       "* The raw inputs are provided from the labels on the left.n" +
  2195.       "* The red nodes are hidden layers.n" +
  2196.       "* The orange nodes are the output nodes.n" +
  2197.       "* The labels on the right show the class the output node represents." +
  2198.       " Note that with a numeric class the output node will automatically be" +
  2199.       " made into an unthresholded linear unit.nn" +
  2200.       "Alterations to the neural network can only be done while the network" +
  2201.       " is not running, This also applies to the learning rate and other" +
  2202.       " fields on the control panel.nn" + 
  2203.       "* You can accept the network as being finished at any time.n" +
  2204.       "* The network is automatically paused at the beginning.n" +
  2205.       "* There is a running indication of what epoch the network is up to" + 
  2206.       " and what the (rough) error for that epoch was (or for" +
  2207.       " the validation if that is being used). Note that this error value" +
  2208.       " is based on a network that changes as the value is computed." +
  2209.       " (also depending on whether" +
  2210.       " the class is normalized will effect the error reported for numeric" +
  2211.       " classes.n" +
  2212.       "* Once the network is done it will pause again and either wait to be" +
  2213.       " accepted or trained more.nn" +
  2214.       "Note that if the gui is not set the network will not require any" +
  2215.       " interaction.n";
  2216.   }
  2217.   
  2218.   /**
  2219.    * @return a string to describe the validation size option.
  2220.    */
  2221.   public String validationSetSizeTipText() {
  2222.     return "The percentage size of the validation set." +
  2223.       "(The training will continue until it is observed that" +
  2224.       " the error on the validation set has been consistently getting" +
  2225.       " worse, or if the training time is reached).n" +
  2226.       "If This is set to zero no validation set will be used and instead" +
  2227.       " the network will train for the specified number of epochs.";
  2228.   }
  2229.   
  2230.   /**
  2231.    * @return a string to describe the learning rate option.
  2232.    */
  2233.   public String trainingTimeTipText() {
  2234.     return "The number of epochs to train through." + 
  2235.       " If the validation set is non-zero then it can terminate the network" +
  2236.       " early";
  2237.   }
  2238.   /**
  2239.    * @return a string to describe the nominal to binary option.
  2240.    */
  2241.   public String nominalToBinaryFilterTipText() {
  2242.     return "This will preprocess the instances with the filter." +
  2243.       " This could help improve performance if there are nominal attributes" +
  2244.       " in the data.";
  2245.   }
  2246.   /**
  2247.    * @return a string to describe the hidden layers in the network.
  2248.    */
  2249.   public String hiddenLayersTipText() {
  2250.     return "This defines the hidden layers of the neural network." +
  2251.       " This is a list of positive whole numbers. 1 for each hidden layer." +
  2252.       " Comma seperated. To have no hidden layers put a single 0 here." +
  2253.       " This will only be used if autobuild is set. There are also wildcard" +
  2254.       " values 'a' = (attribs + classes) / 2, 'i' = attribs, 'o' = classes" +
  2255.       " , 't' = attribs + classes.";
  2256.   }
  2257.   /**
  2258.    * @return a string to describe the nominal to binary option.
  2259.    */
  2260.   public String normalizeNumericClassTipText() {
  2261.     return "This will normalize the class if it's numeric." +
  2262.       " This could help improve performance of the network, It normalizes" +
  2263.       " the class to be between -1 and 1. Note that this is only internally" +
  2264.       ", the output will be scaled back to the original range.";
  2265.   }
  2266.   /**
  2267.    * @return a string to describe the nominal to binary option.
  2268.    */
  2269.   public String normalizeAttributesTipText() {
  2270.     return "This will normalize the attributes." +
  2271.       " This could help improve performance of the network." +
  2272.       " This is not reliant on the class being numeric. This will also" +
  2273.       " normalize nominal attributes as well (after they have been run" +
  2274.       " through the nominal to binary filter if that is in use) so that the" +
  2275.       " nominal values are between -1 and 1";
  2276.   }
  2277.   /**
  2278.    * @return a string to describe the Reset option.
  2279.    */
  2280.   public String resetTipText() {
  2281.     return "This will allow the network to reset with a lower learning rate." +
  2282.       " If the network diverges from the answer this will automatically" +
  2283.       " reset the network with a lower learning rate and begin training" +
  2284.       " again. This option is only available if the gui is not set. Note" +
  2285.       " that if the network diverges but isn't allowed to reset it will" +
  2286.       " fail the training process and return an error message.";
  2287.   }
  2288.   
  2289.   /**
  2290.    * @return a string to describe the Decay option.
  2291.    */
  2292.   public String decayTipText() {
  2293.     return "This will cause the learning rate to decrease." +
  2294.       " This will divide the starting learning rate by the epoch number, to" +
  2295.       " determine what the current learning rate should be. This may help" +
  2296.       " to stop the network from diverging from the target output, as well" +
  2297.       " as improve general performance. Note that the decaying learning" +
  2298.       " rate will not be shown in the gui, only the original learning rate" +
  2299.       ". If the learning rate is changed in the gui, this is treated as the" +
  2300.       " starting learning rate.";
  2301.   }
  2302.     
  2303. }