8
8
from calendar import c
9
9
from typing import Generator , Mapping , NamedTuple , Sequence , Tuple
10
10
11
- import jax
12
11
import numpy as np
13
-
14
- jax .config .update ("jax_platform_name" , "cpu" ) # suppress warning about no GPUs
15
-
12
+ import jax
16
13
import haiku as hk
17
14
import jax .numpy as jnp
18
15
import optax
@@ -31,12 +28,11 @@ def add_args(parser):
31
28
parser .add_argument ("--hidden_size" , type = int , default = 512 )
32
29
parser .add_argument ("--learning_rate" , type = float , default = 0.001 )
33
30
parser .add_argument ("--batch_size" , type = int , default = 128 )
34
- parser .add_argument ("--training_steps" , type = int , default = 100000 )
31
+ parser .add_argument ("--training_steps" , type = int , default = 30000 )
35
32
parser .add_argument ("--log_interval" , type = int , default = 10000 )
36
33
parser .add_argument ("--num_eval_samples" , type = int , default = 128 )
37
34
parser .add_argument ("--gpu" , default = False , action = argparse .BooleanOptionalAction )
38
35
parser .add_argument ("--random_seed" , type = int , default = 42 )
39
- parser .add_argument ("--train_dir" , type = pathlib .Path , default = "/tmp" )
40
36
41
37
42
38
def load_dataset (
@@ -65,7 +61,7 @@ def __init__(
65
61
hidden_size : int ,
66
62
output_shape : Sequence [int ] = MNIST_IMAGE_SHAPE ,
67
63
):
68
- super ().__init__ ()
64
+ super ().__init__ (name = "model" )
69
65
self ._latent_size = latent_size
70
66
self ._hidden_size = hidden_size
71
67
self ._output_shape = output_shape
@@ -93,7 +89,7 @@ class VariationalMeanField(hk.Module):
93
89
"""Mean field variational distribution q(z | x) parameterized by inference network."""
94
90
95
91
def __init__ (self , latent_size : int , hidden_size : int ):
96
- super ().__init__ ()
92
+ super ().__init__ (name = "variational" )
97
93
self ._latent_size = latent_size
98
94
self ._hidden_size = hidden_size
99
95
self .inference_network = hk .Sequential (
@@ -121,70 +117,49 @@ def __call__(self, x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
121
117
return q_z
122
118
123
119
124
- class ModelAndVariationalOutput (NamedTuple ):
125
- p_z : tfd .Distribution
126
- p_x_given_z : tfd .Distribution
127
- q_z : tfd .Distribution
128
- z : jnp .ndarray
129
-
130
-
131
- class ModelAndVariational (hk .Module ):
132
- """Parent class for creating inputs to the variational inference algorithm."""
133
-
134
- def __init__ (self , latent_size : int , hidden_size : int , output_shape : Sequence [int ]):
135
- super ().__init__ ()
136
- self ._latent_size = latent_size
137
- self ._hidden_size = hidden_size
138
- self ._output_shape = output_shape
139
-
140
- def __call__ (self , x : jnp .ndarray ) -> ModelAndVariationalOutput :
141
- x = x .astype (jnp .float32 )
142
- q_z = VariationalMeanField (self ._latent_size , self ._hidden_size )(x )
143
- # use a single sample from variational distribution to train
144
- # shape [num_samples, batch_size, latent_size]
145
- z = q_z .sample (sample_shape = [1 ], seed = hk .next_rng_key ())
146
-
147
- p_z , p_x_given_z = Model (
148
- self ._latent_size , self ._hidden_size , MNIST_IMAGE_SHAPE
149
- )(x = x , z = z )
150
- return ModelAndVariationalOutput (p_z , p_x_given_z , q_z , z )
151
-
152
-
153
120
def main ():
121
+ start_time = time .time ()
154
122
parser = argparse .ArgumentParser ()
155
123
add_args (parser )
156
124
args = parser .parse_args ()
157
- model_and_variational = hk .transform (
158
- lambda x : ModelAndVariational (
159
- args .latent_size , args .hidden_size , MNIST_IMAGE_SHAPE
160
- )(x )
125
+ rng_seq = hk .PRNGSequence (args .random_seed )
126
+ model = hk .transform (
127
+ lambda x , z : Model (args .latent_size , args .hidden_size , MNIST_IMAGE_SHAPE )(x , z )
128
+ )
129
+ variational = hk .transform (
130
+ lambda x : VariationalMeanField (args .latent_size , args .hidden_size )(x )
161
131
)
132
+ p_params = model .init (
133
+ next (rng_seq ),
134
+ np .zeros ((1 , * MNIST_IMAGE_SHAPE )),
135
+ np .zeros ((1 , args .latent_size )),
136
+ )
137
+ q_params = variational .init (next (rng_seq ), np .zeros ((1 , * MNIST_IMAGE_SHAPE )))
138
+ params = hk .data_structures .merge (p_params , q_params )
139
+ optimizer = optax .rmsprop (args .learning_rate )
140
+ opt_state = optimizer .init (params )
162
141
163
- # @jax.jit
142
@jax.jit
def objective_fn(params: hk.Params, rng_key: PRNGKey, batch: Batch) -> jnp.ndarray:
    """Return the negative ELBO for one batch, using a single sample of z.

    Args:
        params: merged parameter tree holding both the generative model's and
            the variational network's parameters.
        rng_key: PRNG key for this evaluation; it is split below so that no
            key is consumed twice.
        batch: mapping whose "image" entry holds the observations x.

    Returns:
        The negated ELBO, averaged over the sample axis and summed over the
        batch axis (a quantity to minimize).
    """
    x = batch["image"]
    # Modules whose name contains "model" belong to the generative model p;
    # everything else is treated as the variational network q.
    p_params, q_params = hk.data_structures.partition(
        lambda module_name, name, value: "model" in module_name, params
    )
    # JAX keys must not be reused: give each consumer of randomness its own key.
    q_key, sample_key, p_key = jax.random.split(rng_key, 3)
    q_z = variational.apply(q_params, q_key, x)
    # use a single sample from the variational distribution to train;
    # sample_shape=[1] adds a leading sample axis of size 1
    z = q_z.sample(sample_shape=[1], seed=sample_key)
    p_z, p_x_given_z = model.apply(p_params, p_key, x, z)
    # sum over latent dimension
    log_q_z = q_z.log_prob(z).sum(axis=-1)
    # sum over last three image dimensions (width, height, channels)
    log_p_x_given_z = p_x_given_z.log_prob(x).sum(axis=(-3, -2, -1))
    # sum over latent dimension
    log_p_z = p_z.log_prob(z).sum(axis=-1)
    elbo = log_p_x_given_z + log_p_z - log_q_z
    # average elbo over number of samples
    elbo = elbo.mean(axis=0)
    # sum elbo over batch
    elbo = elbo.sum(axis=0)
    return -elbo
179
162
180
- rng_seq = hk .PRNGSequence (args .random_seed )
181
-
182
- params = model_and_variational .init (
183
- next (rng_seq ), np .zeros ((1 , * MNIST_IMAGE_SHAPE ))
184
- )
185
- optimizer = optax .rmsprop (args .learning_rate )
186
- opt_state = optimizer .init (params )
187
-
188
163
@jax .jit
189
164
def train_step (
190
165
params : hk .Params , rng_key : PRNGKey , opt_state : optax .OptState , batch : Batch
@@ -201,13 +176,17 @@ def importance_weighted_estimate(
201
176
) -> Tuple [jnp .ndarray , jnp .ndarray ]:
202
177
"""Estimate marginal log p(x) using importance sampling."""
203
178
x = batch ["image" ]
204
- out : ModelAndVariationalOutput = model_and_variational .apply (params , rng_key , x )
205
- log_q_z = out .q_z .log_prob (out .z ).sum (axis = - 1 )
179
+ # out: ModelAndVariationalOutput = model_and_variational.apply(params, rng_key, x)
180
+ predicate = lambda module_name , name , value : "model" in module_name
181
+ p_params , q_params = hk .data_structures .partition (predicate , params )
182
+ q_z = variational .apply (q_params , rng_key , x )
183
+ z = q_z .sample (args .num_eval_samples , seed = rng_key )
184
+ p_z , p_x_given_z = model .apply (p_params , rng_key , x , z )
185
+ log_q_z = q_z .log_prob (z ).sum (axis = - 1 )
206
186
# sum over last three image dimensions (width, height, channels)
207
- log_p_x_given_z = out . p_x_given_z .log_prob (x ).sum (axis = (- 3 , - 2 , - 1 ))
187
+ log_p_x_given_z = p_x_given_z .log_prob (x ).sum (axis = (- 3 , - 2 , - 1 ))
208
188
# sum over latent dimension
209
- log_p_z = out .p_z .log_prob (out .z ).sum (axis = - 1 )
210
-
189
+ log_p_z = p_z .log_prob (z ).sum (axis = - 1 )
211
190
elbo = log_p_x_given_z + log_p_z - log_q_z
212
191
# importance sampling of approximate marginal likelihood with q(z)
213
192
# as the proposal, and logsumexp in the sample dimension
@@ -253,15 +232,16 @@ def print_progress(step: int, examples_per_sec: float):
253
232
f"Train ELBO estimate: { train_elbo :<5.3f} \t "
254
233
f"Validation ELBO estimate: { elbo :<5.3f} \t "
255
234
f"Validation log p(x) estimate: { log_p_x :<5.3f} \t "
256
- f"Speed: { examples_per_sec :<5.0f } examples/s"
235
+ f"Speed: { examples_per_sec :<5.2e } examples/s"
257
236
)
258
237
259
238
t0 = time .time ()
260
239
for step in range (args .training_steps ):
261
240
if step % args .log_interval == 0 :
262
- examples_per_sec = args .log_interval / (time .time () - t0 )
241
+ t1 = time .time ()
242
+ examples_per_sec = args .log_interval * args .batch_size / (t1 - t0 )
263
243
print_progress (step , examples_per_sec )
264
- t0 = time . time ()
244
+ t0 = t1
265
245
params , opt_state = train_step (params , next (rng_seq ), opt_state , next (train_ds ))
266
246
267
247
test_ds = load_dataset (tfds .Split .TEST , args .batch_size , args .random_seed )
@@ -271,6 +251,7 @@ def print_progress(step: int, examples_per_sec: float):
271
251
f"Test ELBO estimate: { elbo :<5.3f} \t "
272
252
f"Test log p(x) estimate: { log_p_x :<5.3f} \t "
273
253
)
254
+ print (f"Total time: { (time .time () - start_time ) / 60 :.3f} minutes" )
274
255
275
256
276
257
if __name__ == "__main__" :
0 commit comments