Skip to content

Commit

Permalink
Merge pull request #8 from rasbt/ch14-softmax
Browse files Browse the repository at this point in the history
ch14-softmax update
  • Loading branch information
rasbt committed Feb 21, 2022
2 parents d4b9be6 + 6bf3808 commit a09dcc2
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 94 deletions.
294 changes: 204 additions & 90 deletions ch14/ch14_part1.ipynb

Large diffs are not rendered by default.

12 changes: 8 additions & 4 deletions ch14/ch14_part1.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def conv2d(X, W, p=(0, 0), s=(1, 1)):
mnist_dataset = torchvision.datasets.MNIST(root=image_path,
train=True,
transform=transform,
download=False)
download=True)

mnist_valid_dataset = Subset(mnist_dataset, torch.arange(10000))
mnist_train_dataset = Subset(mnist_dataset, torch.arange(10000, len(mnist_dataset)))
Expand All @@ -310,6 +310,8 @@ def conv2d(X, W, p=(0, 0), s=(1, 1)):





batch_size = 64
torch.manual_seed(1)
train_dl = DataLoader(mnist_train_dataset, batch_size, shuffle=True)
Expand Down Expand Up @@ -367,8 +369,6 @@ def conv2d(X, W, p=(0, 0), s=(1, 1)):
model.add_module('dropout', nn.Dropout(p=0.5))

model.add_module('fc2', nn.Linear(1024, 10))
model.add_module('softmax', nn.Softmax(dim=1))




Expand Down Expand Up @@ -430,6 +430,8 @@ def train(model, num_epochs, train_dl, valid_dl):





x_arr = np.arange(len(hist[0])) + 1

fig = plt.figure(figsize=(12, 4))
Expand All @@ -446,6 +448,7 @@ def train(model, num_epochs, train_dl, valid_dl):
ax.set_xlabel('Epoch', size=15)
ax.set_ylabel('Accuracy', size=15)

#plt.savefig('figures/14_13.png')
plt.show()


Expand All @@ -456,7 +459,6 @@ def train(model, num_epochs, train_dl, valid_dl):
pred = model(mnist_test_dataset.data.unsqueeze(1) / 255.)
is_correct = (torch.argmax(pred, dim=1) == mnist_test_dataset.targets).float()
print(f'Test accuracy: {is_correct.mean():.4f}')




Expand All @@ -475,6 +477,8 @@ def train(model, num_epochs, train_dl, valid_dl):
verticalalignment='center',
transform=ax.transAxes)


plt.savefig('figures/14_14.png')
plt.show()


Expand Down
Binary file modified ch14/figures/14_13.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added ch14/figures/14_14.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit a09dcc2

Please sign in to comment.