Skip to content

Commit dfb452b

Browse files
author
Jaan Altosaar
committed
add jax example
1 parent 2d6fc52 commit dfb452b

7 files changed

+371
-25
lines changed

.env

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# dev.env - development configuration
2+
3+
# suppress warnings for jax
4+
JAX_PLATFORM_NAME=cpu
5+
6+
# suppress tensorflow warnings
7+
TF_CPP_MIN_LOG_LEVEL=2

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
*.pyc
2+
launch.json
3+
settings.json
4+
*.code-workspace

README.md

Lines changed: 8 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -34,29 +34,12 @@ step: 20000 train elbo: -101.51
3434
step: 20000 valid elbo: -105.02 valid log p(x): -99.11
3535
step: 30000 train elbo: -98.70
3636
step: 30000 valid elbo: -103.76 valid log p(x): -97.71
37-
step: 40000 train elbo: -104.31
38-
step: 40000 valid elbo: -103.71 valid log p(x): -97.27
39-
step: 50000 train elbo: -97.20
40-
step: 50000 valid elbo: -102.97 valid log p(x): -96.60
41-
step: 60000 train elbo: -97.50
42-
step: 60000 valid elbo: -102.82 valid log p(x): -96.49
43-
step: 70000 train elbo: -94.68
44-
step: 70000 valid elbo: -102.63 valid log p(x): -96.22
45-
step: 80000 train elbo: -92.86
46-
step: 80000 valid elbo: -102.53 valid log p(x): -96.09
47-
step: 90000 train elbo: -93.83
48-
step: 90000 valid elbo: -102.33 valid log p(x): -96.00
49-
step: 100000 train elbo: -93.91
50-
step: 100000 valid elbo: -102.48 valid log p(x): -95.92
51-
step: 110000 train elbo: -94.34
52-
step: 110000 valid elbo: -102.81 valid log p(x): -96.09
53-
step: 120000 train elbo: -88.63
54-
step: 120000 valid elbo: -102.53 valid log p(x): -95.80
55-
step: 130000 train elbo: -96.61
56-
step: 130000 valid elbo: -103.56 valid log p(x): -96.26
57-
step: 140000 train elbo: -94.92
58-
step: 140000 valid elbo: -102.81 valid log p(x): -95.86
59-
step: 150000 train elbo: -97.84
60-
step: 150000 valid elbo: -103.06 valid log p(x): -95.92
61-
step: 150000 test elbo: -101.64 test log p(x): -95.33
6237
```
38+
39+
Using jax:
40+
```
41+
Step 0 Validation ELBO estimate: -507.485 Validation log p(x) estimate: -507.485
42+
Step 10000 Validation ELBO estimate: -152.695 Validation log p(x) estimate: -152.695
43+
Step 20000 Validation ELBO estimate: -150.413 Validation log p(x) estimate: -150.413
44+
Step 30000 Validation ELBO estimate: -150.529 Validation log p(x) estimate: -150.529
45+
```

environment_jax.yml

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
name: jax
2+
channels:
3+
- defaults
4+
dependencies:
5+
- _libgcc_mutex=0.1=main
6+
- ca-certificates=2021.4.13=h06a4308_1
7+
- certifi=2020.12.5=py39h06a4308_0
8+
- ld_impl_linux-64=2.33.1=h53a641e_7
9+
- libffi=3.3=he6710b0_2
10+
- libgcc-ng=9.1.0=hdf63c60_0
11+
- libstdcxx-ng=9.1.0=hdf63c60_0
12+
- ncurses=6.2=he6710b0_1
13+
- openssl=1.1.1k=h27cfd23_0
14+
- pip=21.1.1=py39h06a4308_0
15+
- python=3.9.5=hdb3f193_3
16+
- readline=8.1=h27cfd23_0
17+
- setuptools=52.0.0=py39h06a4308_0
18+
- sqlite=3.35.4=hdfb4753_0
19+
- tk=8.6.10=hbc83047_0
20+
- tzdata=2020f=h52ac0ba_0
21+
- wheel=0.36.2=pyhd3eb1b0_0
22+
- xz=5.2.5=h7b6447c_0
23+
- zlib=1.2.11=h7b6447c_3
24+
- pip:
25+
- absl-py==0.12.0
26+
- astunparse==1.6.3
27+
- attrs==21.2.0
28+
- cachetools==4.2.2
29+
- chardet==4.0.0
30+
- chex==0.0.7
31+
- cloudpickle==1.6.0
32+
- decorator==5.0.9
33+
- dill==0.3.3
34+
- dm-haiku==0.0.5.dev0
35+
- dm-tree==0.1.6
36+
- flatbuffers==1.12
37+
- future==0.18.2
38+
- gast==0.4.0
39+
- google-auth==1.30.0
40+
- google-auth-oauthlib==0.4.4
41+
- google-pasta==0.2.0
42+
- googleapis-common-protos==1.53.0
43+
- grpcio==1.34.1
44+
- h5py==3.1.0
45+
- idna==2.10
46+
- jax==0.2.13
47+
- jaxlib==0.1.67
48+
- jmp==0.0.2
49+
- keras-nightly==2.5.0.dev2021032900
50+
- keras-preprocessing==1.1.2
51+
- markdown==3.3.4
52+
- numpy==1.19.5
53+
- oauthlib==3.1.0
54+
- opt-einsum==3.3.0
55+
- optax==0.0.7
56+
- promise==2.3
57+
- protobuf==3.17.0
58+
- pyasn1==0.4.8
59+
- pyasn1-modules==0.2.8
60+
- requests==2.25.1
61+
- requests-oauthlib==1.3.0
62+
- rsa==4.7.2
63+
- scipy==1.6.3
64+
- six==1.15.0
65+
- tabulate==0.8.9
66+
- tensorboard==2.5.0
67+
- tensorboard-data-server==0.6.1
68+
- tensorboard-plugin-wit==1.8.0
69+
- tensorflow==2.5.0
70+
- tensorflow-datasets==4.3.0
71+
- tensorflow-estimator==2.5.0
72+
- tensorflow-metadata==1.0.0
73+
- termcolor==1.1.0
74+
- tfp-nightly==0.14.0.dev20210521
75+
- toolz==0.11.1
76+
- tqdm==4.60.0
77+
- typing-extensions==3.7.4.3
78+
- urllib3==1.26.4
79+
- werkzeug==2.0.1
80+
- wrapt==1.12.1
81+
prefix: /home/jaan/miniconda3/envs/jax

setup.cfg

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
[flake8]
2+
max-line-length = 88

0 commit comments

Comments
 (0)