Skip to content

Commit c2dbd8d

Browse files
authored
Update test.py
Made the code more readable, added descriptive variable names and proper indentation
1 parent 17ab545 commit c2dbd8d

File tree

1 file changed

+27
-14
lines changed
  • I/image-classifier-using-transfer-learning

1 file changed

+27
-14
lines changed
Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,41 +1,54 @@
1-
from train import load_data, create_model, IMAGE_SHAPE, batch_size, np
1+
import numpy as np
22
import matplotlib.pyplot as plt
3-
# load the data generators
3+
from train import load_data, create_model, IMAGE_SHAPE, batch_size
4+
5+
# Load the data generators
46
train_generator, validation_generator, class_names = load_data()
5-
# constructs the model
7+
8+
# Construct the model
69
model = create_model(input_shape=IMAGE_SHAPE)
7-
# load the optimal weights
10+
11+
# Load the optimal weights
812
model.load_weights("results/MobileNetV2_finetune_last5-loss-0.66.h5")
913

14+
# Calculate the number of validation steps per epoch
1015
validation_steps_per_epoch = np.ceil(validation_generator.samples / batch_size)
11-
# print the validation loss & accuracy
16+
17+
# Print the validation loss & accuracy
1218
evaluation = model.evaluate_generator(validation_generator, steps=validation_steps_per_epoch, verbose=1)
1319
print("Val loss:", evaluation[0])
1420
print("Val Accuracy:", evaluation[1])
1521

16-
# get a random batch of images
22+
# Get a random batch of images
1723
image_batch, label_batch = next(iter(validation_generator))
18-
# turn the original labels into human-readable text
24+
25+
# Convert the original labels into human-readable text
1926
label_batch = [class_names[np.argmax(label_batch[i])] for i in range(batch_size)]
20-
# predict the images on the model
27+
28+
# Predict the images using the model
2129
predicted_class_names = model.predict(image_batch)
2230
predicted_ids = [np.argmax(predicted_class_names[i]) for i in range(batch_size)]
23-
# turn the predicted vectors to human readable labels
31+
32+
# Convert the predicted vectors to human-readable labels
2433
predicted_class_names = np.array([class_names[id] for id in predicted_ids])
2534

26-
# some nice plotting
27-
plt.figure(figsize=(10,9))
35+
# Plot the results
36+
plt.figure(figsize=(10, 9))
2837
for n in range(30):
29-
plt.subplot(6,5,n+1)
30-
plt.subplots_adjust(hspace = 0.3)
38+
plt.subplot(6, 5, n + 1)
39+
plt.subplots_adjust(hspace=0.3)
3140
plt.imshow(image_batch[n])
41+
42+
# Set the title and color based on correctness
3243
if predicted_class_names[n] == label_batch[n]:
3344
color = "blue"
3445
title = predicted_class_names[n].title()
3546
else:
3647
color = "red"
3748
title = f"{predicted_class_names[n].title()}, correct:{label_batch[n]}"
49+
3850
plt.title(title, color=color)
3951
plt.axis('off')
52+
4053
_ = plt.suptitle("Model predictions (blue: correct, red: incorrect)")
41-
plt.show()
54+
plt.show()

0 commit comments

Comments
 (0)