Code/Resource
Windows Develop
Linux-Unix program
Internet-Socket-Network
Web Server
Browser Client
Ftp Server
Ftp Client
Browser Plugins
Proxy Server
Email Server
Email Client
WEB Mail
Firewall-Security
Telnet Server
Telnet Client
ICQ-IM-Chat
Search Engine
Sniffer Package capture
Remote Control
xml-soap-webservice
P2P
WEB(ASP,PHP,...)
TCP/IP Stack
SNMP
Grid Computing
SilverLight
DNS
Cluster Service
Network Security
Communication-Mobile
Game Program
Editor
Multimedia program
Graph program
Compiler program
Compress-Decompress algorithms
Crypt_Decrypt algorithms
Mathematics-Numerical algorithms
MultiLanguage
Disk/Storage
Java Develop
assembly language
Applications
Other systems
Database system
Embedded-SCM Develop
FlashMX/Flex
source in ebook
Delphi VCL
OS Develop
MiddleWare
MPI
MacOS develop
LabView
ELanguage
Software/Tools
E-Books
Article/Document
NeuralNetwork.java
Package: Weka-3-2.rar [view]
Upload User: rhdiban
Upload Date: 2013-08-09
Package Size: 15085k
Code Size: 71k
Category:
Windows Develop
Development Platform:
Java
- /*
- * This program is free software; you can redistribute it and/or modify
- * it under the terms of the GNU General Public License as published by
- * the Free Software Foundation; either version 2 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * along with this program; if not, write to the Free Software
- * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
- */
- /*
- * NeuralNetwork.java
- * Copyright (C) 2000 Malcolm Ware
- */
- package weka.classifiers.functions.neural;
- import java.util.*;
- import java.awt.*;
- import java.awt.event.*;
- import javax.swing.*;
- import weka.classifiers.*;
- import weka.core.*;
- import weka.filters.unsupervised.attribute.NominalToBinary;
- import weka.filters.Filter;
- /**
- * A Classifier that uses backpropagation to classify instances.
- * This network can be built by hand, created by an algorithm or both.
- * The network can also be monitored and modified during training time.
- * The nodes in this network are all sigmoid (except for when the class
- * is numeric in which case the output nodes become unthresholded linear
- * units).
- *
- * @author Malcolm Ware (mfw4@cs.waikato.ac.nz)
- * @version $Revision: 1.10 $
- */
- public class NeuralNetwork extends DistributionClassifier
- implements OptionHandler, WeightedInstancesHandler {
- /**
- * Main method for testing this class.
- *
- * @param argv should contain command line options (see setOptions)
- */
- public static void main(String [] argv) {
- try {
- System.out.println(Evaluation.evaluateModel(new NeuralNetwork(), argv));
- } catch (Exception e) {
- System.err.println(e.getMessage());
- e.printStackTrace();
- }
- System.exit(0);
- }
- /**
- * This inner class is used to connect the nodes in the network up to
- * the data that they are classifying, Note that objects of this class are
- * only suitable to go on the attribute side or class side of the network
- * and not both.
- */
- protected class NeuralEnd extends NeuralConnection {
- /**
- * the value that represents the instance value this node represents.
- * For an input it is the attribute number, for an output, if nominal
- * it is the class value.
- */
- private int m_link;
- /** True if node is an input, False if it's an output. */
- private boolean m_input;
- public NeuralEnd(String id) {
- super(id);
- m_link = 0;
- m_input = true;
- }
- /**
- * Call this function to determine if the point at x,y is on the unit.
- * @param g The graphics context for font size info.
- * @param x The x coord.
- * @param y The y coord.
- * @param w The width of the display.
- * @param h The height of the display.
- * @return True if the point is on the unit, false otherwise.
- */
- public boolean onUnit(Graphics g, int x, int y, int w, int h) {
- FontMetrics fm = g.getFontMetrics();
- int l = (int)(m_x * w) - fm.stringWidth(m_id) / 2;
- int t = (int)(m_y * h) - fm.getHeight() / 2;
- if (x < l || x > l + fm.stringWidth(m_id) + 4
- || y < t || y > t + fm.getHeight() + fm.getDescent() + 4) {
- return false;
- }
- return true;
- }
- /**
- * This will draw the node id to the graphics context.
- * @param g The graphics context.
- * @param w The width of the drawing area.
- * @param h The height of the drawing area.
- */
- public void drawNode(Graphics g, int w, int h) {
- if ((m_type & PURE_INPUT) == PURE_INPUT) {
- g.setColor(Color.green);
- }
- else {
- g.setColor(Color.orange);
- }
- FontMetrics fm = g.getFontMetrics();
- int l = (int)(m_x * w) - fm.stringWidth(m_id) / 2;
- int t = (int)(m_y * h) - fm.getHeight() / 2;
- g.fill3DRect(l, t, fm.stringWidth(m_id) + 4
- , fm.getHeight() + fm.getDescent() + 4
- , true);
- g.setColor(Color.black);
- g.drawString(m_id, l + 2, t + fm.getHeight() + 2);
- }
- /**
- * Call this function to draw the node highlighted.
- * @param g The graphics context.
- * @param w The width of the drawing area.
- * @param h The height of the drawing area.
- */
- public void drawHighlight(Graphics g, int w, int h) {
- g.setColor(Color.black);
- FontMetrics fm = g.getFontMetrics();
- int l = (int)(m_x * w) - fm.stringWidth(m_id) / 2;
- int t = (int)(m_y * h) - fm.getHeight() / 2;
- g.fillRect(l - 2, t - 2, fm.stringWidth(m_id) + 8
- , fm.getHeight() + fm.getDescent() + 8);
- drawNode(g, w, h);
- }
- /**
- * Call this to get the output value of this unit.
- * @param calculate True if the value should be calculated if it hasn't
- * been already.
- * @return The output value, or NaN, if the value has not been calculated.
- */
- public double outputValue(boolean calculate) {
- if (Double.isNaN(m_unitValue) && calculate) {
- if (m_input) {
- if (m_currentInstance.isMissing(m_link)) {
- m_unitValue = 0;
- }
- else {
- m_unitValue = m_currentInstance.value(m_link);
- }
- }
- else {
- //node is an output.
- m_unitValue = 0;
- for (int noa = 0; noa < m_numInputs; noa++) {
- m_unitValue += m_inputList[noa].outputValue(true);
- }
- if (m_numeric && m_normalizeClass) {
- //then scale the value;
- //this scales linearly from between -1 and 1
- m_unitValue = m_unitValue *
- m_attributeRanges[m_instances.classIndex()] +
- m_attributeBases[m_instances.classIndex()];
- }
- }
- }
- return m_unitValue;
- }
- /**
- * Call this to get the error value of this unit, which in this case is
- * the difference between the predicted class, and the actual class.
- * @param calculate True if the value should be calculated if it hasn't
- * been already.
- * @return The error value, or NaN, if the value has not been calculated.
- */
- public double errorValue(boolean calculate) {
- if (!Double.isNaN(m_unitValue) && Double.isNaN(m_unitError)
- && calculate) {
- if (m_input) {
- m_unitError = 0;
- for (int noa = 0; noa < m_numOutputs; noa++) {
- m_unitError += m_outputList[noa].errorValue(true);
- }
- }
- else {
- if (m_currentInstance.classIsMissing()) {
- m_unitError = .1;
- }
- else if (m_instances.classAttribute().isNominal()) {
- if (m_currentInstance.classValue() == m_link) {
- m_unitError = 1 - m_unitValue;
- }
- else {
- m_unitError = 0 - m_unitValue;
- }
- }
- else if (m_numeric) {
- if (m_normalizeClass) {
- if (m_attributeRanges[m_instances.classIndex()] == 0) {
- m_unitError = 0;
- }
- else {
- m_unitError = (m_currentInstance.classValue() - m_unitValue ) /
- m_attributeRanges[m_instances.classIndex()];
- //m_numericRange;
- }
- }
- else {
- m_unitError = m_currentInstance.classValue() - m_unitValue;
- }
- }
- }
- }
- return m_unitError;
- }
- /**
- * Call this to reset the value and error for this unit, ready for the next
- * run. This will also call the reset function of all units that are
- * connected as inputs to this one.
- * This is also the time that the update for the listeners will be
- * performed.
- */
- public void reset() {
- if (!Double.isNaN(m_unitValue) || !Double.isNaN(m_unitError)) {
- m_unitValue = Double.NaN;
- m_unitError = Double.NaN;
- m_weightsUpdated = false;
- for (int noa = 0; noa < m_numInputs; noa++) {
- m_inputList[noa].reset();
- }
- }
- }
- /**
- * Call this function to set What this end unit represents.
- * @param input True if this unit is used for entering an attribute,
- * False if it's used for determining a class value.
- * @param val The attribute number or class type that this unit represents.
- * (for nominal attributes).
- */
- public void setLink(boolean input, int val) throws Exception {
- m_input = input;
- if (input) {
- m_type = PURE_INPUT;
- }
- else {
- m_type = PURE_OUTPUT;
- }
- if (val < 0 || (input && val > m_instances.numAttributes())
- || (!input && m_instances.classAttribute().isNominal()
- && val > m_instances.classAttribute().numValues())) {
- m_link = 0;
- }
- else {
- m_link = val;
- }
- }
- /**
- * @return link for this node.
- */
- public int getLink() {
- return m_link;
- }
- }
- /** Inner class used to draw the nodes onto.(uses the node lists!!)
- * This will also handle the user input. */
- private class NodePanel extends JPanel {
- /**
- * The constructor.
- */
- public NodePanel() {
- addMouseListener(new MouseAdapter() {
- public void mousePressed(MouseEvent e) {
- if (!m_stopped) {
- return;
- }
- if ((e.getModifiers() & e.BUTTON1_MASK) == e.BUTTON1_MASK) {
- Graphics g = NodePanel.this.getGraphics();
- int x = e.getX();
- int y = e.getY();
- int w = NodePanel.this.getWidth();
- int h = NodePanel.this.getHeight();
- int u = 0;
- FastVector tmp = new FastVector(4);
- for (int noa = 0; noa < m_numAttributes; noa++) {
- if (m_inputs[noa].onUnit(g, x, y, w, h)) {
- tmp.addElement(m_inputs[noa]);
- selection(tmp,
- (e.getModifiers() & e.CTRL_MASK) == e.CTRL_MASK
- , true);
- return;
- }
- }
- for (int noa = 0; noa < m_numClasses; noa++) {
- if (m_outputs[noa].onUnit(g, x, y, w, h)) {
- tmp.addElement(m_outputs[noa]);
- selection(tmp,
- (e.getModifiers() & e.CTRL_MASK) == e.CTRL_MASK
- , true);
- return;
- }
- }
- for (int noa = 0; noa < m_neuralNodes.length; noa++) {
- if (m_neuralNodes[noa].onUnit(g, x, y, w, h)) {
- tmp.addElement(m_neuralNodes[noa]);
- selection(tmp,
- (e.getModifiers() & e.CTRL_MASK) == e.CTRL_MASK
- , true);
- return;
- }
- }
- NeuralNode temp = new NeuralNode(String.valueOf(m_nextId),
- m_random, m_sigmoidUnit);
- m_nextId++;
- temp.setX((double)e.getX() / w);
- temp.setY((double)e.getY() / h);
- tmp.addElement(temp);
- addNode(temp);
- selection(tmp, (e.getModifiers() & e.CTRL_MASK) == e.CTRL_MASK
- , true);
- }
- else {
- //then right click
- Graphics g = NodePanel.this.getGraphics();
- int x = e.getX();
- int y = e.getY();
- int w = NodePanel.this.getWidth();
- int h = NodePanel.this.getHeight();
- int u = 0;
- FastVector tmp = new FastVector(4);
- for (int noa = 0; noa < m_numAttributes; noa++) {
- if (m_inputs[noa].onUnit(g, x, y, w, h)) {
- tmp.addElement(m_inputs[noa]);
- selection(tmp,
- (e.getModifiers() & e.CTRL_MASK) == e.CTRL_MASK
- , false);
- return;
- }
- }
- for (int noa = 0; noa < m_numClasses; noa++) {
- if (m_outputs[noa].onUnit(g, x, y, w, h)) {
- tmp.addElement(m_outputs[noa]);
- selection(tmp,
- (e.getModifiers() & e.CTRL_MASK) == e.CTRL_MASK
- , false);
- return;
- }
- }
- for (int noa = 0; noa < m_neuralNodes.length; noa++) {
- if (m_neuralNodes[noa].onUnit(g, x, y, w, h)) {
- tmp.addElement(m_neuralNodes[noa]);
- selection(tmp,
- (e.getModifiers() & e.CTRL_MASK) == e.CTRL_MASK
- , false);
- return;
- }
- }
- selection(null, (e.getModifiers() & e.CTRL_MASK) == e.CTRL_MASK
- , false);
- }
- }
- });
- }
- /**
- * This function gets called when the user has clicked something
- * It will amend the current selection or connect the current selection
- * to the new selection.
- * Or if nothing was selected and the right button was used it will
- * delete the node.
- * @param v The units that were selected.
- * @param ctrl True if ctrl was held down.
- * @param left True if it was the left mouse button.
- */
- private void selection(FastVector v, boolean ctrl, boolean left) {
- if (v == null) {
- //then unselect all.
- m_selected.removeAllElements();
- repaint();
- return;
- }
- //then exclusive or the new selection with the current one.
- if ((ctrl || m_selected.size() == 0) && left) {
- boolean removed = false;
- for (int noa = 0; noa < v.size(); noa++) {
- removed = false;
- for (int nob = 0; nob < m_selected.size(); nob++) {
- if (v.elementAt(noa) == m_selected.elementAt(nob)) {
- //then remove that element
- m_selected.removeElementAt(nob);
- removed = true;
- break;
- }
- }
- if (!removed) {
- m_selected.addElement(v.elementAt(noa));
- }
- }
- repaint();
- return;
- }
- if (left) {
- //then connect the current selection to the new one.
- for (int noa = 0; noa < m_selected.size(); noa++) {
- for (int nob = 0; nob < v.size(); nob++) {
- NeuralConnection
- .connect((NeuralConnection)m_selected.elementAt(noa)
- , (NeuralConnection)v.elementAt(nob));
- }
- }
- }
- else if (m_selected.size() > 0) {
- //then disconnect the current selection from the new one.
- for (int noa = 0; noa < m_selected.size(); noa++) {
- for (int nob = 0; nob < v.size(); nob++) {
- NeuralConnection
- .disconnect((NeuralConnection)m_selected.elementAt(noa)
- , (NeuralConnection)v.elementAt(nob));
- NeuralConnection
- .disconnect((NeuralConnection)v.elementAt(nob)
- , (NeuralConnection)m_selected.elementAt(noa));
- }
- }
- }
- else {
- //then remove the selected node. (it was right clicked while
- //no other units were selected
- for (int noa = 0; noa < v.size(); noa++) {
- ((NeuralConnection)v.elementAt(noa)).removeAllInputs();
- ((NeuralConnection)v.elementAt(noa)).removeAllOutputs();
- removeNode((NeuralConnection)v.elementAt(noa));
- }
- }
- repaint();
- }
- /**
- * This will paint the nodes ontot the panel.
- * @param g The graphics context.
- */
- public void paintComponent(Graphics g) {
- super.paintComponent(g);
- int x = getWidth();
- int y = getHeight();
- if (25 * m_numAttributes > 25 * m_numClasses &&
- 25 * m_numAttributes > y) {
- setSize(x, 25 * m_numAttributes);
- }
- else if (25 * m_numClasses > y) {
- setSize(x, 25 * m_numClasses);
- }
- else {
- setSize(x, y);
- }
- y = getHeight();
- for (int noa = 0; noa < m_numAttributes; noa++) {
- m_inputs[noa].drawInputLines(g, x, y);
- }
- for (int noa = 0; noa < m_numClasses; noa++) {
- m_outputs[noa].drawInputLines(g, x, y);
- m_outputs[noa].drawOutputLines(g, x, y);
- }
- for (int noa = 0; noa < m_neuralNodes.length; noa++) {
- m_neuralNodes[noa].drawInputLines(g, x, y);
- }
- for (int noa = 0; noa < m_numAttributes; noa++) {
- m_inputs[noa].drawNode(g, x, y);
- }
- for (int noa = 0; noa < m_numClasses; noa++) {
- m_outputs[noa].drawNode(g, x, y);
- }
- for (int noa = 0; noa < m_neuralNodes.length; noa++) {
- m_neuralNodes[noa].drawNode(g, x, y);
- }
- for (int noa = 0; noa < m_selected.size(); noa++) {
- ((NeuralConnection)m_selected.elementAt(noa)).drawHighlight(g, x, y);
- }
- }
- }
- /**
- * This provides the basic controls for working with the neuralnetwork
- * @author Malcolm Ware (mfw4@cs.waikato.ac.nz)
- * @version $Revision: 1.10 $
- */
- class ControlPanel extends JPanel {
- /** The start stop button. */
- public JButton m_startStop;
- /** The button to accept the network (even if it hasn't done all epochs. */
- public JButton m_acceptButton;
- /** A label to state the number of epochs processed so far. */
- public JPanel m_epochsLabel;
- /** A label to state the total number of epochs to be processed. */
- public JLabel m_totalEpochsLabel;
- /** A text field to allow the changing of the total number of epochs. */
- public JTextField m_changeEpochs;
- /** A label to state the learning rate. */
- public JLabel m_learningLabel;
- /** A label to state the momentum. */
- public JLabel m_momentumLabel;
- /** A text field to allow the changing of the learning rate. */
- public JTextField m_changeLearning;
- /** A text field to allow the changing of the momentum. */
- public JTextField m_changeMomentum;
- /** A label to state roughly the accuracy of the network.(because the
- accuracy is calculated per epoch, but the network is changing
- throughout each epoch train).
- */
- public JPanel m_errorLabel;
- /** The constructor. */
- public ControlPanel() {
- setBorder(BorderFactory.createTitledBorder("Controls"));
- m_totalEpochsLabel = new JLabel("Num Of Epochs ");
- m_epochsLabel = new JPanel(){
- public void paintComponent(Graphics g) {
- super.paintComponent(g);
- g.setColor(m_controlPanel.m_totalEpochsLabel.getForeground());
- g.drawString("Epoch " + m_epoch, 0, 10);
- }
- };
- m_epochsLabel.setFont(m_totalEpochsLabel.getFont());
- m_changeEpochs = new JTextField();
- m_changeEpochs.setText("" + m_numEpochs);
- m_errorLabel = new JPanel(){
- public void paintComponent(Graphics g) {
- super.paintComponent(g);
- g.setColor(m_controlPanel.m_totalEpochsLabel.getForeground());
- if (m_valSize == 0) {
- g.drawString("Error per Epoch = " +
- Utils.doubleToString(m_error, 7), 0, 10);
- }
- else {
- g.drawString("Validation Error per Epoch = "
- + Utils.doubleToString(m_error, 7), 0, 10);
- }
- }
- };
- m_errorLabel.setFont(m_epochsLabel.getFont());
- m_learningLabel = new JLabel("Learning Rate = ");
- m_momentumLabel = new JLabel("Momentum = ");
- m_changeLearning = new JTextField();
- m_changeMomentum = new JTextField();
- m_changeLearning.setText("" + m_learningRate);
- m_changeMomentum.setText("" + m_momentum);
- setLayout(new BorderLayout(15, 10));
- m_stopIt = true;
- m_accepted = false;
- m_startStop = new JButton("Start");
- m_startStop.setActionCommand("Start");
- m_acceptButton = new JButton("Accept");
- m_acceptButton.setActionCommand("Accept");
- JPanel buttons = new JPanel();
- buttons.setLayout(new BoxLayout(buttons, BoxLayout.Y_AXIS));
- buttons.add(m_startStop);
- buttons.add(m_acceptButton);
- add(buttons, BorderLayout.WEST);
- JPanel data = new JPanel();
- data.setLayout(new BoxLayout(data, BoxLayout.Y_AXIS));
- Box ab = new Box(BoxLayout.X_AXIS);
- ab.add(m_epochsLabel);
- data.add(ab);
- ab = new Box(BoxLayout.X_AXIS);
- Component b = Box.createGlue();
- ab.add(m_totalEpochsLabel);
- ab.add(m_changeEpochs);
- m_changeEpochs.setMaximumSize(new Dimension(200, 20));
- ab.add(b);
- data.add(ab);
- ab = new Box(BoxLayout.X_AXIS);
- ab.add(m_errorLabel);
- data.add(ab);
- add(data, BorderLayout.CENTER);
- data = new JPanel();
- data.setLayout(new BoxLayout(data, BoxLayout.Y_AXIS));
- ab = new Box(BoxLayout.X_AXIS);
- b = Box.createGlue();
- ab.add(m_learningLabel);
- ab.add(m_changeLearning);
- m_changeLearning.setMaximumSize(new Dimension(200, 20));
- ab.add(b);
- data.add(ab);
- ab = new Box(BoxLayout.X_AXIS);
- b = Box.createGlue();
- ab.add(m_momentumLabel);
- ab.add(m_changeMomentum);
- m_changeMomentum.setMaximumSize(new Dimension(200, 20));
- ab.add(b);
- data.add(ab);
- add(data, BorderLayout.EAST);
- m_startStop.addActionListener(new ActionListener() {
- public void actionPerformed(ActionEvent e) {
- if (e.getActionCommand().equals("Start")) {
- m_stopIt = false;
- m_startStop.setText("Stop");
- m_startStop.setActionCommand("Stop");
- int n = Integer.valueOf(m_changeEpochs.getText()).intValue();
- m_numEpochs = n;
- m_changeEpochs.setText("" + m_numEpochs);
- double m=Double.valueOf(m_changeLearning.getText()).
- doubleValue();
- setLearningRate(m);
- m_changeLearning.setText("" + m_learningRate);
- m = Double.valueOf(m_changeMomentum.getText()).doubleValue();
- setMomentum(m);
- m_changeMomentum.setText("" + m_momentum);
- blocker(false);
- }
- else if (e.getActionCommand().equals("Stop")) {
- m_stopIt = true;
- m_startStop.setText("Start");
- m_startStop.setActionCommand("Start");
- }
- }
- });
- m_acceptButton.addActionListener(new ActionListener() {
- public void actionPerformed(ActionEvent e) {
- m_accepted = true;
- blocker(false);
- }
- });
- m_changeEpochs.addActionListener(new ActionListener() {
- public void actionPerformed(ActionEvent e) {
- int n = Integer.valueOf(m_changeEpochs.getText()).intValue();
- if (n > 0) {
- m_numEpochs = n;
- blocker(false);
- }
- }
- });
- }
- }
- /** The training instances. */
- private Instances m_instances;
- /** The current instance running through the network. */
- private Instance m_currentInstance;
- /** A flag to say that it's a numeric class. */
- private boolean m_numeric;
- /** The ranges for all the attributes. */
- private double[] m_attributeRanges;
- /** The base values for all the attributes. */
- private double[] m_attributeBases;
- /** The output units.(only feeds the errors, does no calcs) */
- private NeuralEnd[] m_outputs;
- /** The input units.(only feeds the inputs does no calcs) */
- private NeuralEnd[] m_inputs;
- /** All the nodes that actually comprise the logical neural net. */
- private NeuralConnection[] m_neuralNodes;
- /** The number of classes. */
- private int m_numClasses = 0;
- /** The number of attributes. */
- private int m_numAttributes = 0; //note the number doesn't include the class.
- /** The panel the nodes are displayed on. */
- private NodePanel m_nodePanel;
- /** The control panel. */
- private ControlPanel m_controlPanel;
- /** The next id number available for default naming. */
- private int m_nextId;
- /** A Vector list of the units currently selected. */
- private FastVector m_selected;
- /** A Vector list of the graphers. */
- private FastVector m_graphers;
- /** The number of epochs to train through. */
- private int m_numEpochs;
- /** a flag to state if the network should be running, or stopped. */
- private boolean m_stopIt;
- /** a flag to state that the network has in fact stopped. */
- private boolean m_stopped;
- /** a flag to state that the network should be accepted the way it is. */
- private boolean m_accepted;
- /** The window for the network. */
- private JFrame m_win;
- /** A flag to tell the build classifier to automatically build a neural net.
- */
- private boolean m_autoBuild;
- /** A flag to state that the gui for the network should be brought up.
- To allow interaction while training. */
- private boolean m_gui;
- /** An int to say how big the validation set should be. */
- private int m_valSize;
- /** The number to to use to quit on validation testing. */
- private int m_driftThreshold;
- /** The number used to seed the random number generator. */
- private long m_randomSeed;
- /** The actual random number generator. */
- private Random m_random;
- /** A flag to state that a nominal to binary filter should be used. */
- private boolean m_useNomToBin;
- /** The actual filter. */
- private NominalToBinary m_nominalToBinaryFilter;
- /** The string that defines the hidden layers */
- private String m_hiddenLayers;
- /** This flag states that the user wants the input values normalized. */
- private boolean m_normalizeAttributes;
- /** This flag states that the user wants the learning rate to decay. */
- private boolean m_decay;
- /** This is the learning rate for the network. */
- private double m_learningRate;
- /** This is the momentum for the network. */
- private double m_momentum;
- /** Shows the number of the epoch that the network just finished. */
- private int m_epoch;
- /** Shows the error of the epoch that the network just finished. */
- private double m_error;
- /** This flag states that the user wants the network to restart if it
- * is found to be generating infinity or NaN for the error value. This
- * would restart the network with the current options except that the
- * learning rate would be smaller than before, (perhaps half of its current
- * value). This option will not be available if the gui is chosen (if the
- * gui is open the user can fix the network themselves, it is an
- * architectural minefield for the network to be reset with the gui open). */
- private boolean m_reset;
- /** This flag states that the user wants the class to be normalized while
- * processing in the network is done. (the final answer will be in the
- * original range regardless). This option will only be used when the class
- * is numeric. */
- private boolean m_normalizeClass;
- /**
- * this is a sigmoid unit.
- */
- private SigmoidUnit m_sigmoidUnit;
- /**
- * This is a linear unit.
- */
- private LinearUnit m_linearUnit;
/**
 * The constructor. Starts with an empty network and every option at its
 * default value.
 */
public NeuralNetwork() {
  // no data and no gui yet.
  m_instances = null;
  m_currentInstance = null;
  m_controlPanel = null;
  m_nodePanel = null;
  m_random = null;

  // an empty network.
  m_inputs = new NeuralEnd[0];
  m_outputs = new NeuralEnd[0];
  m_neuralNodes = new NeuralConnection[0];
  m_numAttributes = 0;
  m_numClasses = 0;
  m_numeric = false;
  m_nextId = 0;
  m_selected = new FastVector(4);
  m_graphers = new FastVector(2);

  // training has not begun.
  m_epoch = 0;
  m_error = 0;
  m_stopIt = true;
  m_stopped = true;
  m_accepted = false;

  m_nominalToBinaryFilter = new NominalToBinary();
  m_sigmoidUnit = new SigmoidUnit();
  m_linearUnit = new LinearUnit();

  //the option defaults. To completely change these defaults they will
  //also need to be changed down the bottom in the setoptions function
  //(the text info in the accompanying functions should also be changed
  //to reflect the new defaults
  m_hiddenLayers = "a";
  m_learningRate = .3;
  m_momentum = .2;
  m_numEpochs = 500;
  m_valSize = 0;
  m_driftThreshold = 20;
  m_randomSeed = 0;
  m_autoBuild = true;
  m_gui = false;
  m_useNomToBin = true;
  m_normalizeAttributes = true;
  m_normalizeClass = true;
  m_reset = true;
  m_decay = false;
}
- /**
- * @param d True if the learning rate should decay.
- */
- public void setDecay(boolean d) {
- m_decay = d;
- }
- /**
- * @return the flag for having the learning rate decay.
- */
- public boolean getDecay() {
- return m_decay;
- }
- /**
- * This sets the network up to be able to reset itself with the current
- * settings and the learning rate at half of what it is currently. This
- * will only happen if the network creates NaN or infinite errors. Also this
- * will continue to happen until the network is trained properly. The
- * learning rate will also get set back to it's original value at the end of
- * this. This can only be set to true if the GUI is not brought up.
- * @param r True if the network should restart with it's current options
- * and set the learning rate to half what it currently is.
- */
- public void setReset(boolean r) {
- if (m_gui) {
- r = false;
- }
- m_reset = r;
- }
- /**
- * @return The flag for reseting the network.
- */
- public boolean getReset() {
- return m_reset;
- }
- /**
- * @param c True if the class should be normalized (the class will only ever
- * be normalized if it is numeric). (Normalization puts the range between
- * -1 - 1).
- */
- public void setNormalizeNumericClass(boolean c) {
- m_normalizeClass = c;
- }
- /**
- * @return The flag for normalizing a numeric class.
- */
- public boolean getNormalizeNumericClass() {
- return m_normalizeClass;
- }
- /**
- * @param a True if the attributes should be normalized (even nominal
- * attributes will get normalized here) (range goes between -1 - 1).
- */
- public void setNormalizeAttributes(boolean a) {
- m_normalizeAttributes = a;
- }
- /**
- * @return The flag for normalizing attributes.
- */
- public boolean getNormalizeAttributes() {
- return m_normalizeAttributes;
- }
- /**
- * @param f True if a nominalToBinary filter should be used on the
- * data.
- */
- public void setNominalToBinaryFilter(boolean f) {
- m_useNomToBin = f;
- }
- /**
- * @return The flag for nominal to binary filter use.
- */
- public boolean getNominalToBinaryFilter() {
- return m_useNomToBin;
- }
- /**
- * This seeds the random number generator, that is used when a random
- * number is needed for the network.
- * @param l The seed.
- */
- public void setRandomSeed(long l) {
- if (l >= 0) {
- m_randomSeed = l;
- }
- }
- /**
- * @return The seed for the random number generator.
- */
- public long getRandomSeed() {
- return m_randomSeed;
- }
- /**
- * This sets the threshold to use for when validation testing is being done.
- * It works by ending testing once the error on the validation set has
- * consecutively increased a certain number of times.
- * @param t The threshold to use for this.
- */
- public void setValidationThreshold(int t) {
- if (t > 0) {
- m_driftThreshold = t;
- }
- }
- /**
- * @return The threshold used for validation testing.
- */
- public int getValidationThreshold() {
- return m_driftThreshold;
- }
- /**
- * The learning rate can be set using this command.
- * NOTE That this is a static variable so it affect all networks that are
- * running.
- * Must be greater than 0 and no more than 1.
- * @param l The New learning rate.
- */
- public void setLearningRate(double l) {
- if (l > 0 && l <= 1) {
- m_learningRate = l;
- if (m_controlPanel != null) {
- m_controlPanel.m_changeLearning.setText("" + l);
- }
- }
- }
- /**
- * @return The learning rate for the nodes.
- */
- public double getLearningRate() {
- return m_learningRate;
- }
- /**
- * The momentum can be set using this command.
- * THE same conditions apply to this as to the learning rate.
- * @param m The new Momentum.
- */
- public void setMomentum(double m) {
- if (m >= 0 && m <= 1) {
- m_momentum = m;
- if (m_controlPanel != null) {
- m_controlPanel.m_changeMomentum.setText("" + m);
- }
- }
- }
- /**
- * @return The momentum for the nodes.
- */
- public double getMomentum() {
- return m_momentum;
- }
- /**
- * This will set whether the network is automatically built
- * or if it is left up to the user. (there is nothing to stop a user
- * from altering an autobuilt network however).
- * @param a True if the network should be auto built.
- */
- public void setAutoBuild(boolean a) {
- if (!m_gui) {
- a = true;
- }
- m_autoBuild = a;
- }
- /**
- * @return The auto build state.
- */
- public boolean getAutoBuild() {
- return m_autoBuild;
- }
- /**
- * This will set what the hidden layers are made up of when auto build is
- * enabled. Note to have no hidden units, just put a single 0, Any more
- * 0's will indicate that the string is badly formed and make it unaccepted.
- * Negative numbers, and floats will do the same. There are also some
- * wildcards. These are 'a' = (number of attributes + number of classes) / 2,
- * 'i' = number of attributes, 'o' = number of classes, and 't' = number of
- * attributes + number of classes.
- * @param h A string with a comma seperated list of numbers. Each number is
- * the number of nodes to be on a hidden layer.
- */
- public void setHiddenLayers(String h) {
- String tmp = "";
- StringTokenizer tok = new StringTokenizer(h, ",");
- if (tok.countTokens() == 0) {
- return;
- }
- double dval;
- int val;
- String c;
- boolean first = true;
- while (tok.hasMoreTokens()) {
- c = tok.nextToken().trim();
- if (c.equals("a") || c.equals("i") || c.equals("o") ||
- c.equals("t")) {
- tmp += c;
- }
- else {
- dval = Double.valueOf(c).doubleValue();
- val = (int)dval;
- if ((val == dval && (val != 0 || (tok.countTokens() == 0 && first)) &&
- val >= 0)) {
- tmp += val;
- }
- else {
- return;
- }
- }
- first = false;
- if (tok.hasMoreTokens()) {
- tmp += ", ";
- }
- }
- m_hiddenLayers = tmp;
- }
- /**
- * @return A string representing the hidden layers, each number is the number
- * of nodes on a hidden layer.
- */
- public String getHiddenLayers() {
- return m_hiddenLayers;
- }
- /**
- * This will set whether A GUI is brought up to allow interaction by the user
- * with the neural network during training.
- * @param a True if gui should be created.
- */
- public void setGUI(boolean a) {
- m_gui = a;
- if (!a) {
- setAutoBuild(true);
- }
- else {
- setReset(false);
- }
- }
- /**
- * @return The true if should show gui.
- */
- public boolean getGUI() {
- return m_gui;
- }
- /**
- * This will set the size of the validation set.
- * @param a The size of the validation set, as a percentage of the whole.
- */
- public void setValidationSetSize(int a) {
- if (a < 0 || a > 99) {
- return;
- }
- m_valSize = a;
- }
- /**
- * @return The percentage size of the validation set.
- */
- public int getValidationSetSize() {
- return m_valSize;
- }
- /**
- * Set the number of training epochs to perform.
- * Must be greater than 0.
- * @param n The number of epochs to train through.
- */
- public void setTrainingTime(int n) {
- if (n > 0) {
- m_numEpochs = n;
- }
- }
- /**
- * @return The number of epochs to train through.
- */
- public int getTrainingTime() {
- return m_numEpochs;
- }
- /**
- * Call this function to place a node into the network list.
- * @param n The node to place in the list.
- */
- private void addNode(NeuralConnection n) {
- NeuralConnection[] temp1 = new NeuralConnection[m_neuralNodes.length + 1];
- for (int noa = 0; noa < m_neuralNodes.length; noa++) {
- temp1[noa] = m_neuralNodes[noa];
- }
- temp1[temp1.length-1] = n;
- m_neuralNodes = temp1;
- }
- /**
- * Call this function to remove the passed node from the list.
- * This will only remove the node if it is in the neuralnodes list.
- * @param n The neuralConnection to remove.
- * @return True if removed false if not (because it wasn't there).
- */
- private boolean removeNode(NeuralConnection n) {
- NeuralConnection[] temp1 = new NeuralConnection[m_neuralNodes.length - 1];
- int skip = 0;
- for (int noa = 0; noa < m_neuralNodes.length; noa++) {
- if (n == m_neuralNodes[noa]) {
- skip++;
- }
- else if (!((noa - skip) >= temp1.length)) {
- temp1[noa - skip] = m_neuralNodes[noa];
- }
- else {
- return false;
- }
- }
- m_neuralNodes = temp1;
- return true;
- }
- /**
- * This function sets what the m_numeric flag to represent the passed class
- * it also performs the normalization of the attributes if applicable
- * and sets up the info to normalize the class. (note that regardless of
- * the options it will fill an array with the range and base, set to
- * normalize all attributes and the class to be between -1 and 1)
- * @param inst the instances.
- * @return The modified instances. This needs to be done. If the attributes
- * are normalized then deep copies will be made of all the instances which
- * will need to be passed back out.
- */
- private Instances setClassType(Instances inst) throws Exception {
- if (inst != null) {
- // x bounds
- double min=Double.POSITIVE_INFINITY;
- double max=Double.NEGATIVE_INFINITY;
- double value;
- m_attributeRanges = new double[inst.numAttributes()];
- m_attributeBases = new double[inst.numAttributes()];
- for (int noa = 0; noa < inst.numAttributes(); noa++) {
- min = Double.POSITIVE_INFINITY;
- max = Double.NEGATIVE_INFINITY;
- for (int i=0; i < inst.numInstances();i++) {
- if (!inst.instance(i).isMissing(noa)) {
- value = inst.instance(i).value(noa);
- if (value < min) {
- min = value;
- }
- if (value > max) {
- max = value;
- }
- }
- }
- m_attributeRanges[noa] = (max - min) / 2;
- m_attributeBases[noa] = (max + min) / 2;
- if (noa != inst.classIndex() && m_normalizeAttributes) {
- for (int i = 0; i < inst.numInstances(); i++) {
- if (m_attributeRanges[noa] != 0) {
- inst.instance(i).setValue(noa, (inst.instance(i).value(noa)
- - m_attributeBases[noa]) /
- m_attributeRanges[noa]);
- }
- else {
- inst.instance(i).setValue(noa, inst.instance(i).value(noa) -
- m_attributeBases[noa]);
- }
- }
- }
- }
- if (inst.classAttribute().isNumeric()) {
- m_numeric = true;
- }
- else {
- m_numeric = false;
- }
- }
- return inst;
- }
- /**
- * A function used to stop the code that called buildclassifier
- * from continuing on before the user has finished the decision tree.
- * @param tf True to stop the thread, False to release the thread that is
- * waiting there (if one).
- */
- public synchronized void blocker(boolean tf) {
- if (tf) {
- try {
- wait();
- } catch(InterruptedException e) {
- }
- }
- else {
- notifyAll();
- }
- }
- /**
- * Call this function to update the control panel for the gui.
- */
- private void updateDisplay() {
- if (m_gui) {
- m_controlPanel.m_errorLabel.repaint();
- m_controlPanel.m_epochsLabel.repaint();
- }
- }
- /**
- * this will reset all the nodes in the network.
- */
- private void resetNetwork() {
- for (int noc = 0; noc < m_numClasses; noc++) {
- m_outputs[noc].reset();
- }
- }
- /**
- * This will cause the output values of all the nodes to be calculated.
- * Note that the m_currentInstance is used to calculate these values.
- */
- private void calculateOutputs() {
- for (int noc = 0; noc < m_numClasses; noc++) {
- //get the values.
- m_outputs[noc].outputValue(true);
- }
- }
- /**
- * This will cause the error values to be calculated for all nodes.
- * Note that the m_currentInstance is used to calculate these values.
- * Also the output values should have been calculated first.
- * @return The squared error.
- */
- private double calculateErrors() throws Exception {
- double ret = 0, temp = 0;
- for (int noc = 0; noc < m_numAttributes; noc++) {
- //get the errors.
- m_inputs[noc].errorValue(true);
- }
- for (int noc = 0; noc < m_numClasses; noc++) {
- temp = m_outputs[noc].errorValue(false);
- ret += temp * temp;
- }
- return ret;
- }
- /**
- * This will cause the weight values to be updated based on the learning
- * rate, momentum and the errors that have been calculated for each node.
- * @param l The learning rate to update with.
- * @param m The momentum to update with.
- */
- private void updateNetworkWeights(double l, double m) {
- for (int noc = 0; noc < m_numClasses; noc++) {
- //update weights
- m_outputs[noc].updateWeights(l, m);
- }
- }
- /**
- * This creates the required input units.
- */
- private void setupInputs() throws Exception {
- m_inputs = new NeuralEnd[m_numAttributes];
- int now = 0;
- for (int noa = 0; noa < m_numAttributes+1; noa++) {
- if (m_instances.classIndex() != noa) {
- m_inputs[noa - now] = new NeuralEnd(m_instances.attribute(noa).name());
- m_inputs[noa - now].setX(.1);
- m_inputs[noa - now].setY((noa - now + 1.0) / (m_numAttributes + 1));
- m_inputs[noa - now].setLink(true, noa);
- }
- else {
- now = 1;
- }
- }
- }
- /**
- * This creates the required output units.
- */
- private void setupOutputs() throws Exception {
- m_outputs = new NeuralEnd[m_numClasses];
- for (int noa = 0; noa < m_numClasses; noa++) {
- if (m_numeric) {
- m_outputs[noa] = new NeuralEnd(m_instances.classAttribute().name());
- }
- else {
- m_outputs[noa]= new NeuralEnd(m_instances.classAttribute().value(noa));
- }
- m_outputs[noa].setX(.9);
- m_outputs[noa].setY((noa + 1.0) / (m_numClasses + 1));
- m_outputs[noa].setLink(false, noa);
- NeuralNode temp = new NeuralNode(String.valueOf(m_nextId), m_random,
- m_sigmoidUnit);
- m_nextId++;
- temp.setX(.75);
- temp.setY((noa + 1.0) / (m_numClasses + 1));
- addNode(temp);
- NeuralConnection.connect(temp, m_outputs[noa]);
- }
- }
- /**
- * Call this function to automatically generate the hidden units
- */
- private void setupHiddenLayer()
- {
- StringTokenizer tok = new StringTokenizer(m_hiddenLayers, ",");
- int val = 0; //num of nodes in a layer
- int prev = 0; //used to remember the previous layer
- int num = tok.countTokens(); //number of layers
- String c;
- for (int noa = 0; noa < num; noa++) {
- //note that I am using the Double to get the value rather than the
- //Integer class, because for some reason the Double implementation can
- //handle leading white space and the integer version can't!?!
- c = tok.nextToken().trim();
- if (c.equals("a")) {
- val = (m_numAttributes + m_numClasses) / 2;
- }
- else if (c.equals("i")) {
- val = m_numAttributes;
- }
- else if (c.equals("o")) {
- val = m_numClasses;
- }
- else if (c.equals("t")) {
- val = m_numAttributes + m_numClasses;
- }
- else {
- val = Double.valueOf(c).intValue();
- }
- for (int nob = 0; nob < val; nob++) {
- NeuralNode temp = new NeuralNode(String.valueOf(m_nextId), m_random,
- m_sigmoidUnit);
- m_nextId++;
- temp.setX(.5 / (num) * noa + .25);
- temp.setY((nob + 1.0) / (val + 1));
- addNode(temp);
- if (noa > 0) {
- //then do connections
- for (int noc = m_neuralNodes.length - nob - 1 - prev;
- noc < m_neuralNodes.length - nob - 1; noc++) {
- NeuralConnection.connect(m_neuralNodes[noc], temp);
- }
- }
- }
- prev = val;
- }
- tok = new StringTokenizer(m_hiddenLayers, ",");
- c = tok.nextToken();
- if (c.equals("a")) {
- val = (m_numAttributes + m_numClasses) / 2;
- }
- else if (c.equals("i")) {
- val = m_numAttributes;
- }
- else if (c.equals("o")) {
- val = m_numClasses;
- }
- else if (c.equals("t")) {
- val = m_numAttributes + m_numClasses;
- }
- else {
- val = Double.valueOf(c).intValue();
- }
- if (val == 0) {
- for (int noa = 0; noa < m_numAttributes; noa++) {
- for (int nob = 0; nob < m_numClasses; nob++) {
- NeuralConnection.connect(m_inputs[noa], m_neuralNodes[nob]);
- }
- }
- }
- else {
- for (int noa = 0; noa < m_numAttributes; noa++) {
- for (int nob = m_numClasses; nob < m_numClasses + val; nob++) {
- NeuralConnection.connect(m_inputs[noa], m_neuralNodes[nob]);
- }
- }
- for (int noa = m_neuralNodes.length - prev; noa < m_neuralNodes.length;
- noa++) {
- for (int nob = 0; nob < m_numClasses; nob++) {
- NeuralConnection.connect(m_neuralNodes[noa], m_neuralNodes[nob]);
- }
- }
- }
- }
- /**
- * This will go through all the nodes and check if they are connected
- * to a pure output unit. If so they will be set to be linear units.
- * If not they will be set to be sigmoid units.
- */
- private void setEndsToLinear() {
- for (int noa = 0; noa < m_neuralNodes.length; noa++) {
- if ((m_neuralNodes[noa].getType() & NeuralConnection.OUTPUT) ==
- NeuralConnection.OUTPUT) {
- ((NeuralNode)m_neuralNodes[noa]).setMethod(m_linearUnit);
- }
- else {
- ((NeuralNode)m_neuralNodes[noa]).setMethod(m_sigmoidUnit);
- }
- }
- }
- /**
- * Call this function to build and train a neural network for the training
- * data provided.
- * @param i The training data.
- * @exception Throws exception if can't build classification properly.
- */
- public void buildClassifier(Instances i) throws Exception {
- if (i.checkForStringAttributes()) {
- throw new UnsupportedAttributeTypeException("Cannot handle string attributes!");
- }
- if (i.numInstances() == 0) {
- throw new IllegalArgumentException("No training instances.");
- }
- m_epoch = 0;
- m_error = 0;
- m_instances = null;
- m_currentInstance = null;
- m_controlPanel = null;
- m_nodePanel = null;
- m_outputs = new NeuralEnd[0];
- m_inputs = new NeuralEnd[0];
- m_numAttributes = 0;
- m_numClasses = 0;
- m_neuralNodes = new NeuralConnection[0];
- m_selected = new FastVector(4);
- m_graphers = new FastVector(2);
- m_nextId = 0;
- m_stopIt = true;
- m_stopped = true;
- m_accepted = false;
- m_instances = new Instances(i);
- m_instances.deleteWithMissingClass();
- if (m_instances.numInstances() == 0) {
- m_instances = null;
- throw new IllegalArgumentException("All class values missing.");
- }
- m_random = new Random(m_randomSeed);
- m_instances.randomize(m_random);
- if (m_useNomToBin) {
- m_nominalToBinaryFilter = new NominalToBinary();
- m_nominalToBinaryFilter.setInputFormat(m_instances);
- m_instances = Filter.useFilter(m_instances,
- m_nominalToBinaryFilter);
- }
- m_numAttributes = m_instances.numAttributes() - 1;
- m_numClasses = m_instances.numClasses();
- setClassType(m_instances);
- //this sets up the validation set.
- Instances valSet = null;
- //numinval is needed later
- int numInVal = (int)(m_valSize / 100.0 * m_instances.numInstances());
- if (m_valSize > 0) {
- if (numInVal == 0) {
- numInVal = 1;
- }
- valSet = new Instances(m_instances, 0, numInVal);
- }
- ///////////
- setupInputs();
- setupOutputs();
- if (m_autoBuild) {
- setupHiddenLayer();
- }
- /////////////////////////////
- //this sets up the gui for usage
- if (m_gui) {
- m_win = new JFrame();
- m_win.addWindowListener(new WindowAdapter() {
- public void windowClosing(WindowEvent e) {
- boolean k = m_stopIt;
- m_stopIt = true;
- int well =JOptionPane.showConfirmDialog(m_win,
- "Are You Sure...n"
- + "Click Yes To Accept"
- + " The Neural Network"
- + "n Click No To Return",
- "Accept Neural Network",
- JOptionPane.YES_NO_OPTION);
- if (well == 0) {
- m_win.setDefaultCloseOperation(JFrame.DISPOSE_ON_CLOSE);
- m_accepted = true;
- blocker(false);
- }
- else {
- m_win.setDefaultCloseOperation(JFrame.DO_NOTHING_ON_CLOSE);
- }
- m_stopIt = k;
- }
- });
- m_win.getContentPane().setLayout(new BorderLayout());
- m_win.setTitle("Neural Network");
- m_nodePanel = new NodePanel();
- JScrollPane sp = new JScrollPane(m_nodePanel,
- JScrollPane.VERTICAL_SCROLLBAR_ALWAYS,
- JScrollPane.HORIZONTAL_SCROLLBAR_NEVER);
- m_controlPanel = new ControlPanel();
- m_win.getContentPane().add(sp, BorderLayout.CENTER);
- m_win.getContentPane().add(m_controlPanel, BorderLayout.SOUTH);
- m_win.setSize(640, 480);
- m_win.show();
- }
- //This sets up the initial state of the gui
- if (m_gui) {
- blocker(true);
- m_controlPanel.m_changeEpochs.setEnabled(false);
- m_controlPanel.m_changeLearning.setEnabled(false);
- m_controlPanel.m_changeMomentum.setEnabled(false);
- }
- //For silly situations in which the network gets accepted before training
- //commenses
- if (m_numeric) {
- setEndsToLinear();
- }
- if (m_accepted) {
- m_win.dispose();
- m_controlPanel = null;
- m_nodePanel = null;
- m_instances = new Instances(m_instances, 0);
- return;
- }
- //connections done.
- double right = 0;
- double driftOff = 0;
- double lastRight = Double.POSITIVE_INFINITY;
- double tempRate;
- double totalWeight = 0;
- double totalValWeight = 0;
- double origRate = m_learningRate; //only used for when reset
- //ensure that at least 1 instance is trained through.
- if (numInVal == m_instances.numInstances()) {
- numInVal--;
- }
- if (numInVal < 0) {
- numInVal = 0;
- }
- for (int noa = numInVal; noa < m_instances.numInstances(); noa++) {
- if (!m_instances.instance(noa).classIsMissing()) {
- totalWeight += m_instances.instance(noa).weight();
- }
- }
- if (m_valSize != 0) {
- for (int noa = 0; noa < valSet.numInstances(); noa++) {
- if (!valSet.instance(noa).classIsMissing()) {
- totalValWeight += valSet.instance(noa).weight();
- }
- }
- }
- m_stopped = false;
- for (int noa = 1; noa < m_numEpochs + 1; noa++) {
- right = 0;
- for (int nob = numInVal; nob < m_instances.numInstances(); nob++) {
- m_currentInstance = m_instances.instance(nob);
- if (!m_currentInstance.classIsMissing()) {
- //this is where the network updating (and training occurs, for the
- //training set
- resetNetwork();
- calculateOutputs();
- tempRate = m_learningRate * m_currentInstance.weight();
- if (m_decay) {
- tempRate /= noa;
- }
- right += (calculateErrors() / m_instances.numClasses()) *
- m_currentInstance.weight();
- updateNetworkWeights(tempRate, m_momentum);
- }
- }
- right /= totalWeight;
- if (Double.isInfinite(right) || Double.isNaN(right)) {
- if (!m_reset) {
- m_instances = null;
- throw new Exception("Network cannot train. Try restarting with a" +
- " smaller learning rate.");
- }
- else {
- //reset the network
- m_learningRate /= 2;
- buildClassifier(i);
- m_learningRate = origRate;
- m_instances = new Instances(m_instances, 0);
- return;
- }
- }
- ////////////////////////do validation testing if applicable
- if (m_valSize != 0) {
- right = 0;
- for (int nob = 0; nob < valSet.numInstances(); nob++) {
- m_currentInstance = valSet.instance(nob);
- if (!m_currentInstance.classIsMissing()) {
- //this is where the network updating occurs, for the validation set
- resetNetwork();
- calculateOutputs();
- right += (calculateErrors() / valSet.numClasses())
- * m_currentInstance.weight();
- //note 'right' could be calculated here just using
- //the calculate output values. This would be faster.
- //be less modular
- }
- }
- if (right < lastRight) {
- driftOff = 0;
- }
- else {
- driftOff++;
- }
- lastRight = right;
- if (driftOff > m_driftThreshold || noa + 1 >= m_numEpochs) {
- m_accepted = true;
- }
- right /= totalValWeight;
- }
- m_epoch = noa;
- m_error = right;
- //shows what the neuralnet is upto if a gui exists.
- updateDisplay();
- //This junction controls what state the gui is in at the end of each
- //epoch, Such as if it is paused, if it is resumable etc...
- if (m_gui) {
- while ((m_stopIt || (m_epoch >= m_numEpochs && m_valSize == 0)) &&
- !m_accepted) {
- m_stopIt = true;
- m_stopped = true;
- if (m_epoch >= m_numEpochs && m_valSize == 0) {
- m_controlPanel.m_startStop.setEnabled(false);
- }
- else {
- m_controlPanel.m_startStop.setEnabled(true);
- }
- m_controlPanel.m_startStop.setText("Start");
- m_controlPanel.m_startStop.setActionCommand("Start");
- m_controlPanel.m_changeEpochs.setEnabled(true);
- m_controlPanel.m_changeLearning.setEnabled(true);
- m_controlPanel.m_changeMomentum.setEnabled(true);
- blocker(true);
- if (m_numeric) {
- setEndsToLinear();
- }
- }
- m_controlPanel.m_changeEpochs.setEnabled(false);
- m_controlPanel.m_changeLearning.setEnabled(false);
- m_controlPanel.m_changeMomentum.setEnabled(false);
- m_stopped = false;
- //if the network has been accepted stop the training loop
- if (m_accepted) {
- m_win.dispose();
- m_controlPanel = null;
- m_nodePanel = null;
- m_instances = new Instances(m_instances, 0);
- return;
- }
- }
- if (m_accepted) {
- m_instances = new Instances(m_instances, 0);
- return;
- }
- }
- if (m_gui) {
- m_win.dispose();
- m_controlPanel = null;
- m_nodePanel = null;
- }
- m_instances = new Instances(m_instances, 0);
- }
- /**
- * Call this function to predict the class of an instance once a
- * classification model has been built with the buildClassifier call.
- * @param i The instance to classify.
- * @return A double array filled with the probabilities of each class type.
- * @exception if can't classify instance.
- */
- public double[] distributionForInstance(Instance i) throws Exception {
- if (m_useNomToBin) {
- m_nominalToBinaryFilter.input(i);
- m_currentInstance = m_nominalToBinaryFilter.output();
- }
- else {
- m_currentInstance = i;
- }
- if (m_normalizeAttributes) {
- for (int noa = 0; noa < m_instances.numAttributes(); noa++) {
- if (noa != m_instances.classIndex()) {
- if (m_attributeRanges[noa] != 0) {
- m_currentInstance.setValue(noa, (m_currentInstance.value(noa) -
- m_attributeBases[noa]) /
- m_attributeRanges[noa]);
- }
- else {
- m_currentInstance.setValue(noa, m_currentInstance.value(noa) -
- m_attributeBases[noa]);
- }
- }
- }
- }
- resetNetwork();
- //since all the output values are needed.
- //They are calculated manually here and the values collected.
- double[] theArray = new double[m_numClasses];
- for (int noa = 0; noa < m_numClasses; noa++) {
- theArray[noa] = m_outputs[noa].outputValue(true);
- }
- if (m_instances.classAttribute().isNumeric()) {
- return theArray;
- }
- //now normalize the array
- double count = 0;
- for (int noa = 0; noa < m_numClasses; noa++) {
- count += theArray[noa];
- }
- if (count <= 0) {
- return null;
- }
- for (int noa = 0; noa < m_numClasses; noa++) {
- theArray[noa] /= count;
- }
- return theArray;
- }
- /**
- * Returns an enumeration describing the available options.
- *
- * @return an enumeration of all the available options.
- */
- public Enumeration listOptions() {
- Vector newVector = new Vector(14);
- newVector.addElement(new Option(
- "tLearning Rate for the backpropagation algorithm.n"
- +"t(Value should be between 0 - 1, Default = 0.3).",
- "L", 1, "-L <learning rate>"));
- newVector.addElement(new Option(
- "tMomentum Rate for the backpropagation algorithm.n"
- +"t(Value should be between 0 - 1, Default = 0.2).",
- "M", 1, "-M <momentum>"));
- newVector.addElement(new Option(
- "tNumber of epochs to train through.n"
- +"t(Default = 500).",
- "N", 1,"-N <number of epochs>"));
- newVector.addElement(new Option(
- "tPercentage size of validation set to use to terminate" +
- " training (if this is non zero it can pre-empt num of epochs.n"
- +"t(Value should be between 0 - 100, Default = 0).",
- "V", 1, "-V <percentage size of validation set>"));
- newVector.addElement(new Option(
- "tThe value used to seed the random number generator" +
- "t(Value should be >= 0 and and a long, Default = 0).",
- "S", 1, "-S <seed>"));
- newVector.addElement(new Option(
- "tThe consequetive number of errors allowed for validation" +
- " testing before the netwrok terminates." +
- "t(Value should be > 0, Default = 20).",
- "E", 1, "-E <threshold for number of consequetive errors>"));
- newVector.addElement(new Option(
- "tGUI will be opened.n"
- +"t(Use this to bring up a GUI).",
- "G", 0,"-G"));
- newVector.addElement(new Option(
- "tAutocreation of the network connections will NOT be done.n"
- +"t(This will be ignored if -G is NOT set)",
- "A", 0,"-A"));
- newVector.addElement(new Option(
- "tA NominalToBinary filter will NOT automatically be used.n"
- +"t(Set this to not use a NominalToBinary filter).",
- "B", 0,"-B"));
- newVector.addElement(new Option(
- "tThe hidden layers to be created for the network.n"
- +"t(Value should be a list of comma seperated Natural numbers" +
- " or the letters 'a' = (attribs + classes) / 2, 'i'" +
- " = attribs, 'o' = classes, 't' = attribs .+ classes)" +
- " For wildcard values" +
- ",Default = a).",
- "H", 1, "-H <comma seperated numbers for nodes on each layer>"));
- newVector.addElement(new Option(
- "tNormalizing a numeric class will NOT be done.n"
- +"t(Set this to not normalize the class if it's numeric).",
- "C", 0,"-C"));
- newVector.addElement(new Option(
- "tNormalizing the attributes will NOT be done.n"
- +"t(Set this to not normalize the attributes).",
- "I", 0,"-I"));
- newVector.addElement(new Option(
- "tReseting the network will NOT be allowed.n"
- +"t(Set this to not allow the network to reset).",
- "R", 0,"-R"));
- newVector.addElement(new Option(
- "tLearning rate decay will occur.n"
- +"t(Set this to cause the learning rate to decay).",
- "D", 0,"-D"));
- return newVector.elements();
- }
- /**
- * Parses a given list of options. Valid options are:<p>
- *
- * -L num <br>
- * Set the learning rate.
- * (default 0.3) <p>
- *
- * -M num <br>
- * Set the momentum
- * (default 0.2) <p>
- *
- * -N num <br>
- * Set the number of epochs to train through.
- * (default 500) <p>
- *
- * -V num <br>
- * Set the percentage size of the validation set from the training to use.
- * (default 0 (no validation set is used, instead num of epochs is used) <p>
- *
- * -S num <br>
- * Set the seed for the random number generator.
- * (default 0) <p>
- *
- * -E num <br>
- * Set the threshold for the number of consequetive errors allowed during
- * validation testing.
- * (default 20) <p>
- *
- * -G <br>
- * Bring up a GUI for the neural net.
- * <p>
- *
- * -A <br>
- * Do not automatically create the connections in the net.
- * (can only be used if -G is specified) <p>
- *
- * -B <br>
- * Do Not automatically Preprocess the instances with a nominal to binary
- * filter. <p>
- *
- * -H str <br>
- * Set the number of nodes to be used on each layer. Each number represents
- * its own layer and the num of nodes on that layer. Each number should be
- * comma seperated. There are also the wildcards 'a', 'i', 'o', 't'
- * (default 4) <p>
- *
- * -C <br>
- * Do not automatically Normalize the class if it's numeric. <p>
- *
- * -I <br>
- * Do not automatically Normalize the attributes. <p>
- *
- * -R <br>
- * Do not allow the network to be automatically reset. <p>
- *
- * -D <br>
- * Cause the learning rate to decay as training is done. <p>
- *
- * @param options the list of options as an array of strings
- * @exception Exception if an option is not supported
- */
- public void setOptions(String[] options) throws Exception {
- //the defaults can be found here!!!!
- String learningString = Utils.getOption('L', options);
- if (learningString.length() != 0) {
- setLearningRate((new Double(learningString)).doubleValue());
- } else {
- setLearningRate(0.3);
- }
- String momentumString = Utils.getOption('M', options);
- if (momentumString.length() != 0) {
- setMomentum((new Double(momentumString)).doubleValue());
- } else {
- setMomentum(0.2);
- }
- String epochsString = Utils.getOption('N', options);
- if (epochsString.length() != 0) {
- setTrainingTime(Integer.parseInt(epochsString));
- } else {
- setTrainingTime(500);
- }
- String valSizeString = Utils.getOption('V', options);
- if (valSizeString.length() != 0) {
- setValidationSetSize(Integer.parseInt(valSizeString));
- } else {
- setValidationSetSize(0);
- }
- String seedString = Utils.getOption('S', options);
- if (seedString.length() != 0) {
- setRandomSeed(Long.parseLong(seedString));
- } else {
- setRandomSeed(0);
- }
- String thresholdString = Utils.getOption('E', options);
- if (thresholdString.length() != 0) {
- setValidationThreshold(Integer.parseInt(thresholdString));
- } else {
- setValidationThreshold(20);
- }
- String hiddenLayers = Utils.getOption('H', options);
- if (hiddenLayers.length() != 0) {
- setHiddenLayers(hiddenLayers);
- } else {
- setHiddenLayers("a");
- }
- if (Utils.getFlag('G', options)) {
- setGUI(true);
- } else {
- setGUI(false);
- } //small note. since the gui is the only option that can change the other
- //options this should be set first to allow the other options to set
- //properly
- if (Utils.getFlag('A', options)) {
- setAutoBuild(false);
- } else {
- setAutoBuild(true);
- }
- if (Utils.getFlag('B', options)) {
- setNominalToBinaryFilter(false);
- } else {
- setNominalToBinaryFilter(true);
- }
- if (Utils.getFlag('C', options)) {
- setNormalizeNumericClass(false);
- } else {
- setNormalizeNumericClass(true);
- }
- if (Utils.getFlag('I', options)) {
- setNormalizeAttributes(false);
- } else {
- setNormalizeAttributes(true);
- }
- if (Utils.getFlag('R', options)) {
- setReset(false);
- } else {
- setReset(true);
- }
- if (Utils.getFlag('D', options)) {
- setDecay(true);
- } else {
- setDecay(false);
- }
- Utils.checkForRemainingOptions(options);
- }
- /**
- * Gets the current settings of NeuralNet.
- *
- * @return an array of strings suitable for passing to setOptions()
- */
- public String [] getOptions() {
- String [] options = new String [21];
- int current = 0;
- options[current++] = "-L"; options[current++] = "" + getLearningRate();
- options[current++] = "-M"; options[current++] = "" + getMomentum();
- options[current++] = "-N"; options[current++] = "" + getTrainingTime();
- options[current++] = "-V"; options[current++] = "" +getValidationSetSize();
- options[current++] = "-S"; options[current++] = "" + getRandomSeed();
- options[current++] = "-E"; options[current++] =""+getValidationThreshold();
- options[current++] = "-H"; options[current++] = getHiddenLayers();
- if (getGUI()) {
- options[current++] = "-G";
- }
- if (!getAutoBuild()) {
- options[current++] = "-A";
- }
- if (!getNominalToBinaryFilter()) {
- options[current++] = "-B";
- }
- if (!getNormalizeNumericClass()) {
- options[current++] = "-C";
- }
- if (!getNormalizeAttributes()) {
- options[current++] = "-I";
- }
- if (!getReset()) {
- options[current++] = "-R";
- }
- if (getDecay()) {
- options[current++] = "-D";
- }
- while (current < options.length) {
- options[current++] = "";
- }
- return options;
- }
- /**
- * @return string describing the model.
- */
- public String toString() {
- StringBuffer model = new StringBuffer(m_neuralNodes.length * 100);
- //just a rough size guess
- NeuralNode con;
- double[] weights;
- NeuralConnection[] inputs;
- for (int noa = 0; noa < m_neuralNodes.length; noa++) {
- con = (NeuralNode) m_neuralNodes[noa]; //this would need a change
- //for items other than nodes!!!
- weights = con.getWeights();
- inputs = con.getInputs();
- if (con.getMethod() instanceof SigmoidUnit) {
- model.append("Sigmoid ");
- }
- else if (con.getMethod() instanceof LinearUnit) {
- model.append("Linear ");
- }
- model.append("Node " + con.getId() + "n Inputs Weightsn");
- model.append(" Threshold " + weights[0] + "n");
- for (int nob = 1; nob < con.getNumInputs() + 1; nob++) {
- if ((inputs[nob - 1].getType() & NeuralConnection.PURE_INPUT)
- == NeuralConnection.PURE_INPUT) {
- model.append(" Attrib " +
- m_instances.attribute(((NeuralEnd)inputs[nob-1]).
- getLink()).name()
- + " " + weights[nob] + "n");
- }
- else {
- model.append(" Node " + inputs[nob-1].getId() + " " +
- weights[nob] + "n");
- }
- }
- }
- //now put in the ends
- for (int noa = 0; noa < m_outputs.length; noa++) {
- inputs = m_outputs[noa].getInputs();
- model.append("Class " +
- m_instances.classAttribute().
- value(m_outputs[noa].getLink()) +
- "n Inputn");
- for (int nob = 0; nob < m_outputs[noa].getNumInputs(); nob++) {
- if ((inputs[nob].getType() & NeuralConnection.PURE_INPUT)
- == NeuralConnection.PURE_INPUT) {
- model.append(" Attrib " +
- m_instances.attribute(((NeuralEnd)inputs[nob]).
- getLink()).name() + "n");
- }
- else {
- model.append(" Node " + inputs[nob].getId() + "n");
- }
- }
- }
- return model.toString();
- }
- /**
- * This will return a string describing the classifier.
- * @return The string.
- */
- public String globalInfo() {
- return "This neural network uses backpropagation to train.";
- }
- /**
- * @return a string to describe the learning rate option.
- */
- public String learningRateTipText() {
- return "The amount the" +
- " weights are updated.";
- }
- /**
- * @return a string to describe the momentum option.
- */
- public String momentumTipText() {
- return "Momentum applied to the weights during updating.";
- }
- /**
- * @return a string to describe the AutoBuild option.
- */
- public String autoBuildTipText() {
- return "Adds and connects up hidden layers in the network.";
- }
- /**
- * @return a string to describe the random seed option.
- */
- public String randomSeedTipText() {
- return "Seed used to initialise the random number generator." +
- "Random numbers are used for setting the initial weights of the" +
- " connections betweem nodes, and also for shuffling the training data.";
- }
- /**
- * @return a string to describe the validation threshold option.
- */
- public String validationThresholdTipText() {
- return "Used to terminate validation testing." +
- "The value here dictates how many times in a row the validation set" +
- " error can get worse before training is terminated.";
- }
- /**
- * @return a string to describe the GUI option.
- */
- public String GUITipText() {
- return "Brings up a gui interface." +
- " This will allow the pausing and altering of the nueral network" +
- " during training.nn" +
- "* To add a node left click (this node will be automatically selected," +
- " ensure no other nodes were selected).n" +
- "* To select a node left click on it either while no other node is" +
- " selected or while holding down the control key (this toggles that" +
- " node as being selected and not selected.n" +
- "* To connect a node, first have the start node(s) selected, then click"+
- " either the end node or on an empty space (this will create a new node"+
- " that is connected with the selected nodes). The selection status of" +
- " nodes will stay the same after the connection. (Note these are" +
- " directed connections, also a connection between two nodes will not" +
- " be established more than once and certain connections that are" +
- " deemed to be invalid will not be made).n" +
- "* To remove a connection select one of the connected node(s) in the" +
- " connection and then right click the other node (it does not matter" +
- " whether the node is the start or end the connection will be removed" +
- ").n" +
- "* To remove a node right click it while no other nodes (including it)" +
- " are selected. (This will also remove all connections to it)n." +
- "* To deselect a node either left click it while holding down control," +
- " or right click on empty space.n" +
- "* The raw inputs are provided from the labels on the left.n" +
- "* The red nodes are hidden layers.n" +
- "* The orange nodes are the output nodes.n" +
- "* The labels on the right show the class the output node represents." +
- " Note that with a numeric class the output node will automatically be" +
- " made into an unthresholded linear unit.nn" +
- "Alterations to the neural network can only be done while the network" +
- " is not running, This also applies to the learning rate and other" +
- " fields on the control panel.nn" +
- "* You can accept the network as being finished at any time.n" +
- "* The network is automatically paused at the beginning.n" +
- "* There is a running indication of what epoch the network is up to" +
- " and what the (rough) error for that epoch was (or for" +
- " the validation if that is being used). Note that this error value" +
- " is based on a network that changes as the value is computed." +
- " (also depending on whether" +
- " the class is normalized will effect the error reported for numeric" +
- " classes.n" +
- "* Once the network is done it will pause again and either wait to be" +
- " accepted or trained more.nn" +
- "Note that if the gui is not set the network will not require any" +
- " interaction.n";
- }
- /**
- * @return a string to describe the validation size option.
- */
- public String validationSetSizeTipText() {
- return "The percentage size of the validation set." +
- "(The training will continue until it is observed that" +
- " the error on the validation set has been consistently getting" +
- " worse, or if the training time is reached).n" +
- "If This is set to zero no validation set will be used and instead" +
- " the network will train for the specified number of epochs.";
- }
- /**
- * @return a string to describe the learning rate option.
- */
- public String trainingTimeTipText() {
- return "The number of epochs to train through." +
- " If the validation set is non-zero then it can terminate the network" +
- " early";
- }
- /**
- * @return a string to describe the nominal to binary option.
- */
- public String nominalToBinaryFilterTipText() {
- return "This will preprocess the instances with the filter." +
- " This could help improve performance if there are nominal attributes" +
- " in the data.";
- }
- /**
- * @return a string to describe the hidden layers in the network.
- */
- public String hiddenLayersTipText() {
- return "This defines the hidden layers of the neural network." +
- " This is a list of positive whole numbers. 1 for each hidden layer." +
- " Comma seperated. To have no hidden layers put a single 0 here." +
- " This will only be used if autobuild is set. There are also wildcard" +
- " values 'a' = (attribs + classes) / 2, 'i' = attribs, 'o' = classes" +
- " , 't' = attribs + classes.";
- }
- /**
- * @return a string to describe the nominal to binary option.
- */
- public String normalizeNumericClassTipText() {
- return "This will normalize the class if it's numeric." +
- " This could help improve performance of the network, It normalizes" +
- " the class to be between -1 and 1. Note that this is only internally" +
- ", the output will be scaled back to the original range.";
- }
- /**
- * @return a string to describe the nominal to binary option.
- */
- public String normalizeAttributesTipText() {
- return "This will normalize the attributes." +
- " This could help improve performance of the network." +
- " This is not reliant on the class being numeric. This will also" +
- " normalize nominal attributes as well (after they have been run" +
- " through the nominal to binary filter if that is in use) so that the" +
- " nominal values are between -1 and 1";
- }
- /**
- * @return a string to describe the Reset option.
- */
- public String resetTipText() {
- return "This will allow the network to reset with a lower learning rate." +
- " If the network diverges from the answer this will automatically" +
- " reset the network with a lower learning rate and begin training" +
- " again. This option is only available if the gui is not set. Note" +
- " that if the network diverges but isn't allowed to reset it will" +
- " fail the training process and return an error message.";
- }
- /**
- * @return a string to describe the Decay option.
- */
- public String decayTipText() {
- return "This will cause the learning rate to decrease." +
- " This will divide the starting learning rate by the epoch number, to" +
- " determine what the current learning rate should be. This may help" +
- " to stop the network from diverging from the target output, as well" +
- " as improve general performance. Note that the decaying learning" +
- " rate will not be shown in the gui, only the original learning rate" +
- ". If the learning rate is changed in the gui, this is treated as the" +
- " starting learning rate.";
- }
- }