From 804377f59337f3c31840d6cb852d4d8aca23afad Mon Sep 17 00:00:00 2001 From: Tomas Stolker Date: Tue, 24 Sep 2024 15:35:17 +0200 Subject: [PATCH] Updated unit tests --- tests/test_nested_sampler.py | 49 +++++++++++++++++++++++++++++------- 1 file changed, 40 insertions(+), 9 deletions(-) diff --git a/tests/test_nested_sampler.py b/tests/test_nested_sampler.py index a6130c35..a00fb8ca 100644 --- a/tests/test_nested_sampler.py +++ b/tests/test_nested_sampler.py @@ -1,5 +1,5 @@ """ -Tests the NestedSampler class by fixing all parameters except for eccentricity. +Tests the NestedSampler and MultiNest classes by fixing all parameters except for eccentricity. """ from orbitize import system, sampler @@ -40,25 +40,56 @@ def test_nested_sampler(): # run both static & dynamic nested samplers mysampler = sampler.NestedSampler(sys) - _ = mysampler.run_sampler(bound="multi", pfrac=0.95, static=False, num_threads=8) + _ = mysampler.run_sampler(bound="multi", pfrac=0.95, static=False, num_threads=8, run_nested_kwargs={}) print("Finished first run!") dynamic_eccentricities = mysampler.results.post[:, lab["ecc1"]] assert np.median(dynamic_eccentricities) == pytest.approx(ecc, abs=0.1) - _ = mysampler.run_sampler(bound="multi", static=True, num_threads=8) + _ = mysampler.run_sampler(bound="multi", static=True, num_threads=8, run_nested_kwargs={}) print("Finished second run!") static_eccentricities = mysampler.results.post[:, lab["ecc1"]] assert np.median(static_eccentricities) == pytest.approx(ecc, abs=0.1) - # check that the static sampler raises an error when user tries to set pfrac - # for static sampler - try: - mysampler.run_sampler(pfrac=0.1, static=True) - except ValueError: - pass + +def test_multinest(): + # generate data + mtot = 1.2 # total system mass [M_sol] + plx = 60.0 # parallax [mas] + orbit_frac = 95 + data_table, sma = generate_synthetic_data( + orbit_frac, + mtot, + plx, + num_obs=30, + ) + + # assumed ecc value + ecc = 0.5 + + # initialize orbitize `System` object + sys = system.System(1, data_table, mtot, plx) + lab = sys.param_idx + + ecc = 0.5 # eccentricity + + # set all parameters except eccentricity to fixed values (same as used to generate data) + sys.sys_priors[lab["inc1"]] = np.pi / 4 + sys.sys_priors[lab["sma1"]] = sma + sys.sys_priors[lab["aop1"]] = np.pi / 4 + sys.sys_priors[lab["pan1"]] = np.pi / 4 + sys.sys_priors[lab["tau1"]] = 0.8 + sys.sys_priors[lab["plx"]] = plx + sys.sys_priors[lab["mtot"]] = mtot + + # running the actual sampler is not possible without compiling MultiNest + mysampler = sampler.MultiNest(sys) + assert hasattr(mysampler, "run_sampler") + assert hasattr(mysampler, "results") + assert hasattr(mysampler, "system") if __name__ == "__main__": test_nested_sampler() + test_multinest()