@@ -204,7 +204,13 @@ def _get_loss_updates(self,
             loss, self.params, **kwargs
         )
 
-        # Store last update function
+        # If the model was loaded from file, reload params
+        if self.restored_update_params:
+            for p, value in zip(updates.keys(), self.restored_update_params):
+                p.set_value(value)
+            self.restored_update_params = None
+
+        # Store last update function to be later saved
         self.updates = updates
 
         return loss, updates
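The ordering here matters: on the Theano/Lasagne stack this repo uses, the optimizer's state (e.g. the Nesterov momentum velocities) lives in shared variables that update_fn creates, so values loaded from disk can only be copied in after the update rules have been rebuilt. A minimal sketch of why updates.keys() carries that state (a toy model, not this repo's network):

    import theano.tensor as T
    import lasagne

    # Toy network: the specific layers are illustrative only
    x = T.matrix('x')
    l_in = lasagne.layers.InputLayer((None, 3), input_var=x)
    l_out = lasagne.layers.DenseLayer(l_in, num_units=1)

    loss = lasagne.layers.get_output(l_out).mean()
    params = lasagne.layers.get_all_params(l_out, trainable=True)

    # nesterov_momentum returns an OrderedDict whose keys are shared
    # variables: the trainable parameters plus one velocity per parameter
    updates = lasagne.updates.nesterov_momentum(
        loss, params, learning_rate=1e-3, momentum=0.9)

    state = [p.get_value() for p in updates.keys()]    # params + velocities
    for p, value in zip(updates.keys(), state):        # restoring is symmetric
        p.set_value(value)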
@@ -363,9 +369,6 @@ def train(self,
         """
 
         # @TODO? Should these be managed by the logger => then you can do logger.getMetrics
-        # train_loss = []
-        # train_ci = []
-
         x_train, e_train, t_train = self.prepare_data(train_data)
 
         # Set Standardization layer offset and scale to training data mean and std
@@ -374,8 +377,6 @@ def train(self,
         self.scale = x_train.std(axis = 0)
 
         if valid_data:
-            # valid_loss = []
-            # valid_ci = []
             x_valid, e_valid, t_valid = self.prepare_data(valid_data)
 
             # Initialize Metrics
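For context, the offset/scale lines above capture the training set's per-feature mean and standard deviation; the standardization step then presumably computes (x - offset) / scale on the way into the network. A quick numpy illustration (the layer wiring itself is an assumption):

    import numpy as np

    x_train = np.random.randn(100, 3) * 5.0 + 2.0   # synthetic inputs
    offset = x_train.mean(axis=0)                   # per-feature mean
    scale = x_train.std(axis=0)                     # per-feature std
    x_std = (x_train - offset) / scale              # ~zero mean, unit variance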
@@ -474,24 +475,16 @@ def save_model(self, filename, weights_file = None):
         if weights_file:
             self.save_weights(weights_file)
 
-    # # @TODO need to reimplement with it working
-    # # @TODO need to add save_model
-    # def load_model(self, params):
-    #     """
-    #     Loads the network's parameters from a previously saved state.
-
-    #     Parameters:
-    #         params: a list of parameters in same order as network.params
-    #     """
-    #     lasagne.layers.set_all_param_values(self.network, params, trainable=True)
-
     def save_weights(self, filename):
         def save_list_by_idx(group, lst):
             for (idx, param) in enumerate(lst):
                 group.create_dataset(str(idx), data = param)
 
         weights_out = lasagne.layers.get_all_param_values(self.network, trainable = False)
-        updates_out = [p.get_value() for p in self.updates.keys()]
+        if self.updates:
+            updates_out = [p.get_value() for p in self.updates.keys()]
+        else:
+            raise Exception("Model has not been trained: no params to save!")
 
         # Store all of the parameters in an HDF5 file
         # We store the parameter under the index in the list
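save_weights() stores each parameter array as a dataset named by its list index inside a group, so ordering survives the round trip even though HDF5 groups are unordered; the sort_params_by_idx helper in the next hunk undoes the str() naming on load. A hedged sketch of that pattern with h5py (the group name here is an assumption):

    import h5py
    import numpy as np

    params = [np.zeros((3, 2)), np.ones(2)]         # stand-in parameter arrays

    with h5py.File('weights.h5', 'w') as f:
        grp = f.create_group('updates')             # assumed group name
        for idx, param in enumerate(params):
            grp.create_dataset(str(idx), data=param)

    with h5py.File('weights.h5', 'r') as f:
        grp = f['updates']
        # Sort keys numerically ('10' would sort before '2' as strings)
        restored = [grp[k][:] for k in sorted(grp.keys(), key=int)]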
@@ -527,8 +520,7 @@ def sort_params_by_idx(params):
             trainable = False)
 
         sorted_updates_in = sort_params_by_idx(updates_in)
-        for p, value in zip(self.updates.keys(), sorted_updates_in):
-            p.set_value(value)
+        self.restored_update_params = sorted_updates_in
 
     def risk(self, deterministic = False):
         """
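Taken together, the two restored_update_params hunks implement a save/load round trip for optimizer state: load_weights() can run before any update rules exist, so it only stashes the sorted values, and the first _get_loss_updates() call replays and then clears the stash. A hedged usage sketch (the constructor call, the data-dict layout, and the weights filename are assumptions, not taken from this diff):

    import numpy as np

    train_data = {                                  # assumed prepare_data() layout
        'x': np.random.randn(50, 10).astype('float32'),
        'e': np.random.randint(0, 2, 50).astype('int32'),
        't': np.random.rand(50).astype('float32'),
    }

    model = DeepSurv(n_in=10)                       # hypothetical constructor call
    model.load_weights('weights.h5')                # stashes optimizer state
    model.train(train_data)                         # rebuilds update rules, then
                                                    # replays the stashed state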