feat: flow matching posterior estimation (#1049)
* added files for vector field-based estimators and posteriors

* Flow estimator base (requires tests)

* Prepare snpe for flow and score matching (test required)

* working on score estimation

* implementing mean and std funcs for score estimator

* draft for score-based potentials (for multiple IID obs)

* score estimator losses completed

* added score.py file for network construction for score estimation and related networks

* added structure for FMPE implementation

* moved flow matching est. to density est.

* added density est. interface

* removed flow matching est. spec. functionality

* adapted train to return a density est.

* wip fmpe

* implement build_zuko_flow_matching

* added typing for Tuple

* return flow

* fixed starting round to single round

* changed int to shape

* adapted shapes in model construction

* add z_scores to fm-estimator

* adapted to new structure

* adapted to new structure

* z_score

* fixed transforms

* implement new standardized backend

* test new standardize backend

* wip fmpe with broken z-scoring

* fix zuko_transform backend in standardize_transform function

* added fmpe to tests

* added fmpe to imports

* wip: add tests fmpe

* Update sbi/inference/fmpe/base.py

consistent naming

Co-authored-by: Jan <[email protected]>

* Update sbi/inference/fmpe/base.py

consistent naming

Co-authored-by: Jan <[email protected]>

* Update sbi/inference/fmpe/base.py

consistent naming

Co-authored-by: Jan <[email protected]>

* Update sbi/inference/fmpe/base.py

Co-authored-by: Jan <[email protected]>

* licence, docstrings, naming

* Some minor changes

* WIP: refactor fmpe

* WIP: Add docstrings

Co-authored-by: Jan <[email protected]>

* WIP: implement flowmatching_nn

* remove score-based code from this PR.

* BREAKING CHANGES: refactoring, adding TODOs

* refactoring:
- move flow matching build fns to flow_matcher.py
- move FlowMatchingEstimator to zuko_flow.py
- introduce embedding nets with z-scoring
- use the same signatures as in ZukoFlow.

* small fixes for typing.

* fix transforms

* refactor density estimator tests

* fix FM embedding nets and transforms

* add todos to tests

* using flowmatching factory function instead of NPE

* added attribute embedding net for condition

* added to-do

* added convergence check for training

* adapted convergence and density est. to project convention

* added experiment tracking according to pkg. convention

* removed unused check for multi dim. inputs

* minimal formatting

* minimal formatting

* removed opt. x_dim for direct post.

* removed commented get_dataloader function as it's used from super

* removed all multi-round tests

* removed everything snpe

* removed sim helper

* formatting

* syntax

* get bugfix for shape handling from main branch.

* fix: stopping criterion baseline

* fixed early stopping bug

* fixed test argument

* fixed test argument to match FMPE

* changed to flow matching nn

* commented continue training test

* adopted new pairplot api

* added warning about missing feature

* string formatting

* feat: add GLU ResNet, fix embedding net handling.

* fix: FMPE imports.

* test: refactor and extend FMPE tests.

* fix: import Adam explicitly

* fix tests: add map xfail, fix resume training

* fix: theta as resnet output dim

* undo loss log_prob renaming in inference classes. adapt fmpe

* refactor: renaming and docs

* remove score estimator leftovers

* move nets to flow matcher.

* fix: add multi-round handling for FMPE.

* fix: resume training vs force first round loss handling, also for SNPE

* refactor tests: map passing

* typo fix.

* docs: add FMPE to implemented methods tutorial

---------

Co-authored-by: Faried Abu Zaid <[email protected]>
Co-authored-by: Jan <[email protected]>
Co-authored-by: Jan Boelts <[email protected]>
4 people authored Aug 14, 2024
1 parent 8ee2a2c commit 8f31c7b
Showing 17 changed files with 1,346 additions and 16 deletions.
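
Taken together, the diff below adds a single-round FMPE interface that mirrors the existing SNPE workflow: construct, append simulations, train, build a posterior. A minimal usage sketch, assuming a toy Gaussian simulator and a BoxUniform prior (both illustrative, not part of this commit):

    import torch
    from sbi.inference import FMPE
    from sbi.utils import BoxUniform

    # Illustrative prior and simulator; any (theta, x) pairs drawn from the
    # prior work here.
    prior = BoxUniform(low=-2 * torch.ones(2), high=2 * torch.ones(2))
    theta = prior.sample((1000,))
    x = theta + 0.1 * torch.randn_like(theta)

    inference = FMPE(prior, density_estimator="mlp")
    inference.append_simulations(theta, x)
    estimator = inference.train()            # trained flow matching estimator
    posterior = inference.build_posterior()  # DirectPosterior; direct sampling only
    samples = posterior.sample((100,), x=x[:1])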
1 change: 1 addition & 0 deletions sbi/inference/__init__.py
@@ -5,6 +5,7 @@
infer,
simulate_for_sbi,
)
from sbi.inference.fmpe import FMPE
from sbi.inference.snle import MNLE, SNLE_A
from sbi.inference.snpe import SNPE_A, SNPE_B, SNPE_C # noqa: F401
from sbi.inference.snre import BNRE, SNRE, SNRE_A, SNRE_B, SNRE_C # noqa: F401
1 change: 1 addition & 0 deletions sbi/inference/fmpe/__init__.py
@@ -0,0 +1 @@
from sbi.inference.fmpe.fmpe_base import FMPE
344 changes: 344 additions & 0 deletions sbi/inference/fmpe/fmpe_base.py
@@ -0,0 +1,344 @@
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
# under the Apache License v2.0, see <https://www.apache.org/licenses/LICENSE-2.0>.


import time
from copy import deepcopy
from typing import Any, Callable, Dict, Optional, Union

import torch
from torch.distributions import Distribution
from torch.nn.utils.clip_grad import clip_grad_norm_
from torch.optim.adam import Adam
from torch.utils.tensorboard.writer import SummaryWriter

from sbi import utils as utils
from sbi.inference.base import NeuralInference
from sbi.inference.posteriors.direct_posterior import DirectPosterior
from sbi.neural_nets import ConditionalDensityEstimator, flowmatching_nn
from sbi.utils import (
RestrictedPrior,
handle_invalid_x,
npe_msg_on_invalid_x,
validate_theta_and_x,
warn_if_zscoring_changes_data,
x_shape_from_simulation,
)
from sbi.utils.sbiutils import mask_sims_from_prior


class FMPE(NeuralInference):
"""Implements the Flow Matching Posterior Estimator (FMPE) for
simulation-based inference.
"""

def __init__(
self,
prior: Optional[Distribution],
density_estimator: Union[str, Callable] = "mlp",
device: str = "cpu",
logging_level: Union[int, str] = "WARNING",
summary_writer: Optional[SummaryWriter] = None,
show_progress_bars: bool = True,
) -> None:
"""Initialization method for the FMPE class.
Args:
prior: Prior distribution.
density_estimator: Neural network architecture used to learn the vector
field for flow matching. Can be a string, e.g., 'mlp' or 'resnet', or a
`Callable` that builds a corresponding neural network.
device: Device to use for training.
logging_level: Logging level.
summary_writer: Summary writer for tensorboard.
show_progress_bars: Whether to show progress bars.
"""
# obtain the shape of the prior samples
if isinstance(density_estimator, str):
self._build_neural_net = flowmatching_nn(model=density_estimator)
else:
self._build_neural_net = density_estimator

super().__init__(
prior=prior,
device=device,
logging_level=logging_level,
summary_writer=summary_writer,
show_progress_bars=show_progress_bars,
)

def append_simulations(
self,
theta: torch.Tensor,
x: torch.Tensor,
proposal: Optional[DirectPosterior] = None,
exclude_invalid_x: Optional[bool] = None,
data_device: Optional[str] = None,
) -> NeuralInference:
if (
proposal is None
or proposal is self._prior
or (
isinstance(proposal, RestrictedPrior) and proposal._prior is self._prior
)
):
current_round = 0
else:
raise NotImplementedError(
"Sequential FMPE with proposal different from prior is not implemented."
)

if exclude_invalid_x is None:
exclude_invalid_x = current_round == 0

if data_device is None:
data_device = self._device

theta, x = validate_theta_and_x(
theta, x, data_device=data_device, training_device=self._device
)

is_valid_x, num_nans, num_infs = handle_invalid_x(
x, exclude_invalid_x=exclude_invalid_x
)

x = x[is_valid_x]
theta = theta[is_valid_x]

# Check for problematic z-scoring
warn_if_zscoring_changes_data(x)
# Check whether there are NaNs or Infs in the data and remove accordingly.
npe_msg_on_invalid_x(
num_nans=num_nans,
num_infs=num_infs,
exclude_invalid_x=exclude_invalid_x,
algorithm="Single-round FMPE",
)

self._data_round_index.append(current_round)
prior_masks = mask_sims_from_prior(int(current_round > 0), theta.size(0))

self._theta_roundwise.append(theta)
self._x_roundwise.append(x)
self._prior_masks.append(prior_masks)

return self

def train(
self,
training_batch_size: int = 50,
learning_rate: float = 5e-4,
validation_fraction: float = 0.1,
stop_after_epochs: int = 20,
max_num_epochs: int = 2**31 - 1,
clip_max_norm: Optional[float] = 5.0,
resume_training: bool = False,
force_first_round_loss: bool = False,
show_train_summary: bool = False,
dataloader_kwargs: Optional[dict] = None,
) -> ConditionalDensityEstimator:
"""Train the flow matching estimator.
Args:
training_batch_size: Batch size for training. Defaults to 50.
learning_rate: Learning rate for training. Defaults to 5e-4.
validation_fraction: Fraction of the data to use for validation.
            stop_after_epochs: Number of epochs to wait for improvement on the
                validation set before terminating training. Defaults to 20.
max_num_epochs: Maximum number of epochs to train for.
clip_max_norm: Maximum norm for gradient clipping. Defaults to 5.0.
resume_training: Can be used in case training time is limited, e.g. on a
cluster. If `True`, the split between train and validation set, the
optimizer, the number of epochs, and the best validation log-prob will
be restored from the last time `.train()` was called.
force_first_round_loss: Whether to allow training with
simulations that have not been sampled from the prior, e.g., in a
                sequential inference setting. Note that this can lead to biased
                inference results.
show_train_summary: Whether to show the training summary. Defaults to False.
dataloader_kwargs: Additional keyword arguments for the dataloader.
Returns:
DensityEstimator: Trained flow matching estimator.
"""

# Load data from most recent round.
self._round = max(self._data_round_index)

        if self._round == 0 and self._neural_net is not None:
            assert force_first_round_loss or resume_training, (
                "You have already trained this neural network. After you had trained "
                "the network, you again appended simulations with `append_simulations"
                "(theta, x)`, but you did not provide a proposal. If the new "
                "simulations are sampled from the prior, you can set "
                "`.train(..., force_first_round_loss=True)`. However, if the new "
"simulations were not sampled from the prior, you should pass the "
"proposal, i.e. `append_simulations(theta, x, proposal)`. If "
"your samples are not sampled from the prior and you do not pass a "
"proposal and you set `force_first_round_loss=True`, the result of "
"FMPE will not be the true posterior. Instead, it will be the proposal "
"posterior, which (usually) is more narrow than the true posterior."
)

start_idx = 0 # as there is no multi-round FMPE yet

train_loader, val_loader = self.get_dataloaders(
start_idx,
training_batch_size,
validation_fraction,
resume_training=resume_training,
dataloader_kwargs=dataloader_kwargs,
)

if self._neural_net is None:
# Get theta, x to initialize NN
theta, x, _ = self.get_simulations(starting_round=start_idx)

# Use only training data for building the neural net (z-scoring transforms)
self._neural_net = self._build_neural_net(
theta[self.train_indices].to("cpu"),
x[self.train_indices].to("cpu"),
)
self._x_shape = x_shape_from_simulation(x.to("cpu"))

del theta, x

# Move entire net to device for training.
self._neural_net.to(self._device)

# initialize optimizer and training parameters
if not resume_training:
self.optimizer = Adam(
list(self._neural_net.net.parameters()), lr=learning_rate
)
self.epoch = 0
# NOTE: in the FMPE context we use MSE loss, not log probs.
self._val_loss = float("Inf")

while self.epoch <= max_num_epochs and not self._converged(
self.epoch, stop_after_epochs
):
self._neural_net.net.train()
train_loss_sum = 0
epoch_start_time = time.time()
for batch in train_loader:
self.optimizer.zero_grad()
# get batches on current device.
theta_batch, x_batch = (
batch[0].to(self._device),
batch[1].to(self._device),
)

train_loss = self._neural_net.loss(theta_batch, x_batch).mean()
train_loss_sum += train_loss.item()

train_loss.backward()
if clip_max_norm is not None:
clip_grad_norm_(
self._neural_net.net.parameters(), max_norm=clip_max_norm
)
self.optimizer.step()

self.epoch += 1

train_loss_average = train_loss_sum / len(train_loader) # type: ignore
# TODO: rename to loss once renaming is done in base class.
self._summary["training_log_probs"].append(-train_loss_average)

# Calculate validation performance.
self._neural_net.eval()
val_loss_sum = 0

with torch.no_grad():
for batch in val_loader:
theta_batch, x_batch = (
batch[0].to(self._device),
batch[1].to(self._device),
)
# Aggregate the validation losses.
val_losses = self._neural_net.loss(theta_batch, x_batch)
val_loss_sum += val_losses.sum().item()

# Take mean over all validation samples.
self._val_loss = val_loss_sum / (
len(val_loader) * val_loader.batch_size # type: ignore
)
# TODO: remove this once renaming to loss in base class is done.
self._val_log_prob = -self._val_loss
# Log validation log prob for every epoch.
# TODO: rename to loss and fix sign once renaming in base is done.
self._summary["validation_log_probs"].append(-self._val_loss)
self._summary["epoch_durations_sec"].append(time.time() - epoch_start_time)

self._maybe_show_progress(self._show_progress_bars, self.epoch)

self._report_convergence_at_end(self.epoch, stop_after_epochs, max_num_epochs)

# Update summary.
self._summary["epochs_trained"].append(self.epoch)
# TODO: rename to loss once renaming is done in base class.
self._summary["best_validation_log_prob"].append(self._best_val_log_prob)

# Update tensorboard and summary dict.
self._summarize(round_=self._round)

# Update description for progress bar.
if show_train_summary:
print(self._describe_round(self._round, self._summary))

# Avoid keeping the gradients in the resulting network, which can
# cause memory leakage when benchmarking.
self._neural_net.zero_grad(set_to_none=True)

return deepcopy(self._neural_net)

def build_posterior(
self,
density_estimator: Optional[ConditionalDensityEstimator] = None,
prior: Optional[Distribution] = None,
sample_with: str = "direct",
direct_sampling_parameters: Optional[Dict[str, Any]] = None,
**kwargs,
) -> DirectPosterior:
"""Build the posterior distribution.
Args:
density_estimator: Density estimator for the posterior.
prior: Prior distribution.
sample_with: Sampling method, currently only "direct" is supported.
direct_sampling_parameters: kwargs for DirectPosterior.
Returns:
DirectPosterior: Posterior distribution.
"""
if sample_with != "direct":
raise NotImplementedError(
"Currently, only direct sampling is supported for FMPE."
)

        if prior is None:
            assert self._prior is not None, (
                "You did not pass a prior. You have to pass the prior either at "
                "initialization `inference = FMPE(prior)` or to "
"`.build_posterior(prior=prior)`."
)
prior = self._prior
else:
utils.check_prior(prior)

if density_estimator is None:
posterior_estimator = self._neural_net
            # If the internal net is used, the device is already defined.
device = self._device
else:
posterior_estimator = density_estimator
# Otherwise, infer it from the device of the net parameters.
device = next(density_estimator.parameters()).device.type

self._posterior = DirectPosterior(
posterior_estimator=posterior_estimator, # type: ignore
prior=prior,
device=device,
**direct_sampling_parameters or {},
)

return deepcopy(self._posterior)
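
As the NOTE in train() points out, the estimator is trained with an MSE-style flow matching loss rather than a log-prob. A minimal sketch of such a loss under a linear interpolation path — the schedule and the vector_field_net(theta_t, x, t) signature are illustrative assumptions, not the estimator's actual internals:

    import torch

    def flow_matching_loss(vector_field_net, theta, x):
        # Sample a time point and a base noise draw per batch element.
        t = torch.rand(theta.shape[0], 1)
        eps = torch.randn_like(theta)
        # Linear path between noise and data; its time derivative is the target.
        theta_t = (1 - t) * eps + t * theta
        target = theta - eps
        pred = vector_field_net(theta_t, x, t)
        # Per-sample squared error, matching how .loss(theta, x) is reduced above.
        return ((pred - target) ** 2).sum(dim=-1)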
2 changes: 1 addition & 1 deletion sbi/inference/posteriors/ensemble_posterior.py
@@ -138,7 +138,7 @@ def weights(self, weights: Optional[Union[List[float], Tensor]]) -> None:
self._weights = torch.tensor([
1.0 / self.num_components for _ in range(self.num_components)
])
elif type(weights) == Tensor or type(weights) == List:
elif weights is Tensor or weights is List:
self._weights = torch.tensor(weights) / sum(weights)
else:
raise TypeError
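
For reference, the new `weights is Tensor` comparison is an identity check against the class object and is only True when `weights` is literally the Tensor class; a conventional runtime type check uses isinstance (a sketch, not what this commit ships):

    import torch
    from torch import Tensor

    def normalize_weights(weights):
        # isinstance inspects the value's type; `weights is Tensor` does not.
        if isinstance(weights, (Tensor, list)):
            weights = torch.as_tensor(weights, dtype=torch.float32)
            return weights / weights.sum()
        raise TypeError(f"Unsupported weights type: {type(weights)}.")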
2 changes: 1 addition & 1 deletion sbi/inference/snpe/snpe_base.py
@@ -263,7 +263,7 @@ def train(
self._round = max(self._data_round_index)

if self._round == 0 and self._neural_net is not None:
assert force_first_round_loss, (
assert force_first_round_loss or resume_training, (
"You have already trained this neural network. After you had trained "
"the network, you again appended simulations with `append_simulations"
"(theta, x)`, but you did not provide a proposal. If the new "
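
With the relaxed assertion, a first-round run can be continued without forcing the first-round loss. A sketch of the intended pattern, continuing the usage sketch above (epoch budget arbitrary):

    # Train with a limited epoch budget, e.g., within a cluster time limit.
    inference.append_simulations(theta, x)
    inference.train(max_num_epochs=50)

    # Later: restore the train/validation split, optimizer state, and epoch
    # counter instead of starting from scratch.
    inference.train(resume_training=True)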
7 changes: 6 additions & 1 deletion sbi/neural_nets/__init__.py
@@ -9,7 +9,12 @@
FCEmbedding,
PermutationInvariantEmbedding,
)
from sbi.neural_nets.factory import classifier_nn, likelihood_nn, posterior_nn
from sbi.neural_nets.factory import (
classifier_nn,
flowmatching_nn,
likelihood_nn,
posterior_nn,
)
from sbi.neural_nets.flow import (
build_made,
build_maf,
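
The new flowmatching_nn factory mirrors posterior_nn: it returns a builder that is called later with training theta and x so the estimator can derive its z-scoring transforms (see _build_neural_net and FMPE.train() above). A sketch — theta_train and x_train are placeholder tensors:

    from sbi.neural_nets import flowmatching_nn

    build_fn = flowmatching_nn(model="resnet")  # or "mlp", as in FMPE(density_estimator=...)
    estimator = build_fn(theta_train, x_train)  # builds the net and z-scoring from data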
3 changes: 3 additions & 0 deletions sbi/neural_nets/density_estimators/__init__.py
@@ -3,6 +3,9 @@
CategoricalMassEstimator,
CategoricalNet,
)
from sbi.neural_nets.density_estimators.flowmatching_estimator import (
FlowMatchingEstimator,
)
from sbi.neural_nets.density_estimators.mixed_density_estimator import (
MixedDensityEstimator,
)
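
FlowMatchingEstimator is exported as a ConditionalDensityEstimator, so DirectPosterior and the training loop above can rely on the shared interface. A hedged sketch of that contract — theta_batch, x_batch, and x_o are placeholders, and the sample signature follows the package's estimator conventions rather than anything confirmed by this diff:

    losses = estimator.loss(theta_batch, x_batch)      # per-sample FM losses, as in train()
    samples = estimator.sample((100,), condition=x_o)  # draw theta given condition x_o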