This repository contains an efficient approximate algorithm for inference and learning in temporal graphical models with binary latent variables. This code is accompanied by the following paper. A video explaining the algorithm and code will follow soon.
In order to use the algorithm, probability densities for `p(zt|ztm1)` and `p(xt|zt)` need to be defined. The algorithm has two control knobs that trade off computational burden against accuracy. The first control knob is the number of samples used to approximate the data likelihood (`num_samples`), whereas the second control knob controls the accuracy of the underlying sampler (`EPS`).
The input to `p(xt|zt)` is `xt: (num_steps, x_dim)` and `zt: (num_steps, num_samples, z_dim)`, and the output is the log probability for each state, i.e. it has shape `(num_steps, num_samples)`. Because `jax.lax.scan` is employed, the input to `p(zt|ztm1)` does not have the `num_steps` dimension, i.e. it is `zt: (num_samples, z_dim)` and `ztm1: (num_samples, z_dim)`, and the output is the log probability for each combination of states, i.e. it has shape `(num_samples, num_samples)`. The algorithm therefore scales quadratically in `num_samples`.
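For concreteness, here is a minimal sketch of what these two densities could look like, assuming a Gaussian emission model and a transition in which every bit of the latent state flips independently. `W`, `sigma` and `rho` are hypothetical parameters chosen for illustration and are not part of the NVIF API:

```python
import jax.numpy as jnp

z_dim, x_dim = 15, 156
W = jnp.zeros((z_dim, x_dim))  # hypothetical emission weights (learned or chosen in practice)
sigma, rho = 1.0, 0.05         # hypothetical noise scale and bit-flip probability

def p_xz(xt, zt):
    """xt: (num_steps, x_dim), zt: (num_steps, num_samples, z_dim)
    -> log probabilities of shape (num_steps, num_samples)."""
    mean = zt @ W  # (num_steps, num_samples, x_dim)
    resid = xt[:, None, :] - mean
    return (-0.5 * jnp.sum((resid / sigma) ** 2, axis=-1)
            - 0.5 * x_dim * jnp.log(2 * jnp.pi * sigma ** 2))

def p_zz(zt, ztm1):
    """zt: (num_samples, z_dim), ztm1: (num_samples, z_dim)
    -> log probabilities of shape (num_samples, num_samples)."""
    flips = jnp.sum(zt[:, None, :] != ztm1[None, :, :], axis=-1)
    return flips * jnp.log(rho) + (z_dim - flips) * jnp.log(1 - rho)
```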
Once these two potentially parameterized functions are defined, the model can be fit and inference can be performed by:

```python
from nvif import NVIF
from flax import optim  # assuming the legacy flax.optim API for optim.Adam

N = NVIF(p_zz=p_zz, p_xz=p_xz, num_steps=128,
         num_samples=512, z_dim=15, x_dim=156)
# x: observations of shape (num_total_steps, x_dim)
N.train(x[:5000], optimizer=optim.Adam(3E-3), num_epochs=20)
z_hat = N.inference(x[5000:])
```
See `nilm_example.ipynb` for an example of how to use the algorithm in the context of a synthetic problem inspired by Non-Intrusive Load Monitoring (NILM).
- Why does the time required per epoch vary over time?
- Most of the time is spent sampling without replacement from the auxiliary distribution Q. Sampling without replacement according to pre-defined inclusion probabilities is difficult and in some cases even impossible, and it becomes harder the lower the entropy of the distribution being sampled from. During training, the entropy of Q decreases (in the beginning, most states have the same probability), making the sampling step more time-consuming.
- Is the model on which learning and inference are performed a Factorial Hidden Markov Model (FHMM)?
- No! FHMMs make the assumption that the individual latent chains are marginally independent. NVIF does not require this assumption, which makes the class of models NVIF can perform inference and learning on a lot richer. Specifically for NILM, the 'independence between chains' assumption is often not a good one because, from experience, you want to constrain the number of latent states that switch at the same time.
- What are potential avenues for future work?
- The underlying sampler (Yves Tillé's elimination sampler) is slow but accurate. There are faster but less accurate alternatives, such as the Pareto sampler. Studying the effect of swapping out the sampler might considerably speed up inference (see the sketch after this list).
- The main contribution of NVIF is an approximate algorithm for inference and learning in temporal models with binary latent states. However, little research has gone into the best instantiations (`p_zz` and `p_xz`) to solve e.g. Non-Intrusive Load Monitoring or Energy Disaggregation. The performance of NVIF can most likely be improved substantially by finding better choices for `p_zz` and `p_xz`.
- So far, NVIF has only been evaluated in the context of NILM; applying the algorithm to other problems that require inference of binary latent states might be interesting.
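As a rough illustration of the sampler swap mentioned above, here is a minimal sketch of Pareto sampling following Rosén's design: each unit gets a transformed-uniform ranking variable and the `n` smallest ranks form the sample, with the target inclusion probabilities only approximately respected. The helper `pareto_sample` is hypothetical and is not the sampler shipped with NVIF:

```python
import jax
import jax.numpy as jnp

def pareto_sample(key, pi, n):
    """Draw n indices without replacement given target inclusion
    probabilities pi (shape (N,), summing roughly to n)."""
    u = jax.random.uniform(key, pi.shape, minval=1e-6, maxval=1 - 1e-6)
    # Ranking variables q_i = [u_i / (1 - u_i)] / [pi_i / (1 - pi_i)];
    # the n units with the smallest q_i form the sample.
    q = (u / (1 - u)) / (pi / (1 - pi))
    return jnp.argsort(q)[:n]

# Example: draw 64 of 512 states with uniform inclusion probabilities.
idx = pareto_sample(jax.random.PRNGKey(0), jnp.full((512,), 64 / 512), 64)
```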