
Commit f7778b4
Add low rank spline implementation

Parent: 99d077c

7 files changed: +103 −33 lines
New config file (+22)

```diff
@@ -0,0 +1,22 @@
+save_dir: ./out/toy/mueller_single_gaussian_low_rank_spline
+
+test_system: mueller_brown
+ode: first_order
+parameterization: low_rank
+T: 275e-4
+xi: 5.0
+gamma: 1.0
+
+model: spline
+num_points: 25
+
+num_gaussians: 1
+trainable_weights: False
+base_sigma: 2.5e-2
+
+epochs: 100000
+save_interval: 10000
+BS: 512
+
+num_paths: 1000
+dt: 1e-4
```
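For orientation, a minimal sketch (not part of the commit) of the parameter shapes that `LowRankSpline` in `training/setups/lowrank.py` below allocates for this config, assuming the 2D Mueller-Brown system (`ndim = 2`). Only interior knots are learnable, since the endpoint knots are pinned to A and B:

```python
# Shapes allocated by LowRankSpline (see training/setups/lowrank.py below)
# for num_points: 25 on a 2D system; ndim = 2 is an assumption here.
ndim, num_points = 2, 25
n_tril = ndim * (ndim + 1) // 2        # free entries of the lower-triangular S: 3
mu_knots = (num_points - 2, ndim)      # 'mu_params': interior knots only -> (23, 2)
S_knots = (num_points - 2, n_tril)     # 'S_params': interior knots only -> (23, 3)
print(mu_knots, S_knots)
```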

main.py (+14 −5)

```diff
@@ -46,6 +46,12 @@
                     help="Whether the weights of the mixture model are trainable.")

 # model parameters
+parser.add_argument('--model', type=str, choices=['mlp', 'spline'], default='mlp',
+                    help="The model that will be used. Note that spline will not work with all configurations.")
+
+parser.add_argument('--num_points', type=int, default=100, help="Number of points in the spline model.")
+
+# MLP arguments
 parser.add_argument('--hidden_layers', nargs='+', type=int, help='The dimensions of the hidden layer of the MLP.',
                     default=[128, 128, 128])
 parser.add_argument('--activation', type=str, default='swish', choices=['tanh', 'relu', 'swish'],
@@ -75,7 +81,7 @@
 # plotting
 parser.add_argument('--log_plots', type=str2bool, nargs='?', const=True, default=False,
                     help="Save plots in log scale where possible")
-
+parser.add_argument('--extension', type=str, default='pdf', help="Extension of the saved plots.")

 def main():
     print("!!!!Next todos: plot ALDP")
@@ -120,7 +126,10 @@ def main():

     # You can play around with any model here
     # The chosen setup will append a final layer so that the output is mu, sigma, and weights
-    model = MLP(args.hidden_layers, args.activation, args.resnet)
+    model = None
+    if args.model == 'mlp':
+        model = MLP(args.hidden_layers, args.activation, args.resnet)
+
     setup = qsetup.construct(system, model, xi, A, B, args)

     key = jax.random.PRNGKey(args.seed)
@@ -166,7 +175,7 @@ def main():
         print("Warning: Loss contains NaNs")
     plt.plot(ckpt['losses'])
     log_scale(args.log_plots, False, True)
-    show_or_save_fig(args.save_dir, 'loss_plot.pdf')
+    show_or_save_fig(args.save_dir, 'loss_plot', args.extension)

     print("!!!TODO: how to plot this nicely?")
     # t = args.T * jnp.linspace(0, 1, args.BS, dtype=jnp.float32).reshape((-1, 1))
@@ -192,14 +201,14 @@ def main():
     if system.plot:
         # In case we have a second order integration scheme, we remove the velocity for plotting
         system.plot(title='Deterministic Paths', trajectories=x_t_det[:, :, :system.A.shape[0]])
-        show_or_save_fig(args.save_dir, 'paths_deterministic.pdf')
+        show_or_save_fig(args.save_dir, 'paths_deterministic', args.extension)

     key, path_key = jax.random.split(key)
     x_t_stoch = setup.sample_paths(state_q, x_0, args.dt, args.T, args.BS, path_key)

     if system.plot:
         system.plot(title='Stochastic Paths', trajectories=x_t_stoch[:, :, :system.A.shape[0]])
-        show_or_save_fig(args.save_dir, 'paths_stochastic.pdf')
+        show_or_save_fig(args.save_dir, 'paths_stochastic', args.extension)


 if __name__ == '__main__':
```

Note: the model dispatch compares against `'mlp'` to match the lowercase `choices` of the new `--model` flag; comparing against `'MLP'` would leave `model` as `None` on the default path.
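With these flags in place, the spline run from the config above corresponds to an invocation along the lines of `python main.py --parameterization low_rank --model spline --num_points 25 --extension pdf ...` (only `--model`, `--num_points`, and `--extension` are confirmed by this diff; the remaining flag names are assumed to mirror the config keys).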

model/__init__.py (−1)

```diff
@@ -1,5 +1,4 @@
 from typing import Tuple
-
 from flax import linen as nn
```

training/qsetup.py (+14 −7)

```diff
@@ -73,7 +73,7 @@ def u_t(self, state_q: TrainState, t: ArrayLike, x_t: ArrayLike, deterministic:
         raise NotImplementedError


-def construct(system: System, model: nn.module, xi: float, A: ArrayLike, B: ArrayLike,
+def construct(system: System, model: Optional[nn.Module], xi: float, A: ArrayLike, B: ArrayLike,
               args: argparse.Namespace) -> QSetup:
     from training.setups import diagonal, lowrank

@@ -83,14 +83,21 @@ def construct(system: System, model: Optional[nn.Module], xi: float, A: ArrayLike, B: ArrayLike,
         transform = aldp.InternalCoordinateWrapper(system.A.reshape(1, -1))

     if args.parameterization == 'diagonal':
-        wrapped_module = diagonal.DiagonalWrapper(
+        if args.model == 'spline':
+            raise ValueError("Spline model is not supported with diagonal parameterization")
+        model = diagonal.DiagonalWrapper(
             model, args.T, transform, A, B, args.num_gaussians, args.trainable_weights, args.base_sigma
         )
-        return diagonal.DiagonalSetup(system, wrapped_module, xi, args.ode, args.T)
+        return diagonal.DiagonalSetup(system, model, xi, args.ode, args.T)
     elif args.parameterization == 'low_rank':
-        wrapped_module = lowrank.LowRankWrapper(
-            model, args.T, transform, A, B, args.num_gaussians, args.trainable_weights, args.base_sigma
-        )
-        return lowrank.LowRankSetup(system, wrapped_module, xi, args.ode, args.T)
+        if args.model == 'spline':
+            model = lowrank.LowRankSpline(
+                args.num_points, args.T, transform, A, B, args.num_gaussians, args.trainable_weights, args.base_sigma
+            )
+        else:
+            model = lowrank.LowRankWrapper(
+                model, args.T, transform, A, B, args.num_gaussians, args.trainable_weights, args.base_sigma
+            )
+        return lowrank.LowRankSetup(system, model, xi, args.ode, args.T)
     else:
         raise ValueError(f"Unknown parameterization: {args.parameterization}")
```
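Note the division of labor this hunk introduces: with `--model spline`, the `model` argument arriving here is `None` (main.py only builds the MLP for `--model mlp`), and `construct` builds the `LowRankSpline` directly from `args` instead of wrapping a network; the diagonal branch rejects the spline combination up front rather than wrapping `None`.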

training/setups/diagonal.py (−6)

```diff
@@ -41,13 +41,7 @@ def _post_process(self, h: ArrayLike, t: ArrayLike):
         return mu, sigma, w_logits


-@dataclass
 class DiagonalSetup(DriftedSetup):
-    model_q: DiagonalWrapper
-
-    def __init__(self, system: System, model_q: DiagonalWrapper, xi: ArrayLike, ode: str, T: float):
-        super().__init__(system, model_q, xi, ode, T)
-
     def construct_loss(self, state_q: TrainState, gamma: float, BS: int) -> Callable[
         [Union[FrozenVariableDict, Dict[str, Any]], ArrayLike], ArrayLike]:
```
training/setups/lowrank.py

+51-12
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,64 @@
1-
from dataclasses import dataclass
21
from jax.typing import ArrayLike
32
from flax import linen as nn
43
import jax.numpy as jnp
5-
from typing import Union, Dict, Any, Callable, Tuple
4+
from typing import Union, Dict, Any, Callable, Tuple, Optional
65
from flax.training.train_state import TrainState
76
import jax
87
from flax.typing import FrozenVariableDict
98
from model.utils import WrappedModule
10-
from systems import System
119
from training.setups.drift import DriftedSetup
1210
from training.utils import forward_and_derivatives
1311

12+
interp = jax.vmap(jnp.interp, in_axes=(None, None, 1))
13+
14+
15+
class LowRankSpline(nn.Module):
16+
n_points: int
17+
T: float
18+
transform: Optional[Callable[[Any], Any]]
19+
A: ArrayLike
20+
B: ArrayLike
21+
num_mixtures: int
22+
trainable_weights: bool
23+
base_sigma: float
24+
25+
@nn.compact
26+
def __call__(self, t):
27+
print("WARNING: Mixtures for low rank not yet implemented!")
28+
assert self.num_mixtures == 1, "Mixtures for low rank not yet implemented!"
29+
30+
ndim = self.A.shape[0]
31+
t = t / self.T
32+
t_grid = jnp.linspace(0, 1, self.n_points, dtype=jnp.float32)
33+
S_0 = jnp.log(self.base_sigma) * jnp.eye(ndim, dtype=jnp.float32)
34+
S_0_vec = S_0[jnp.tril_indices(ndim)]
35+
mu_params = self.param('mu_params', lambda rng: jnp.linspace(self.A, self.B, self.n_points)[1:-1])
36+
S_params = self.param('S_params', lambda rng: jnp.linspace(S_0_vec, S_0_vec, self.n_points)[1:-1])
37+
y_grid = jnp.concatenate([self.A.reshape(1, -1), mu_params, self.B.reshape(1, -1)])
38+
S_grid = jnp.concatenate([S_0_vec[None, :], S_params, S_0_vec[None, :]])
39+
40+
@jax.vmap
41+
def get_tril(v):
42+
a = jnp.zeros((ndim, ndim), dtype=jnp.float32)
43+
a = a.at[jnp.tril_indices(ndim)].set(v)
44+
return a
45+
46+
mu = interp(t.flatten(), t_grid, y_grid).T
47+
S = interp(t.flatten(), t_grid, S_grid).T
48+
S = get_tril(S)
49+
S = jnp.tril(2 * jax.nn.sigmoid(S) - 1.0, k=-1) + jnp.eye(ndim, dtype=jnp.float32)[None, ...] * jnp.exp(S)
50+
51+
if self.trainable_weights:
52+
w_logits = self.param('w_logits', nn.initializers.zeros_init(), (self.num_mixtures,), dtype=jnp.float32)
53+
else:
54+
w_logits = jnp.zeros(self.num_mixtures, dtype=jnp.float32)
55+
56+
out = (mu, S, w_logits)
57+
if self.transform:
58+
out = self.transform(out)
59+
60+
return out
61+
1462

1563
class LowRankWrapper(WrappedModule):
1664
A: ArrayLike
@@ -57,19 +105,10 @@ def get_tril(v):
57105
else:
58106
w_logits = jnp.zeros(num_mixtures, dtype=jnp.float32)
59107

60-
print('mu.shape', mu.shape)
61-
print('S.shape', S.shape)
62-
63108
return mu, S, w_logits
64109

65110

66-
@dataclass
67111
class LowRankSetup(DriftedSetup):
68-
model_q: LowRankWrapper
69-
70-
def __init__(self, system: System, model_q: LowRankWrapper, xi: ArrayLike, order: str, T: float):
71-
super().__init__(system, model_q, xi, order, T)
72-
73112
def construct_loss(self, state_q: TrainState, gamma: float, BS: int) -> Callable[
74113
[Union[FrozenVariableDict, Dict[str, Any]], ArrayLike], ArrayLike]:
75114
def loss_fn(params_q: Union[FrozenVariableDict, Dict[str, Any]], key: ArrayLike) -> ArrayLike:
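A minimal sketch (not part of the commit) of how the new spline is exercised end to end. The vmapped `interp` linearly interpolates each output dimension of the knot grid independently, returning a `(dim, batch)` array that the module transposes; the endpoint coordinates below are illustrative stand-ins for the Müller–Brown minima, and the hyperparameters follow the config above:

```python
import jax
import jax.numpy as jnp

# Illustrative 2D endpoints (stand-ins for the Mueller-Brown minima A and B).
A = jnp.array([-0.55, 1.44], dtype=jnp.float32)
B = jnp.array([0.62, 0.03], dtype=jnp.float32)

model = LowRankSpline(n_points=25, T=275e-4, transform=None, A=A, B=B,
                      num_mixtures=1, trainable_weights=False, base_sigma=2.5e-2)

# A batch of 8 times in [0, T]; __call__ rescales them to [0, 1] internally.
t = jnp.linspace(0.0, 275e-4, 8, dtype=jnp.float32).reshape(-1, 1)
params = model.init(jax.random.PRNGKey(0), t)
mu, S, w_logits = model.apply(params, t)
print(mu.shape, S.shape, w_logits.shape)  # (8, 2) (8, 2, 2) (1,)

# At t = 0 the mean sits exactly at A and S is the diagonal base_sigma matrix,
# because the endpoint knots are fixed rather than learned.
print(mu[0])  # ~[-0.55, 1.44]
print(S[0])   # ~[[0.025, 0.], [0., 0.025]]
```

The initialization is worth noting: `mu_params` starts as the straight line from A to B, so before training the spline is the linear interpolant between the two states, and `S_params` starts constant at the `base_sigma` profile.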

utils/plot.py (+2 −2)

```diff
@@ -14,9 +14,9 @@ def log_scale(log_plot: bool, x: bool, y: bool):
         plt.gca().set_yscale('log')


-def show_or_save_fig(save_dir: Optional[str], name: str):
+def show_or_save_fig(save_dir: Optional[str], name: str, extension: str):
     if save_dir is not None:
-        plt.savefig(f'{save_dir}/{name}', bbox_inches='tight')
+        plt.savefig(f'{save_dir}/{name}.{extension}', bbox_inches='tight')
         plt.clf()
     else:
         plt.show()
```
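For completeness, a small usage sketch of the updated helper (not part of the commit):

```python
import matplotlib.pyplot as plt
from utils.plot import show_or_save_fig

plt.plot([0, 1], [1, 0])
# With a save_dir set, the figure is written and cleared; with save_dir=None
# it is shown interactively instead. The extension now arrives as its own
# argument, so callers can switch formats via --extension rather than
# hard-coding '.pdf' into every figure name.
show_or_save_fig('./out/toy', 'loss_plot', 'png')  # writes ./out/toy/loss_plot.png
```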
