Commit f212232

Author: Jaan Altosaar
Commit message: make jax implementation modular and update timing info
Parent: f049960

5 files changed: +160 additions, -103 deletions

.env

Lines changed: 3 additions & 1 deletion
@@ -4,4 +4,6 @@
 JAX_PLATFORM_NAME=cpu
 
 # suppress tensorflow warnings
-TF_CPP_MIN_LOG_LEVEL=2
+TF_CPP_MIN_LOG_LEVEL=2
+
+TFDS_DATA_DIR=/scratch/gpfs/altosaar/tensorflow_datasets

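These three variables select JAX's default backend, quiet TensorFlow's C++ logging, and point TensorFlow Datasets at a shared cache. A minimal sketch of how a script could pick them up, assuming the file is exported into the process environment (for example via `source .env` or python-dotenv; the repository's own loading mechanism is not shown in this diff):

```
import os

# all three must be set before tensorflow / jax / tensorflow_datasets are imported
platform = os.environ.get("JAX_PLATFORM_NAME", "cpu")       # "cpu" or "gpu"
tf_log_level = os.environ.get("TF_CPP_MIN_LOG_LEVEL", "0")  # "2" hides info and warning messages
tfds_data_dir = os.environ.get("TFDS_DATA_DIR")             # where tensorflow_datasets caches MNIST
```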
README.md

Lines changed: 21 additions & 15 deletions
@@ -9,16 +9,20 @@ Variational inference is used to fit the model to binarized MNIST handwritten di
 
 Blog post: https://jaan.io/what-is-variational-autoencoder-vae-tutorial/
 
-Example output with importance sampling for estimating the marginal likelihood on Hugo Larochelle's Binary MNIST dataset. Final marginal likelihood on the test set was `-97.10` nats after 65k iterations.
+
+## PyTorch implementation
+
+(anaconda environment is in `environment-pytorch.yml`)
+
+Importance sampling is used to estimate the marginal likelihood on Hugo Larochelle's Binary MNIST dataset. The final marginal likelihood on the test set was `-97.10` nats, comparable to published numbers.
 
 ```
-$ python train_variational_autoencoder_pytorch.py --variational mean-field
-step: 0 train elbo: -558.28
-step: 0 valid elbo: -392.78 valid log p(x): -359.91
-step: 10000 train elbo: -106.67
-step: 10000 valid elbo: -109.12 valid log p(x): -103.11
-step: 20000 train elbo: -107.28
-step: 20000 valid elbo: -105.65 valid log p(x): -99.74
+$ python train_variational_autoencoder_pytorch.py --variational mean-field --use_gpu --data_dir $DAT --max_iterations 30000 --log_interval 10000
+Step 0 Train ELBO estimate: -558.027 Validation ELBO estimate: -384.432 Validation log p(x) estimate: -355.430 Speed: 2.72e+06 examples/s
+Step 10000 Train ELBO estimate: -111.323 Validation ELBO estimate: -109.048 Validation log p(x) estimate: -103.746 Speed: 2.64e+04 examples/s
+Step 20000 Train ELBO estimate: -103.013 Validation ELBO estimate: -107.655 Validation log p(x) estimate: -101.275 Speed: 2.63e+04 examples/s
+Step 29999 Test ELBO estimate: -106.642 Test log p(x) estimate: -100.309
+Total time: 2.49 minutes
 ```
 
 
@@ -36,12 +40,14 @@ step: 30000 train elbo: -98.70
 step: 30000 valid elbo: -103.76 valid log p(x): -97.71
 ```
 
-Using jax (anaconda environment is in `environment-jax.yml`):
+## jax implementation
+
+Using jax (anaconda environment is in `environment-jax.yml`), to get a 3x speedup over pytorch:
 ```
-Step 0 Train ELBO estimate: -565.785 Validation ELBO estimate: -565.775 Validation log p(x) estimate: -565.775 Speed: 3813003636 examples/s
-Step 10000 Train ELBO estimate: -99.048 Validation ELBO estimate: -105.412 Validation log p(x) estimate: -105.412 Speed: 134 examples/s
-Step 20000 Train ELBO estimate: -108.399 Validation ELBO estimate: -105.191 Validation log p(x) estimate: -105.191 Speed: 140 examples/s
-Step 30000 Train ELBO estimate: -100.839 Validation ELBO estimate: -105.404 Validation log p(x) estimate: -105.404 Speed: 139 examples/s
-Step 40000 Train ELBO estimate: -97.761 Validation ELBO estimate: -105.382 Validation log p(x) estimate: -105.382 Speed: 139 examples/s
-Step 50000 Train ELBO estimate: -98.228 Validation ELBO estimate: -105.718 Validation log p(x) estimate: -105.718 Speed: 139 examples/s
+$ python train_variational_autoencoder_jax.py --gpu
+Step 0 Train ELBO estimate: -566.059 Validation ELBO estimate: -565.755 Validation log p(x) estimate: -557.914 Speed: 2.56e+11 examples/s
+Step 10000 Train ELBO estimate: -98.560 Validation ELBO estimate: -105.725 Validation log p(x) estimate: -98.973 Speed: 7.03e+04 examples/s
+Step 20000 Train ELBO estimate: -109.794 Validation ELBO estimate: -105.756 Validation log p(x) estimate: -97.914 Speed: 4.26e+04 examples/s
+Step 29999 Test ELBO estimate: -104.867 Test log p(x) estimate: -96.716
+Total time: 0.810 minutes
 ```
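The `log p(x)` columns in the outputs above come from the importance-sampling estimator the README describes. A minimal sketch of that estimator, assuming per-sample log densities with shape `[num_samples, batch_size]` (the function name here is illustrative, not the script's):

```
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def log_marginal_estimate(log_p_x_given_z, log_p_z, log_q_z):
    # log importance weights log p(x, z_k) - log q(z_k | x), shape [num_samples, batch_size]
    log_w = log_p_x_given_z + log_p_z - log_q_z
    # log p(x) is approximately logsumexp over the K samples minus log K, per batch element
    return logsumexp(log_w, axis=0) - jnp.log(log_w.shape[0])
```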

environment-pytorch.yml

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
+name: /scratch/gpfs/altosaar/environment-pytorch
+channels:
+  - pytorch
+  - nvidia
+  - defaults
+dependencies:
+  - _libgcc_mutex=0.1=main
+  - blas=1.0=mkl
+  - bzip2=1.0.8=h7b6447c_0
+  - ca-certificates=2021.4.13=h06a4308_1
+  - certifi=2020.12.5=py38h06a4308_0
+  - cudatoolkit=11.1.74=h6bb024c_0
+  - ffmpeg=4.3=hf484d3e_0
+  - freetype=2.10.4=h5ab3b9f_0
+  - gmp=6.2.1=h2531618_2
+  - gnutls=3.6.15=he1e5248_0
+  - h5py=2.10.0=py38hd6299e0_1
+  - hdf5=1.10.6=hb1b8bf9_0
+  - intel-openmp=2021.2.0=h06a4308_610
+  - jpeg=9b=h024ee3a_2
+  - lame=3.100=h7b6447c_0
+  - lcms2=2.12=h3be6417_0
+  - ld_impl_linux-64=2.33.1=h53a641e_7
+  - libffi=3.3=he6710b0_2
+  - libgcc-ng=9.1.0=hdf63c60_0
+  - libgfortran-ng=7.3.0=hdf63c60_0
+  - libiconv=1.15=h63c8f33_5
+  - libidn2=2.3.1=h27cfd23_0
+  - libpng=1.6.37=hbc83047_0
+  - libstdcxx-ng=9.1.0=hdf63c60_0
+  - libtasn1=4.16.0=h27cfd23_0
+  - libtiff=4.1.0=h2733197_1
+  - libunistring=0.9.10=h27cfd23_0
+  - libuv=1.40.0=h7b6447c_0
+  - lz4-c=1.9.3=h2531618_0
+  - mkl=2021.2.0=h06a4308_296
+  - mkl-service=2.3.0=py38h27cfd23_1
+  - mkl_fft=1.3.0=py38h42c9631_2
+  - mkl_random=1.2.1=py38ha9443f7_2
+  - ncurses=6.2=he6710b0_1
+  - nettle=3.7.2=hbbd107a_1
+  - ninja=1.10.2=hff7bd54_1
+  - numpy=1.20.2=py38h2d18471_0
+  - numpy-base=1.20.2=py38hfae3a4d_0
+  - olefile=0.46=py_0
+  - openh264=2.1.0=hd408876_0
+  - openssl=1.1.1k=h27cfd23_0
+  - pillow=8.2.0=py38he98fc37_0
+  - pip=21.1.1=py38h06a4308_0
+  - python=3.8.10=hdb3f193_7
+  - pytorch=1.8.1=py3.8_cuda11.1_cudnn8.0.5_0
+  - readline=8.1=h27cfd23_0
+  - setuptools=52.0.0=py38h06a4308_0
+  - six=1.15.0=py38h06a4308_0
+  - sqlite=3.35.4=hdfb4753_0
+  - tk=8.6.10=hbc83047_0
+  - torchaudio=0.8.1=py38
+  - torchvision=0.9.1=py38_cu111
+  - typing_extensions=3.7.4.3=pyha847dfd_0
+  - wheel=0.36.2=pyhd3eb1b0_0
+  - xz=5.2.5=h7b6447c_0
+  - zlib=1.2.11=h7b6447c_3
+  - zstd=1.4.9=haebb681_0
+prefix: /scratch/gpfs/altosaar/environment-pytorch

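Because the environment file pins an absolute `prefix` on the cluster, it can presumably be recreated elsewhere with `conda env create -f environment-pytorch.yml` (optionally with `--prefix` to reproduce the pinned install location) and then activated with `conda activate`.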
train_variational_autoencoder_jax.py

Lines changed: 44 additions & 63 deletions
@@ -8,11 +8,8 @@
 from calendar import c
 from typing import Generator, Mapping, NamedTuple, Sequence, Tuple
 
-import jax
 import numpy as np
-
-jax.config.update("jax_platform_name", "cpu")  # suppress warning about no GPUs
-
+import jax
 import haiku as hk
 import jax.numpy as jnp
 import optax
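Dropping the hard-coded `jax.config.update("jax_platform_name", "cpu")` call means the backend is presumably chosen by the `JAX_PLATFORM_NAME` variable in `.env` and by the `--gpu` flag. A minimal sketch of the environment-variable route, which has to happen before `jax` is imported:

```
import os

# assumption: the backend is selected via the environment rather than jax.config.update
os.environ.setdefault("JAX_PLATFORM_NAME", "cpu")

import jax
print(jax.devices())  # e.g. [CpuDevice(id=0)]
```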
@@ -31,12 +28,11 @@ def add_args(parser):
     parser.add_argument("--hidden_size", type=int, default=512)
     parser.add_argument("--learning_rate", type=float, default=0.001)
     parser.add_argument("--batch_size", type=int, default=128)
-    parser.add_argument("--training_steps", type=int, default=100000)
+    parser.add_argument("--training_steps", type=int, default=30000)
     parser.add_argument("--log_interval", type=int, default=10000)
     parser.add_argument("--num_eval_samples", type=int, default=128)
     parser.add_argument("--gpu", default=False, action=argparse.BooleanOptionalAction)
     parser.add_argument("--random_seed", type=int, default=42)
-    parser.add_argument("--train_dir", type=pathlib.Path, default="/tmp")
 
 
 def load_dataset(
@@ -65,7 +61,7 @@ def __init__(
         hidden_size: int,
         output_shape: Sequence[int] = MNIST_IMAGE_SHAPE,
     ):
-        super().__init__()
+        super().__init__(name="model")
         self._latent_size = latent_size
         self._hidden_size = hidden_size
         self._output_shape = output_shape
@@ -93,7 +89,7 @@ class VariationalMeanField(hk.Module):
     """Mean field variational distribution q(z | x) parameterized by inference network."""
 
     def __init__(self, latent_size: int, hidden_size: int):
-        super().__init__()
+        super().__init__(name="variational")
         self._latent_size = latent_size
         self._hidden_size = hidden_size
         self.inference_network = hk.Sequential(
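The explicit `name="model"` and `name="variational"` arguments are what make the refactor below work: haiku prefixes every parameter key with its module name, so the training code can split a merged parameter tree with a simple substring predicate (see the sketch after the next hunk).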
@@ -121,70 +117,49 @@ def __call__(self, x: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
         return q_z
 
 
-class ModelAndVariationalOutput(NamedTuple):
-    p_z: tfd.Distribution
-    p_x_given_z: tfd.Distribution
-    q_z: tfd.Distribution
-    z: jnp.ndarray
-
-
-class ModelAndVariational(hk.Module):
-    """Parent class for creating inputs to the variational inference algorithm."""
-
-    def __init__(self, latent_size: int, hidden_size: int, output_shape: Sequence[int]):
-        super().__init__()
-        self._latent_size = latent_size
-        self._hidden_size = hidden_size
-        self._output_shape = output_shape
-
-    def __call__(self, x: jnp.ndarray) -> ModelAndVariationalOutput:
-        x = x.astype(jnp.float32)
-        q_z = VariationalMeanField(self._latent_size, self._hidden_size)(x)
-        # use a single sample from variational distribution to train
-        # shape [num_samples, batch_size, latent_size]
-        z = q_z.sample(sample_shape=[1], seed=hk.next_rng_key())
-
-        p_z, p_x_given_z = Model(
-            self._latent_size, self._hidden_size, MNIST_IMAGE_SHAPE
-        )(x=x, z=z)
-        return ModelAndVariationalOutput(p_z, p_x_given_z, q_z, z)
-
-
 def main():
+    start_time = time.time()
     parser = argparse.ArgumentParser()
     add_args(parser)
     args = parser.parse_args()
-    model_and_variational = hk.transform(
-        lambda x: ModelAndVariational(
-            args.latent_size, args.hidden_size, MNIST_IMAGE_SHAPE
-        )(x)
+    rng_seq = hk.PRNGSequence(args.random_seed)
+    model = hk.transform(
+        lambda x, z: Model(args.latent_size, args.hidden_size, MNIST_IMAGE_SHAPE)(x, z)
+    )
+    variational = hk.transform(
+        lambda x: VariationalMeanField(args.latent_size, args.hidden_size)(x)
     )
+    p_params = model.init(
+        next(rng_seq),
+        np.zeros((1, *MNIST_IMAGE_SHAPE)),
+        np.zeros((1, args.latent_size)),
+    )
+    q_params = variational.init(next(rng_seq), np.zeros((1, *MNIST_IMAGE_SHAPE)))
+    params = hk.data_structures.merge(p_params, q_params)
+    optimizer = optax.rmsprop(args.learning_rate)
+    opt_state = optimizer.init(params)
 
-    # @jax.jit
+    @jax.jit
     def objective_fn(params: hk.Params, rng_key: PRNGKey, batch: Batch) -> jnp.ndarray:
         x = batch["image"]
-        out: ModelAndVariationalOutput = model_and_variational.apply(params, rng_key, x)
-        log_q_z = out.q_z.log_prob(out.z).sum(axis=-1)
+        predicate = lambda module_name, name, value: "model" in module_name
+        p_params, q_params = hk.data_structures.partition(predicate, params)
+        q_z = variational.apply(q_params, rng_key, x)
+        z = q_z.sample(sample_shape=[1], seed=rng_key)
+        p_z, p_x_given_z = model.apply(p_params, rng_key, x, z)
+        # out: ModelAndVariationalOutput = model_and_variational.apply(params, rng_key, x)
+        log_q_z = q_z.log_prob(z).sum(axis=-1)
         # sum over last three image dimensions (width, height, channels)
-        log_p_x_given_z = out.p_x_given_z.log_prob(x).sum(axis=(-3, -2, -1))
+        log_p_x_given_z = p_x_given_z.log_prob(x).sum(axis=(-3, -2, -1))
         # sum over latent dimension
-        log_p_z = out.p_z.log_prob(out.z).sum(axis=-1)
-
+        log_p_z = p_z.log_prob(z).sum(axis=-1)
         elbo = log_p_x_given_z + log_p_z - log_q_z
         # average elbo over number of samples
         elbo = elbo.mean(axis=0)
         # sum elbo over batch
         elbo = elbo.sum(axis=0)
         return -elbo
 
-    rng_seq = hk.PRNGSequence(args.random_seed)
-
-    params = model_and_variational.init(
-        next(rng_seq), np.zeros((1, *MNIST_IMAGE_SHAPE))
-    )
-    optimizer = optax.rmsprop(args.learning_rate)
-    opt_state = optimizer.init(params)
-
     @jax.jit
     def train_step(
         params: hk.Params, rng_key: PRNGKey, opt_state: optax.OptState, batch: Batch
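The hunk above replaces the monolithic `ModelAndVariational` wrapper with two separately transformed modules whose parameters are merged into one tree for the optimizer and partitioned again inside the loss. A minimal, self-contained sketch of that merge/partition pattern using hypothetical toy modules (the real `Model` and `VariationalMeanField` are defined earlier in this file):

```
import haiku as hk
import numpy as np

class Toy(hk.Module):
    """Stand-in for Model / VariationalMeanField; only the explicit module name matters here."""
    def __call__(self, x):
        return hk.Linear(3)(x)

# transform each module separately so each gets its own name prefix in the parameter keys
model = hk.transform(lambda x: Toy(name="model")(x))
variational = hk.transform(lambda x: Toy(name="variational")(x))

rng_seq = hk.PRNGSequence(42)
x = np.zeros((1, 4), dtype=np.float32)

# merge both parameter sets into a single pytree so one optimizer state covers them
params = hk.data_structures.merge(
    model.init(next(rng_seq), x), variational.init(next(rng_seq), x)
)

# inside the loss, split them back out by module name, as objective_fn does above
is_model = lambda module_name, name, value: "model" in module_name
p_params, q_params = hk.data_structures.partition(is_model, params)
```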
@@ -201,13 +176,17 @@ def importance_weighted_estimate(
     ) -> Tuple[jnp.ndarray, jnp.ndarray]:
         """Estimate marginal log p(x) using importance sampling."""
         x = batch["image"]
-        out: ModelAndVariationalOutput = model_and_variational.apply(params, rng_key, x)
-        log_q_z = out.q_z.log_prob(out.z).sum(axis=-1)
+        # out: ModelAndVariationalOutput = model_and_variational.apply(params, rng_key, x)
+        predicate = lambda module_name, name, value: "model" in module_name
+        p_params, q_params = hk.data_structures.partition(predicate, params)
+        q_z = variational.apply(q_params, rng_key, x)
+        z = q_z.sample(args.num_eval_samples, seed=rng_key)
+        p_z, p_x_given_z = model.apply(p_params, rng_key, x, z)
+        log_q_z = q_z.log_prob(z).sum(axis=-1)
         # sum over last three image dimensions (width, height, channels)
-        log_p_x_given_z = out.p_x_given_z.log_prob(x).sum(axis=(-3, -2, -1))
+        log_p_x_given_z = p_x_given_z.log_prob(x).sum(axis=(-3, -2, -1))
         # sum over latent dimension
-        log_p_z = out.p_z.log_prob(out.z).sum(axis=-1)
-
+        log_p_z = p_z.log_prob(z).sum(axis=-1)
         elbo = log_p_x_given_z + log_p_z - log_q_z
         # importance sampling of approximate marginal likelihood with q(z)
         # as the proposal, and logsumexp in the sample dimension
@@ -253,15 +232,16 @@ def print_progress(step: int, examples_per_sec: float):
             f"Train ELBO estimate: {train_elbo:<5.3f}\t"
             f"Validation ELBO estimate: {elbo:<5.3f}\t"
             f"Validation log p(x) estimate: {log_p_x:<5.3f}\t"
-            f"Speed: {examples_per_sec:<5.0f} examples/s"
+            f"Speed: {examples_per_sec:<5.2e} examples/s"
         )
 
     t0 = time.time()
     for step in range(args.training_steps):
         if step % args.log_interval == 0:
-            examples_per_sec = args.log_interval / (time.time() - t0)
+            t1 = time.time()
+            examples_per_sec = args.log_interval * args.batch_size / (t1 - t0)
             print_progress(step, examples_per_sec)
-            t0 = time.time()
+            t0 = t1
         params, opt_state = train_step(params, next(rng_seq), opt_state, next(train_ds))
 
     test_ds = load_dataset(tfds.Split.TEST, args.batch_size, args.random_seed)
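The logged speed is now measured in examples per second (`log_interval * batch_size` examples per interval) instead of steps per second, which is why the README numbers jump to the 1e4–1e5 range; the step-0 reading is still inflated because it is taken immediately after `t0` is set. A quick sanity check against the README's PyTorch numbers, assuming the default batch size of 128:

```
# 10000 steps * 128 examples/step = 1.28e6 examples per logging interval;
# at the reported 2.64e4 examples/s one interval takes ~48.5 s, so three
# intervals come to roughly 2.4 minutes, consistent with "Total time: 2.49 minutes".
interval_examples = 10_000 * 128
seconds_per_interval = interval_examples / 2.64e4  # ~48.5
```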
@@ -271,6 +251,7 @@ def print_progress(step: int, examples_per_sec: float):
         f"Test ELBO estimate: {elbo:<5.3f}\t"
         f"Test log p(x) estimate: {log_p_x:<5.3f}\t"
     )
+    print(f"Total time: {(time.time() - start_time) / 60:.3f} minutes")
 
 
 if __name__ == "__main__":
