diff --git a/src/main/java/basicneuralnetwork/NeuralNetwork.java b/src/main/java/basicneuralnetwork/NeuralNetwork.java index 4c17363..5c8b345 100644 --- a/src/main/java/basicneuralnetwork/NeuralNetwork.java +++ b/src/main/java/basicneuralnetwork/NeuralNetwork.java @@ -5,6 +5,7 @@ import basicneuralnetwork.utilities.MatrixUtilities; import org.ejml.simple.SimpleMatrix; +import java.util.ArrayList; import java.util.Arrays; import java.util.Random; @@ -14,7 +15,7 @@ public class NeuralNetwork { private ActivationFunctionFactory activationFunctionFactory = new ActivationFunctionFactory(); - + private Random random = new Random(); // Dimensions of the neural network @@ -25,7 +26,9 @@ public class NeuralNetwork { private SimpleMatrix[] weights; private SimpleMatrix[] biases; - + + private ArrayList activations; + private double learningRate; private String activationFunctionKey; @@ -119,14 +122,31 @@ public double[] guess(double[] input) { // Transform array to matrix SimpleMatrix output = MatrixUtilities.arrayToMatrix(input); + //Stores an the activation matrix for each layer + activations = new ArrayList(); + activations.add(output); + for (int i = 0; i < hiddenLayers + 1; i++) { output = calculateLayer(weights[i], biases[i], output, activationFunction); + activations.add(output); } return MatrixUtilities.getColumnFromMatrixAsArray(output, 0); } } + //Return 2D array of the activation values for each neuron in each layer + //Neurons are updated every time guess() is called + public double[][] getActivations() + { + double[][] values = new double[hiddenLayers + 2][]; + for (int m = 0; m < activations.size(); m++) + { + values[m] = MatrixUtilities.getColumnFromMatrixAsArray(activations.get(m), 0); + } + return values; + } + public void train(double[] inputArray, double[] targetArray) { if (inputArray.length != inputNodes) { throw new WrongDimensionException(inputArray.length, inputNodes, "Input");