# Learner: Nguyen Truong Thinh
# Contact me: [email protected] || +84393280504
#
# Topic: Supervised Learning: The Zen of Testing
# A neural network implementation

# Load the (pre-shuffled) MNIST data and split it into training, validation, and test sets.
import matplotlib.pyplot as plt
import numpy as np
from ml.supervised_learning.neural_networks import training_the_network as tn
from ml.util import load_images, load_labels, one_hot_encoding

TRAIN_IMAGE = "../../../fundamentals/datasets/mnist/train-images-idx3-ubyte.gz"
TRAIN_LABEL = "../../../fundamentals/datasets/mnist/train-labels-idx1-ubyte.gz"
TEST_IMAGE = "../../../fundamentals/datasets/mnist/t10k-images-idx3-ubyte.gz"
TEST_LABEL = "../../../fundamentals/datasets/mnist/t10k-labels-idx1-ubyte.gz"
# X_train / X_validation / X_test: 60k / 5k / 5k images
# Each image has 784 elements (28 * 28 pixels)
X_train = load_images(TRAIN_IMAGE)
X_test_all = load_images(TEST_IMAGE)  # If shuffling were needed, see the shared-permutation note below
X_validation, X_test = np.split(X_test_all, 2)

# 60k labels, each a single digit from 0 to 9
Y_train_unencoded = load_labels(TRAIN_LABEL)
# Y_train: 60k labels, each consisting of 10 one-hot encoded elements
Y_train = one_hot_encoding(Y_train_unencoded, 10)
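# For reference, a minimal NumPy sketch of the behavior assumed of
# one_hot_encoding (the actual implementation lives in ml.util; the name
# _one_hot_sketch is hypothetical and unused below):
#
# def _one_hot_sketch(labels, n_classes):
#     encoded = np.zeros((labels.shape[0], n_classes))
#     for row, label in enumerate(labels.reshape(-1)):
#         encoded[row][int(label)] = 1
#     return encoded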
# Y_validation / Y_test: 5k / 5k labels, each a single digit from 0 to 9
Y_test_all = load_labels(TEST_LABEL)  # If shuffling were needed, see the shared-permutation note below
Y_validation, Y_test = np.split(Y_test_all, 2)
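# Note on shuffling: calling np.random.shuffle on the images and the labels
# separately would scramble the image/label pairing. If the test data did
# need shuffling, both arrays would have to share one permutation, e.g.:
#
# permutation = np.random.permutation(X_test_all.shape[0])
# X_test_all, Y_test_all = X_test_all[permutation], Y_test_all[permutation]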


# Cross-entropy loss, averaged over the examples. Unlike the loss() in other
# source files, this version takes the weights as explicit parameters.
def loss(_x, _y, _w1, _w2):
    _y_hat, _ = tn.forward(_x, _w1, _w2)
    return -np.sum(_y * np.log(_y_hat)) / _y.shape[0]


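# For reference, tn.forward is assumed to implement a standard two-layer
# forward pass: a sigmoid hidden layer followed by a softmax output layer,
# each with a prepended bias column. A commented sketch under that
# assumption (_sigmoid, _softmax, and _forward_sketch are hypothetical
# names; the real implementation lives in training_the_network):
#
# def _sigmoid(z):
#     return 1 / (1 + np.exp(-z))
#
# def _softmax(logits):
#     exponentials = np.exp(logits)
#     return exponentials / np.sum(exponentials, axis=1).reshape(-1, 1)
#
# def _forward_sketch(x, w1, w2):
#     h = _sigmoid(np.insert(x, 0, 1, axis=1) @ w1)      # hidden activations
#     y_hat = _softmax(np.insert(h, 0, 1, axis=1) @ w2)  # class probabilities
#     return y_hat, h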
def train(x_train, y_train, x_test, y_test, _n_hidden_nodes, iterations, lr):
    n_input_variables = x_train.shape[1]
    n_classes = y_train.shape[1]
    # Alternative: initialize all the weights at zero. With all-zero weights
    # every hidden node receives identical gradients and stays identical, so
    # the hidden layer cannot break symmetry; kept here only for comparison.
    # _w1 = np.zeros((n_input_variables + 1, _n_hidden_nodes))
    # _w2 = np.zeros((_n_hidden_nodes + 1, n_classes))
    # Initialize the weights with a good initialization instead:
    _w1, _w2 = tn.initialize_weights(n_input_variables, _n_hidden_nodes, n_classes)
    _training_losses = []
    _test_losses = []

    for i in range(iterations):
        y_hat_train, h = tn.forward(x_train, _w1, _w2)
        y_hat_test, _ = tn.forward(x_test, _w1, _w2)
        w1_gradient, w2_gradient = tn.back(x_train, y_train, y_hat_train, _w2, h)
        # Gradient descent step: w <- w - lr * gradient
        _w1 = _w1 - (w1_gradient * lr)
        _w2 = _w2 - (w2_gradient * lr)

        # Track the cross-entropy loss on both sets to monitor generalization
        training_loss = -np.sum(y_train * np.log(y_hat_train)) / y_train.shape[0]
        _training_losses.append(training_loss)
        test_loss = -np.sum(y_test * np.log(y_hat_test)) / y_test.shape[0]
        _test_losses.append(test_loss)

        print("%5d > Training loss: %.5f - Test loss: %.5f" % (i, training_loss, test_loss))

    return _training_losses, _test_losses, _w1, _w2
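# For reference, tn.back is assumed to compute the gradients of the
# cross-entropy loss by backpropagation. With a softmax output layer the
# output-layer error simplifies to (y_hat - y); a commented sketch under
# that assumption (_back_sketch is a hypothetical name, matching the
# forward() sketch above):
#
# def _back_sketch(x, y, y_hat, w2, h):
#     m = x.shape[0]
#     w2_gradient = np.insert(h, 0, 1, axis=1).T @ (y_hat - y) / m
#     # Propagate the error through w2 (skipping its bias row) and through
#     # the sigmoid derivative h * (1 - h):
#     w1_gradient = (np.insert(x, 0, 1, axis=1).T
#                    @ ((y_hat - y) @ w2[1:].T * h * (1 - h)) / m)
#     return w1_gradient, w2_gradient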

# Train the network and record the per-iteration losses on both sets.
training_losses, test_losses, w1, w2 = train(X_train, Y_train,
                                             X_test,
                                             one_hot_encoding(Y_test, 10),
                                             _n_hidden_nodes=200,
                                             iterations=10000,
                                             lr=0.01)
# tn.accuracy compares predictions against unencoded labels (Y_test below is
# raw digits), so pass Y_train_unencoded rather than the one-hot Y_train:
training_accuracy = tn.accuracy(X_train, Y_train_unencoded, w1, w2)
test_accuracy = tn.accuracy(X_test, Y_test, w1, w2)
print("Training accuracy: %.2f%%, Test accuracy: %.2f%%" % (training_accuracy, test_accuracy))
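# For reference, a plausible sketch of tn.accuracy under the assumptions
# above (argmax classification against unencoded labels; _accuracy_sketch
# is a hypothetical name, not the module's confirmed implementation):
#
# def _accuracy_sketch(x, y_unencoded, w1, w2):
#     y_hat, _ = tn.forward(x, w1, w2)
#     predictions = np.argmax(y_hat, axis=1).reshape(-1, 1)
#     return np.average(predictions == y_unencoded) * 100.0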
# Plot the training-loss and test-loss curves to visualize how the two diverge.
plt.plot(training_losses, label="Training set", color='blue', linestyle='-')
plt.plot(test_losses, label="Test set", color='green', linestyle='--')
plt.xlabel("Iterations", fontsize=30)
plt.ylabel("Loss", fontsize=30)
plt.xticks(fontsize=15)
plt.yticks(fontsize=15)
plt.legend(fontsize=30)
plt.show()