Skip to content

Commit 673b6d9

Browse files
committed
Implement exact tps baseline
1 parent 5da73ca commit 673b6d9

File tree

3 files changed

+72
-34
lines changed

3 files changed

+72
-34
lines changed

tps/second_order.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
import jax.numpy as jnp
33
from tqdm import tqdm
44

5+
from utils.plot import human_format
6+
57
MAX_STEPS = 2_000
68
MAX_ABS_VALUE = 5
79

@@ -153,7 +155,8 @@ def two_way_shooting(system, trajectory, _previous_velocities, fixed_length, _dt
153155
return False, new_trajectory, new_velocities
154156

155157

156-
def mcmc_shooting(system, proposal, initial_trajectory, num_paths, dt, key, fixed_length=0, warmup=50, stored=None):
158+
def mcmc_shooting(system, proposal, initial_trajectory, num_paths, dt, key, fixed_length=0, warmup=50, stored=None,
159+
max_force_evaluations=10 ** 10):
157160
# pick an initial trajectory
158161
trajectories = [initial_trajectory]
159162
velocities = []
@@ -177,6 +180,7 @@ def mcmc_shooting(system, proposal, initial_trajectory, num_paths, dt, key, fixe
177180
num_tries = 0
178181
num_force_evaluations = 0
179182
num_metropolis_rejected = 0
183+
total_num_force_evaluations = sum(statistics['num_force_evaluations'])
180184
try:
181185
with tqdm(total=num_paths + warmup, initial=len(trajectories) - 1,
182186
desc='warming up' if warmup > 0 else '') as pbar:
@@ -197,6 +201,9 @@ def mcmc_shooting(system, proposal, initial_trajectory, num_paths, dt, key, fixe
197201
trajectories) > 1 else None,
198202
fixed_length, dt, ikey)
199203
num_force_evaluations += len(new_trajectory) - 1
204+
total_num_force_evaluations += len(new_trajectory) - 1
205+
206+
pbar.set_postfix({'total_force_evaluations': human_format(total_num_force_evaluations)})
200207

201208
if not found:
202209
continue
@@ -218,6 +225,10 @@ def mcmc_shooting(system, proposal, initial_trajectory, num_paths, dt, key, fixe
218225
pbar.update(1)
219226
else:
220227
num_metropolis_rejected += 1
228+
229+
if total_num_force_evaluations > max_force_evaluations:
230+
print('Max force evaluations reached, stopping early')
231+
break
221232
except KeyboardInterrupt:
222233
print('SIGINT received, stopping early')
223234
# Fix in case we stop when adding a trajectory

tps_baseline.py

+43-33
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@
2020

2121
from utils.angles import phi_psi_from_mdtraj
2222
from utils.animation import save_trajectory, to_md_traj
23-
from utils.plot import show_or_save_fig
23+
from utils.plot import show_or_save_fig, human_format
2424
from utils.rmsd import kabsch_align, kabsch_rmsd
2525

2626
from argparse import ArgumentParser
2727

2828
parser = ArgumentParser()
2929
parser.add_argument('--mechanism', type=str, choices=['one-way-shooting', 'two-way-shooting'], required=True)
30-
parser.add_argument('--states', type=str, default='phi-psi', choices=['phi-psi', 'rmsd'])
30+
parser.add_argument('--states', type=str, default='phi-psi', choices=['phi-psi', 'rmsd', 'exact'])
3131
parser.add_argument('--fixed_length', type=int, default=0)
3232
parser.add_argument('--warmup', type=int, default=0)
3333
parser.add_argument('--num_paths', type=int, required=True)
@@ -39,23 +39,6 @@
3939
help='Ensure that the initial path connects A with B by prepending A and appending B.')
4040

4141

42-
def human_format(num):
    """Format a number with three significant digits and a magnitude suffix.

    Based on https://stackoverflow.com/a/45846841/4417954

    Values with magnitude >= 1 use the suffixes '', 'K', 'M', 'B', 'T';
    values with magnitude < 1 use 'm', 'µ', 'n', 'p', 'f'. Values outside the
    range covered by the suffixes are expressed in the largest or smallest
    available unit instead of raising an IndexError.

    :param num: the number to format (anything accepted by ``'{:.2e}'.format``).
    :return: the formatted string, e.g. ``'1.23K'``, ``'500m'`` or ``'1f'``.
    """
    large_suffixes = ['', 'K', 'M', 'B', 'T']
    small_suffixes = ['', 'm', 'µ', 'n', 'p', 'f']

    # Round to three significant digits in scientific notation so that the
    # decimal exponent can be read off exactly, instead of repeatedly
    # multiplying/dividing by 1000 (which accumulates floating point error
    # and never terminates for 0 or inf).
    text = '{:.2e}'.format(num)
    if 'e' not in text:
        # inf / nan have no exponent; return them as they are
        return text

    mantissa_text, exponent_text = text.split('e')
    mantissa = float(mantissa_text)
    exponent = int(exponent_text)
    if mantissa == 0:
        return '0'

    # Group the exponent in steps of 1000 and clamp to the available suffixes
    group = exponent // 3
    group = max(-(len(small_suffixes) - 1), min(len(large_suffixes) - 1, group))
    scaled = mantissa * 10.0 ** (exponent - 3 * group)

    suffix = large_suffixes[group] if group >= 0 else small_suffixes[-group]
    return '{}{}'.format('{:f}'.format(scaled).rstrip('0').rstrip('.'), suffix)
57-
58-
5942
dt_as_unit = unit.Quantity(value=1, unit=unit.femtosecond)
6043
dt_in_ps = dt_as_unit.value_in_unit(unit.picosecond)
6144
dt = dt_as_unit.value_in_unit(unit.second)
@@ -101,6 +84,8 @@ def step_n(step, _x, _v, n, _key):
10184
savedir += f'-{args.fixed_length}steps'
10285
if args.states == 'rmsd':
10386
savedir += '-rmsd'
87+
elif args.states == 'exact':
88+
savedir += '-exact'
10489

10590
os.makedirs(savedir, exist_ok=True)
10691

@@ -118,6 +103,7 @@ def U_padded(x):
118103
x = x_empty.at[:x.shape[0], :].set(x.reshape(-1, 66))
119104
return system.U(x)[:orig_length]
120105

106+
121107
@jax.jit
122108
def step(_x, _key):
123109
"""Perform one step of forward euler"""
@@ -197,6 +183,19 @@ def langevin_log_path_likelihood(path_and_velocities):
197183
state_B = jax.jit(
198184
lambda s: is_within(phis_psis(s.reshape(-1, 22, 3)).reshape(-1, 2), phis_psis(system.B.reshape(-1, 22, 3)),
199185
radius))
186+
elif args.states == 'exact':
187+
from scipy.stats import chi2
188+
percentile = 0.99
189+
noise_scale = 1e-4
190+
threshold = jnp.sqrt(chi2.ppf(percentile, system.A.shape[0]) * noise_scale)
191+
print(threshold)
192+
def kabsch_l2(A, B):
193+
a, b = kabsch_align(A, B)
194+
195+
return jnp.linalg.norm(a - b)
196+
197+
state_A = jax.jit(jax.vmap(lambda s: kabsch_l2(system.A.reshape(22, 3), s.reshape(22, 3)) <= threshold))
198+
state_B = jax.jit(jax.vmap(lambda s: kabsch_l2(system.B.reshape(22, 3), s.reshape(22, 3)) <= threshold))
200199
else:
201200
raise ValueError(f"Unknown states {args.states}")
202201

@@ -216,9 +215,10 @@ def langevin_log_path_likelihood(path_and_velocities):
216215
save_trajectory(system.mdtraj_topology, jnp.array(initial_trajectory), f'{savedir}/initial_trajectory.pdb')
217216

218217
if args.resume:
219-
paths = [[x for x in p.astype(np.float32)] for p in np.load(f'{savedir}/paths.npy', allow_pickle=True)]
218+
print('Loading stored data.')
219+
paths = [[x for x in p.astype(np.float32)] for p in tqdm(np.load(f'{savedir}/paths.npy', allow_pickle=True))]
220220
velocities = [[v for v in p.astype(np.float32)] for p in
221-
np.load(f'{savedir}/velocities.npy', allow_pickle=True)]
221+
tqdm(np.load(f'{savedir}/velocities.npy', allow_pickle=True))]
222222
with open(f'{savedir}/stats.json', 'r') as fp:
223223
statistics = json.load(fp)
224224

@@ -227,6 +227,8 @@ def langevin_log_path_likelihood(path_and_velocities):
227227
'velocities': velocities,
228228
'statistics': statistics
229229
}
230+
231+
print('Loaded', len(paths), 'paths.')
230232
else:
231233
if os.path.exists(f'{savedir}/paths.npy') and not args.override:
232234
print(f"The target directory is not empty.\n"
@@ -235,8 +237,8 @@ def langevin_log_path_likelihood(path_and_velocities):
235237

236238
stored = None
237239

238-
assert ((tps_config.start_state(system.A) and tps_config.target_state(system.B))
239-
or (tps_config.start_state(system.B) and tps_config.target_state(system.A))), \
240+
assert ((tps_config.start_state(system.A.reshape(1, -1)) and tps_config.target_state(system.B.reshape(1, -1)))
241+
or (tps_config.start_state(system.B.reshape(1, -1)) and tps_config.target_state(system.A.reshape(1, -1)))), \
240242
'A and B are not in the correct states. Please check your settings.'
241243

242244
if args.mechanism == 'one-way-shooting':
@@ -258,14 +260,19 @@ def langevin_log_path_likelihood(path_and_velocities):
258260
fixed_length=args.fixed_length,
259261
stored=stored)
260262
# paths = tps2.unguided_md(tps_config, B, 1, key)
261-
paths = [jnp.array(p) for p in paths]
262-
velocities = [jnp.array(p) for p in velocities]
263-
# store paths
264-
np.save(f'{savedir}/paths.npy', np.array(paths, dtype=object), allow_pickle=True)
265-
np.save(f'{savedir}/velocities.npy', np.array(velocities, dtype=object), allow_pickle=True)
266-
# save statistics, which is a dictionary
267-
with open(f'{savedir}/stats.json', 'w') as fp:
268-
json.dump(statistics, fp)
263+
print('Converting paths to jax.numpy arrays.')
264+
paths = [jnp.array(p) for p in tqdm(paths)]
265+
velocities = [jnp.array(p) for p in tqdm(velocities)]
266+
267+
if not args.resume:
268+
# If we are resuming, everything is already stored
269+
print('Storing paths ...')
270+
np.save(f'{savedir}/paths.npy', np.array(paths, dtype=object), allow_pickle=True)
271+
print('Storing velocities ...')
272+
np.save(f'{savedir}/velocities.npy', np.array(velocities, dtype=object), allow_pickle=True)
273+
# save statistics, which is a dictionary
274+
with open(f'{savedir}/stats.json', 'w') as fp:
275+
json.dump(statistics, fp)
269276
except Exception as e:
270277
print(traceback.format_exc())
271278
breakpoint()
@@ -280,8 +287,11 @@ def langevin_log_path_likelihood(path_and_velocities):
280287
if args.fixed_length == 0:
281288
print([len(p) for p in paths])
282289
plt.hist([len(p) for p in paths], bins=jnp.sqrt(len(paths)).astype(int).item())
283-
plt.savefig(f'{savedir}/lengths.png', bbox_inches='tight')
284-
plt.show()
290+
show_or_save_fig(savedir, 'lengths', 'png')
291+
292+
max_energy = [jnp.max(U_padded(path)) for path in tqdm(paths)]
293+
max_energy = np.array(max_energy)
294+
np.save(f'{savedir}/max_energy.npy', max_energy)
285295

286296
plt.title(f"{human_format(len(paths))} paths @ {temp} K, dt = {human_format(dt)}s")
287297
system.plot(trajectories=paths, alpha=0.7)

utils/plot.py

+17
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,23 @@
1212
from flax.training.train_state import TrainState
1313

1414

15+
def human_format(num):
    """Format a number with three significant digits and a magnitude suffix.

    Based on https://stackoverflow.com/a/45846841/4417954

    Values with magnitude >= 1 use the suffixes '', 'K', 'M', 'B', 'T';
    values with magnitude < 1 use 'm', 'µ', 'n', 'p', 'f'. Values outside the
    range covered by the suffixes are expressed in the largest or smallest
    available unit instead of raising an IndexError.

    :param num: the number to format (anything accepted by ``'{:.2e}'.format``).
    :return: the formatted string, e.g. ``'1.23K'``, ``'500m'`` or ``'1f'``.
    """
    large_suffixes = ['', 'K', 'M', 'B', 'T']
    small_suffixes = ['', 'm', 'µ', 'n', 'p', 'f']

    # Round to three significant digits in scientific notation so that the
    # decimal exponent can be read off exactly, instead of repeatedly
    # multiplying/dividing by 1000 (which accumulates floating point error
    # and never terminates for 0 or inf).
    text = '{:.2e}'.format(num)
    if 'e' not in text:
        # inf / nan have no exponent; return them as they are
        return text

    mantissa_text, exponent_text = text.split('e')
    mantissa = float(mantissa_text)
    exponent = int(exponent_text)
    if mantissa == 0:
        return '0'

    # Group the exponent in steps of 1000 and clamp to the available suffixes
    group = exponent // 3
    group = max(-(len(small_suffixes) - 1), min(len(large_suffixes) - 1, group))
    scaled = mantissa * 10.0 ** (exponent - 3 * group)

    suffix = large_suffixes[group] if group >= 0 else small_suffixes[-group]
    return '{}{}'.format('{:f}'.format(scaled).rstrip('0').rstrip('.'), suffix)
30+
31+
1532
def log_scale(log_plot: bool, x: bool, y: bool):
1633
if log_plot:
1734
if x:

0 commit comments

Comments
 (0)