Skip to content

Commit 3502be9

Browse files
committed
Neural networks: split our pre-shuffled MNIST dataset into batches and use mini-batch gradient descent (GD) to speed up training.
1 parent 813209b commit 3502be9

File tree

2 files changed

+69
-4
lines changed

2 files changed

+69
-4
lines changed

ml/supervised_learning/classifications/our_own_mnist_lib.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,15 @@
1212

1313
from ml.util import prepend_bias, one_hot_encoding
1414

15-
TRAIN_IMAGE = "../../../fundamentals/datasets/mnist/train-images-idx3-ubyte.gz"
16-
TRAIN_LABEL = "../../../fundamentals/datasets/mnist/train-labels-idx1-ubyte.gz"
17-
TEST_IMAGE = "../../../fundamentals/datasets/mnist/t10k-images-idx3-ubyte.gz"
18-
TEST_LABEL = "../../../fundamentals/datasets/mnist/t10k-labels-idx1-ubyte.gz"
15+
# TRAIN_IMAGE = "../../../fundamentals/datasets/mnist/train-images-idx3-ubyte.gz"
16+
# TRAIN_LABEL = "../../../fundamentals/datasets/mnist/train-labels-idx1-ubyte.gz"
17+
# TEST_IMAGE = "../../../fundamentals/datasets/mnist/t10k-images-idx3-ubyte.gz"
18+
# TEST_LABEL = "../../../fundamentals/datasets/mnist/t10k-labels-idx1-ubyte.gz"
19+
20+
import os

# Resolve the MNIST directory from this module's own location instead of the
# current working directory, so the paths stay valid no matter which script
# (or which directory) the import is triggered from.
# NOTE(review): assumes this module lives three levels below the repository
# root (ml/supervised_learning/classifications/) and that the datasets sit in
# <root>/fundamentals/datasets/mnist -- confirm against the repository layout.
_MNIST_DIR = os.path.normpath(
    os.path.join(
        os.path.dirname(os.path.abspath(__file__)),
        "..", "..", "..",
        "fundamentals", "datasets", "mnist",
    )
)

TRAIN_IMAGE = os.path.join(_MNIST_DIR, "train-images-idx3-ubyte.gz")
TRAIN_LABEL = os.path.join(_MNIST_DIR, "train-labels-idx1-ubyte.gz")
TEST_IMAGE = os.path.join(_MNIST_DIR, "t10k-images-idx3-ubyte.gz")
TEST_LABEL = os.path.join(_MNIST_DIR, "t10k-labels-idx1-ubyte.gz")
1924

2025

2126
def load_images(filename):
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Learner: Nguyen Truong Thinh
2+
# Contact me: [email protected] || +84393280504
3+
#
4+
# Topic: Supervised Learning: Classifications in Action.
5+
# Neural Networks: leaps and bounds more powerful than perceptrons.
6+
# The Network's classification: A Classifier's Answers
7+
# A neural network implementation
8+
9+
import numpy as np
10+
11+
from ml.supervised_learning.neural_networks import training_the_network as tn
12+
13+
14+
def prepare_batches(x_train, y_train, batch_size):
    """Split a pre-shuffled dataset into consecutive mini-batches.

    The examples are not shuffled here; they are sliced in their existing
    order, so the caller is expected to pass data that is already shuffled.

    Args:
        x_train: Examples, sliced along the first axis. Must expose ``shape``.
        y_train: Labels, sliced with the same row ranges as ``x_train``.
        batch_size: Maximum number of examples per batch. Must be positive.

    Returns:
        A tuple ``(x_batches, y_batches)`` of two lists of equal length. The
        last batch is shorter than ``batch_size`` when the number of examples
        is not an exact multiple of it.

    Raises:
        ValueError: If ``batch_size`` is zero or negative.
    """
    # A negative step would make range() empty and training would silently do
    # nothing; a zero step raises an unrelated-looking error. Fail clearly.
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer, got %r" % (batch_size,))

    n_examples = x_train.shape[0]
    batch_starts = range(0, n_examples, batch_size)
    # Slicing past the end is clipped, so the final batch needs no special case.
    x_batches = [x_train[start:start + batch_size] for start in batch_starts]
    y_batches = [y_train[start:start + batch_size] for start in batch_starts]
    return x_batches, y_batches
26+
27+
28+
def report(epoch, batch, x_train, y_train, x_test, y_test, _w1, _w2):
    """Print the training loss and test accuracy for the given weights.

    The loss is computed by running the forward pass on ``x_train`` and
    comparing against ``y_train``; the accuracy is the percentage of
    ``x_test`` classifications that equal ``y_test``.

    Args:
        epoch: Epoch index, printed as the first part of the progress label.
        batch: Batch index, printed as the second part of the progress label.
        x_train: Training examples used for the loss.
        y_train: Training labels used for the loss.
        x_test: Test examples used for the accuracy.
        y_test: Test labels compared element-wise with the classifications.
        _w1: First weight matrix passed to the network helpers.
        _w2: Second weight matrix passed to the network helpers.
    """
    # forward() returns a pair; only the predictions are needed for the loss.
    predictions, _ = tn.forward(x_train, _w1, _w2)
    current_loss = tn.loss(y_train, predictions)

    predicted_labels = tn.classify(x_test, _w1, _w2)
    matches = predicted_labels == y_test
    accuracy_pct = np.average(matches) * 100.0

    line = "%5d-%d > Loss: %.8f, Accuracy: %.2f%%" % (epoch, batch, current_loss, accuracy_pct)
    print(line)
35+
36+
37+
def train(x_train, y_train, x_test, y_test, _n_hidden_nodes, epochs, batch_size, lr):
    """Train the network with mini-batch gradient descent.

    The training set is split once into batches of ``batch_size`` examples;
    each epoch then walks over every batch, computing the gradients from that
    batch alone and stepping both weight matrices by ``lr`` times the gradient.

    Args:
        x_train: Training examples; the second axis gives the input count.
        y_train: Training labels; the second axis gives the class count.
        x_test: Test examples, used only for progress reporting.
        y_test: Test labels, used only for progress reporting.
        _n_hidden_nodes: Number of hidden nodes passed to weight initialization.
        epochs: Number of full passes over the batches.
        batch_size: Number of examples per mini-batch.
        lr: Learning rate that scales each gradient step.

    Returns:
        The tuple ``(_w1, _w2)`` of trained weight matrices.
    """
    n_input_variables = x_train.shape[1]
    n_classes = y_train.shape[1]
    # Weights come from tn.initialize_weights rather than all zeros.
    _w1, _w2 = tn.initialize_weights(n_input_variables, _n_hidden_nodes, n_classes)
    x_batches, y_batches = prepare_batches(x_train, y_train, batch_size)

    for epoch in range(epochs):
        for batch, (x_batch, y_batch) in enumerate(zip(x_batches, y_batches)):
            y_hat, h = tn.forward(x_batch, _w1, _w2)
            w1_gradient, w2_gradient = tn.back(x_batch, y_batch, y_hat, _w2, h)
            _w1 = _w1 - (w1_gradient * lr)
            _w2 = _w2 - (w2_gradient * lr)

            # NOTE(review): reporting once per batch, since report() takes the
            # batch index for its "epoch-batch" label -- confirm this matches
            # the intended reporting frequency.
            report(epoch, batch, x_train, y_train, x_test, y_test, _w1, _w2)
    return _w1, _w2
56+
57+
58+
if __name__ == "__main__":
    # Imported inside the guard, so merely importing this module does not
    # pull in the MNIST helper module.
    from ml.supervised_learning.classifications import our_own_mnist_lib as mnist

    # Script entry point: two epochs over batches of 20000 examples.
    w1, w2 = train(
        mnist.X_train,
        mnist.Y_train,
        mnist.X_test,
        mnist.Y_test,
        _n_hidden_nodes=200,
        epochs=2,
        batch_size=20000,
        lr=0.01,
    )

0 commit comments

Comments
 (0)