Skip to content

Commit

Permalink
Updated unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
tomasstolker committed Sep 24, 2024
1 parent 827c44f commit 804377f
Showing 1 changed file with 40 additions and 9 deletions.
49 changes: 40 additions & 9 deletions tests/test_nested_sampler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Tests the NestedSampler class by fixing all parameters except for eccentricity.
Tests the NestedSampler and MultiNest classes by fixing all parameters except for eccentricity.
"""

from orbitize import system, sampler
Expand Down Expand Up @@ -40,25 +40,56 @@ def test_nested_sampler():

# run both static & dynamic nested samplers
mysampler = sampler.NestedSampler(sys)
_ = mysampler.run_sampler(bound="multi", pfrac=0.95, static=False, num_threads=8)
_ = mysampler.run_sampler(bound="multi", pfrac=0.95, static=False, num_threads=8, run_nested_kwargs={})
print("Finished first run!")

dynamic_eccentricities = mysampler.results.post[:, lab["ecc1"]]
assert np.median(dynamic_eccentricities) == pytest.approx(ecc, abs=0.1)

_ = mysampler.run_sampler(bound="multi", static=True, num_threads=8)
_ = mysampler.run_sampler(bound="multi", static=True, num_threads=8, run_nested_kwargs={})
print("Finished second run!")

static_eccentricities = mysampler.results.post[:, lab["ecc1"]]
assert np.median(static_eccentricities) == pytest.approx(ecc, abs=0.1)

# check that the static sampler raises an error when user tries to set pfrac
# for static sampler
try:
mysampler.run_sampler(pfrac=0.1, static=True)
except ValueError:
pass

def test_multinest():
# generate data
mtot = 1.2 # total system mass [M_sol]
plx = 60.0 # parallax [mas]
orbit_frac = 95
data_table, sma = generate_synthetic_data(
orbit_frac,
mtot,
plx,
num_obs=30,
)

# assumed ecc value
ecc = 0.5

# initialize orbitize `System` object
sys = system.System(1, data_table, mtot, plx)
lab = sys.param_idx

ecc = 0.5 # eccentricity

# set all parameters except eccentricity to fixed values (same as used to generate data)
sys.sys_priors[lab["inc1"]] = np.pi / 4
sys.sys_priors[lab["sma1"]] = sma
sys.sys_priors[lab["aop1"]] = np.pi / 4
sys.sys_priors[lab["pan1"]] = np.pi / 4
sys.sys_priors[lab["tau1"]] = 0.8
sys.sys_priors[lab["plx"]] = plx
sys.sys_priors[lab["mtot"]] = mtot

# running the actual sampler is not possible without compiling MultiNest
mysampler = sampler.MultiNest(sys)
assert hasattr(mysampler, "run_sampler")
assert hasattr(mysampler, "results")
assert hasattr(mysampler, "system")


if __name__ == "__main__":
test_nested_sampler()
test_multinest()

0 comments on commit 804377f

Please sign in to comment.