Skip to content

Commit

Permalink
fix: score imports (#1246)
Browse files Browse the repository at this point in the history
  • Loading branch information
janfb authored Aug 29, 2024
1 parent b2e4fd0 commit bf2f96f
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 6 deletions.
4 changes: 1 addition & 3 deletions sbi/inference/posteriors/score_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@
from sbi.neural_nets.estimators.shape_handling import (
reshape_to_batch_event,
)
from sbi.samplers.score.correctors import Corrector
from sbi.samplers.score.predictors import Predictor
from sbi.samplers.score.score import Diffuser
from sbi.samplers.score import Corrector, Diffuser, Predictor
from sbi.sbi_types import Shape
from sbi.utils import check_prior
from sbi.utils.torchutils import ensure_theta_batched
Expand Down
3 changes: 3 additions & 0 deletions sbi/samplers/score/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from sbi.samplers.score.correctors import Corrector, get_corrector
from sbi.samplers.score.predictors import Predictor, get_predictor
from sbi.samplers.score.score import Diffuser
3 changes: 1 addition & 2 deletions sbi/samplers/score/score.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from sbi.inference.potentials.score_based_potential import (
PosteriorScoreBasedPotential,
)
from sbi.samplers.score.correctors import Corrector, get_corrector
from sbi.samplers.score.predictors import Predictor, get_predictor
from sbi.samplers.score import Corrector, Predictor, get_corrector, get_predictor


class Diffuser:
Expand Down
2 changes: 1 addition & 1 deletion tests/score_samplers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
score_estimator_based_potential,
)
from sbi.neural_nets.net_builders import build_score_estimator
from sbi.samplers.score.score import Diffuser
from sbi.samplers.score import Diffuser


@pytest.mark.parametrize("sde_type", ["vp", "ve", "subvp"])
Expand Down

0 comments on commit bf2f96f

Please sign in to comment.