@@ -191,20 +191,6 @@ def main():
191
191
mu_t , _ , w_logits = state_q .apply_fn (state_q .params , t )
192
192
w = jax .nn .softmax (w_logits )
193
193
print ('Weights of mixtures:' , w )
194
- if system .plot :
195
- mu_t_no_vel = mu_t [:, :, :system .A .shape [0 ]]
196
- num_trajectories = jnp .array ((w * 100 ).round (), dtype = int )
197
-
198
- trajectories = jnp .swapaxes (mu_t_no_vel , 0 , 1 )
199
- trajectories = (jnp .vstack ([trajectories [i ].repeat (n , axis = 0 ) for i , n in enumerate (num_trajectories ) if n > 0 ])
200
- .reshape (num_trajectories .sum (), - 1 , mu_t_no_vel .shape [2 ]))
201
-
202
- system .plot (title = 'Weighted mean paths' , trajectories = trajectories )
203
- show_or_save_fig (args .save_dir , 'mean_paths' , args .extension )
204
-
205
- if system .plot and system .A .shape [0 ] == 2 :
206
- print ('Animating gif, this might take a few seconds ...' )
207
- plot_u_t (system , setup , state_q , args .T , args .save_dir , 'u_t' , frames = 100 )
208
194
209
195
key , init_key = jax .random .split (key )
210
196
x_0 = jnp .ones ((args .num_paths , A .shape [0 ]), dtype = jnp .float32 ) * A
@@ -228,6 +214,17 @@ def main():
228
214
save_trajectory (system .mdtraj_topology , x_t_stoch_no_vel [- 1 ].reshape (1 , - 1 , 3 ), f'{ args .save_dir } /stoch_-1.pdb' )
229
215
230
216
if system .plot :
217
+ mu_t_no_vel = mu_t [:, :, :system .A .shape [0 ]]
218
+ num_trajectories = jnp .array ((w * 100 ).round (), dtype = int )
219
+
220
+ trajectories = jnp .swapaxes (mu_t_no_vel , 0 , 1 )
221
+ trajectories = (
222
+ jnp .vstack ([trajectories [i ].repeat (n , axis = 0 ) for i , n in enumerate (num_trajectories ) if n > 0 ])
223
+ .reshape (num_trajectories .sum (), - 1 , mu_t_no_vel .shape [2 ]))
224
+
225
+ system .plot (title = 'Weighted mean paths' , trajectories = trajectories )
226
+ show_or_save_fig (args .save_dir , 'mean_paths' , args .extension )
227
+
231
228
plot_energy (system , [x_t_det_no_vel [0 ], x_t_det_no_vel [- 1 ]], args .log_plots )
232
229
show_or_save_fig (args .save_dir , 'path_energy_deterministic' , args .extension )
233
230
@@ -248,6 +245,10 @@ def main():
248
245
plt .plot (x_t_stoch_no_vel [i , :, 0 ].T , x_t_stoch_no_vel [i , :, 1 ].T , c = c )
249
246
show_or_save_fig (args .save_dir , 'paths_stochastic_and_individual' , args .extension )
250
247
248
+ if system .A .shape [0 ] == 2 :
249
+ print ('Animating gif, this might take a few seconds ...' )
250
+ plot_u_t (system , setup , state_q , args .T , args .save_dir , 'u_t' , frames = 100 )
251
+
251
252
252
253
if __name__ == '__main__' :
253
254
try :
0 commit comments