Skip to content

Commit

Permalink
speed up hipparcos unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
sblunt committed Jan 10, 2024
1 parent d54e660 commit a6c0693
Showing 1 changed file with 28 additions and 11 deletions.
39 changes: 28 additions & 11 deletions tests/test_hipparcos.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,13 +151,13 @@ def test_iad_refitting():
"""

post, myHipLogProb = nielsen_iad_refitting_test(
"{}/HIP027321.d".format(DATADIR), burn_steps=10, mcmc_steps=200, saveplot=None
"{}/HIP027321.d".format(DATADIR), burn_steps=10, mcmc_steps=50, saveplot=None
)

# check that we get reasonable values for the posteriors of the refit IAD
# (we're only running the MCMC for a few steps, so these are not strict)
assert np.isclose(0, np.median(post[:, -1]), atol=0.1)
assert np.isclose(myHipLogProb.plx0, np.median(post[:, 0]), atol=0.1)
assert np.isclose(0, np.median(post[:, -1]), atol=0.5)
assert np.isclose(myHipLogProb.plx0, np.median(post[:, 0]), atol=0.5)


def test_save_load_dvd():
Expand All @@ -180,13 +180,20 @@ def test_save_load_dvd():
data_table_with_rvs,
1.22,
56.95,
mass_err=0.08,
plx_err=0.26,
hipparcos_IAD=myHip,
fit_secondary_mass=True,
gaia=myGaia,
)
n_walkers = 50

# fix some values to speed up fit
mySys.sys_priors[0] = 1
mySys.sys_priors[1] = 0
mySys.sys_priors[2] = 0
mySys.sys_priors[3] = 0
mySys.sys_priors[4] = 0
mySys.sys_priors[5] = 0

n_walkers = 20
mySamp = sampler.MCMC(mySys, num_walkers=n_walkers)
mySamp.run_sampler(n_walkers, burn_steps=0)
filename = "tmp1.hdf5"
Expand All @@ -195,7 +202,7 @@ def test_save_load_dvd():
myResults = results.Results()
myResults.load_results(filename)

# os.system('rm tmp.hdf5')
os.system("rm tmp*.hdf5")


def test_save_load_2021():
Expand Down Expand Up @@ -224,7 +231,16 @@ def test_save_load_2021():
fit_secondary_mass=True,
gaia=myGaia,
)
n_walkers = 50

# fix some values to speed up fit
mySys.sys_priors[0] = 1
mySys.sys_priors[1] = 0
mySys.sys_priors[2] = 0
mySys.sys_priors[3] = 0
mySys.sys_priors[4] = 0
mySys.sys_priors[5] = 0

n_walkers = 20
mySamp = sampler.MCMC(mySys, num_walkers=n_walkers)
mySamp.run_sampler(n_walkers, burn_steps=0)
filename = "tmp2.hdf5"
Expand All @@ -233,12 +249,13 @@ def test_save_load_2021():
myResults = results.Results()
myResults.load_results(filename)

# os.system('rm tmp.hdf5')
os.system("rm tmp*.hdf5")


if __name__ == "__main__":
test_save_load_dvd()
test_save_load_2021()
# test_save_load_2021()
# test_hipparcos_api()
# test_iad_refitting()
# test_dvd_vs_2021catalog()

# test_iad_refitting()

0 comments on commit a6c0693

Please sign in to comment.