-
Notifications
You must be signed in to change notification settings - Fork 277
/
mnist.py
47 lines (43 loc) · 1.55 KB
/
mnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
"""
PyTorch version: https://github.com/pytorch/examples/blob/master/mnist/main.py
TensorFlow version: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/tutorials/mnist/mnist.py
"""
# pip install thinc ml_datasets typer
from thinc.api import Model, chain, Relu, Softmax, Adam
import ml_datasets
from wasabi import msg
from tqdm import tqdm
import typer
def main(
    n_hidden: int = 256, dropout: float = 0.2, n_iter: int = 10, batch_size: int = 128
):
    # Train a Relu -> Relu -> Softmax network on MNIST and print the dev-set
    # accuracy after every epoch.
    #
    #   n_hidden   -- output width (nO) of each of the two Relu layers
    #   dropout    -- dropout rate handed to both Relu layers
    #   n_iter     -- number of passes over the training batches
    #   batch_size -- batch size used for both the training and the dev batches

    # Build the network; the Softmax output size is left unset here and is
    # filled in by model.initialize() below.
    first_hidden = Relu(nO=n_hidden, dropout=dropout)
    second_hidden = Relu(nO=n_hidden, dropout=dropout)
    model: Model = chain(first_hidden, second_hidden, Softmax())

    # Fetch the train / dev splits.
    train_split, dev_split = ml_datasets.mnist()
    train_X, train_Y = train_split
    dev_X, dev_Y = dev_split

    # A five-example sample is enough for the model to infer missing shapes.
    model.initialize(X=train_X[:5], Y=train_Y[:5])

    # Only the training batches are shuffled.
    ops = model.ops
    train_batches = ops.multibatch(batch_size, train_X, train_Y, shuffle=True)
    dev_batches = ops.multibatch(batch_size, dev_X, dev_Y)

    optimizer = Adam(0.001)

    for epoch in range(n_iter):
        # One pass of training.
        for inputs, truths in tqdm(train_batches, leave=False):
            guesses, backprop = model.begin_update(inputs)
            # The raw prediction error is passed back as the gradient.
            backprop(guesses - truths)
            model.finish_update(optimizer)

        # Measure accuracy on the dev split.
        n_correct = 0
        n_seen = 0
        for inputs, truths in dev_batches:
            guesses = model.predict(inputs)
            # argmax over axis 1 on both sides: assumes the labels are
            # one-hot rows, like the predictions -- TODO confirm.
            hits = guesses.argmax(axis=1) == truths.argmax(axis=1)
            n_correct += hits.sum()
            n_seen += guesses.shape[0]
        accuracy = n_correct / n_seen

        # One table row per epoch: epoch index and accuracy to 3 decimals.
        msg.row((epoch, f"{accuracy:.3f}"), widths=(3, 5))
# Script entry point: only when executed directly (not on import), hand
# `main` to typer, which runs it as a command-line application.
if __name__ == "__main__":
    typer.run(main)