From d5a023ca022d74c3e061a8d65fbd48bcaebba81a Mon Sep 17 00:00:00 2001 From: "Will Conrad (Mac)" Date: Wed, 1 Jun 2022 18:17:35 -0400 Subject: [PATCH] Added getActivations method --- .../basicneuralnetwork/NeuralNetwork.java | 24 +++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/src/main/java/basicneuralnetwork/NeuralNetwork.java b/src/main/java/basicneuralnetwork/NeuralNetwork.java index 4c17363..5c8b345 100644 --- a/src/main/java/basicneuralnetwork/NeuralNetwork.java +++ b/src/main/java/basicneuralnetwork/NeuralNetwork.java @@ -5,6 +5,7 @@ import basicneuralnetwork.utilities.MatrixUtilities; import org.ejml.simple.SimpleMatrix; +import java.util.ArrayList; import java.util.Arrays; import java.util.Random; @@ -14,7 +15,7 @@ public class NeuralNetwork { private ActivationFunctionFactory activationFunctionFactory = new ActivationFunctionFactory(); - + private Random random = new Random(); // Dimensions of the neural network @@ -25,7 +26,9 @@ public class NeuralNetwork { private SimpleMatrix[] weights; private SimpleMatrix[] biases; - + + private ArrayList activations; + private double learningRate; private String activationFunctionKey; @@ -119,14 +122,31 @@ public double[] guess(double[] input) { // Transform array to matrix SimpleMatrix output = MatrixUtilities.arrayToMatrix(input); + //Stores an the activation matrix for each layer + activations = new ArrayList(); + activations.add(output); + for (int i = 0; i < hiddenLayers + 1; i++) { output = calculateLayer(weights[i], biases[i], output, activationFunction); + activations.add(output); } return MatrixUtilities.getColumnFromMatrixAsArray(output, 0); } } + //Return 2D array of the activation values for each neuron in each layer + //Neurons are updated every time guess() is called + public double[][] getActivations() + { + double[][] values = new double[hiddenLayers + 2][]; + for (int m = 0; m < activations.size(); m++) + { + values[m] = MatrixUtilities.getColumnFromMatrixAsArray(activations.get(m), 0); + } + return values; + } + public void train(double[] inputArray, double[] targetArray) { if (inputArray.length != inputNodes) { throw new WrongDimensionException(inputArray.length, inputNodes, "Input");