Skip to content

Commit

Permalink
build: drop py3.8 and torch1.12; update deps. fix slow tests. (#1233)
Browse files Browse the repository at this point in the history
* add tutorials test; fix slow tests, typing.

* build: drop python3.8 and torch 1.12, update deps.

- add movie writer to dev deps
- downgrade scipy to enable py3.9 support
  • Loading branch information
janfb authored Aug 27, 2024
1 parent 5584f13 commit 9648aff
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 21 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/cd.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.8'
python-version: '3.9'

- name: Cache dependency
id: cache-dependencies
Expand Down
7 changes: 3 additions & 4 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ['3.8']
torch-version: ['1.11', '2.2']
python-version: ['3.9', '3.12']

steps:
- name: Checkout
Expand All @@ -40,15 +39,15 @@ jobs:
uses: actions/cache@v4
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ matrix.python-version }}-${{ matrix.torch-version }}$
key: ${{ runner.os }}-pip-${{ matrix.python-version }}
restore-keys: |
${{ runner.os }}-pip-${{ matrix.python-version }}-
${{ runner.os }}-pip-
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install torch==${{ matrix.torch-version }} --extra-index-url https://download.pytorch.org/whl/cpu
pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
pip install -e .[dev]
- name: Run the fast CPU tests with coverage
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: '3.8'
python-version: '3.9'
- uses: pre-commit/[email protected]
with:
extra_args: --all-files --show-diff-on-failure
Expand All @@ -40,14 +40,14 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.8'
python-version: '3.9'

- name: Cache dependency
id: cache-dependencies
uses: actions/cache@v4
with:
path: ~/.cache/pip
key: ubuntu-latest-pip-3.8
key: ubuntu-latest-pip-3.9
restore-keys: |
ubuntu-latest-pip-
Expand Down
15 changes: 8 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ classifiers = [
"Programming Language :: Python :: 3",
"Development Status :: 3 - Alpha",
]
requires-python = ">=3.8"
requires-python = ">=3.9"
dynamic = ["version"]
readme = "README.md"
keywords = ["Bayesian inference", "simulation-based inference", "PyTorch"]
Expand All @@ -33,14 +33,14 @@ dependencies = [
"joblib>=1.0.0",
"jupyter",
"matplotlib",
"numpy",
"numpy<2.0.0",
"pillow",
"pyknos>=0.16.0",
"pyro-ppl>=1.3.1",
"scikit-learn",
"scipy",
"scipy<1.13",
"tensorboard",
"torch>=1.8.0",
"torch>=1.13.0",
"tqdm",
"pymc>=5.0.0",
"zuko>=1.2.0",
Expand All @@ -61,6 +61,7 @@ doc = [
"mike"
]
dev = [
"ffmpeg",
# Lint
"pre-commit == 3.5.0",
"pyyaml",
Expand Down Expand Up @@ -134,9 +135,9 @@ xfail_strict = true

# Pyright configuration
[tool.pyright]
include = ["sbi", "tests"]
exclude = ["**/__pycache__", "**/__node_modules__", ".git", "docs", "examples", "tutorials", "tests"]
python_version = "3.8"
include = ["sbi"]
exclude = ["**/__pycache__", "**/__node_modules__", ".git", "docs", "tutorials", "tests"]
python_version = "3.9"
reportUnsupportedDunderAll = false
reportGeneralTypeIssues = false
reportInvalidTypeForm = false
Expand Down
13 changes: 8 additions & 5 deletions sbi/analysis/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,10 +179,10 @@ def plt_kde_2d(
ax.imshow(
Z,
extent=(
limits_col[0],
limits_col[1],
limits_row[0],
limits_row[1],
limits_col[0].item(),
limits_col[1].item(),
limits_row[0].item(),
limits_row[1].item(),
),
**offdiag_kwargs["mpl_kwargs"],
)
Expand Down Expand Up @@ -350,7 +350,7 @@ def get_offdiag_funcs(
def _format_subplot(
ax: Axes,
current: str,
limits: Union[List, torch.Tensor],
limits: Union[List[List[float]], torch.Tensor],
ticks: Optional[Union[List, torch.Tensor]],
labels_dim: List[str],
fig_kwargs: Dict,
Expand Down Expand Up @@ -384,6 +384,9 @@ def _format_subplot(
):
ax.set_facecolor(fig_kwargs["fig_bg_colors"][current])
# Limits
if isinstance(limits, Tensor):
assert limits.dim() == 2, "Limits should be a 2D tensor."
limits = limits.tolist()
if current == "diag":
eps = fig_kwargs["x_lim_add_eps"]
ax.set_xlim((limits[col][0] - eps, limits[col][1] + eps))
Expand Down
2 changes: 1 addition & 1 deletion tests/sbc_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def test_consistent_sbc_results(density_estimator, cov_method):
def simulator(theta):
return linear_gaussian(theta, likelihood_shift, likelihood_cov)

num_simulations = 2000
num_simulations = 4000
num_posterior_samples = 1000
num_sbc_runs = 100

Expand Down
39 changes: 39 additions & 0 deletions tests/tutorials_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import os

import nbformat
import pytest
from nbclient.exceptions import CellExecutionError
from nbconvert.preprocessors import ExecutePreprocessor


def list_notebooks(directory: str) -> list:
"""Return sorted list of all notebooks in a directory."""
notebooks = [
os.path.join(directory, f)
for f in os.listdir(directory)
if f.endswith("distributions.ipynb")
]
return sorted(notebooks)


@pytest.mark.slow
@pytest.mark.parametrize("notebook_path", list_notebooks("tutorials/"))
def test_tutorials(notebook_path):
"""Test that all notebooks in the tutorials directory can be executed."""
with open(notebook_path) as f:
nb = nbformat.read(f, as_version=4)
ep = ExecutePreprocessor(timeout=600, kernel_name='python3')
print(f"Executing notebook {notebook_path}")
try:
ep.preprocess(nb, {'metadata': {'path': os.path.dirname(notebook_path)}})
except CellExecutionError as e:
# Conditional_distributions tutorial requires movie writer that is failing
# in GitHub CI.
if "Requested MovieWriter" in str(e):
print("Skipping error in movie writer.")
else:
raise CellExecutionError from e
except Exception as e:
raise AssertionError(
f"Error executing the notebook {notebook_path}: {e}"
) from e

0 comments on commit 9648aff

Please sign in to comment.