First of all, I just want to say that your WTTE is really cool. Great blog post and paper. I've been using an adapted version of it for a time-to-event task and wanted to share a trick I've found useful for dealing with numerical instability, in case you or anyone else is interested.
A couple of things to note in my case:
- I'm not using an RNN, since I have sufficient engineered features describing the history at a given point in time.
- I rewrote it in pytorch, so my code here is in pytorch.
- My case uses the discrete likelihood (a rough sketch of the loss I'm using follows this list). I haven't tested anything for the continuous case, but I don't see why it wouldn't work there too.
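For context, the discrete Weibull negative log-likelihood I'm minimizing looks roughly like this. This is only a sketch: the function name, the `eps` handling, and the mean reduction are just how I happened to write it, following the discrete loglik from the WTTE paper.

```python
import torch as tt

def discrete_weibull_nll(alpha, beta, y, u, eps=1e-6):
    # y = time to event, u = 1 if the event was observed, 0 if censored;
    # both are float tensors with the same shape as alpha and beta
    # cumulative hazard of the Weibull at t and t+1
    hazard0 = tt.pow((y + eps) / alpha, beta)
    hazard1 = tt.pow((y + 1.0) / alpha, beta)
    loglik = u * tt.log(tt.exp(hazard1 - hazard0) - 1.0 + eps) - hazard1
    # negative mean log-likelihood as the loss to minimize
    return -tt.mean(loglik)
```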
While testing it I had a lot of issues with NaN loss and numerical instability during the fit of `alpha` and `beta`. I know you've worked a lot on this from reading the other GitHub issues.
I've found that this parameterization for `alpha` and `beta` helps a lot:
```python
import numpy as np
import torch as tt
import torch.nn as nn


class WTTE(nn.Module):
    def __init__(self, nnet_output_dim):
        super(WTTE, self).__init__()
        # this is the neural net whose outputs are then used to find alpha and beta
        self.nnet = InnerNNET()
        self.softplus = nn.Softplus()
        self.tanh = nn.Tanh()
        self.alpha_scaling = nn.Linear(nnet_output_dim, 1)
        self.beta_scaling = nn.Linear(nnet_output_dim, 1)

        # offset and scale parameters
        alpha_offset_init, beta_offset_init = 1.0, 1.0
        alpha_scale_init, beta_scale_init = 1.0, 1.0
        self.alpha_offset = nn.Parameter(tt.from_numpy(np.array([alpha_offset_init])).float(), requires_grad=True)
        self.beta_offset = nn.Parameter(tt.from_numpy(np.array([beta_offset_init])).float(), requires_grad=True)
        self.alpha_scale = nn.Parameter(tt.from_numpy(np.array([alpha_scale_init])).float(), requires_grad=True)
        self.beta_scale = nn.Parameter(tt.from_numpy(np.array([beta_scale_init])).float(), requires_grad=True)

    def forward(self, x):
        x = self.nnet(x)
        # derive alpha and beta individual scaling factors
        a_scaler = self.alpha_scaling(x)
        b_scaler = self.beta_scaling(x)
        # enforce the scaling factors to be between -1 and 1
        a_scaler = self.tanh(a_scaler)
        b_scaler = self.tanh(b_scaler)
        # combine the global offsets and scale factors with the individual ones
        alpha = self.alpha_offset + (self.alpha_scale * a_scaler)
        beta = self.beta_offset + (self.beta_scale * b_scaler)
        # put alpha on positive range with exp, beta with softplus
        alpha = tt.exp(alpha)
        beta = self.softplus(beta)
        return alpha, beta
```
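For completeness, here is roughly how I wire this up in a training step, using the discrete loss sketched above. Treat it as a sketch rather than a drop-in: `InnerNNET`, the output dimension, the learning rate, and `train_loader` are placeholders from my own setup.

```python
model = WTTE(nnet_output_dim=32)   # 32 is just an example width
optimizer = tt.optim.Adam(model.parameters(), lr=1e-3)

for x, y, u in train_loader:       # any iterable of (features, time-to-event, censoring) batches
    optimizer.zero_grad()
    alpha, beta = model(x)
    loss = discrete_weibull_nll(alpha, beta, y, u)
    loss.backward()
    optimizer.step()
```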
Essentially, this helps because the `tanh` activation forces the individual/per-observation scaling factors to always lie between -1 and 1, so you don't have to worry about the raw outputs of your network being too small or too large. `alpha_scale` and `beta_scale` set the range that those -1 to 1 outputs get multiplied by, and the offsets act as an intercept or centering mechanism. For example, with `alpha_offset = 2.0` and `alpha_scale = 1.5`, `alpha` is confined to roughly [exp(0.5), exp(3.5)] ≈ [1.6, 33] no matter what the inner network outputs.
If you initialize the offsets and scaling factors to low values (I start them at 1.0, for example), they will slowly creep up to their optimal values during the fit. Here is some output from a recent fit of mine to show what I mean:
```
A off: 1.10000 A scale: 1.10000 B off: 0.90000 B scale: 1.10000
A off: 1.23279 A scale: 1.03885 B off: 0.90022 B scale: 0.89804
A off: 1.25786 A scale: 1.06547 B off: 0.90056 B scale: 0.89798
A off: 1.28466 A scale: 1.09343 B off: 0.90163 B scale: 0.89878
A off: 1.34988 A scale: 1.16266 B off: 0.90678 B scale: 0.90290
A off: 1.44370 A scale: 1.25528 B off: 0.93324 B scale: 0.93015
A off: 1.53040 A scale: 1.32979 B off: 0.98308 B scale: 0.98226
...[many epochs later]...
A off: 1.92782 A scale: 1.57879 B off: 2.97086 B scale: 2.55308
A off: 1.93340 A scale: 1.59249 B off: 3.01380 B scale: 2.59065
A off: 1.94988 A scale: 1.59956 B off: 3.01739 B scale: 2.54407
A off: 1.94464 A scale: 1.59733 B off: 3.03923 B scale: 2.55807
A off: 1.95629 A scale: 1.60365 B off: 3.06733 B scale: 2.58807
A off: 1.95865 A scale: 1.59092 B off: 3.09355 B scale: 2.60830
```
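In case it's useful, that output is just the four global scalar parameters printed once per epoch; something along these lines (the exact formatting here is incidental):

```python
# log the global offset/scale parameters once per epoch
print(f"A off: {model.alpha_offset.item():.5f} A scale: {model.alpha_scale.item():.5f} "
      f"B off: {model.beta_offset.item():.5f} B scale: {model.beta_scale.item():.5f}")
```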
You could also easily enforce maximums on `alpha` and `beta` if you wanted to, by adding `torch.clamp` calls around the outputs, but I have not found this to be necessary.
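If you did want that, it would be something like the following at the end of `forward`, just before the `return`. The particular limits here are made up purely for illustration; pick whatever makes sense for your data.

```python
# optional: cap alpha and beta to keep them in a sane range (limits are illustrative)
alpha = tt.clamp(alpha, max=1e4)
beta = tt.clamp(beta, max=50.0)
```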
I have only tested this on my own data and so I can't make any claims that this will solve numerical instability issues for other people, but I figured it may help someone!