From fe71a6556aa6112c985a1d1350f5a7d69447fa3b Mon Sep 17 00:00:00 2001 From: mj-will Date: Fri, 31 May 2024 09:40:19 +0100 Subject: [PATCH 1/4] TST: add existing from bilby --- tests/test_sampler_class.py | 95 +++++++++++++++++++++++++++++++++++++ 1 file changed, 95 insertions(+) create mode 100644 tests/test_sampler_class.py diff --git a/tests/test_sampler_class.py b/tests/test_sampler_class.py new file mode 100644 index 0000000..2a71a4f --- /dev/null +++ b/tests/test_sampler_class.py @@ -0,0 +1,95 @@ +import bilby +import copy +from nessai_bilby.plugin import Nessai, ImportanceNessai +import os +import pytest +from unittest.mock import mock_open, patch + + +@pytest.fixture(params=[Nessai, ImportanceNessai]) +def SamplerClass(request): + return request.param + + +@pytest.fixture() +def create_sampler(SamplerClass, bilby_gaussian_likelihood_and_priors, tmp_path): + likelihood, priors = bilby_gaussian_likelihood_and_priors + + def create_fn(**kwargs): + return SamplerClass( + likelihood, + priors, + outdir=tmp_path / "outdir", + label="test", + use_ratio=False, + **kwargs + ) + return create_fn + + +@pytest.fixture +def sampler(create_sampler): + return create_sampler() + + +@pytest.fixture +def default_kwargs(sampler): + expected = copy.deepcopy(sampler.default_kwargs) + expected["output"] = os.path.join( + sampler.outdir, f"{sampler.label}_{sampler.sampler_name}", "" + ) + expected["seed"] = 12345 + return expected + + +@pytest.mark.parametrize( + "key", + bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs, +) +def test_translate_kwargs_nlive(create_sampler, key): + sampler = create_sampler(**{key: 1000}) + assert sampler.kwargs["nlive"] == 1000 + + +@pytest.mark.parametrize( + "key", + bilby.core.sampler.base_sampler.NestedSampler.npool_equiv_kwargs, +) +def test_translate_kwargs_npool(create_sampler, key): + sampler = create_sampler(**{key: 2}) + assert sampler.kwargs["n_pool"] == 2 + + +def test_split_kwargs(sampler): + kwargs, run_kwargs = sampler.split_kwargs() + assert "save" not in run_kwargs + assert "plot" in run_kwargs + + +def test_translate_kwargs_no_npool(create_sampler): + sampler = create_sampler() + assert sampler.kwargs["n_pool"] == 1 + + +def test_translate_kwargs_seed(create_sampler): + sampler = create_sampler(sampling_seed=150914) + assert sampler.kwargs["seed"] == 150914 + + +@patch("builtins.open", mock_open(read_data='{"nlive": 4000}')) +def test_update_from_config_file(create_sampler): + sampler = create_sampler(config_file="config_file.json") + assert sampler.kwargs["nlive"] == 4000 + + +def test_expected_outputs(SamplerClass): + expected = os.path.join("outdir", f"test_{SamplerClass.sampler_name}", "") + filenames, dirs = SamplerClass.get_expected_outputs( + outdir="outdir", + label="test", + ) + assert len(filenames) == 0 + assert len(dirs) == 3 + assert dirs[0] == expected + assert dirs[1] == os.path.join(expected, "proposal", "") + assert dirs[2] == os.path.join(expected, "diagnostics", "") From ee944f684988edd7cc52f6fa6c50fdc06a443013 Mon Sep 17 00:00:00 2001 From: mj-will Date: Fri, 31 May 2024 09:42:54 +0100 Subject: [PATCH 2/4] STY: apply black --- src/nessai_bilby/__init__.py | 1 + src/nessai_bilby/model.py | 1 + src/nessai_bilby/plugin.py | 5 ++++- tests/test_bilby_integration.py | 1 + tests/test_model.py | 4 +--- tests/test_sampler_class.py | 7 +++++-- 6 files changed, 13 insertions(+), 6 deletions(-) diff --git a/src/nessai_bilby/__init__.py b/src/nessai_bilby/__init__.py index 6e5c0ef..50350f5 100644 --- a/src/nessai_bilby/__init__.py +++ b/src/nessai_bilby/__init__.py @@ -2,6 +2,7 @@ Includes support for standard nessai and inessai. """ + from importlib.metadata import PackageNotFoundError, version try: diff --git a/src/nessai_bilby/model.py b/src/nessai_bilby/model.py index acce087..9750068 100644 --- a/src/nessai_bilby/model.py +++ b/src/nessai_bilby/model.py @@ -1,4 +1,5 @@ """Utilities for using nessai with external packages""" + from nessai.livepoint import dict_to_live_points from nessai.model import Model import numpy as np diff --git a/src/nessai_bilby/plugin.py b/src/nessai_bilby/plugin.py index b53ea06..ca451af 100644 --- a/src/nessai_bilby/plugin.py +++ b/src/nessai_bilby/plugin.py @@ -1,4 +1,5 @@ """Interface for nessai in bilby""" + import os import sys @@ -285,7 +286,9 @@ def get_expected_outputs(cls, outdir=None, label=None): List of directory names. """ dirs = [os.path.join(outdir, f"{label}_{cls.sampler_name}", "")] - dirs += [os.path.join(dirs[0], d, "") for d in ["proposal", "diagnostics"]] + dirs += [ + os.path.join(dirs[0], d, "") for d in ["proposal", "diagnostics"] + ] filenames = [] return filenames, dirs diff --git a/tests/test_bilby_integration.py b/tests/test_bilby_integration.py index 7f422fb..e747542 100644 --- a/tests/test_bilby_integration.py +++ b/tests/test_bilby_integration.py @@ -1,4 +1,5 @@ """Test the integration with bilby""" + import bilby import pytest diff --git a/tests/test_model.py b/tests/test_model.py index 9c8cfbb..80ef971 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -19,9 +19,7 @@ def test_create_model(ModelClass, bilby_gaussian_likelihood_and_priors): def test_sample_model_with_nessai( - bilby_gaussian_likelihood_and_priors, - tmp_path, - ModelClass + bilby_gaussian_likelihood_and_priors, tmp_path, ModelClass ): likelihood, priors = bilby_gaussian_likelihood_and_priors priors = bilby.core.prior.PriorDict(priors) diff --git a/tests/test_sampler_class.py b/tests/test_sampler_class.py index 2a71a4f..587a3d7 100644 --- a/tests/test_sampler_class.py +++ b/tests/test_sampler_class.py @@ -12,7 +12,9 @@ def SamplerClass(request): @pytest.fixture() -def create_sampler(SamplerClass, bilby_gaussian_likelihood_and_priors, tmp_path): +def create_sampler( + SamplerClass, bilby_gaussian_likelihood_and_priors, tmp_path +): likelihood, priors = bilby_gaussian_likelihood_and_priors def create_fn(**kwargs): @@ -22,8 +24,9 @@ def create_fn(**kwargs): outdir=tmp_path / "outdir", label="test", use_ratio=False, - **kwargs + **kwargs, ) + return create_fn From a237be315c483e6ea9202207031a3158d94c4ff6 Mon Sep 17 00:00:00 2001 From: mj-will Date: Fri, 31 May 2024 09:43:07 +0100 Subject: [PATCH 3/4] CI: add linting workflow --- .github/workflows/lint.yml | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 .github/workflows/lint.yml diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 0000000..df5ceec --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,23 @@ +name: Lint + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + black: + name: Black + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: psf/black@stable + with: + options: "--check --diff" + src: "./src" + version: "~= 24.0" \ No newline at end of file From 235e14c8b09da017c43d8878b5c99c7aebc5e04b Mon Sep 17 00:00:00 2001 From: mj-will Date: Fri, 31 May 2024 09:44:21 +0100 Subject: [PATCH 4/4] STY: add missing empty line --- .github/workflows/lint.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index df5ceec..8eee39a 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -20,4 +20,4 @@ jobs: with: options: "--check --diff" src: "./src" - version: "~= 24.0" \ No newline at end of file + version: "~= 24.0"