Skip to content

Commit

Permalink
fix: pyro flow pyright errors in GitHub action (#1165)
Browse files Browse the repository at this point in the history
* ignore types for wierd pyright thing

* Ignore params_dims specifically

* try to not ignore pyright

* next try

* Fixing this issue

* Adding better typing to pyro flows

* Revert "Adding better typing to pyro flows"

This reverts commit 825a525.

Notebooks should not be pushed

* Adding better typing infomration to avoid errors
  • Loading branch information
manuelgloeckler authored Jun 11, 2024
1 parent 66d1f2b commit fa6a874
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 20 deletions.
39 changes: 21 additions & 18 deletions sbi/samplers/vi/vi_pyro_flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,8 @@ def build_fn(

def init_affine_autoregressive(dim: int, device: str = "cpu", **kwargs):
"""Provides the default initial arguments for an affine autoregressive transform."""
hidden_dims = kwargs.pop("hidden_dims", [3 * dim + 5, 3 * dim + 5])
skip_connections = kwargs.pop("skip_connections", False)
hidden_dims: List[int] = kwargs.pop("hidden_dims", [3 * dim + 5, 3 * dim + 5])
skip_connections: bool = kwargs.pop("skip_connections", False)
nonlinearity = kwargs.pop("nonlinearity", nn.ReLU())
arn = AutoRegressiveNN(
dim, hidden_dims, nonlinearity=nonlinearity, skip_connections=skip_connections
Expand All @@ -170,12 +170,12 @@ def init_affine_autoregressive(dim: int, device: str = "cpu", **kwargs):

def init_spline_autoregressive(dim: int, device: str = "cpu", **kwargs):
"""Provides the default initial arguments for an spline autoregressive transform."""
hidden_dims = kwargs.pop("hidden_dims", [3 * dim + 5, 3 * dim + 5])
skip_connections = kwargs.pop("skip_connections", False)
hidden_dims: List[int] = kwargs.pop("hidden_dims", [3 * dim + 5, 3 * dim + 5])
skip_connections: bool = kwargs.pop("skip_connections", False)
nonlinearity = kwargs.pop("nonlinearity", nn.ReLU())
count_bins = kwargs.get("count_bins", 10)
order = kwargs.get("order", "linear")
bound = kwargs.get("bound", 10)
count_bins: int = kwargs.get("count_bins", 10)
order: str = kwargs.get("order", "linear")
bound: int = kwargs.get("bound", 10)
if order == "linear":
param_dims = [count_bins, count_bins, (count_bins - 1), count_bins]
else:
Expand All @@ -194,25 +194,28 @@ def init_affine_coupling(dim: int, device: str = "cpu", **kwargs):
"""Provides the default initial arguments for an affine autoregressive transform."""
assert dim > 1, "In 1d this would be equivalent to affine flows, use them."
nonlinearity = kwargs.pop("nonlinearity", nn.ReLU())
split_dim = kwargs.get("split_dim", dim // 2)
hidden_dims = kwargs.pop("hidden_dims", [5 * dim + 20, 5 * dim + 20])
params_dims = (dim - split_dim, dim - split_dim)
arn = DenseNN(split_dim, hidden_dims, params_dims, nonlinearity=nonlinearity).to(
device
)
split_dim: int = int(kwargs.get("split_dim", dim // 2))
hidden_dims: List[int] = kwargs.pop("hidden_dims", [5 * dim + 20, 5 * dim + 20])
params_dims: List[int] = [dim - split_dim, dim - split_dim]
arn = DenseNN(
split_dim,
hidden_dims,
params_dims,
nonlinearity=nonlinearity,
).to(device)
return [split_dim, arn], {"log_scale_min_clip": -3.0}


def init_spline_coupling(dim: int, device: str = "cpu", **kwargs):
"""Intitialize a spline coupling transform, by providing necessary args and
kwargs."""
assert dim > 1, "In 1d this would be equivalent to affine flows, use them."
split_dim = kwargs.get("split_dim", dim // 2)
hidden_dims = kwargs.pop("hidden_dims", [5 * dim + 30, 5 * dim + 30])
split_dim: int = kwargs.get("split_dim", dim // 2)
hidden_dims: List[int] = kwargs.pop("hidden_dims", [5 * dim + 30, 5 * dim + 30])
nonlinearity = kwargs.pop("nonlinearity", nn.ReLU())
count_bins = kwargs.get("count_bins", 15)
order = kwargs.get("order", "linear")
bound = kwargs.get("bound", 10)
count_bins: int = kwargs.get("count_bins", 15)
order: str = kwargs.get("order", "linear")
bound: int = kwargs.get("bound", 10)
if order == "linear":
param_dims = [
(dim - split_dim) * count_bins,
Expand Down
5 changes: 3 additions & 2 deletions sbi/utils/pyroutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ def prior():
model_trace = poutine.trace(model).get_trace(*model_args, **model_kwargs)

for name, node in model_trace.iter_stochastic_nodes():
fn = node["fn"]
transforms[name] = biject_to(fn.support).inv
if "fn" in node:
fn = node["fn"]
transforms[name] = biject_to(fn.support).inv

return transforms

0 comments on commit fa6a874

Please sign in to comment.