@@ -19,15 +19,16 @@ class LowRankWrapper(WrappedModule):
     trainable_weights: bool
     base_sigma: float

-    def _pre_process(self, t: ArrayLike) -> Tuple[ArrayLike, Tuple[ArrayLike, ArrayLike, ArrayLike]]:
+    def _pre_process(self, t: ArrayLike) -> Tuple[Tuple[ArrayLike], Tuple[ArrayLike, ArrayLike, ArrayLike]]:
         ndim = self.A.shape[0]

         h_mu = (1 - t) * self.A + t * self.B
-        S_0 = jnp.eye(ndim)
-        S_0 = S_0 * jnp.vstack([self.base_sigma * jnp.ones((ndim // 2, 1)), self.base_sigma * jnp.ones((ndim // 2, 1))])
+        S_0 = jnp.eye(ndim, dtype=jnp.float32)
+        S_0 = S_0 * jnp.vstack([self.base_sigma * jnp.ones((ndim // 2, 1), dtype=jnp.float32),
+                                self.base_sigma * jnp.ones((ndim // 2, 1), dtype=jnp.float32)])
         S_0 = S_0[None, ...]
         h_S = (1 - 2 * t * (1 - t))[..., None] * S_0
-        return jnp.hstack([h_mu, h_S.reshape(-1, ndim * ndim), t]), (h_mu, h_S, t)
+        return (jnp.hstack([h_mu, h_S.reshape(-1, ndim * ndim), t]),), (h_mu, h_S, t)

     @nn.compact
     def _post_process(self, h: ArrayLike, h_mu: ArrayLike, h_S: ArrayLike, t: ArrayLike):
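Side note on the interpolation in `_pre_process` (the `h_mu` and `h_S` lines above are unchanged context): the mean follows the straight line `(1 - t) * A + t * B`, and the base Cholesky factor is scaled by `1 - 2 * t * (1 - t)`, so both endpoints recover `A`/`B` with the base factor `S_0`. A minimal standalone sketch of that boundary behaviour, assuming `A` and `B` are `(1, ndim)` float32 arrays and `ndim` is even so the `vstack` reduces to `base_sigma * I` (none of these values are defined in this hunk):

```python
import jax.numpy as jnp

# Toy stand-ins for the module's A, B, base_sigma (placeholders, not the real fields).
ndim, base_sigma = 4, 0.1
A = jnp.zeros((1, ndim), dtype=jnp.float32)
B = jnp.ones((1, ndim), dtype=jnp.float32)
S_0 = base_sigma * jnp.eye(ndim, dtype=jnp.float32)[None, ...]

for t in (jnp.zeros((1, 1)), jnp.ones((1, 1))):
    h_mu = (1 - t) * A + t * B                    # straight-line mean
    h_S = (1 - 2 * t * (1 - t))[..., None] * S_0  # scaled base Cholesky factor
    print(h_mu[0, 0], jnp.allclose(h_S, S_0))     # endpoint mean, True at both t=0 and t=1
```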
@@ -43,18 +44,18 @@ def _post_process(self, h: ArrayLike, h_mu: ArrayLike, h_S: ArrayLike, t: ArrayL

         @jax.vmap
         def get_tril(v):
-            a = jnp.zeros((ndim, ndim))
+            a = jnp.zeros((ndim, ndim), dtype=jnp.float32)
             a = a.at[jnp.tril_indices(ndim)].set(v)
             return a

         S = get_tril(h[:, ndim:])
-        S = jnp.tril(2 * jax.nn.sigmoid(S) - 1.0, k=-1) + jnp.eye(ndim)[None, ...] * jnp.exp(S)
+        S = jnp.tril(2 * jax.nn.sigmoid(S) - 1.0, k=-1) + jnp.eye(ndim, dtype=jnp.float32)[None, ...] * jnp.exp(S)
         S = h_S + 2 * ((1 - t) * t)[..., None] * S

         if self.trainable_weights:
-            w_logits = self.param('w_logits', nn.initializers.zeros_init(), (num_mixtures,))
+            w_logits = self.param('w_logits', nn.initializers.zeros_init(), (num_mixtures,), dtype=jnp.float32)
         else:
-            w_logits = jnp.zeros(num_mixtures)
+            w_logits = jnp.zeros(num_mixtures, dtype=jnp.float32)

         print('mu.shape', mu.shape)
         print('S.shape', S.shape)
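Side note on `_post_process` (editor's sketch, not code from this PR): `get_tril` scatters the flat network output into a lower-triangular matrix, the `sigmoid`/`exp` line keeps the strict lower part in `(-1, 1)` and the diagonal strictly positive, and the result is blended with `h_S` by the weight `2 * (1 - t) * t`, which vanishes at the endpoints. A single-sample illustration with a made-up input vector `v`:

```python
import jax
import jax.numpy as jnp

# Stand-in for h[:, ndim:] of one sample; the real values come from the network.
ndim = 3
v = jnp.arange(ndim * (ndim + 1) // 2, dtype=jnp.float32)

# Same mapping as in _post_process, written out for a single (unbatched) sample.
a = jnp.zeros((ndim, ndim), dtype=jnp.float32).at[jnp.tril_indices(ndim)].set(v)
S = jnp.tril(2 * jax.nn.sigmoid(a) - 1.0, k=-1) + jnp.eye(ndim, dtype=jnp.float32) * jnp.exp(a)

print(jnp.all(jnp.diag(S) > 0))                 # True: positive diagonal, a valid Cholesky factor
print(jnp.all(jnp.abs(jnp.tril(S, k=-1)) < 1))  # True: bounded strictly-lower entries
```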
@@ -65,93 +66,47 @@ def get_tril(v):
 @dataclass
 class LowRankSetup(DriftedSetup):
     model_q: LowRankWrapper
-    T: float

     def __init__(self, system: System, model_q: LowRankWrapper, xi: ArrayLike, order: str, T: float):
-        super().__init__(system, model_q, xi, order)
-        self.T = T
+        super().__init__(system, model_q, xi, order, T)

     def construct_loss(self, state_q: TrainState, gamma: float, BS: int) -> Callable[
         [Union[FrozenVariableDict, Dict[str, Any]], ArrayLike], ArrayLike]:
         def loss_fn(params_q: Union[FrozenVariableDict, Dict[str, Any]], key: ArrayLike) -> ArrayLike:
             ndim = self.model_q.A.shape[-1]

             key = jax.random.split(key)
-            t = self.T * jax.random.uniform(key[0], [BS, 1])
-            eps = jax.random.normal(key[1], [BS, ndim, 1])
-
-            mu_t = lambda _t: state_q.apply_fn(params_q, _t)[0]
-            S_t = lambda _t: state_q.apply_fn(params_q, _t)[1]
-
-            def dmudt(_t):
-                _dmudt = jax.jacrev(lambda _t: mu_t(_t).sum(0))
-                return _dmudt(_t).squeeze().T
-
-            def dSdt(_t):
-                _dSdt = jax.jacrev(lambda _t: S_t(_t).sum(0))
-                return _dSdt(_t).squeeze().T
+            t = self.T * jax.random.uniform(key[0], [BS, 1], dtype=jnp.float32)
+            eps = jax.random.normal(key[1], [BS, ndim, 1], dtype=jnp.float32)

             def v_t(_eps, _t):
-                S_t_val, dSdt_val = S_t(_t), dSdt(_t)
-                _x = mu_t(_t) + jax.lax.batch_matmul(S_t_val, _eps).squeeze()
-                dlogdx = -jax.scipy.linalg.solve_triangular(jnp.transpose(S_t_val, (0, 2, 1)), _eps)
+                _mu_t, _S_t_val, _w_logits, _dmudt, _dSdt_val = forward_and_derivatives(state_q, _t, params_q)
+
+                _x = _mu_t + jax.lax.batch_matmul(_S_t_val, _eps).squeeze()
+                dlogdx = -jax.scipy.linalg.solve_triangular(jnp.transpose(_S_t_val, (0, 2, 1)), _eps)
                 # S_t_val_inv = jnp.transpose(jnp.linalg.inv(S_t_val), (0,2,1))
                 # dlogdx = -jax.lax.batch_matmul(S_t_val_inv, _eps)
-                dSigmadt = jax.lax.batch_matmul(dSdt_val, jnp.transpose(S_t_val, (0, 2, 1)))
-                dSigmadt += jax.lax.batch_matmul(S_t_val, jnp.transpose(dSdt_val, (0, 2, 1)))
-                u_t = dmudt(_t) - 0.5 * jax.lax.batch_matmul(dSigmadt, dlogdx).squeeze()
+                dSigmadt = jax.lax.batch_matmul(_dSdt_val, jnp.transpose(_S_t_val, (0, 2, 1)))
+                dSigmadt += jax.lax.batch_matmul(_S_t_val, jnp.transpose(_dSdt_val, (0, 2, 1)))
+                u_t = _dmudt - 0.5 * jax.lax.batch_matmul(dSigmadt, dlogdx).squeeze()
                 out = (u_t - self._drift(_x.reshape(BS, ndim), gamma)) + 0.5 * (self.xi ** 2) * dlogdx.squeeze()
                 return out

             loss = 0.5 * ((v_t(eps, t) / self.xi) ** 2).sum(1, keepdims=True)
-            print(loss.shape, 'loss.shape', flush=True)
+            print(loss.shape, 'loss.shape', 'loss.dtype', loss.dtype, flush=True)
             return loss.mean()

-            # ndim = self.model_q.A.shape[-1]
-            # key = jax.random.split(key)
-            #
-            # t = self.T * jax.random.uniform(key[0], [BS, 1], dtype=jnp.float32)
-            # #TODO: the following needs to be changed for num gaussians. It should be BS, num_mitures, ndim
-            # eps = jax.random.normal(key[1], [BS, ndim, 1], dtype=jnp.float32)
-            #
-            # def v_t(_eps, _t):
-            #     """This function is equal to v_t * xi ** 2."""
-            #     _mu_t, _sigma_t, _w_logits, _dmudt, _dsigmadt = forward_and_derivatives(state_q, _t, params_q)
-            #     _i = jax.random.categorical(key[2], _w_logits, shape=[BS, ])
-            #
-            #     _x = _mu_t[jnp.arange(BS), _i, None] + _sigma_t[jnp.arange(BS), _i, None] * eps
-            #
-            #     if _mu_t.shape[1] == 1:
-            #         # This completely ignores the weights and saves some time
-            #         relative_mixture_weights = 1
-            #     else:
-            #         log_q_i = jax.scipy.stats.norm.logpdf(_x, _mu_t, _sigma_t).sum(-1)
-            #         relative_mixture_weights = jax.nn.softmax(_w_logits + log_q_i)[:, :, None]
-            #
-            #     log_q_t = -(relative_mixture_weights / (_sigma_t ** 2) * (_x - _mu_t)).sum(axis=1)
-            #     u_t = (relative_mixture_weights * (1 / _sigma_t * _dsigmadt * (_x - _mu_t) + _dmudt)).sum(axis=1)
-            #
-            #     return u_t - self._drift(_x.reshape(BS, ndim), gamma) + 0.5 * (self.xi ** 2) * log_q_t
-            #
-            # loss = 0.5 * ((v_t(eps, t) / self.xi) ** 2).sum(-1, keepdims=True)
-            # return loss.mean()
-
         return loss_fn

     def u_t(self, state_q: TrainState, t: ArrayLike, x_t: ArrayLike, deterministic: bool, *args, **kwargs) -> ArrayLike:
-        raise NotImplementedError
-
-        # _mu_t, _sigma_t, _w_logits, _dmudt, _dsigmadt = forward_and_derivatives(state_q, t)
-        # _x = x_t[:, None, :]
-        #
-        # log_q_i = jax.scipy.stats.norm.logpdf(_x, _mu_t, _sigma_t).sum(-1)
-        # relative_mixture_weights = jax.nn.softmax(_w_logits + log_q_i)[:, :, None]
-        #
-        # _u_t = (relative_mixture_weights * (1 / _sigma_t * _dsigmadt * (_x - _mu_t) + _dmudt)).sum(axis=1)
-        #
-        # if deterministic:
-        #     return _u_t
-        #
-        # log_q_t = -(relative_mixture_weights / (_sigma_t ** 2) * (_x - _mu_t)).sum(axis=1)
-        #
-        # return _u_t + 0.5 * (self.xi ** 2) * log_q_t
+        _mu_t, _S_t_val, _w_logits, _dmudt, _dSdt_val = forward_and_derivatives(state_q, t)
+
+        dSigmadt = jax.lax.batch_matmul(_dSdt_val, jnp.transpose(_S_t_val, (0, 2, 1)))
+        dSigmadt += jax.lax.batch_matmul(_S_t_val, jnp.transpose(_dSdt_val, (0, 2, 1)))
+        STdlogdx = jax.scipy.linalg.solve_triangular(_S_t_val, (x_t - _mu_t)[..., None])
+        dlogdx = -jax.scipy.linalg.solve_triangular(jnp.transpose(_S_t_val, (0, 2, 1)), STdlogdx)
+
+        if deterministic:
+            return _dmudt + (-0.5 * jax.lax.batch_matmul(dSigmadt, dlogdx)).squeeze()
+
+        return _dmudt + (-0.5 * jax.lax.batch_matmul(dSigmadt, dlogdx) + 0.5 * self.xi ** 2 * dlogdx).squeeze()
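A last side note on the reworked `u_t` (again an illustrative sketch, not PR code): for a Gaussian with lower-triangular factor `S`, chaining a solve against `S` and then against `S.T` yields the score `-Sigma^{-1} (x - mu)` with `Sigma = S @ S.T`, which `u_t` then combines with `dSigma/dt` and `dmu/dt`. A small numerical check of that identity, using a hypothetical well-conditioned factor `L` and with the `lower=` flags spelled out for clarity:

```python
import jax
import jax.numpy as jnp

# Hypothetical well-conditioned lower-triangular factor and a residual x - mu.
ndim = 4
L = jnp.eye(ndim, dtype=jnp.float32) + jnp.tril(jax.random.normal(jax.random.PRNGKey(0), (ndim, ndim)), k=-1)
x_minus_mu = jax.random.normal(jax.random.PRNGKey(1), (ndim, 1))

# Two chained triangular solves, mirroring the structure of u_t above.
STdlogdx = jax.scipy.linalg.solve_triangular(L, x_minus_mu, lower=True)
dlogdx = -jax.scipy.linalg.solve_triangular(L.T, STdlogdx, lower=False)

# Reference: the Gaussian score -Sigma^{-1} (x - mu) with Sigma = L @ L.T.
Sigma = L @ L.T
print(jnp.allclose(dlogdx, -jnp.linalg.solve(Sigma, x_minus_mu), atol=1e-5))  # True
```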