Commit 99d077c

Add low rank training
1 parent 2d044d7 commit 99d077c

10 files changed (+57 -96 lines)

configs/toy/mueller_single_gaussian.yaml (+1 -1)

@@ -1,7 +1,7 @@
 save_dir: ./out/toy/mueller_single_gaussian
 
 test_system: mueller_brown
-ode: second_order
+ode: first_order
 parameterization: diagonal
 T: 275e-4
 xi: 5.0

configs/toy/mueller_single_gaussian_low_rank.yaml (+2 -2)

@@ -1,8 +1,8 @@
-save_dir: ./out/toy/mueller_single_gaussian
+save_dir: ./out/toy/mueller_single_gaussian_low_rank
 
 test_system: mueller_brown
 ode: first_order
-parameterization: diagonal
+parameterization: low_rank
 T: 275e-4
 xi: 5.0
 gamma: 1.0
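
Note: relative to the diagonal config, only the save directory and the parameterization key change; `parameterization: low_rank` is what the dispatch in training/qsetup.py keys on (see that file's diff below). A minimal sketch of reading these keys into an argparse-style namespace, assuming PyYAML is available and not necessarily how the repository's own entry point parses configs:

import argparse
import yaml  # assumption: PyYAML is installed; the repo may load configs differently

with open('configs/toy/mueller_single_gaussian_low_rank.yaml') as fh:
    cfg = yaml.safe_load(fh)

args = argparse.Namespace(**cfg)
print(args.parameterization)  # 'low_rank' -> routed to lowrank.LowRankSetup in training/qsetup.py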

environment.yml (+1)

@@ -15,6 +15,7 @@ dependencies:
 - openpathsampling=1.6.1
 - jax=0.4.23
 - flax=0.8.3
+ - notebook=7.0.8
 - pip:
   - dmff @ git+https://github.com/deepmodeling/[email protected]
   - rdkit==2023.3.3

main.py (+11 -7)

@@ -125,7 +125,9 @@ def main():
 
     key = jax.random.PRNGKey(args.seed)
     key, init_key = jax.random.split(key)
+
     params_q = setup.model_q.init(init_key, jnp.zeros([args.BS, 1], dtype=jnp.float32))
+    print(params_q)
 
     optimizer_q = optax.adam(learning_rate=args.lr)
     state_q = train_state.TrainState.create(apply_fn=setup.model_q.apply, params=params_q, tx=optimizer_q)
@@ -146,6 +148,8 @@ def main():
     if checkpoint_manager.latest_step() is None:
         print("Warning: No checkpoint found.")
     else:
+        # TODO: fix this. At least for low rank it does not work
+
         print('Loading checkpoint:', checkpoint_manager.latest_step())
 
         state_restored = checkpoint_manager.restore(checkpoint_manager.latest_step())
@@ -165,12 +169,12 @@ def main():
     show_or_save_fig(args.save_dir, 'loss_plot.pdf')
 
     print("!!!TODO: how to plot this nicely?")
-    t = args.T * jnp.linspace(0, 1, args.BS, dtype=jnp.float32).reshape((-1, 1))
-    key, path_key = jax.random.split(key)
-    eps = jax.random.normal(path_key, [args.BS, args.num_gaussians, A.shape[-1]])
-    mu_t, sigma_t, w_logits = state_q.apply_fn(state_q.params, t)
-    w = jax.nn.softmax(w_logits)[None, :, None]
-    samples = (w * (mu_t + sigma_t * eps)).sum(axis=1)
+    # t = args.T * jnp.linspace(0, 1, args.BS, dtype=jnp.float32).reshape((-1, 1))
+    # key, path_key = jax.random.split(key)
+    # eps = jax.random.normal(path_key, [args.BS, args.num_gaussians, A.shape[-1]])
+    # mu_t, sigma_t, w_logits = state_q.apply_fn(state_q.params, t)
+    # w = jax.nn.softmax(w_logits)[None, :, None]
+    # samples = (w * (mu_t + sigma_t * eps)).sum(axis=1)
 
     # plot_energy_surface()
     # plt.scatter(samples[:, 0], samples[:, 1])
@@ -180,7 +184,7 @@ def main():
 
     key, init_key = jax.random.split(key)
     x_0 = jnp.ones((args.num_paths, A.shape[0]), dtype=jnp.float32) * A
-    eps = jax.random.normal(key, shape=x_0.shape)
+    eps = jax.random.normal(key, shape=x_0.shape, dtype=jnp.float32)
    x_0 += args.base_sigma * eps
 
     x_t_det = setup.sample_paths(state_q, x_0, args.dt, args.T, args.BS, None)
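
Note: the block commented out above sampled from the diagonal mixture via a per-dimension sigma_t; it does not carry over to the low-rank parameterization, where the wrapper returns a full lower-triangular factor S_t instead. A rough, self-contained sketch of what sampling would look like in that case, with stand-in values for mu_t and S_t (the real ones would come from state_q.apply_fn):

import jax
import jax.numpy as jnp

BS, ndim = 4, 2                                                     # stand-in sizes
mu_t = jnp.zeros((BS, ndim), dtype=jnp.float32)                     # stand-in for state_q.apply_fn(...)[0]
S_t = jnp.tile(0.1 * jnp.eye(ndim, dtype=jnp.float32), (BS, 1, 1))  # stand-in for state_q.apply_fn(...)[1]
eps = jax.random.normal(jax.random.PRNGKey(0), (BS, ndim, 1), dtype=jnp.float32)

# x = mu + S @ eps, i.e. one sample from N(mu, S S^T) per batch element
samples = mu_t + jax.lax.batch_matmul(S_t, eps).squeeze(-1)
print(samples.shape)  # (4, 2)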

prepare_molecule.py (+1 -1)

@@ -34,7 +34,7 @@ def minimize(pdb, out, steps):
     for mass_ in mass:
         for _ in range(3):
             new_mass.append(mass_)
-    mass = jnp.array(new_mass)
+    mass = jnp.array(new_mass, dtype=jnp.float32)
 
     # Initialize the potential energy with amber forcefields
     ff = Hamiltonian('amber14/protein.ff14SB.xml', 'amber14/tip3p.xml')

systems.py (+1 -1)

@@ -50,7 +50,7 @@ def from_name(cls, name: str, force_clip: float) -> Self:
         plot = partial(toy_plot_energy_surface,
                        U=U, states=list(zip(['A', 'B'], [A, B])), xlim=xlim, ylim=ylim, alpha=1.0
                        )
-        mass = jnp.array([1.0, 1.0])
+        mass = jnp.array([1.0, 1.0], dtype=jnp.float32)
         return cls(U, A, B, mass, plot, force_clip)
 
     @classmethod

training/qsetup.py (+6 -1)

@@ -75,7 +75,7 @@ def u_t(self, state_q: TrainState, t: ArrayLike, x_t: ArrayLike, deterministic:
 
 def construct(system: System, model: nn.module, xi: float, A: ArrayLike, B: ArrayLike,
               args: argparse.Namespace) -> QSetup:
-    from training.setups import diagonal
+    from training.setups import diagonal, lowrank
 
     transform = None
     if args.internal_coordinates:
@@ -87,5 +87,10 @@ def construct(system: System, model: nn.module, xi: float, A: ArrayLike, B: ArrayLike,
            model, args.T, transform, A, B, args.num_gaussians, args.trainable_weights, args.base_sigma
        )
        return diagonal.DiagonalSetup(system, wrapped_module, xi, args.ode, args.T)
+    elif args.parameterization == 'low_rank':
+        wrapped_module = lowrank.LowRankWrapper(
+            model, args.T, transform, A, B, args.num_gaussians, args.trainable_weights, args.base_sigma
+        )
+        return lowrank.LowRankSetup(system, wrapped_module, xi, args.ode, args.T)
     else:
         raise ValueError(f"Unknown parameterization: {args.parameterization}")

training/setups/diagonal.py (+2 -2)

@@ -34,9 +34,9 @@ def _post_process(self, h: ArrayLike, t: ArrayLike):
         )
 
         if self.trainable_weights:
-            w_logits = self.param('w_logits', nn.initializers.zeros_init(), (num_mixtures,))
+            w_logits = self.param('w_logits', nn.initializers.zeros_init(), (num_mixtures,), dtype=jnp.float32)
         else:
-            w_logits = jnp.zeros(num_mixtures)
+            w_logits = jnp.zeros(num_mixtures, dtype=jnp.float32)
 
         return mu, sigma, w_logits

training/setups/lowrank.py (+31 -76)

@@ -19,15 +19,16 @@ class LowRankWrapper(WrappedModule):
     trainable_weights: bool
     base_sigma: float
 
-    def _pre_process(self, t: ArrayLike) -> Tuple[ArrayLike, Tuple[ArrayLike, ArrayLike, ArrayLike]]:
+    def _pre_process(self, t: ArrayLike) -> Tuple[Tuple[ArrayLike], Tuple[ArrayLike, ArrayLike, ArrayLike]]:
         ndim = self.A.shape[0]
 
         h_mu = (1 - t) * self.A + t * self.B
-        S_0 = jnp.eye(ndim)
-        S_0 = S_0 * jnp.vstack([self.base_sigma * jnp.ones((ndim // 2, 1)), self.base_sigma * jnp.ones((ndim // 2, 1))])
+        S_0 = jnp.eye(ndim, dtype=jnp.float32)
+        S_0 = S_0 * jnp.vstack([self.base_sigma * jnp.ones((ndim // 2, 1), dtype=jnp.float32),
+                                self.base_sigma * jnp.ones((ndim // 2, 1), dtype=jnp.float32)])
         S_0 = S_0[None, ...]
         h_S = (1 - 2 * t * (1 - t))[..., None] * S_0
-        return jnp.hstack([h_mu, h_S.reshape(-1, ndim * ndim), t]), (h_mu, h_S, t)
+        return (jnp.hstack([h_mu, h_S.reshape(-1, ndim * ndim), t]),), (h_mu, h_S, t)
 
     @nn.compact
     def _post_process(self, h: ArrayLike, h_mu: ArrayLike, h_S: ArrayLike, t: ArrayLike):
@@ -43,18 +44,18 @@ def _post_process(self, h: ArrayLike, h_mu: ArrayLike, h_S: ArrayLike, t: ArrayLike):
 
         @jax.vmap
         def get_tril(v):
-            a = jnp.zeros((ndim, ndim))
+            a = jnp.zeros((ndim, ndim), dtype=jnp.float32)
             a = a.at[jnp.tril_indices(ndim)].set(v)
             return a
 
         S = get_tril(h[:, ndim:])
-        S = jnp.tril(2 * jax.nn.sigmoid(S) - 1.0, k=-1) + jnp.eye(ndim)[None, ...] * jnp.exp(S)
+        S = jnp.tril(2 * jax.nn.sigmoid(S) - 1.0, k=-1) + jnp.eye(ndim, dtype=jnp.float32)[None, ...] * jnp.exp(S)
         S = h_S + 2 * ((1 - t) * t)[..., None] * S
 
         if self.trainable_weights:
-            w_logits = self.param('w_logits', nn.initializers.zeros_init(), (num_mixtures,))
+            w_logits = self.param('w_logits', nn.initializers.zeros_init(), (num_mixtures,), dtype=jnp.float32)
         else:
-            w_logits = jnp.zeros(num_mixtures)
+            w_logits = jnp.zeros(num_mixtures, dtype=jnp.float32)
 
         print('mu.shape', mu.shape)
         print('S.shape', S.shape)
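
Note: read in isolation, the S construction above scatters the flat network output into a lower-triangular matrix, squashes the strictly lower entries into (-1, 1) with a sigmoid, and exponentiates the diagonal, so every S is a valid Cholesky factor with a strictly positive diagonal. A single-sample restatement (v is a stand-in for the network output; nothing beyond jax is assumed):

import jax
import jax.numpy as jnp

ndim = 3
v = 0.1 * jnp.arange(ndim * (ndim + 1) // 2, dtype=jnp.float32)  # stand-in network output

a = jnp.zeros((ndim, ndim), dtype=jnp.float32).at[jnp.tril_indices(ndim)].set(v)
S = jnp.tril(2 * jax.nn.sigmoid(a) - 1.0, k=-1) + jnp.eye(ndim, dtype=jnp.float32) * jnp.exp(a)

print(jnp.all(jnp.diag(S) > 0))      # True: positive diagonal
print(jnp.allclose(S, jnp.tril(S)))  # True: lower triangular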
@@ -65,93 +66,47 @@ def get_tril(v):
 @dataclass
 class LowRankSetup(DriftedSetup):
     model_q: LowRankWrapper
-    T: float
 
     def __init__(self, system: System, model_q: LowRankWrapper, xi: ArrayLike, order: str, T: float):
-        super().__init__(system, model_q, xi, order)
-        self.T = T
+        super().__init__(system, model_q, xi, order, T)
 
     def construct_loss(self, state_q: TrainState, gamma: float, BS: int) -> Callable[
         [Union[FrozenVariableDict, Dict[str, Any]], ArrayLike], ArrayLike]:
         def loss_fn(params_q: Union[FrozenVariableDict, Dict[str, Any]], key: ArrayLike) -> ArrayLike:
             ndim = self.model_q.A.shape[-1]
 
             key = jax.random.split(key)
-            t = self.T * jax.random.uniform(key[0], [BS, 1])
-            eps = jax.random.normal(key[1], [BS, ndim, 1])
-
-            mu_t = lambda _t: state_q.apply_fn(params_q, _t)[0]
-            S_t = lambda _t: state_q.apply_fn(params_q, _t)[1]
-
-            def dmudt(_t):
-                _dmudt = jax.jacrev(lambda _t: mu_t(_t).sum(0))
-                return _dmudt(_t).squeeze().T
-
-            def dSdt(_t):
-                _dSdt = jax.jacrev(lambda _t: S_t(_t).sum(0))
-                return _dSdt(_t).squeeze().T
+            t = self.T * jax.random.uniform(key[0], [BS, 1], dtype=jnp.float32)
+            eps = jax.random.normal(key[1], [BS, ndim, 1], dtype=jnp.float32)
 
             def v_t(_eps, _t):
-                S_t_val, dSdt_val = S_t(_t), dSdt(_t)
-                _x = mu_t(_t) + jax.lax.batch_matmul(S_t_val, _eps).squeeze()
-                dlogdx = -jax.scipy.linalg.solve_triangular(jnp.transpose(S_t_val, (0, 2, 1)), _eps)
+                _mu_t, _S_t_val, _w_logits, _dmudt, _dSdt_val = forward_and_derivatives(state_q, _t, params_q)
+
+                _x = _mu_t + jax.lax.batch_matmul(_S_t_val, _eps).squeeze()
+                dlogdx = -jax.scipy.linalg.solve_triangular(jnp.transpose(_S_t_val, (0, 2, 1)), _eps)
                 # S_t_val_inv = jnp.transpose(jnp.linalg.inv(S_t_val), (0,2,1))
                 # dlogdx = -jax.lax.batch_matmul(S_t_val_inv, _eps)
-                dSigmadt = jax.lax.batch_matmul(dSdt_val, jnp.transpose(S_t_val, (0, 2, 1)))
-                dSigmadt += jax.lax.batch_matmul(S_t_val, jnp.transpose(dSdt_val, (0, 2, 1)))
-                u_t = dmudt(_t) - 0.5 * jax.lax.batch_matmul(dSigmadt, dlogdx).squeeze()
+                dSigmadt = jax.lax.batch_matmul(_dSdt_val, jnp.transpose(_S_t_val, (0, 2, 1)))
+                dSigmadt += jax.lax.batch_matmul(_S_t_val, jnp.transpose(_dSdt_val, (0, 2, 1)))
+                u_t = _dmudt - 0.5 * jax.lax.batch_matmul(dSigmadt, dlogdx).squeeze()
                 out = (u_t - self._drift(_x.reshape(BS, ndim), gamma)) + 0.5 * (self.xi ** 2) * dlogdx.squeeze()
                 return out
 
             loss = 0.5 * ((v_t(eps, t) / self.xi) ** 2).sum(1, keepdims=True)
-            print(loss.shape, 'loss.shape', flush=True)
+            print(loss.shape, 'loss.shape', 'loss.dtype', loss.dtype, flush=True)
             return loss.mean()
 
-            # ndim = self.model_q.A.shape[-1]
-            # key = jax.random.split(key)
-            #
-            # t = self.T * jax.random.uniform(key[0], [BS, 1], dtype=jnp.float32)
-            # #TODO: the following needs to be changed for num gaussians. It should be BS, num_mitures, ndim
-            # eps = jax.random.normal(key[1], [BS, ndim, 1], dtype=jnp.float32)
-            #
-            # def v_t(_eps, _t):
-            #     """This function is equal to v_t * xi ** 2."""
-            #     _mu_t, _sigma_t, _w_logits, _dmudt, _dsigmadt = forward_and_derivatives(state_q, _t, params_q)
-            #     _i = jax.random.categorical(key[2], _w_logits, shape=[BS, ])
-            #
-            #     _x = _mu_t[jnp.arange(BS), _i, None] + _sigma_t[jnp.arange(BS), _i, None] * eps
-            #
-            #     if _mu_t.shape[1] == 1:
-            #         # This completely ignores the weights and saves some time
-            #         relative_mixture_weights = 1
-            #     else:
-            #         log_q_i = jax.scipy.stats.norm.logpdf(_x, _mu_t, _sigma_t).sum(-1)
-            #         relative_mixture_weights = jax.nn.softmax(_w_logits + log_q_i)[:, :, None]
-            #
-            #     log_q_t = -(relative_mixture_weights / (_sigma_t ** 2) * (_x - _mu_t)).sum(axis=1)
-            #     u_t = (relative_mixture_weights * (1 / _sigma_t * _dsigmadt * (_x - _mu_t) + _dmudt)).sum(axis=1)
-            #
-            #     return u_t - self._drift(_x.reshape(BS, ndim), gamma) + 0.5 * (self.xi ** 2) * log_q_t
-            #
-            # loss = 0.5 * ((v_t(eps, t) / self.xi) ** 2).sum(-1, keepdims=True)
-            # return loss.mean()
-
         return loss_fn
 
     def u_t(self, state_q: TrainState, t: ArrayLike, x_t: ArrayLike, deterministic: bool, *args, **kwargs) -> ArrayLike:
-        raise NotImplementedError
-
-        # _mu_t, _sigma_t, _w_logits, _dmudt, _dsigmadt = forward_and_derivatives(state_q, t)
-        # _x = x_t[:, None, :]
-        #
-        # log_q_i = jax.scipy.stats.norm.logpdf(_x, _mu_t, _sigma_t).sum(-1)
-        # relative_mixture_weights = jax.nn.softmax(_w_logits + log_q_i)[:, :, None]
-        #
-        # _u_t = (relative_mixture_weights * (1 / _sigma_t * _dsigmadt * (_x - _mu_t) + _dmudt)).sum(axis=1)
-        #
-        # if deterministic:
-        #     return _u_t
-        #
-        # log_q_t = -(relative_mixture_weights / (_sigma_t ** 2) * (_x - _mu_t)).sum(axis=1)
-        #
-        # return _u_t + 0.5 * (self.xi ** 2) * log_q_t
+        _mu_t, _S_t_val, _w_logits, _dmudt, _dSdt_val = forward_and_derivatives(state_q, t)
+
+        dSigmadt = jax.lax.batch_matmul(_dSdt_val, jnp.transpose(_S_t_val, (0, 2, 1)))
+        dSigmadt += jax.lax.batch_matmul(_S_t_val, jnp.transpose(_dSdt_val, (0, 2, 1)))
+        STdlogdx = jax.scipy.linalg.solve_triangular(_S_t_val, (x_t - _mu_t)[..., None])
+        dlogdx = -jax.scipy.linalg.solve_triangular(jnp.transpose(_S_t_val, (0, 2, 1)), STdlogdx)
+
+        if deterministic:
+            return _dmudt + (-0.5 * jax.lax.batch_matmul(dSigmadt, dlogdx)).squeeze()
+
+        return _dmudt + (-0.5 * jax.lax.batch_matmul(dSigmadt, dlogdx) + 0.5 * self.xi ** 2 * dlogdx).squeeze()
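
Note: the loss above and the new u_t both rely on two identities for Sigma = S S^T: the product rule dSigma/dt = (dS/dt) S^T + S (dS/dt)^T (the two batch_matmul lines), and the score grad_x log q = -Sigma^{-1} (x - mu), which for x = mu + S eps reduces to -S^{-T} eps and is what the solve_triangular calls compute. A small numerical check of the score identity (illustration only, not part of the commit):

import jax
import jax.numpy as jnp

ndim = 3
# well-conditioned lower-triangular stand-in for S_t
S = jnp.eye(ndim, dtype=jnp.float32) + 0.1 * jnp.tril(
    jax.random.normal(jax.random.PRNGKey(0), (ndim, ndim), dtype=jnp.float32), k=-1)
eps = jax.random.normal(jax.random.PRNGKey(1), (ndim, 1), dtype=jnp.float32)

score_via_solve = -jax.scipy.linalg.solve_triangular(S.T, eps)      # -S^{-T} eps, as in the code above
score_via_inverse = -jnp.linalg.inv(S @ S.T) @ (S @ eps)            # -Sigma^{-1} (x - mu) with x - mu = S eps
print(jnp.allclose(score_via_solve, score_via_inverse, atol=1e-5))  # True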

training/train.py (+1 -5)

@@ -20,16 +20,12 @@ def train_step(_state_q: TrainState, _key: ArrayLike) -> (TrainState, float):
         _state_q = _state_q.apply_gradients(grads=grads)
         return _state_q, loss
 
-    log_loss = False
     with trange(ckpt['model'].step, epochs) as pbar:
         for i in pbar:
             key, loc_key = jax.random.split(key)
             ckpt['model'], loss = train_step(ckpt['model'], loc_key)
             if loss > 1e4:
-                log_loss = True
-
-            if log_loss:
-                pbar.set_postfix(log_loss=f"{jnp.log(loss):.4f}")
+                pbar.set_postfix(log_loss=f"{jnp.log10(loss):.4f}")
             else:
                 pbar.set_postfix(loss=f"{loss:.4f}")
             ckpt['losses'].append(loss.item())