Skip to content

Commit ca696ce

Browse files
committed
Reorder plotting in main
1 parent 658d924 commit ca696ce

File tree

2 files changed

+19
-16
lines changed

2 files changed

+19
-16
lines changed

main.py

+15-14
Original file line numberDiff line numberDiff line change
@@ -191,20 +191,6 @@ def main():
191191
mu_t, _, w_logits = state_q.apply_fn(state_q.params, t)
192192
w = jax.nn.softmax(w_logits)
193193
print('Weights of mixtures:', w)
194-
if system.plot:
195-
mu_t_no_vel = mu_t[:, :, :system.A.shape[0]]
196-
num_trajectories = jnp.array((w * 100).round(), dtype=int)
197-
198-
trajectories = jnp.swapaxes(mu_t_no_vel, 0, 1)
199-
trajectories = (jnp.vstack([trajectories[i].repeat(n, axis=0) for i, n in enumerate(num_trajectories) if n > 0])
200-
.reshape(num_trajectories.sum(), -1, mu_t_no_vel.shape[2]))
201-
202-
system.plot(title='Weighted mean paths', trajectories=trajectories)
203-
show_or_save_fig(args.save_dir, 'mean_paths', args.extension)
204-
205-
if system.plot and system.A.shape[0] == 2:
206-
print('Animating gif, this might take a few seconds ...')
207-
plot_u_t(system, setup, state_q, args.T, args.save_dir, 'u_t', frames=100)
208194

209195
key, init_key = jax.random.split(key)
210196
x_0 = jnp.ones((args.num_paths, A.shape[0]), dtype=jnp.float32) * A
@@ -228,6 +214,17 @@ def main():
228214
save_trajectory(system.mdtraj_topology, x_t_stoch_no_vel[-1].reshape(1, -1, 3), f'{args.save_dir}/stoch_-1.pdb')
229215

230216
if system.plot:
217+
mu_t_no_vel = mu_t[:, :, :system.A.shape[0]]
218+
num_trajectories = jnp.array((w * 100).round(), dtype=int)
219+
220+
trajectories = jnp.swapaxes(mu_t_no_vel, 0, 1)
221+
trajectories = (
222+
jnp.vstack([trajectories[i].repeat(n, axis=0) for i, n in enumerate(num_trajectories) if n > 0])
223+
.reshape(num_trajectories.sum(), -1, mu_t_no_vel.shape[2]))
224+
225+
system.plot(title='Weighted mean paths', trajectories=trajectories)
226+
show_or_save_fig(args.save_dir, 'mean_paths', args.extension)
227+
231228
plot_energy(system, [x_t_det_no_vel[0], x_t_det_no_vel[-1]], args.log_plots)
232229
show_or_save_fig(args.save_dir, 'path_energy_deterministic', args.extension)
233230

@@ -248,6 +245,10 @@ def main():
248245
plt.plot(x_t_stoch_no_vel[i, :, 0].T, x_t_stoch_no_vel[i, :, 1].T, c=c)
249246
show_or_save_fig(args.save_dir, 'paths_stochastic_and_individual', args.extension)
250247

248+
if system.A.shape[0] == 2:
249+
print('Animating gif, this might take a few seconds ...')
250+
plot_u_t(system, setup, state_q, args.T, args.save_dir, 'u_t', frames=100)
251+
251252

252253
if __name__ == '__main__':
253254
try:

training/qsetup.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
import argparse
22
from abc import ABC, abstractmethod
33
from dataclasses import dataclass
4-
from typing import Callable, Optional
4+
from typing import Callable, Optional, Union, Dict, Any
55
from flax import linen as nn
66
from flax.training.train_state import TrainState
7+
from flax.typing import FrozenVariableDict
78
from systems import System
89
from jax.typing import ArrayLike
910
import jax.numpy as jnp
1011
import jax
1112
from tqdm import trange
12-
import utils.aldp as aldp
1313

1414

1515
@dataclass
@@ -79,6 +79,8 @@ def construct(system: System, model: Optional[nn.module], xi: float, A: ArrayLik
7979

8080
transform = None
8181
if args.internal_coordinates:
82+
import utils.aldp as aldp
83+
8284
# Initialize transform with the initial state (without second order elements)
8385
transform = aldp.InternalCoordinateWrapper(system.A.reshape(1, -1))
8486
# convert A to internal coordinates, but discard the second order elements (if they exist)

0 commit comments

Comments
 (0)