
Commit a0e4ede

Add saving and loading models
1 parent cf1aeda commit a0e4ede

2 files changed: +113 −20 lines


deepsurv/deep_surv.py

Lines changed: 108 additions & 17 deletions
@@ -3,6 +3,8 @@
 import lasagne
 import numpy
 import time
+import json
+import h5py

 import theano
 import theano.tensor as T
@@ -16,7 +18,7 @@ def __init__(self, n_in,
         learning_rate, hidden_layers_sizes = None,
         lr_decay = 0.0, momentum = 0.9,
         L2_reg = 0.0, L1_reg = 0.0,
-        activation = lasagne.nonlinearities.rectify,
+        activation = "rectify",
         dropout = None,
         batch_norm = False,
         standardize = False,
@@ -59,9 +61,14 @@ def __init__(self, n_in,
             shared_axes = 0)
         self.standardize = standardize

+        if activation == 'rectify':
+            activation_fn = lasagne.nonlinearities.rectify
+        else:
+            raise ValueError("Unknown activation function: %s" % activation)
+
         # Construct Neural Network
         for n_layer in (hidden_layers_sizes or []):
-            if activation == lasagne.nonlinearities.rectify:
+            if activation_fn == lasagne.nonlinearities.rectify:
                 W_init = lasagne.init.GlorotUniform()
             else:
                 # TODO: implement other initializations
@@ -70,7 +77,7 @@ def __init__(self, n_in,

             network = lasagne.layers.DenseLayer(
                 network, num_units = n_layer,
-                nonlinearity = activation,
+                nonlinearity = activation_fn,
                 W = W_init
             )

@@ -95,7 +102,21 @@ def __init__(self, n_in,
         # Relevant Functions
         self.partial_hazard = T.exp(self.risk(deterministic = True)) # e^h(x)

-        # Set Hyper-parameters:
+        # Store and set needed Hyper-parameters:
+        self.hyperparams = {
+            'n_in': n_in,
+            'learning_rate': learning_rate,
+            'hidden_layers_sizes': hidden_layers_sizes,
+            'lr_decay': lr_decay,
+            'momentum': momentum,
+            'L2_reg': L2_reg,
+            'L1_reg': L1_reg,
+            'activation': activation,
+            'dropout': dropout,
+            'batch_norm': batch_norm,
+            'standardize': standardize
+        }
+
         self.n_in = n_in
         self.learning_rate = learning_rate
         self.lr_decay = lr_decay
@@ -183,6 +204,9 @@ def _get_loss_updates(self,
             loss, self.params, **kwargs
         )

+        # Store the last updates dict (needed to save/load optimizer state)
+        self.updates = updates
+
         return loss, updates

     def _get_train_valid_fn(self,
@@ -373,9 +397,6 @@ def train(self,

         start = time.time()
         for epoch in range(n_epochs):
-            if logger and (epoch % validation_frequency == 0):
-                logger.print_progress_bar(epoch, n_epochs)
-
             # Power-Learning Rate Decay
             lr = self.learning_rate / (1 + epoch * self.lr_decay)

@@ -415,6 +436,9 @@ def train(self,
                 # best_params_idx = epoch
                 best_validation_loss = validation_loss

+            if logger and (epoch % validation_frequency == 0):
+                logger.print_progress_bar(epoch, n_epochs, loss)
+
             if patience <= epoch:
                 break

@@ -440,16 +464,71 @@ def train(self,

         return logger.history

-    # @TODO need to reimplement with it working
-    # @TODO need to add save_model
-    def load_model(self, params):
-        """
-        Loads the network's parameters from a previously saved state.
-
-        Parameters:
-            params: a list of parameters in same order as network.params
-        """
-        lasagne.layers.set_all_param_values(self.network, params, trainable=True)
+    def to_json(self):
+        return json.dumps(self.hyperparams)
+
+    def save_model(self, filename, weights_file = None):
+        with open(filename, 'w') as fp:
+            fp.write(self.to_json())
+
+        if weights_file:
+            self.save_weights(weights_file)
+
+    # # @TODO need to reimplement with it working
+    # # @TODO need to add save_model
+    # def load_model(self, params):
+    #     """
+    #     Loads the network's parameters from a previously saved state.
+
+    #     Parameters:
+    #         params: a list of parameters in same order as network.params
+    #     """
+    #     lasagne.layers.set_all_param_values(self.network, params, trainable=True)
+
+    def save_weights(self, filename):
+        def save_list_by_idx(group, lst):
+            for (idx, param) in enumerate(lst):
+                group.create_dataset(str(idx), data=param)
+
+        weights_out = lasagne.layers.get_all_param_values(self.network, trainable=False)
+        updates_out = [p.get_value() for p in self.updates.keys()]
+
+        # Store all of the parameters in an HDF5 file.
+        # Each parameter is stored under its index in the list,
+        # so that when we read it back later we can reconstruct
+        # the list of parameters in the same order they were saved.
+        with h5py.File(filename, 'w') as f_out:
+            weights_grp = f_out.create_group('weights')
+            save_list_by_idx(weights_grp, weights_out)
+
+            updates_grp = f_out.create_group('updates')
+            save_list_by_idx(updates_grp, updates_out)
+
+    def load_weights(self, filename):
+        def load_all_keys(fp):
+            results = []
+            for key in fp:
+                dataset = fp[key][:]
+                results.append((int(key), dataset))
+            return results
+
+        def sort_params_by_idx(params):
+            return [param for (idx, param) in sorted(params,
+                                                     key=lambda param: param[0])]
+
+        # Load all of the parameters
+        with h5py.File(filename, 'r') as f_in:
+            weights_in = load_all_keys(f_in['weights'])
+            updates_in = load_all_keys(f_in['updates'])
+
+        # Sort them according to the idx to ensure they are set correctly
+        sorted_weights_in = sort_params_by_idx(weights_in)
+        lasagne.layers.set_all_param_values(self.network, sorted_weights_in,
+                                            trainable=False)
+
+        sorted_updates_in = sort_params_by_idx(updates_in)
+        for p, value in zip(self.updates.keys(), sorted_updates_in):
+            p.set_value(value)

     def risk(self,deterministic = False):
         """
@@ -554,3 +633,15 @@ def plot_risk_surface(self, data, i = 0, j = 1,
         plt.ylabel('$x_{%d}$' % j, fontsize=18)

         return fig
+
+def load_model_from_json(model_fp, weights_fp = None):
+    with open(model_fp, 'r') as fp:
+        json_model = fp.read()
+        hyperparams = json.loads(json_model)
+
+    model = DeepSurv(**hyperparams)
+
+    if weights_fp:
+        model.load_weights(weights_fp)
+
+    return model
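
Taken together, the new deep_surv.py methods split persistence in two: hyperparameters are serialized to JSON (to_json / save_model), while network weights and optimizer-update state go to HDF5 (save_weights / load_weights), and the module-level load_model_from_json reassembles a model from both files. A minimal round-trip sketch, not part of the commit; the instance name `model` and the file names `model.json` / `model.h5` are illustrative:

    # Assumes `model` is a trained DeepSurv instance (names are illustrative).
    model.save_model('model.json', weights_file = 'model.h5')

    # save_weights() lays the HDF5 file out as /weights/0, /weights/1, ... and
    # /updates/0, /updates/1, ..., keyed by stringified list index, so that
    # load_weights() can rebuild both lists in their original order.
    restored = load_model_from_json('model.json', weights_fp = 'model.h5')

Note that load_weights restores optimizer state through self.updates, which is only populated once _get_loss_updates has run, so restoring the 'updates' group assumes the model's training functions have been compiled first.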

deepsurv/deepsurv_logger.py

Lines changed: 5 additions & 3 deletions
@@ -9,11 +9,13 @@ def __init__(self):
     def logMessage(self,message):
         self.logger.info(message)

-    def print_progress_bar(self, step, max_steps, bar_length = 50, char = '*'):
+    def print_progress_bar(self, step, max_steps, loss = None, bar_length = 25, char = '*'):
         progress_length = int(bar_length * step / max_steps)
         progress_bar = [char] * (progress_length) + [' '] * (bar_length - progress_length)
-        self.logger.info("Training step %d/%d |" % (step, max_steps)
-            + ''.join(progress_bar) + "|")
+        message = "Training step %d/%d |" % (step, max_steps) + ''.join(progress_bar) + "|"
+        if loss is not None:
+            message += " - loss: %.4f" % loss
+        self.logger.info(message)


 class TensorboardLogger(DeepSurvLogger):

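For illustration, with the new optional loss argument the progress bar appends the current loss to each logged line. A made-up call against the reworked method, where `logger` is a hypothetical DeepSurvLogger instance:

    # Epoch 50 of 100; with the new default bar_length = 25,
    # progress_length = int(25 * 50 / 100) = 12, i.e. 12 stars and 13 spaces.
    logger.print_progress_bar(50, 100, loss = 0.1234)
    # Logged: Training step 50/100 |************             | - loss: 0.1234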