Skip to content

Commit

Permalink
fix: zuko UnconditionalTransform, fix BPF args. (#1182)
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb authored Jun 20, 2024
1 parent d76a085 commit 2867816
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 13 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ dependencies = [
"torch>=1.8.0",
"tqdm",
"pymc>=5.0.0",
"zuko>=1.1.0",
"zuko>=1.2.0",
]

[project.optional-dependencies]
Expand Down
4 changes: 1 addition & 3 deletions sbi/neural_nets/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,7 +962,6 @@ def build_zuko_bpf(
num_transforms: int = 3,
embedding_net: nn.Module = nn.Identity(),
degree: int = 16,
linear: bool = False,
**kwargs,
) -> ZukoFlow:
"""
Expand Down Expand Up @@ -993,11 +992,10 @@ def build_zuko_bpf(
num_transforms: The number of transformations in the flow. Defaults to 5.
embedding_net: The embedding network to use. Defaults to nn.Identity().
degree: The degree :math:`M` of the Bernstein polynomial.
linear: Whether to use a linear or sigmoid mapping to :math:`[0, 1]`.
**kwargs: Additional keyword arguments to pass to the flow constructor.
"""
which_nf = "BPF"
additional_kwargs = {"degree": degree, "linear": linear, **kwargs}
additional_kwargs = {"degree": degree, **kwargs}
flow = build_zuko_flow(
which_nf,
batch_x,
Expand Down
4 changes: 2 additions & 2 deletions sbi/utils/sbiutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
biject_to,
constraints,
)
from zuko.flows import UnconditionalTransform

from sbi.sbi_types import TorchTransform
from sbi.utils.torchutils import atleast_2d
from sbi.utils.zukoutils import UnconditionalLazyTransform


def warn_if_zscoring_changes_data(x: Tensor, duplicate_tolerance: float = 0.1) -> None:
Expand Down Expand Up @@ -187,7 +187,7 @@ def standardizing_transform_zuko(
Affine transform for z-scoring
"""
t_mean, t_std = z_standardization(batch_t, structured_dims, min_std)
return UnconditionalLazyTransform(
return UnconditionalTransform(
AffineTransform,
loc=-t_mean / t_std,
scale=1 / t_std,
Expand Down
7 changes: 0 additions & 7 deletions sbi/utils/zukoutils.py

This file was deleted.

0 comments on commit 2867816

Please sign in to comment.