fix: upgrade pyright, fix typing.
janfb committed Jul 30, 2024
1 parent 98493dd commit 923fb40
Showing 11 changed files with 31 additions and 26 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -58,7 +58,7 @@ dev = [
# Lint
"pre-commit == 3.5.0",
"pyyaml",
-"pyright == 1.1.355",
+"pyright",
"ruff>=0.3.3",
# Test
"pytest",
2 changes: 1 addition & 1 deletion sbi/analysis/plot.py
@@ -1562,7 +1562,7 @@ def _sbc_rank_plot(
for idx, rank in enumerate(ranks_list):
assert isinstance(rank, (Tensor, np.ndarray))
if isinstance(rank, Tensor):
-ranks_list[idx] = rank.numpy()
+ranks_list[idx]: np.ndarray = rank.numpy() # type: ignore

plot_types = ["hist", "cdf"]
assert (
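Aside from the annotated assignment above, a checker-friendly way to express the same conversion is to build a new, explicitly typed list. A minimal sketch under that assumption; the helper name is mine, not anything in sbi:

from typing import List, Union

import numpy as np
import torch
from torch import Tensor


def to_numpy_list(ranks_list: List[Union[Tensor, np.ndarray]]) -> List[np.ndarray]:
    # Convert every Tensor to a numpy array; ndarrays pass through unchanged.
    return [r.numpy() if isinstance(r, Tensor) else r for r in ranks_list]


ranks = to_numpy_list([torch.rand(100, 3), np.random.rand(100, 3)])
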
5 changes: 3 additions & 2 deletions sbi/analysis/sensitivity_analysis.py
@@ -8,9 +8,10 @@

import torch
from pyknos.nflows.nn import nets
-from torch import Tensor, nn, optim, relu
+from torch import Tensor, nn, relu
from torch.nn import MSELoss
from torch.nn.utils.clip_grad import clip_grad_norm_
+from torch.optim.adam import Adam
from torch.utils import data
from torch.utils.data.sampler import SubsetRandomSampler

@@ -286,7 +287,7 @@ def train(
self._device
)

-optimizer = optim.Adam(
+optimizer = Adam(
list(self._regression_net.parameters()),
lr=learning_rate,
)
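The recurring change in this commit is importing Adam from its defining submodule, torch.optim.adam, instead of the torch.optim re-export, presumably to satisfy the upgraded pyright; at runtime both names refer to the same class. A minimal sketch of the pattern, with a hypothetical stand-in network rather than sbi's regression net:

import torch
from torch import nn
from torch.optim.adam import Adam  # same class object as torch.optim.Adam

# Hypothetical stand-in for sbi's regression network; not the library's actual model.
net = nn.Sequential(nn.Linear(3, 16), nn.ReLU(), nn.Linear(16, 1))
optimizer = Adam(list(net.parameters()), lr=5e-4)

x, y = torch.randn(8, 3), torch.randn(8, 1)
loss = nn.MSELoss()(net(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

assert Adam is torch.optim.Adam  # both import paths resolve to the same class
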
10 changes: 2 additions & 8 deletions sbi/diagnostics/lc2st.py
@@ -670,15 +670,9 @@ def train_lc2st(
Returns:
Trained classifier.
"""
-# cpu and numpy
-theta_p = theta_p.cpu().numpy()
-theta_q = theta_q.cpu().numpy()
-x_p = x_p.cpu().numpy()
-x_q = x_q.cpu().numpy()
-
# concatenate to get joint data
-joint_p = np.concatenate([theta_p, x_p], axis=1)
-joint_q = np.concatenate([theta_q, x_q], axis=1)
+joint_p = np.concatenate([theta_p.cpu().numpy(), x_p.cpu().numpy()], axis=1)
+joint_q = np.concatenate([theta_q.cpu().numpy(), x_q.cpu().numpy()], axis=1)

# prepare data
data = np.concatenate((joint_p, joint_q))
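A minimal sketch of what the rewritten lines compute, with hypothetical shapes; converting to numpy inline means the Tensor-typed arguments are never rebound to arrays. The label construction at the end is a typical C2ST setup, not taken from this diff:

import numpy as np
import torch

# Hypothetical shapes: 100 samples, 2-dimensional theta, 3-dimensional x.
theta_p, x_p = torch.randn(100, 2), torch.randn(100, 3)
theta_q, x_q = torch.randn(100, 2), torch.randn(100, 3)

# Concatenate parameters and conditioning data along the feature axis.
joint_p = np.concatenate([theta_p.cpu().numpy(), x_p.cpu().numpy()], axis=1)
joint_q = np.concatenate([theta_q.cpu().numpy(), x_q.cpu().numpy()], axis=1)

data = np.concatenate((joint_p, joint_q))  # shape (200, 5)
labels = np.concatenate((np.zeros(len(joint_p)), np.ones(len(joint_q))))  # class 0 = P, class 1 = Q
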
2 changes: 1 addition & 1 deletion sbi/diagnostics/tarp.py
@@ -366,7 +366,7 @@ def check_tarp(
midindex = nentries // 2
atc = float((ecp[midindex:, ...] - alpha[midindex:, ...]).sum())

-kstest_pvals = kstest(ecp.numpy(), alpha.numpy())[1]
+kstest_pvals: float = kstest(ecp.numpy(), alpha.numpy())[1] # type: ignore

return atc, kstest_pvals

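For context, a small self-contained sketch of the kstest call being annotated, with synthetic stand-ins for ecp and alpha (not TARP's real coverage curves):

import numpy as np
from scipy.stats import kstest

# Synthetic stand-ins for expected coverage (ecp) and credibility levels (alpha).
alpha = np.linspace(0.0, 1.0, 100)
ecp = alpha + 0.01 * np.random.randn(100)

result = kstest(ecp, alpha)      # two-sample KS test when both arguments are samples
kstest_pvals: float = result[1]  # equivalently result.pvalue; the annotation pins the type for pyright
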
5 changes: 3 additions & 2 deletions sbi/inference/snle/snle_base.py
@@ -6,9 +6,10 @@
from typing import Any, Callable, Dict, Optional, Union

import torch
-from torch import Tensor, optim
+from torch import Tensor
from torch.distributions import Distribution
from torch.nn.utils.clip_grad import clip_grad_norm_
+from torch.optim.adam import Adam
from torch.utils.tensorboard.writer import SummaryWriter

from sbi.inference.base import NeuralInference
@@ -182,7 +183,7 @@ def train(

self._neural_net.to(self._device)
if not resume_training:
-self.optimizer = optim.Adam(
+self.optimizer = Adam(
list(self._neural_net.parameters()),
lr=learning_rate,
)
7 changes: 3 additions & 4 deletions sbi/inference/snpe/snpe_base.py
@@ -8,9 +8,10 @@
from warnings import warn

import torch
-from torch import Tensor, ones, optim
+from torch import Tensor, ones
from torch.distributions import Distribution
from torch.nn.utils.clip_grad import clip_grad_norm_
+from torch.optim.adam import Adam
from torch.utils.tensorboard.writer import SummaryWriter

from sbi.inference.base import NeuralInference, check_if_proposal_has_default_x
@@ -334,9 +335,7 @@ def default_calibration_kernel(x):
self._neural_net.to(self._device)

if not resume_training:
-self.optimizer = optim.Adam(
-list(self._neural_net.parameters()), lr=learning_rate
-)
+self.optimizer = Adam(list(self._neural_net.parameters()), lr=learning_rate)
self.epoch, self._val_log_prob = 0, float("-Inf")

while self.epoch <= max_num_epochs and not self._converged(
5 changes: 3 additions & 2 deletions sbi/inference/snre/snre_base.py
@@ -6,9 +6,10 @@
from typing import Any, Callable, Dict, Optional, Union

import torch
-from torch import Tensor, eye, nn, ones, optim
+from torch import Tensor, eye, nn, ones
from torch.distributions import Distribution
from torch.nn.utils.clip_grad import clip_grad_norm_
+from torch.optim.adam import Adam
from torch.utils.tensorboard.writer import SummaryWriter

from sbi.inference.base import NeuralInference
@@ -207,7 +208,7 @@ def train(
self._neural_net.to(self._device)

if not resume_training:
-self.optimizer = optim.Adam(
+self.optimizer = Adam(
list(self._neural_net.parameters()),
lr=learning_rate,
)
9 changes: 8 additions & 1 deletion sbi/samplers/vi/vi_divergence_optimizers.py
@@ -8,7 +8,12 @@
import torch
from torch import Tensor, nn
from torch.distributions import Distribution
-from torch.optim import ASGD, SGD, Adadelta, Adagrad, Adam, AdamW, Adamax, RMSprop
+from torch.optim.adadelta import Adadelta
+from torch.optim.adagrad import Adagrad
+from torch.optim.adam import Adam
+from torch.optim.adamax import Adamax
+from torch.optim.adamw import AdamW
+from torch.optim.asgd import ASGD
from torch.optim.lr_scheduler import (
CosineAnnealingLR,
CosineAnnealingWarmRestarts,
@@ -17,6 +22,8 @@
LambdaLR,
StepLR,
)
+from torch.optim.rmsprop import RMSprop
+from torch.optim.sgd import SGD

from sbi.inference.potentials.base_potential import BasePotential
from sbi.samplers.vi.vi_utils import (
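Same pattern as the other files, applied to every optimizer the VI code can use: each class is imported from its defining submodule, and the classes are identical to the torch.optim re-exports. A sketch with a hypothetical name-to-class registry (not sbi's actual selection logic) to show the imports in use:

import torch
from torch.optim.adam import Adam
from torch.optim.adamw import AdamW
from torch.optim.asgd import ASGD
from torch.optim.sgd import SGD

# The submodule imports are the same classes that torch.optim re-exports.
assert Adam is torch.optim.Adam and SGD is torch.optim.SGD

# Hypothetical string-to-optimizer registry, only to illustrate the imports.
OPTIMIZERS = {"adam": Adam, "adamw": AdamW, "asgd": ASGD, "sgd": SGD}

params = [torch.nn.Parameter(torch.randn(4))]
opt = OPTIMIZERS["adam"](params, lr=1e-3)
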
5 changes: 3 additions & 2 deletions sbi/utils/restriction_estimator.py
@@ -8,9 +8,10 @@
import torch
import torch.nn.functional as F
from pyknos.nflows.nn import nets
-from torch import Tensor, nn, optim, relu
+from torch import Tensor, nn, relu
from torch.distributions import Distribution
from torch.nn.utils.clip_grad import clip_grad_norm_
+from torch.optim.adam import Adam
from torch.utils import data
from torch.utils.data.sampler import SubsetRandomSampler, WeightedRandomSampler

@@ -325,7 +326,7 @@ def train(
self._first_round_validation_theta = theta[val_indices]
self._first_round_validation_label = label[val_indices]

-optimizer = optim.Adam(
+optimizer = Adam(
list(self._classifier.parameters()),
lr=learning_rate,
)
5 changes: 3 additions & 2 deletions sbi/utils/sbiutils.py
@@ -13,7 +13,7 @@
import torch.distributions.transforms as torch_tf
import zuko
from pyro.distributions import Empirical
-from torch import Tensor, ones, optim, zeros
+from torch import Tensor, ones, zeros
from torch import nn as nn
from torch.distributions import (
AffineTransform,
@@ -22,6 +22,7 @@
biject_to,
constraints,
)
+from torch.optim.adam import Adam
from zuko.flows import UnconditionalTransform

from sbi.sbi_types import TorchTransform
@@ -935,7 +936,7 @@ def gradient_ascent(

optimize_inits = theta_transform(optimize_inits)
optimize_inits.requires_grad_(True) # type: ignore
-optimizer = optim.Adam([optimize_inits], lr=learning_rate) # type: ignore
+optimizer = Adam([optimize_inits], lr=learning_rate) # type: ignore

iter_ = 0

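gradient_ascent passes a single leaf tensor to Adam rather than module parameters, which the submodule-imported class accepts just as well. A minimal, self-contained sketch of that usage with a toy potential standing in for sbi's potential_fn:

import torch
from torch.optim.adam import Adam

# Toy potential with its maximum at x = 2; a stand-in, not sbi's potential_fn.
def potential(x: torch.Tensor) -> torch.Tensor:
    return -(x - 2.0) ** 2

optimize_inits = torch.zeros(5, requires_grad=True)  # leaf tensor being optimized
optimizer = Adam([optimize_inits], lr=0.05)          # Adam accepts a plain list of tensors

for _ in range(200):
    optimizer.zero_grad()
    loss = -potential(optimize_inits).sum()  # ascend the potential by descending its negative
    loss.backward()
    optimizer.step()

print(optimize_inits.detach())  # entries approach 2.0
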
