Skip to content

Commit

Permalink
Fix circular imports
Browse files Browse the repository at this point in the history
  • Loading branch information
michaeldeistler committed Apr 10, 2024
1 parent fac45c8 commit 81ce2c9
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions sbi/analysis/conditional_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from torch import Tensor
from torch.distributions import Distribution

from sbi.neural_nets.density_estimators.nflows_flow import NFlowsFlow
from sbi.sbi_types import Shape, TorchTransform
from sbi.utils.conditional_density_utils import (
ConditionedPotential,
Expand Down Expand Up @@ -186,16 +185,17 @@ def conditional_corrcoeff(
class ConditionedMDN:
def __init__(
self,
mdn: NFlowsFlow,
mdn,
x_o: Tensor,
condition: Tensor,
dims_to_sample: List[int],
) -> None:
r"""Class that can sample and evaluate a conditional mixture-of-gaussians.
Args:
mdn: Mixture density network that models $p(\theta|x). We use the normflows
implementation of MDNs.
mdn Mixture density network that models $p(\theta|x). We use the normflows
implementation of MDNs. Type is `NFlowsFlow`, type hint removed to
avoid circular import, see #1140.
x_o: The datapoint at which the `net` is evaluated.
condition: Parameter set that all dimensions not specified in
`dims_to_sample` will be fixed to. Should contain dim_theta elements,
Expand Down

0 comments on commit 81ce2c9

Please sign in to comment.