Skip to content

Commit

Permalink
fix nbody indexing error & add unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
sblunt committed Feb 12, 2024
1 parent 6bfa14b commit e4f3cf9
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 21 deletions.
25 changes: 9 additions & 16 deletions orbitize/nbody.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,14 @@ def calc_orbit(
Returns:
3-tuple:
raoff (np.array): array-like (n_dates x n_orbs) of RA offsets between
raoff (np.array): array-like (n_dates x n_bodies x n_orbs) of RA offsets between
the bodies (origin is at the other body) [mas]
deoff (np.array): array-like (n_dates x n_orbs) of Dec offsets between
deoff (np.array): array-like (n_dates x n_bodies x n_orbs) of Dec offsets between
the bodies [mas]
vz (np.array): array-like (n_dates x n_orbs) of radial velocity of
vz (np.array): array-like (n_dates x n_bodies x n_orbs) of radial velocity of
one of the bodies (see `mass_for_Kamp` description) [km/s]
.. Note::
return is in format [raoff[planet1, planet2,...,planetn],
deoff[planet1, planet2,...,planetn], vz[planet1, planet2,...,planetn]
"""

sim = rebound.Simulation() #creating the simulation in Rebound
Expand Down Expand Up @@ -109,11 +104,9 @@ def calc_orbit(
raoff = plx*ra_reb
deoff = plx*dec_reb

#for formatting purposes
if len(sma)==1:
raoff = raoff.reshape(tx,)
deoff = deoff.reshape(tx,)
vz = vz.reshape(tx,)
return raoff, deoff, vz
else:
return raoff, deoff, vz
# always assume we're using MCMC (i.e. n_orbits = 1)
raoff = raoff.reshape((tx, indv + 1, 1))
deoff = deoff.reshape((tx, indv + 1, 1))
vz = vz.reshape((tx,indv + 1, 1))

return raoff, deoff, vz
34 changes: 29 additions & 5 deletions tests/test_rebound.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from matplotlib import pyplot as plt
from astropy.time import Time
from orbitize import sampler
from orbitize.system import System
from orbitize import DATADIR
from orbitize.read_input import read_file
Expand Down Expand Up @@ -125,8 +126,8 @@ def test_8799_rebound_vs_kepler(plotname=None):
params_arr, epochs=epochs, comp_rebound=False
)

delta_ra = abs(rra - kra[:, :, 0])
delta_de = abs(rde - kde[:, :, 0])
delta_ra = abs(rra - kra)
delta_de = abs(rde - kde)
yepochs = Time(epochs, format="mjd").decimalyear

# check that the difference between these two solvers is smaller than
Expand Down Expand Up @@ -179,8 +180,8 @@ def test_8799_rebound_vs_kepler(plotname=None):
plt.plot(kra[:, 1:5, 0], kde[:, 1:5, 0], "indigo", label="Orbitize approx.")
plt.plot(kra[-1, 1:5, 0], kde[-1, 1:5, 0], "o")

plt.plot(rra, rde, "r", label="Rebound", alpha=0.25)
plt.plot(rra[-1], rde[-1], "o", alpha=0.25)
plt.plot(rra[:, 1:5, 0], rde[:, 1:5, 0], "r", label="Rebound", alpha=0.25)
plt.plot(rra[-1, 1:5, 0], rde[-1, 1:5, 0], "o", alpha=0.25)

plt.plot(0, 0, "*")
plt.legend()
Expand All @@ -199,5 +200,28 @@ def test_8799_rebound_vs_kepler(plotname=None):
plt.savefig("{}_primaryorbittrack.png".format(plotname), dpi=250)


def test_rebound_mcmc():
"""
Test that a 2-body rebound fit runs through one MCMC iteration successfully.
"""

input_file = "{}/test_val_multi.csv".format(DATADIR)
data_table = read_file(input_file)

my_sys = System(num_secondary_bodies=2,
use_rebound=True,
fit_secondary_mass=True,
data_table=data_table,
stellar_or_system_mass=1.0, mass_err=0,
plx=1.0, plx_err=0.0
)


my_mcmc_samp = sampler.MCMC(my_sys, num_temps=1, num_walkers=20, num_threads=1)
my_mcmc_samp.run_sampler(5, burn_steps=1)



if __name__ == "__main__":
test_8799_rebound_vs_kepler(plotname="hr8799_diffs")
# test_8799_rebound_vs_kepler(plotname="hr8799_diffs")
test_rebound_mcmc()

0 comments on commit e4f3cf9

Please sign in to comment.