
Code for reimplementing the Stein Variational Gradient Descent (SVGD) method.


birajpandey/SVGD-reimplementation


Re-implementation of Stein Variational Gradient Descent (SVGD)

In this repo, we re-implement the Stein Variational Gradient Descent (SVGD) algorithm as the final project for AMATH 590: Gradient Flows at the University of Washington, taught by Prof. Bamdad Hosseini. The authors of this study are @birajpandey and @Vilin97.

SVGD was first developed by Liu et al. from the Dartmouth ML group. The original implementation is linked here.

Qiang Liu and Dilin Wang. "Stein Variational Gradient Descent: A General Purpose Bayesian Inference Algorithm". [NeurIPS 2016]
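For orientation, the update from Liu and Wang moves all n particles x_1, ..., x_n simultaneously. The first term inside the sum pulls particles toward high-density regions through the score ∇ log p; the second term, the kernel gradient, pushes nearby particles apart so they do not collapse onto a single mode. Here ε is the step size and k is a positive-definite kernel; as we read the example below, these roughly correspond to `step_size` and `kernel.rbf_kernel` with its `length_scale`, and the score to `density_obj.score`.

```latex
x_i \leftarrow x_i + \epsilon \, \hat{\phi}^*(x_i),
\qquad
\hat{\phi}^*(x) = \frac{1}{n} \sum_{j=1}^{n} \Big[ k(x_j, x) \, \nabla_{x_j} \log p(x_j) + \nabla_{x_j} k(x_j, x) \Big]
```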

Final Report

To read our experimental results, please check out the AMATH_590_SVGD_report.pdf in the repo.

Setup:

  1. Clone the repository.

    git clone https://github.com/birajpandey/SVGD-reproducibility.git
    
  2. Change into the cloned repository.

    cd path/to/SVGD-reproducibility
    
  3. Run make. The Makefile creates an Anaconda environment called svgd_env, installs the required packages, downloads the datasets, and runs the tests.

    make 
    
  4. Activate the conda environment.

    conda activate svgd_env
    
  5. Install the svgd package in editable mode.

    pip install -e .
    
  6. Run the files in scripts/ to reproduce our results.

Remark: This project structure is based on the Cookiecutter Data Science project template. We also drew heavily on The Good Research Code Handbook by Patrick J. Mineault.

Example:

Here we use SVGD to sample from a Gaussian mixture with three modes.

import os
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = 'false'

import jax.numpy as jnp
import jax.random as jrandom
import numpy as np
from svgd import kernel, density, models
import matplotlib.pyplot as plt


# Define the gaussian mixture
means = jnp.array([[-3,0], [3,0], [0, 3]])
covariances = jnp.array([[[0.2, 0],[0, 0.2]], [[0.2, 0],[0, 0.2]],
                         [[0.2, 0],[0, 0.2]]])
weights = jnp.array([1 / 3, 1 / 3, 1 / 3])
density_params = {'mean': means, 'covariance': covariances, 'weights': weights}
density_obj = density.Density(density.gaussian_mixture_pdf,
                              density_params)

# initialize the particles
key = jrandom.PRNGKey(10)
particles = jrandom.normal(key=key, shape=(500, 2)) * 0.5

# define model
model_params = {'length_scale': 0.3}
model_kernel = kernel.Kernel(kernel.rbf_kernel, model_params)
transporter = models.SVGDModel(kernel=model_kernel)

# transport
num_iterations, step_size = 1000, 0.5
transported, trajectory = transporter.predict(particles,
                                              density_obj.score,
                                              num_iterations, step_size,
                                              trajectory=True, 
                                              adapt_length_scale=False)
# Build a grid_res x grid_res evaluation grid over [-4.5, 4.5]^2
grid_res = 100
# Input locations at which to compute probabilities
x_plot = np.linspace(-4.5, 4.5, grid_res)
x_plot = np.stack(np.meshgrid(x_plot, x_plot), axis=-1)

# Plot density
prob = density_obj(x_plot.reshape(-1, 2)).reshape(grid_res, grid_res)
plt.figure(figsize=(5, 5))
plt.contourf(x_plot[:, :, 0], x_plot[:, :, 1], prob, cmap="magma")

# plot initial particles
plt.scatter(particles[:, 0], particles[:, 1], zorder=2, c="w", s=10,
            label="initial sample", alpha=0.5)

# plot final particles
plt.scatter(transported[:, 0], transported[:, 1], zorder=2, c='r',
            s=10, label="final sample", alpha=0.6)
plt.xlim(-4.5, 4.5)
plt.ylim(-4.5, 4.5)
plt.legend()
plt.title(f"Transported Particles, h={model_params['length_scale']}")
plt.show()
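As a rough illustration of what happens inside `transporter.predict` at each iteration, here is a self-contained JAX sketch of the SVGD update applied to the same three-mode mixture. It does not use the `svgd` package, and it is not the repo's implementation. Assumptions: the RBF kernel is taken as k(a, b) = exp(-|a - b|^2 / (2 h^2)) with h playing the role of `length_scale`, which may differ from the convention in `kernel.rbf_kernel`. The helper `median_length_scale` is hypothetical; it implements a variant of the median heuristic described by Liu and Wang and is included only to suggest what an adaptive length scale might look like, not how `adapt_length_scale=True` actually behaves.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import multivariate_normal

# Three-mode Gaussian mixture with the same parameters as the example above.
means = jnp.array([[-3.0, 0.0], [3.0, 0.0], [0.0, 3.0]])
covs = jnp.tile(0.2 * jnp.eye(2), (3, 1, 1))
weights = jnp.full(3, 1.0 / 3.0)


def gmm_logpdf(x):
    """Log-density of the mixture at a single point x of shape (2,)."""
    comp = jax.vmap(lambda m, c: multivariate_normal.logpdf(x, m, c))(means, covs)
    return logsumexp(comp + jnp.log(weights))


# Score function grad log p, vectorized over particles of shape (n, 2).
score_fn = jax.vmap(jax.grad(gmm_logpdf))


def median_length_scale(x):
    """Hypothetical helper: h = sqrt(0.5 * med / log(n + 1)), med = median squared distance."""
    sq = jnp.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
    return jnp.sqrt(0.5 * jnp.median(sq) / jnp.log(x.shape[0] + 1.0))


def svgd_step(x, h, step_size):
    """One SVGD update with the (assumed) RBF kernel exp(-|a - b|^2 / (2 h^2))."""
    diffs = x[:, None, :] - x[None, :, :]  # diffs[i, j] = x_i - x_j
    k = jnp.exp(-jnp.sum(diffs ** 2, axis=-1) / (2 * h ** 2))  # symmetric (n, n) kernel matrix
    drive = k @ score_fn(x)  # sum_j k(x_j, x_i) * grad log p(x_j)
    repulse = jnp.sum(k[:, :, None] * diffs, axis=1) / h ** 2  # sum_j grad_{x_j} k(x_j, x_i)
    return x + step_size * (drive + repulse) / x.shape[0]


def run_svgd(x, num_iterations, h, step_size):
    """Apply svgd_step repeatedly with a fixed length scale h."""
    return jax.lax.fori_loop(0, num_iterations,
                             lambda _, xs: svgd_step(xs, h, step_size), x)


# Same particle count, initial spread, step size, and length scale as the example.
particles = 0.5 * jax.random.normal(jax.random.PRNGKey(10), (500, 2))
transported = run_svgd(particles, 1000, 0.3, 0.5)
```

Passing `median_length_scale(particles)` instead of the fixed `0.3` gives a data-dependent bandwidth computed once from the initial particles; recomputing it every iteration would require moving the call inside `svgd_step`.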

Reproducing our experiments

1D benchmarks

To reproduce our 1D benchmark experiments, run:

python scripts/1.0-bp-svgd-1d.py

2D benchmarks

To reproduce our experiments on the two-dimensional Gaussian mixture, run:

python scripts/2.0-bp-three-gaussian.py

To reproduce our experiments on the two-dimensional circle, run:

python scripts/3.0-vi-svgd-circle.py
