spatial-VAE

Source code and datasets for "Explicitly disentangling image content from translation and rotation with spatial-VAE", which appeared at NeurIPS 2019.

Learned hinge motion of 5HDB (1-d latent variable)
[animation: 5HDB_gif]

Learned arm motion of CODH/ACS (2-d latent variable)
[animation: codhacs_gif]

Learned antibody conformations (2-d latent variable)
[animation: antibody_gif]

BibTeX

@incollection{bepler2019spatialvae,
title = {Explicitly disentangling image content from translation and rotation with spatial-VAE},
author = {Bepler, Tristan and Zhong, Ellen and Kelley, Kotaro and Brignole, Edward and Berger, Bonnie},
booktitle = {Advances in Neural Information Processing Systems 32},
editor = {H. Wallach and H. Larochelle and A. Beygelzimer and F. d\textquotesingle Alch\'{e}-Buc and E. Fox and R. Garnett},
pages = {15409--15419},
year = {2019},
publisher = {Curran Associates, Inc.},
url = {http://papers.nips.cc/paper/9677-explicitly-disentangling-image-content-from-translation-and-rotation-with-spatial-vae.pdf}
}

Setup

Dependencies:

  • python 3
  • pytorch >= 0.4
  • torchvision
  • numpy
  • pillow
  • topaz (for loading MRC files)
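
A minimal environment setup might look like the following (a sketch, assuming a pip-based install; note that PyTorch is published on PyPI as torch, and topaz is typically installed separately following its own instructions):

pip install "torch>=0.4" torchvision numpy pillow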

Datasets

Datasets as tarballs are available from the links below.
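
Once a tarball has been downloaded, it can be unpacked into a data/ directory to match the paths used in the example commands below (a sketch; the archive name codhacs.tar.gz and its internal layout are assumptions):

mkdir -p data
tar -xzf codhacs.tar.gz -C data/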

Usage

The scripts train_mnist.py, train_particles.py, and train_galaxy.py train spatial-VAE models on the MNIST, single particle EM, and galaxy zoo datasets, respectively.

For example, to train a spatial-VAE model on the CODH/ACS dataset:

python train_particles.py data/codhacs/processed_train.npy data/codhacs/processed_test.npy --num-epochs=1000 --augment-rotation

Some script options include:

  • --z-dim: dimension of the unstructured latent variable (default: 2)
  • --p-hidden-dim and --p-num-layers: the number of units per layer and the number of layers in the spatial generator network
  • --q-hidden-dim and --q-num-layers: the number of units per layer and the number of layers in the approximate inference network
  • --dx-prior, --theta-prior: standard deviation of the translation prior (as a fraction of image size) and standard deviation of the rotation prior
  • --no-rotate, --no-translate: flags to disable rotation and translation inference
  • --normalize: normalize the images before training (subtract the mean, divide by the standard deviation)
  • --ctf-train, --ctf-test: paths to tables containing CTF parameters for the train and test images; CTF correction is performed if these are provided
  • --fit-noise: also output a per-pixel standard deviation from the spatial generator network (sometimes called a colored noise model)
  • --save-prefix: save model parameters every few epochs to this path prefix

See --help for the complete list of arguments.
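
As an illustrative combination of the options above (a sketch: the data paths match the CODH/ACS example, and the save prefix saved_models/codhacs is a hypothetical choice):

python train_particles.py data/codhacs/processed_train.npy data/codhacs/processed_test.npy --z-dim=2 --num-epochs=1000 --augment-rotation --normalize --fit-noise --save-prefix=saved_models/codhacs

This trains with a 2-dimensional unstructured latent variable, normalized inputs, and a learned per-pixel noise model, saving model parameters under the given prefix.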

License

This source code is provided under the MIT License.