Skip to content

Commit

Permalink
increase test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
qacwnfq committed Sep 8, 2024
1 parent 1df5756 commit 2b15e66
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 33 deletions.
12 changes: 6 additions & 6 deletions anesthetic/read/dnest4.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from anesthetic.samples import DiffusiveNestedSamples


def _determine_columns_and_labels(n_params, header, delim=' '):
def _determine_columns(n_params, header, delim=' '):
"""
Determine column names from DNest4 output.
Expand All @@ -17,8 +17,7 @@ def _determine_columns_and_labels(n_params, header, delim=' '):
columns = [f'x_{i}' for i in range(n_params)]
else:
columns = [d.strip() for d in dnest4_column_descriptions]
labels = {c: '$' + c + '$' for c in columns}
return columns, labels
return columns


def read_dnest4(root,
Expand Down Expand Up @@ -65,9 +64,10 @@ def read_dnest4(root,
logL_birth = levels[sample_level, 1]

kwargs['label'] = kwargs.get('label', os.path.basename(root))
columns, labels = _determine_columns_and_labels(n_params, header)
columns = kwargs.pop('columns', columns)
labels = kwargs.pop('labels', labels)
columns_ = _determine_columns(n_params, header)
columns = kwargs.pop('columns', columns_)
labels_ = {c: '$' + c + '$' for c in columns}
labels = kwargs.pop('labels', labels_)

return DiffusiveNestedSamples(sample_info=sample_info,
levels=levels,
Expand Down
39 changes: 18 additions & 21 deletions tests/test_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,20 @@ def close_figures_on_teardown():
plt.close("all")


@pytest.mark.parametrize('root', ["./tests/example_data/pc",
"./tests/example_data/mn",
skipif_no_h5py("./tests/example_data/un"),
"./tests/example_data/nf"])
@pytest.mark.parametrize('root', [
"./tests/example_data/pc",
"./tests/example_data/mn",
skipif_no_h5py("./tests/example_data/un"),
"./tests/example_data/nf",
"./tests/example_data/dnest4"])
def test_gui(root):
samples = read_chains(root)
plotter = samples.gui()

# Type buttons
plotter.type.buttons.set_active(1)
assert plotter.type() == 'posterior'
plotter.type.buttons.set_active(0)
assert plotter.type() == 'live'
for i, plot_type in enumerate(samples.plot_types()):
plotter.type.buttons.set_active(i)
assert plotter.type() == plot_type

# Parameter choice buttons
plotter.param_choice.buttons.set_active(1)
Expand All @@ -32,28 +33,24 @@ def test_gui(root):
assert len(plotter.triangle.ax) == 1
plotter.param_choice.buttons.set_active(0)
plotter.param_choice.buttons.set_active(2)
plotter.param_choice.buttons.set_active(3)
assert len(plotter.triangle.ax) == 4
assert len(plotter.triangle.ax) == 3

# Sliders
old = plotter.evolution()
plotter.evolution.slider.set_val(5)
assert plotter.evolution() != old
old = plotter.evolution()
plotter.evolution.slider.set_val(5.5)
assert plotter.evolution() != old
old = plotter.evolution()
plotter.evolution.slider.set_val(0)
assert plotter.evolution() == old
plotter.type.buttons.set_active(1)
if len(samples.plot_types()) > 1:
plotter.type.buttons.set_active(1)

plotter.beta.slider.set_val(0)
assert plotter.beta() == pytest.approx(0, 0, 1e-8)
plotter.beta.slider.set_val(0)
assert plotter.beta() == pytest.approx(0, 0, 1e-8)

plotter.beta.slider.set_val(samples.D_KL())
assert plotter.beta() == pytest.approx(1)
plotter.beta.slider.set_val(1e2)
assert plotter.beta() == 1e10
plotter.beta.slider.set_val(samples.D_KL())
assert plotter.beta() == pytest.approx(1)
plotter.beta.slider.set_val(1e2)
assert plotter.beta() == 1e10
plotter.type.buttons.set_active(0)

# Reload button
Expand Down
33 changes: 27 additions & 6 deletions tests/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,17 +378,38 @@ def test_read_dnest():
def test_read_dnest4_no_column_names():
np.random.seed(3)
ns = read_dnest4('./tests/example_data/dnest4_no_column_names')
params = ['x0', 'x1', 'logL', 'logL_birth', 'nlive']
params = ['x_0', 'x_1', 'logL', 'logL_birth', 'nlive']
assert_array_equal(ns.drop_labels().columns, params)
labels = [r'$x0$',
r'$x1$',
labels = [r'$x_0$',
r'$x_1$',
r'$\ln\mathcal{L}$',
r'$\ln\mathcal{L}_\mathrm{birth}$',
r'$n_\mathrm{live}$']

assert_array_equal(ns.get_labels(), labels)

assert isinstance(ns, DiffusiveNestedSamples)
assert ns.samples_at_level(9, label='x1').shape == (45, 1)
ns.plot_2d(['x0', 'x1'])
ns.plot_1d(['x0', 'x1'])
assert ns.samples_at_level(9, label='x_1').shape == (45, 1)
ns.plot_2d(['x_0', 'x_1'])
ns.plot_1d(['x_0', 'x_1'])


def test_read_dnest4_override_column_names():
np.random.seed(3)
columns = ['y0', 'y1']
ns = read_dnest4('./tests/example_data/dnest4_no_column_names',
columns=columns)
params = ['y0', 'y1', 'logL', 'logL_birth', 'nlive']
assert_array_equal(ns.drop_labels().columns, params)
labels = [r'$y0$',
r'$y1$',
r'$\ln\mathcal{L}$',
r'$\ln\mathcal{L}_\mathrm{birth}$',
r'$n_\mathrm{live}$']

assert_array_equal(ns.get_labels(), labels)

assert isinstance(ns, DiffusiveNestedSamples)
assert ns.samples_at_level(9, label='y0').shape == (45, 1)
ns.plot_2d(['y0', 'y1'])
ns.plot_1d(['y0', 'y1'])

0 comments on commit 2b15e66

Please sign in to comment.