Skip to content

Commit 69591f3

Browse files
Fixed bug with restoring optimizer update parameters: defer applying saved update values until the update expressions are rebuilt in _get_loss_updates, and raise a clear error when saving weights for an untrained model
1 parent a0e4ede commit 69591f3

File tree

1 file changed

+12
-20
lines changed

1 file changed

+12
-20
lines changed

deepsurv/deep_surv.py

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,13 @@ def _get_loss_updates(self,
204204
loss, self.params, **kwargs
205205
)
206206

207-
# Store last update function
207+
# If the model was loaded from file, reload params
208+
if self.restored_update_params:
209+
for p, value in zip(updates.keys(), self.restored_update_params):
210+
p.set_value(value)
211+
self.restored_update_params = None
212+
213+
# Store last update function to be later saved
208214
self.updates = updates
209215

210216
return loss, updates
@@ -363,9 +369,6 @@ def train(self,
363369
"""
364370

365371
# @TODO? Should these be managed by the logger => then you can do logger.getMetrics
366-
# train_loss = []
367-
# train_ci = []
368-
369372
x_train, e_train, t_train = self.prepare_data(train_data)
370373

371374
# Set Standardization layer offset and scale to training data mean and std
@@ -374,8 +377,6 @@ def train(self,
374377
self.scale = x_train.std(axis = 0)
375378

376379
if valid_data:
377-
# valid_loss = []
378-
# valid_ci = []
379380
x_valid, e_valid, t_valid = self.prepare_data(valid_data)
380381

381382
# Initialize Metrics
@@ -474,24 +475,16 @@ def save_model(self, filename, weights_file = None):
474475
if weights_file:
475476
self.save_weights(weights_file)
476477

477-
# # @TODO need to reimplement with it working
478-
# # @TODO need to add save_model
479-
# def load_model(self, params):
480-
# """
481-
# Loads the network's parameters from a previously saved state.
482-
483-
# Parameters:
484-
# params: a list of parameters in same order as network.params
485-
# """
486-
# lasagne.layers.set_all_param_values(self.network, params, trainable=True)
487-
488478
def save_weights(self,filename):
489479
def save_list_by_idx(group, lst):
490480
for (idx, param) in enumerate(lst):
491481
group.create_dataset(str(idx), data=param)
492482

493483
weights_out = lasagne.layers.get_all_param_values(self.network, trainable=False)
494-
updates_out = [p.get_value() for p in self.updates.keys()]
484+
if self.updates:
485+
updates_out = [p.get_value() for p in self.updates.keys()]
486+
else:
487+
raise Exception("Model has not been trained: no params to save!")
495488

496489
# Store all of the parameters in an hd5f file
497490
# We store the parameter under the index in the list
@@ -527,8 +520,7 @@ def sort_params_by_idx(params):
527520
trainable=False)
528521

529522
sorted_updates_in = sort_params_by_idx(updates_in)
530-
for p, value in zip(self.updates.keys(), sorted_updates_in):
531-
p.set_value(value)
523+
self.restored_update_params = sorted_updates_in
532524

533525
def risk(self,deterministic = False):
534526
"""

0 commit comments

Comments (0)