
Neural Variational Filtering and Inference

This repository contains an efficient approximate algorithm for inference and learning in temporal graphical models with binary latent variables. The code accompanies a paper; a video explaining the algorithm and code will follow soon.

Usage

To use the algorithm, probability densities for p(z_t | z_{t-1}) and p(x_t | z_t) need to be defined. The algorithm has two control knobs that trade off computational burden against accuracy: the first is the number of samples used to approximate the data likelihood (num_samples); the second controls the accuracy of the underlying sampler (EPS).

The input to p(x_t | z_t) is xt with shape (num_steps, x_dim) and zt with shape (num_steps, num_samples, z_dim), and the output is the log-probability of each state, i.e. it has shape (num_steps, num_samples). Because jax.lax.scan is employed, the inputs to p(z_t | z_{t-1}) do not have the num_steps dimension: zt and ztm1 both have shape (num_samples, z_dim), and the output is the log-probability of each combination of states, i.e. it has shape (num_samples, num_samples). The algorithm therefore scales quadratically in num_samples.
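As a minimal sketch, the two densities could for instance be implemented as follows; the factorized stay/switch transition, the linear-Gaussian emission, and the weights W are illustrative assumptions, not the models used in this repository:

import jax.numpy as jnp

def p_zz(zt, ztm1):
    # zt: (num_samples, z_dim), ztm1: (num_samples, z_dim)
    # log p(z_t | z_{t-1}) for every pair of states -> (num_samples, num_samples)
    stay = 0.9  # hypothetical probability that a latent unit keeps its state
    agree = zt[:, None, :] == ztm1[None, :, :]  # (num_samples, num_samples, z_dim)
    return jnp.where(agree, jnp.log(stay), jnp.log(1.0 - stay)).sum(-1)

def p_xz(xt, zt):
    # xt: (num_steps, x_dim), zt: (num_steps, num_samples, z_dim)
    # log p(x_t | z_t) for each state -> (num_steps, num_samples)
    W = jnp.ones((15, 156))  # hypothetical emission weights (z_dim, x_dim)
    resid = xt[:, None, :] - zt @ W  # (num_steps, num_samples, x_dim)
    return -0.5 * (resid ** 2).sum(-1)  # unnormalized Gaussian log-density

This particular transition factorizes over latent units; NVIF also supports transitions that couple units, e.g. by constraining how many units may switch state simultaneously (see the FAQ below).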

Once these two (potentially parameterized) functions are defined, the model can be fit and inference performed as follows:

from nvif import NVIF
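# `optim` below is assumed to be the repository's optimizer dependency (import omitted)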

N = NVIF(p_zz=p_zz, p_xz=p_xz, num_steps=128,
         num_samples=512, z_dim=15, x_dim=156)
N.train(x[:5000], optimizer=optim.Adam(3E-3), num_epochs=20)

z_hat = N.inference(x[5000:])

See nilm_example.ipynb for an example of how to use the algorithm in the context of a synthetic problem inspired by Non-Intrusive Load Monitoring (NILM).

Frequently Asked Questions

  • Why does the time required per epoch vary over time?
    • Most of the time is spent sampling without replacement from the auxiliary distribution Q. Sampling without replacement according to pre-defined inclusion probabilities is difficult and in some cases even impossible, and it becomes harder the lower the entropy of the distribution being sampled from. During training, the entropy of Q decreases (in the beginning, most states have the same probability), making the sampling step increasingly time-consuming.
  • Is the model on which learning and inference are performed a Factorial Hidden Markov Model (FHMM)?
    • No! FHMMs assume that the individual latent chains are marginally independent. NVIF does not require this assumption, which makes the class of models on which NVIF can perform inference and learning considerably richer. Specifically for NILM, the 'independence between chains' assumption is often not a good fit because, in practice, you want to constrain the number of latent units that switch state at any one time.
  • What are potential avenues for future work?
    • The underlying sampler (Yves Tillé's elimination sampler) is slow but accurate. There are faster but less accurate alternatives, e.g. the Pareto sampler. Studying the effects of swapping out the sampler might speed up inference considerably (a sketch of one such alternative follows this list).
    • The main contribution of NVIF is an approximate algorithm for inference and learning in temporal models with binary latent states. However, little research has gone into the best instantiations (p_zz and p_xz) for solving e.g. Non-Intrusive Load Monitoring or Energy Disaggregation. The performance of NVIF can most likely be improved substantially by finding better choices for p_zz and p_xz.
    • So far, NVIF has only been evaluated in the context of NILM; applying the algorithm to other problems that require inference of binary latent states might be interesting.
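
As an illustration of the speed/accuracy trade-off, here is a hedged sketch of one well-known fast alternative, the Gumbel-top-k trick (Kool et al., 2019); this is a generic technique for drawing k distinct items, not code from this repository:

import jax
import jax.numpy as jnp

def gumbel_topk_without_replacement(key, log_weights, k):
    # Perturb each log-weight with i.i.d. Gumbel(0, 1) noise and keep the
    # top-k indices. This draws k distinct items via successive sampling
    # proportional to the weights; the induced inclusion probabilities only
    # approximate pre-specified ones, which is where accuracy is traded for
    # speed compared to the elimination sampler.
    uniforms = jax.random.uniform(key, log_weights.shape, minval=1e-20)
    gumbels = -jnp.log(-jnp.log(uniforms))
    _, indices = jax.lax.top_k(log_weights + gumbels, k)
    return indices

# Hypothetical usage: draw 512 distinct states under random log-weights.
wkey, skey = jax.random.split(jax.random.PRNGKey(0))
log_weights = jax.random.normal(wkey, (4096,))
idx = gumbel_topk_without_replacement(skey, log_weights, 512)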
