@@ -151,6 +151,8 @@ def train_model(num_classes, feature_extractor, encoder_criterion, dset_loaders,
     ref_model.train(False)
     ref_model.to(device)
 
+    #print(ref_model)
+
     for param in model_init.Tmodel.classifier.parameters():
         param.requires_grad = True
 
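A side note on what this hunk touches: ref_model.train(False) keeps the reference copy in eval mode, and the loop re-enables gradients only on the classifier head of the model being trained. A minimal sketch of that freeze/unfreeze pattern, assuming a torchvision-style network whose head is exposed as .classifier (freezing the remaining parameters is an assumption; the diff only shows the unfreeze):

    import torch.nn as nn

    def unfreeze_classifier_only(model: nn.Module) -> None:
        # Freeze everything first (assumed to happen elsewhere in the repo) ...
        for param in model.parameters():
            param.requires_grad = False
        # ... then re-enable gradients only for the classifier head, mirroring
        # the requires_grad = True loop in the hunk above.
        for param in model.classifier.parameters():
            param.requires_grad = True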
@@ -176,14 +178,18 @@ def train_model(num_classes, feature_extractor, encoder_criterion, dset_loaders,
     #Actually makes the changes to the model_init, so slightly redundant
     print("Initializing the model to be trained")
     model_init = initialize_new_model(model_init, num_classes, num_of_classes_old)
-    model_init.to(device)
+    #print(model_init)
+    #model_init.to(device)
     start_epoch = 0
 
     #The training process format for LwF (Learning without Forgetting)
     # Add the start epoch code
 
     if (best_relatedness > 0.85):
 
+        model_init.to(device)
+        ref_model.to(device)
+
         print("Using the LwF approach")
         for epoch in range(start_epoch, num_epochs):
             since = time.time()
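initialize_new_model is defined elsewhere in the repo, so this is only a hypothetical sketch of the usual LwF head expansion it presumably performs: grow the final linear layer by num_classes outputs while preserving the old-class weights so the distillation targets stay aligned. Everything below other than Tmodel.classifier and the argument names is an assumption:

    import torch
    import torch.nn as nn

    def initialize_new_model_sketch(model, num_classes, num_of_classes_old):
        old_head = model.Tmodel.classifier[-1]  # assumed to be an nn.Linear
        new_head = nn.Linear(old_head.in_features, num_of_classes_old + num_classes)
        with torch.no_grad():
            # Copy the old-class weights; the new-class rows keep their fresh init.
            new_head.weight[:num_of_classes_old].copy_(old_head.weight)
            new_head.bias[:num_of_classes_old].copy_(old_head.bias)
        model.Tmodel.classifier[-1] = new_head
        return model

The hunk also hoists model_init.to(device) and ref_model.to(device) out of the per-batch loop into the branch header: the move only needs to happen once per training run, so repeating it every iteration was redundant work.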
@@ -197,7 +203,7 @@ def train_model(num_classes, feature_extractor, encoder_criterion, dset_loaders,
 
             #scales the optimizer every 10 epochs
             optimizer = exp_lr_scheduler(optimizer, epoch, lr)
-            model_init = model_init.train(True)
+            # model_init = model_init.train(True)
 
             for data in dset_loaders:
                 input_data, labels = data
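exp_lr_scheduler is the repo's own helper; per the comment above, it rescales the learning rate every 10 epochs. A minimal sketch consistent with that comment (the 0.1 decay factor and the step-size default are assumptions):

    def exp_lr_scheduler(optimizer, epoch, init_lr, lr_decay_epoch=10, decay_factor=0.1):
        # Decay the learning rate by decay_factor once every lr_decay_epoch epochs.
        lr = init_lr * (decay_factor ** (epoch // lr_decay_epoch))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        return optimizer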
@@ -212,32 +218,27 @@ def train_model(num_classes, feature_extractor, encoder_criterion, dset_loaders,
                     input_data = Variable(input_data)
                     labels = Variable(labels)
 
-                model_init.to(device)
-                ref_model.to(device)
-
                 output = model_init(input_data)
                 ref_output = ref_model(input_data)
-
                 del input_data
 
                 optimizer.zero_grad()
-                model_init.zero_grad()
 
                 # loss_1 only takes in the outputs from the nodes of the old classes
 
                 loss1_output = output[:, :num_of_classes_old]
                 loss2_output = output[:, num_of_classes_old:]
 
+                print()
+
                 del output
 
                 loss_1 = model_criterion(loss1_output, ref_output, flag="Distill")
-
                 del ref_output
 
                 # loss_2 takes in the outputs from the nodes that were initialized for the new task
 
                 loss_2 = model_criterion(loss2_output, labels, flag="CE")
-
                 del labels
                 #del output
 
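model_criterion is the repo's combined loss helper, dispatched on the flag argument. A hedged sketch in the LwF spirit: "Distill" matches the new model's old-class logits against the reference model's outputs through temperature-softened distributions, while "CE" is plain cross-entropy on the new-task labels. The temperature T = 2 follows the LwF paper and is an assumption here:

    import torch.nn.functional as F

    def model_criterion_sketch(preds, target, flag, T=2.0):
        if flag == "Distill":
            # Soften both distributions, then take the KD cross-entropy.
            log_p = F.log_softmax(preds / T, dim=1)
            q = F.softmax(target / T, dim=1)
            return -(q * log_p).sum(dim=1).mean()
        elif flag == "CE":
            # Standard cross-entropy against integer class labels.
            return F.cross_entropy(preds, target)

With that shape, the two calls in the hunk read naturally: loss_1 distills output[:, :num_of_classes_old] toward ref_output, and loss_2 trains output[:, num_of_classes_old:] on the new labels.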
@@ -257,7 +258,7 @@ def train_model(num_classes, feature_extractor, encoder_criterion, dset_loaders,
 
             print('Epoch Loss:{}'.format(epoch_loss))
 
-            if (epoch != 0 and epoch != num_of_epochs - 1 and (epoch + 1) % 10 == 0):
+            if (epoch != 0 and epoch != num_epochs - 1 and (epoch + 1) % 10 == 0):
                 epoch_file_name = os.path.join(mypath, str(epoch + 1) + '.pth.tar')
                 torch.save({
                     'epoch': epoch,
@@ -277,6 +278,7 @@ def train_model(num_classes, feature_extractor, encoder_criterion, dset_loaders,
     #Process for finetuning the model
     else:
 
+        model_init.to(device)
         print("Using the finetuning approach")
 
         for epoch in range(start_epoch, num_epochs):
@@ -302,9 +304,6 @@ def train_model(num_classes, feature_extractor, encoder_criterion, dset_loaders,
                     input_data = Variable(input_data)
                     labels = Variable(labels)
 
-                #Shifts the model to the device
-                model_init.to(device)
-
                 output = model_init(input_data)
 
                 del input_data
@@ -314,7 +313,7 @@ def train_model(num_classes, feature_extractor, encoder_criterion, dset_loaders,
                 model_init.zero_grad()
 
                 #Implemented as explained in the doc string
-                loss = model_criterion(output[num_of_classes_old:], labels)
+                loss = model_criterion(output[num_of_classes_old:], labels, flag='CE')
 
                 del output
                 del labels
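One detail worth flagging in this hunk: the LwF branch slices the class dimension (output[:, num_of_classes_old:]), whereas this finetuning loss slices the batch dimension (output[num_of_classes_old:]). If the intent matches the LwF branch, the call would presumably read:

    # Assumes the class-dimension slice from the LwF branch is also intended here;
    # the committed line slices the batch dimension instead.
    loss = model_criterion(output[:, num_of_classes_old:], labels, flag='CE')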
@@ -330,7 +329,7 @@ def train_model(num_classes, feature_extractor, encoder_criterion, dset_loaders,
 
             print('Epoch Loss:{}'.format(epoch_loss))
 
-            if (epoch != 0 and (epoch + 1) % 5 == 0 and epoch != num_of_epochs - 1):
+            if (epoch != 0 and (epoch + 1) % 5 == 0 and epoch != num_epochs - 1):
                 epoch_file_name = os.path.join(path_to_model, str(epoch + 1) + '.pth.tar')
                 torch.save({
                     'epoch': epoch,
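Both branches checkpoint on a fixed cadence (every 10 epochs for LwF, every 5 for finetuning), skipping epoch 0 and the final epoch. The torch.save(...) dict is truncated in the diff, so this is only a hedged sketch of the pattern, in which every key except 'epoch' is an assumption:

    import os
    import torch

    def save_checkpoint_sketch(model, optimizer, epoch, directory):
        epoch_file_name = os.path.join(directory, str(epoch + 1) + '.pth.tar')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),          # assumed key
            'optimizer_state_dict': optimizer.state_dict(),  # assumed key
        }, epoch_file_name)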