Skip to content

Commit

Permalink
Merge pull request #7 from mj-will/testing-ci-updates
Browse files Browse the repository at this point in the history
Testing and CI updates
  • Loading branch information
mj-will authored May 31, 2024
2 parents bc66258 + 235e14c commit 6898b0f
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 4 deletions.
23 changes: 23 additions & 0 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
name: Lint

on:
push:
branches: [ main ]
pull_request:
branches: [ main ]

concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true

jobs:
black:
name: Black
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: psf/black@stable
with:
options: "--check --diff"
src: "./src"
version: "~= 24.0"
1 change: 1 addition & 0 deletions src/nessai_bilby/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Includes support for standard nessai and inessai.
"""

from importlib.metadata import PackageNotFoundError, version

try:
Expand Down
1 change: 1 addition & 0 deletions src/nessai_bilby/model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Utilities for using nessai with external packages"""

from nessai.livepoint import dict_to_live_points
from nessai.model import Model
import numpy as np
Expand Down
5 changes: 4 additions & 1 deletion src/nessai_bilby/plugin.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Interface for nessai in bilby"""

import os
import sys

Expand Down Expand Up @@ -285,7 +286,9 @@ def get_expected_outputs(cls, outdir=None, label=None):
List of directory names.
"""
dirs = [os.path.join(outdir, f"{label}_{cls.sampler_name}", "")]
dirs += [os.path.join(dirs[0], d, "") for d in ["proposal", "diagnostics"]]
dirs += [
os.path.join(dirs[0], d, "") for d in ["proposal", "diagnostics"]
]
filenames = []
return filenames, dirs

Expand Down
1 change: 1 addition & 0 deletions tests/test_bilby_integration.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Test the integration with bilby"""

import bilby
import pytest

Expand Down
4 changes: 1 addition & 3 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@ def test_create_model(ModelClass, bilby_gaussian_likelihood_and_priors):


def test_sample_model_with_nessai(
bilby_gaussian_likelihood_and_priors,
tmp_path,
ModelClass
bilby_gaussian_likelihood_and_priors, tmp_path, ModelClass
):
likelihood, priors = bilby_gaussian_likelihood_and_priors
priors = bilby.core.prior.PriorDict(priors)
Expand Down
98 changes: 98 additions & 0 deletions tests/test_sampler_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
import bilby
import copy
from nessai_bilby.plugin import Nessai, ImportanceNessai
import os
import pytest
from unittest.mock import mock_open, patch


@pytest.fixture(params=[Nessai, ImportanceNessai])
def SamplerClass(request):
return request.param


@pytest.fixture()
def create_sampler(
SamplerClass, bilby_gaussian_likelihood_and_priors, tmp_path
):
likelihood, priors = bilby_gaussian_likelihood_and_priors

def create_fn(**kwargs):
return SamplerClass(
likelihood,
priors,
outdir=tmp_path / "outdir",
label="test",
use_ratio=False,
**kwargs,
)

return create_fn


@pytest.fixture
def sampler(create_sampler):
return create_sampler()


@pytest.fixture
def default_kwargs(sampler):
expected = copy.deepcopy(sampler.default_kwargs)
expected["output"] = os.path.join(
sampler.outdir, f"{sampler.label}_{sampler.sampler_name}", ""
)
expected["seed"] = 12345
return expected


@pytest.mark.parametrize(
"key",
bilby.core.sampler.base_sampler.NestedSampler.npoints_equiv_kwargs,
)
def test_translate_kwargs_nlive(create_sampler, key):
sampler = create_sampler(**{key: 1000})
assert sampler.kwargs["nlive"] == 1000


@pytest.mark.parametrize(
"key",
bilby.core.sampler.base_sampler.NestedSampler.npool_equiv_kwargs,
)
def test_translate_kwargs_npool(create_sampler, key):
sampler = create_sampler(**{key: 2})
assert sampler.kwargs["n_pool"] == 2


def test_split_kwargs(sampler):
kwargs, run_kwargs = sampler.split_kwargs()
assert "save" not in run_kwargs
assert "plot" in run_kwargs


def test_translate_kwargs_no_npool(create_sampler):
sampler = create_sampler()
assert sampler.kwargs["n_pool"] == 1


def test_translate_kwargs_seed(create_sampler):
sampler = create_sampler(sampling_seed=150914)
assert sampler.kwargs["seed"] == 150914


@patch("builtins.open", mock_open(read_data='{"nlive": 4000}'))
def test_update_from_config_file(create_sampler):
sampler = create_sampler(config_file="config_file.json")
assert sampler.kwargs["nlive"] == 4000


def test_expected_outputs(SamplerClass):
expected = os.path.join("outdir", f"test_{SamplerClass.sampler_name}", "")
filenames, dirs = SamplerClass.get_expected_outputs(
outdir="outdir",
label="test",
)
assert len(filenames) == 0
assert len(dirs) == 3
assert dirs[0] == expected
assert dirs[1] == os.path.join(expected, "proposal", "")
assert dirs[2] == os.path.join(expected, "diagnostics", "")

0 comments on commit 6898b0f

Please sign in to comment.