20
20
21
21
from utils .angles import phi_psi_from_mdtraj
22
22
from utils .animation import save_trajectory , to_md_traj
23
- from utils .plot import show_or_save_fig
23
+ from utils .plot import show_or_save_fig , human_format
24
24
from utils .rmsd import kabsch_align , kabsch_rmsd
25
25
26
26
from argparse import ArgumentParser
27
27
28
28
parser = ArgumentParser ()
29
29
parser .add_argument ('--mechanism' , type = str , choices = ['one-way-shooting' , 'two-way-shooting' ], required = True )
30
- parser .add_argument ('--states' , type = str , default = 'phi-psi' , choices = ['phi-psi' , 'rmsd' ])
30
+ parser .add_argument ('--states' , type = str , default = 'phi-psi' , choices = ['phi-psi' , 'rmsd' , 'exact' ])
31
31
parser .add_argument ('--fixed_length' , type = int , default = 0 )
32
32
parser .add_argument ('--warmup' , type = int , default = 0 )
33
33
parser .add_argument ('--num_paths' , type = int , required = True )
39
39
help = 'Ensure that the initial path connects A with B by prepending A and appending B.' )
40
40
41
41
42
- def human_format (num ):
43
- """https://stackoverflow.com/a/45846841/4417954"""
44
- num = float ('{:.3g}' .format (num ))
45
- if num >= 1 :
46
- magnitude = 0
47
- while abs (num ) >= 1000 :
48
- magnitude += 1
49
- num /= 1000.0
50
- return '{}{}' .format ('{:f}' .format (num ).rstrip ('0' ).rstrip ('.' ), ['' , 'K' , 'M' , 'B' , 'T' ][magnitude ])
51
- else :
52
- magnitude = 0
53
- while abs (num ) < 1 :
54
- magnitude += 1
55
- num *= 1000.0
56
- return '{}{}' .format ('{:f}' .format (num ).rstrip ('0' ).rstrip ('.' ), ['' , 'm' , 'µ' , 'n' , 'p' , 'f' ][magnitude ])
57
-
58
-
59
42
dt_as_unit = unit .Quantity (value = 1 , unit = unit .femtosecond )
60
43
dt_in_ps = dt_as_unit .value_in_unit (unit .picosecond )
61
44
dt = dt_as_unit .value_in_unit (unit .second )
@@ -101,6 +84,8 @@ def step_n(step, _x, _v, n, _key):
101
84
savedir += f'-{ args .fixed_length } steps'
102
85
if args .states == 'rmsd' :
103
86
savedir += '-rmsd'
87
+ elif args .states == 'exact' :
88
+ savedir += '-exact'
104
89
105
90
os .makedirs (savedir , exist_ok = True )
106
91
@@ -118,6 +103,7 @@ def U_padded(x):
118
103
x = x_empty .at [:x .shape [0 ], :].set (x .reshape (- 1 , 66 ))
119
104
return system .U (x )[:orig_length ]
120
105
106
+
121
107
@jax .jit
122
108
def step (_x , _key ):
123
109
"""Perform one step of forward euler"""
@@ -197,6 +183,19 @@ def langevin_log_path_likelihood(path_and_velocities):
197
183
state_B = jax .jit (
198
184
lambda s : is_within (phis_psis (s .reshape (- 1 , 22 , 3 )).reshape (- 1 , 2 ), phis_psis (system .B .reshape (- 1 , 22 , 3 )),
199
185
radius ))
186
+ elif args .states == 'exact' :
187
+ from scipy .stats import chi2
188
+ percentile = 0.99
189
+ noise_scale = 1e-4
190
+ threshold = jnp .sqrt (chi2 .ppf (percentile , system .A .shape [0 ]) * noise_scale )
191
+ print (threshold )
192
+ def kabsch_l2 (A , B ):
193
+ a , b = kabsch_align (A , B )
194
+
195
+ return jnp .linalg .norm (a - b )
196
+
197
+ state_A = jax .jit (jax .vmap (lambda s : kabsch_l2 (system .A .reshape (22 , 3 ), s .reshape (22 , 3 )) <= threshold ))
198
+ state_B = jax .jit (jax .vmap (lambda s : kabsch_l2 (system .B .reshape (22 , 3 ), s .reshape (22 , 3 )) <= threshold ))
200
199
else :
201
200
raise ValueError (f"Unknown states { args .states } " )
202
201
@@ -216,9 +215,10 @@ def langevin_log_path_likelihood(path_and_velocities):
216
215
save_trajectory (system .mdtraj_topology , jnp .array (initial_trajectory ), f'{ savedir } /initial_trajectory.pdb' )
217
216
218
217
if args .resume :
219
- paths = [[x for x in p .astype (np .float32 )] for p in np .load (f'{ savedir } /paths.npy' , allow_pickle = True )]
218
+ print ('Loading stored data.' )
219
+ paths = [[x for x in p .astype (np .float32 )] for p in tqdm (np .load (f'{ savedir } /paths.npy' , allow_pickle = True ))]
220
220
velocities = [[v for v in p .astype (np .float32 )] for p in
221
- np .load (f'{ savedir } /velocities.npy' , allow_pickle = True )]
221
+ tqdm ( np .load (f'{ savedir } /velocities.npy' , allow_pickle = True ) )]
222
222
with open (f'{ savedir } /stats.json' , 'r' ) as fp :
223
223
statistics = json .load (fp )
224
224
@@ -227,6 +227,8 @@ def langevin_log_path_likelihood(path_and_velocities):
227
227
'velocities' : velocities ,
228
228
'statistics' : statistics
229
229
}
230
+
231
+ print ('Loaded' , len (paths ), 'paths.' )
230
232
else :
231
233
if os .path .exists (f'{ savedir } /paths.npy' ) and not args .override :
232
234
print (f"The target directory is not empty.\n "
@@ -235,8 +237,8 @@ def langevin_log_path_likelihood(path_and_velocities):
235
237
236
238
stored = None
237
239
238
- assert ((tps_config .start_state (system .A ) and tps_config .target_state (system .B ))
239
- or (tps_config .start_state (system .B ) and tps_config .target_state (system .A ))), \
240
+ assert ((tps_config .start_state (system .A . reshape ( 1 , - 1 )) and tps_config .target_state (system .B . reshape ( 1 , - 1 ) ))
241
+ or (tps_config .start_state (system .B . reshape ( 1 , - 1 )) and tps_config .target_state (system .A . reshape ( 1 , - 1 ) ))), \
240
242
'A and B are not in the correct states. Please check your settings.'
241
243
242
244
if args .mechanism == 'one-way-shooting' :
@@ -258,14 +260,19 @@ def langevin_log_path_likelihood(path_and_velocities):
258
260
fixed_length = args .fixed_length ,
259
261
stored = stored )
260
262
# paths = tps2.unguided_md(tps_config, B, 1, key)
261
- paths = [jnp .array (p ) for p in paths ]
262
- velocities = [jnp .array (p ) for p in velocities ]
263
- # store paths
264
- np .save (f'{ savedir } /paths.npy' , np .array (paths , dtype = object ), allow_pickle = True )
265
- np .save (f'{ savedir } /velocities.npy' , np .array (velocities , dtype = object ), allow_pickle = True )
266
- # save statistics, which is a dictionary
267
- with open (f'{ savedir } /stats.json' , 'w' ) as fp :
268
- json .dump (statistics , fp )
263
+ print ('Converting paths to jax.numpy arrays.' )
264
+ paths = [jnp .array (p ) for p in tqdm (paths )]
265
+ velocities = [jnp .array (p ) for p in tqdm (velocities )]
266
+
267
+ if not args .resume :
268
+ # If we are resuming, everything is already stored
269
+ print ('Storing paths ...' )
270
+ np .save (f'{ savedir } /paths.npy' , np .array (paths , dtype = object ), allow_pickle = True )
271
+ print ('Storing velocities ...' )
272
+ np .save (f'{ savedir } /velocities.npy' , np .array (velocities , dtype = object ), allow_pickle = True )
273
+ # save statistics, which is a dictionary
274
+ with open (f'{ savedir } /stats.json' , 'w' ) as fp :
275
+ json .dump (statistics , fp )
269
276
except Exception as e :
270
277
print (traceback .format_exc ())
271
278
breakpoint ()
@@ -280,8 +287,11 @@ def langevin_log_path_likelihood(path_and_velocities):
280
287
if args .fixed_length == 0 :
281
288
print ([len (p ) for p in paths ])
282
289
plt .hist ([len (p ) for p in paths ], bins = jnp .sqrt (len (paths )).astype (int ).item ())
283
- plt .savefig (f'{ savedir } /lengths.png' , bbox_inches = 'tight' )
284
- plt .show ()
290
+ show_or_save_fig (savedir , 'lengths' , 'png' )
291
+
292
+ max_energy = [jnp .max (U_padded (path )) for path in tqdm (paths )]
293
+ max_energy = np .array (max_energy )
294
+ np .save (f'{ savedir } /max_energy.npy' , max_energy )
285
295
286
296
plt .title (f"{ human_format (len (paths ))} paths @ { temp } K, dt = { human_format (dt )} s" )
287
297
system .plot (trajectories = paths , alpha = 0.7 )
0 commit comments