1
- from train import load_data , create_model , IMAGE_SHAPE , batch_size , np
1
+ import numpy as np
2
2
import matplotlib .pyplot as plt
3
- # load the data generators
3
+ from train import load_data , create_model , IMAGE_SHAPE , batch_size
4
+
5
+ # Load the data generators
4
6
train_generator , validation_generator , class_names = load_data ()
5
- # constructs the model
7
+
8
+ # Construct the model
6
9
model = create_model (input_shape = IMAGE_SHAPE )
7
- # load the optimal weights
10
+
11
+ # Load the optimal weights
8
12
model .load_weights ("results/MobileNetV2_finetune_last5-loss-0.66.h5" )
9
13
14
+ # Calculate the number of validation steps per epoch
10
15
validation_steps_per_epoch = np .ceil (validation_generator .samples / batch_size )
11
- # print the validation loss & accuracy
16
+
17
+ # Print the validation loss & accuracy
12
18
evaluation = model .evaluate_generator (validation_generator , steps = validation_steps_per_epoch , verbose = 1 )
13
19
print ("Val loss:" , evaluation [0 ])
14
20
print ("Val Accuracy:" , evaluation [1 ])
15
21
16
- # get a random batch of images
22
+ # Get a random batch of images
17
23
image_batch , label_batch = next (iter (validation_generator ))
18
- # turn the original labels into human-readable text
24
+
25
+ # Convert the original labels into human-readable text
19
26
label_batch = [class_names [np .argmax (label_batch [i ])] for i in range (batch_size )]
20
- # predict the images on the model
27
+
28
+ # Predict the images using the model
21
29
predicted_class_names = model .predict (image_batch )
22
30
predicted_ids = [np .argmax (predicted_class_names [i ]) for i in range (batch_size )]
23
- # turn the predicted vectors to human readable labels
31
+
32
+ # Convert the predicted vectors to human-readable labels
24
33
predicted_class_names = np .array ([class_names [id ] for id in predicted_ids ])
25
34
26
- # some nice plotting
27
- plt .figure (figsize = (10 ,9 ))
35
+ # Plot the results
36
+ plt .figure (figsize = (10 , 9 ))
28
37
for n in range (30 ):
29
- plt .subplot (6 ,5 , n + 1 )
30
- plt .subplots_adjust (hspace = 0.3 )
38
+ plt .subplot (6 , 5 , n + 1 )
39
+ plt .subplots_adjust (hspace = 0.3 )
31
40
plt .imshow (image_batch [n ])
41
+
42
+ # Set the title and color based on correctness
32
43
if predicted_class_names [n ] == label_batch [n ]:
33
44
color = "blue"
34
45
title = predicted_class_names [n ].title ()
35
46
else :
36
47
color = "red"
37
48
title = f"{ predicted_class_names [n ].title ()} , correct:{ label_batch [n ]} "
49
+
38
50
plt .title (title , color = color )
39
51
plt .axis ('off' )
52
+
40
53
_ = plt .suptitle ("Model predictions (blue: correct, red: incorrect)" )
41
- plt .show ()
54
+ plt .show ()
0 commit comments