3
3
import lasagne
4
4
import numpy
5
5
import time
6
+ import json
7
+ import h5py
6
8
7
9
import theano
8
10
import theano .tensor as T
@@ -16,7 +18,7 @@ def __init__(self, n_in,
16
18
learning_rate , hidden_layers_sizes = None ,
17
19
lr_decay = 0.0 , momentum = 0.9 ,
18
20
L2_reg = 0.0 , L1_reg = 0.0 ,
19
- activation = lasagne . nonlinearities . rectify ,
21
+ activation = " rectify" ,
20
22
dropout = None ,
21
23
batch_norm = False ,
22
24
standardize = False ,
@@ -59,9 +61,14 @@ def __init__(self, n_in,
59
61
shared_axes = 0 )
60
62
self .standardize = standardize
61
63
64
# Resolve the activation-name string into a lasagne nonlinearity.
# Only 'rectify' is supported so far; reject anything else early.
if activation == 'rectify':
    activation_fn = lasagne.nonlinearities.rectify
else:
    # Bug fix: IllegalArgumentException is a Java name that does not exist in
    # Python -- the original raise would itself fail with NameError. ValueError
    # is the conventional exception for a bad argument value.
    raise ValueError("Unknown activation function: %s" % activation)
62
69
# Construct Neural Network
63
70
for n_layer in (hidden_layers_sizes or []):
64
- if activation == lasagne .nonlinearities .rectify :
71
+ if activation_fn == lasagne .nonlinearities .rectify :
65
72
W_init = lasagne .init .GlorotUniform ()
66
73
else :
67
74
# TODO: implement other initializations
@@ -70,7 +77,7 @@ def __init__(self, n_in,
70
77
71
78
network = lasagne .layers .DenseLayer (
72
79
network , num_units = n_layer ,
73
- nonlinearity = activation ,
80
+ nonlinearity = activation_fn ,
74
81
W = W_init
75
82
)
76
83
@@ -95,7 +102,21 @@ def __init__(self, n_in,
95
102
# Relevant Functions
96
103
self .partial_hazard = T .exp (self .risk (deterministic = True )) # e^h(x)
97
104
98
- # Set Hyper-parameters:
105
+ # Store and set needed Hyper-parameters:
106
+ self .hyperparams = {
107
+ 'n_in' : n_in ,
108
+ 'learning_rate' : learning_rate ,
109
+ 'hidden_layers_sizes' : hidden_layers_sizes ,
110
+ 'lr_decay' : lr_decay ,
111
+ 'momentum' : momentum ,
112
+ 'L2_reg' : L2_reg ,
113
+ 'L1_reg' : L1_reg ,
114
+ 'activation' : activation ,
115
+ 'dropout' : dropout ,
116
+ 'batch_norm' : batch_norm ,
117
+ 'standardize' : standardize
118
+ }
119
+
99
120
self .n_in = n_in
100
121
self .learning_rate = learning_rate
101
122
self .lr_decay = lr_decay
@@ -183,6 +204,9 @@ def _get_loss_updates(self,
183
204
loss , self .params , ** kwargs
184
205
)
185
206
207
+ # Store last update function
208
+ self .updates = updates
209
+
186
210
return loss , updates
187
211
188
212
def _get_train_valid_fn (self ,
@@ -373,9 +397,6 @@ def train(self,
373
397
374
398
start = time .time ()
375
399
for epoch in range (n_epochs ):
376
- if logger and (epoch % validation_frequency == 0 ):
377
- logger .print_progress_bar (epoch , n_epochs )
378
-
379
400
# Power-Learning Rate Decay
380
401
lr = self .learning_rate / (1 + epoch * self .lr_decay )
381
402
@@ -415,6 +436,9 @@ def train(self,
415
436
# best_params_idx = epoch
416
437
best_validation_loss = validation_loss
417
438
439
+ if logger and (epoch % validation_frequency == 0 ):
440
+ logger .print_progress_bar (epoch , n_epochs , loss )
441
+
418
442
if patience <= epoch :
419
443
break
420
444
@@ -440,16 +464,71 @@ def train(self,
440
464
441
465
return logger .history
442
466
443
- # @TODO need to reimplement with it working
444
- # @TODO need to add save_model
445
- def load_model (self , params ):
446
- """
447
- Loads the network's parameters from a previously saved state.
448
-
449
- Parameters:
450
- params: a list of parameters in same order as network.params
451
- """
452
- lasagne .layers .set_all_param_values (self .network , params , trainable = True )
467
def to_json(self):
    """Return this model's stored hyperparameters serialized as a JSON string."""
    serialized = json.dumps(self.hyperparams)
    return serialized
469
+
470
def save_model(self, filename, weights_file=None):
    """Write the model's hyperparameters (as JSON) to `filename`.

    Parameters:
        filename: path of the JSON file to write.
        weights_file: optional path; when given, the network weights are
            additionally saved there via `save_weights`.
    """
    json_blob = self.to_json()
    with open(filename, 'w') as out_fp:
        out_fp.write(json_blob)

    if weights_file:
        self.save_weights(weights_file)
477
+ # # @TODO need to reimplement with it working
478
+ # # @TODO need to add save_model
479
+ # def load_model(self, params):
480
+ # """
481
+ # Loads the network's parameters from a previously saved state.
482
+
483
+ # Parameters:
484
+ # params: a list of parameters in same order as network.params
485
+ # """
486
+ # lasagne.layers.set_all_param_values(self.network, params, trainable=True)
487
+
488
def save_weights(self, filename):
    """Save network parameters and optimizer state to an HDF5 file.

    Two groups are written to `filename`:
      * 'weights' -- every network parameter (trainable and non-trainable),
        as returned by lasagne.layers.get_all_param_values
      * 'updates' -- current values of the optimizer's update variables
        (e.g. momentum accumulators), so training can later resume

    Each list entry is stored under its list index as the dataset name, so
    that load_weights can rebuild the lists in the same order they were saved.

    Parameters:
        filename: path of the HDF5 file to create (overwritten if present).
    """
    def save_list_by_idx(group, lst):
        # One dataset per parameter, keyed by its position in the list.
        for idx, param in enumerate(lst):
            group.create_dataset(str(idx), data=param)

    weights_out = lasagne.layers.get_all_param_values(self.network, trainable=False)

    # self.updates is only assigned once _get_loss_updates has run (i.e. after
    # training has started). Fall back to an empty list so a freshly
    # constructed model can still be saved instead of raising AttributeError.
    updates = getattr(self, 'updates', None)
    updates_out = [p.get_value() for p in updates.keys()] if updates else []

    with h5py.File(filename, 'w') as f_out:
        weights_grp = f_out.create_group('weights')
        save_list_by_idx(weights_grp, weights_out)

        updates_grp = f_out.create_group('updates')
        save_list_by_idx(updates_grp, updates_out)
506
+
507
def load_weights(self, filename):
    """Restore network parameters and optimizer state from an HDF5 file
    previously written by save_weights.

    Parameters:
        filename: path to an HDF5 file containing 'weights' and 'updates'
            groups whose dataset names are list indices.
    """
    def load_all_keys(fp):
        # Collect (index, array) pairs; dataset names encode list positions.
        results = []
        for key in fp:
            dataset = fp[key][:]
            results.append((int(key), dataset))
        return results

    def sort_params_by_idx(params):
        # Recover the original list order from the stored indices.
        return [param for (idx, param) in sorted(params,
                                                 key=lambda param: param[0])]

    # Load all of the parameters
    with h5py.File(filename, 'r') as f_in:
        weights_in = load_all_keys(f_in['weights'])
        updates_in = load_all_keys(f_in['updates'])

    # Sort them according to the idx to ensure they are set correctly
    sorted_weights_in = sort_params_by_idx(weights_in)
    lasagne.layers.set_all_param_values(self.network, sorted_weights_in,
                                        trainable=False)

    # Bug fix: self.updates only exists once training has initialized the
    # optimizer (see _get_loss_updates). A freshly constructed model -- e.g.
    # one built by load_model_from_json -- has no update variables yet, so the
    # saved optimizer state is skipped instead of raising AttributeError.
    sorted_updates_in = sort_params_by_idx(updates_in)
    for p, value in zip(getattr(self, 'updates', {}).keys(), sorted_updates_in):
        p.set_value(value)
453
532
454
533
def risk (self ,deterministic = False ):
455
534
"""
@@ -554,3 +633,15 @@ def plot_risk_surface(self, data, i = 0, j = 1,
554
633
plt .ylabel ('$x_{%d}$' % j , fontsize = 18 )
555
634
556
635
return fig
636
+
637
def load_model_from_json(model_fp, weights_fp=None):
    """Reconstruct a DeepSurv model from a saved JSON hyperparameter file.

    Parameters:
        model_fp: path to a JSON file written by DeepSurv.save_model.
        weights_fp: optional path to an HDF5 file written by
            DeepSurv.save_weights; if given, the weights are restored too.

    Returns:
        a newly constructed DeepSurv instance.
    """
    with open(model_fp, 'r') as fp:
        hyperparams = json.loads(fp.read())

    model = DeepSurv(**hyperparams)
    if weights_fp:
        model.load_weights(weights_fp)
    return model
0 commit comments