
Commit 031b87c

Fix an oversight where the feature extractor was not initialized when testing the autoencoders
1 parent 97579a5 commit 031b87c
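
In effect, the test script now pushes each batch through a frozen AlexNet convolutional backbone and asks the autoencoder to reconstruct those features. A minimal, self-contained sketch of that flow, assuming Alexnet_FE simply wraps the conv layers of torchvision's pretrained AlexNet (its real definition lives elsewhere in this repo) and using a dummy batch in place of the actual DataLoader:

import torch
import torch.nn as nn
from torchvision import models

# Assumed stand-in for the repo's Alexnet_FE: keep AlexNet's conv backbone
# and freeze it so it acts purely as a fixed feature extractor.
class Alexnet_FE(nn.Module):
    def __init__(self, alexnet):
        super().__init__()
        self.features = alexnet.features
        for p in self.features.parameters():
            p.requires_grad = False

    def forward(self, x):
        return self.features(x)

use_gpu = torch.cuda.is_available()
device = torch.device("cuda:0" if use_gpu else "cpu")
feature_extractor = Alexnet_FE(models.alexnet(pretrained=True))
feature_extractor.to(device)

# Flow the commit adds inside the test loop (dummy batch shown here):
input_data = torch.randn(16, 3, 224, 224, device=device)
input_to_ae = feature_extractor(input_data)              # (16, 256, 6, 6) conv features
input_to_ae = input_to_ae.view(input_to_ae.size(0), -1)  # flattened to (16, 9216)
# preds = model(input_to_ae)                              # autoencoder forward pass
# loss = encoder_criterion(preds, input_to_ae)            # reconstruction loss

The hunks below wire exactly this: the extractor is built once before the per-task loop, and every batch is flattened before being scored by encoder_criterion.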

File tree

1 file changed: +25 -15 lines changed

test_models.py

Lines changed: 25 additions & 15 deletions
@@ -22,6 +22,7 @@
 import argparse
 import numpy as np
 from random import shuffle
+import os
 
 import copy
 from autoencoder import *
@@ -38,13 +39,14 @@
 parser = argparse.ArgumentParser(description='Test file')
 #parser.add_argument('--task_number', default=1, type=int, help='Select the task you want to test out the architecture; choose from 1-4')
 parser.add_argument('--use_gpu', default=False, type=bool, help = 'Set the flag if you wish to use the GPU')
-
+parser.add_argument('--batch_size', default=16, type=int, help='Batch size you want to use whilst testing the model')
 args = parser.parse_args()
 use_gpu = args.use_gpu
 
+
 #randomly shuffle the tasks in the sequence
 task_number_list = [x for x in range(1, 10)]
-shuffle(task_number)
+shuffle(task_number_list)
 
 
 #transformations for the test data
@@ -65,12 +67,10 @@
 ])
 }
 
-
-#create the results.txt file
-with open("results.txt", "w") as myfile:
-myfile.write()
-myfile.close()
-
+#set the device to be used and initialize the feature extractor to feed the data into the autoencoder
+device = torch.device("cuda:0" if use_gpu else "cpu")
+feature_extractor = Alexnet_FE(models.alexnet(pretrained=True))
+feature_extractor.to(device)
 
 for task_number in task_number_list:
 
@@ -91,14 +91,14 @@
 image_folder = datasets.ImageFolder(os.path.join(path_task, 'test'), transform = data_transforms_mnist['test'])
 dset_size = len(image_folder)
 
-device = torch.device("cuda:0" if use_gpu else "cpu")
-
+
 dset_loaders = torch.utils.data.DataLoader(image_folder, batch_size = batch_size,
 shuffle=True, num_workers=4)
 
 best_loss = 99999999999
 model_number = 0
 
+
 #Load autoencoder models for tasks 1-4; need to select the best performing autoencoder model
 for ae_number in range(1, 10):
 ae_path = os.path.join(encoder_path, "autoencoder_" + str(ae_number))
@@ -122,12 +122,19 @@
 else:
 input_data = Variable(input_data)
 
-preds = model(input_data)
-loss = encoder_criterion(preds, input_data)
+
+#get the input to the autoencoder from the conv backbone of the Alexnet
+input_to_ae = feature_extractor(input_data)
+input_to_ae = input_to_ae.view(input_to_ae.size(0), -1)
 
+#get the outputs from the model
+preds = model(input_to_ae)
+loss = encoder_criterion(preds, input_to_ae)
+
 del preds
 del input_data
-
+del input_to_ae
+
 running_loss = running_loss + loss.item()
 
 model_loss = running_loss/dset_size
@@ -146,15 +153,17 @@
 print ("Incorrect routing, wrong model has been selected")
 
 
-trained_model_path = os.path.join(model_path, "model_" + model_number)
+#Load the expert that has been found by this procedure into memory
+trained_model_path = os.path.join(model_path, "model_" + str(model_number))
 
+#Get the number of classes that this expert was exposed to
 file_name = os.path.join(trained_model_path, "classes.txt")
 file_object = open(file_name, 'r')
 
 num_of_classes = file_object.read()
 file_object.close()
 
-num_of_classes = int(num_of_classes_old)
+num_of_classes = int(num_of_classes)
 
 model = GeneralModelClass(num_of_classes)
 model.load_state_dict(torch.load(os.path.join(trained_model_path, 'best_performing_model.pth')))
@@ -193,6 +202,7 @@
 model_loss = running_loss/dset_size
 model_accuracy = running_corrects.double()/dset_size
 
+#Store the results into a file
 with open("results.txt", "a") as myfile:
 myfile.write("\n{}: {}".format(task_number, model_accuracy*100))
 myfile.close()
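
For context on why routing matters here, the surrounding loop scores every autoencoder on the task's test features and loads the expert whose autoencoder reconstructs them with the lowest loss. A condensed, self-contained illustration of that selection step (the autoencoders dict, feature_dim, and the random features are stand-ins, not names from the file):

import torch
import torch.nn as nn

# Stand-ins so the snippet runs on its own; in test_models.py these come
# from the loaded autoencoder checkpoints and the AlexNet feature extractor.
feature_dim = 256 * 6 * 6
encoder_criterion = nn.MSELoss()
autoencoders = {n: nn.Sequential(nn.Linear(feature_dim, 100), nn.ReLU(),
                                 nn.Linear(100, feature_dim))
                for n in range(1, 10)}
features = torch.randn(8, feature_dim)   # pretend flattened AlexNet features

# Routing: the autoencoder with the lowest reconstruction loss on this
# task's features decides which expert model gets loaded afterwards.
best_loss = float("inf")
model_number = 0
for ae_number in range(1, 10):
    with torch.no_grad():
        preds = autoencoders[ae_number](features)
        model_loss = encoder_criterion(preds, features).item()
    if model_loss < best_loss:
        best_loss = model_loss
        model_number = ae_number

# The matching expert is then restored, mirroring the diff:
# trained_model_path = os.path.join(model_path, "model_" + str(model_number))
print("route to expert", model_number)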
