Experimental implementation of SB2 model #131

Draft · wants to merge 1 commit into base: main
13 changes: 13 additions & 0 deletions thejoker/__init__.py
@@ -31,6 +31,10 @@
"phase_coverage_per_period",
]

# SB2:
from .thejoker_sb2 import *
from .prior_sb2 import JokerSB2Prior


__bibtex__ = __citation__ = """@ARTICLE{thejoker,
author = {{Price-Whelan}, Adrian M. and {Hogg}, David W. and
@@ -55,3 +59,12 @@
adsnote = {Provided by the SAO/NASA Astrophysics Data System}
}
"""

__all__ = [
'TheJoker',
'RVData',
'JokerSamples',
'JokerPrior',
'plot_rv_curves',
'TheJokerSB2'
]
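
With the imports above in place, the SB2 entry points become importable from the top-level package. A minimal sketch of what this hunk exposes (assuming thejoker_sb2 exports TheJokerSB2, as the new __all__ implies):

from thejoker import TheJokerSB2    # re-exported via `from .thejoker_sb2 import *`
from thejoker import JokerSB2Prior  # imported above, although not listed in __all__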
23 changes: 13 additions & 10 deletions thejoker/data.py
@@ -33,11 +33,13 @@ class RVData:
(days). Set to ``False`` to disable subtracting the reference time.
clean : bool (optional)
Filter out any NaN or Inf data points.
sort : bool (optional)
Whether or not to sort on time.

"""

@u.quantity_input(rv=u.km / u.s, rv_err=[u.km / u.s, (u.km / u.s) ** 2])
def __init__(self, t, rv, rv_err, t_ref=None, clean=True):
def __init__(self, t, rv, rv_err, t_ref=None, clean=True, sort=True):
# For speed, time is saved internally as BMJD:
if isinstance(t, Time):
_t_bmjd = t.tcb.mjd
@@ -94,15 +96,16 @@ def __init__(self, t, rv, rv_err, t_ref=None, clean=True):
else:
self.rv_err = self.rv_err[idx]

# sort on times
idx = self._t_bmjd.argsort()
self._t_bmjd = self._t_bmjd[idx]
self.rv = self.rv[idx]
if self._has_cov:
self.rv_err = self.rv_err[idx]
self.rv_err = self.rv_err[:, idx]
else:
self.rv_err = self.rv_err[idx]
if sort:
# sort on times
idx = self._t_bmjd.argsort()
self._t_bmjd = self._t_bmjd[idx]
self.rv = self.rv[idx]
if self._has_cov:
self.rv_err = self.rv_err[idx]
self.rv_err = self.rv_err[:, idx]
else:
self.rv_err = self.rv_err[idx]

if t_ref is False:
self.t_ref = None
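
The new sort keyword makes the time-ordering step optional while keeping the old behaviour as the default. A minimal usage sketch (call signature and km/s units follow the RVData class shown above; disabling sorting is presumably what lets the SB2 machinery keep externally paired rows aligned):

import astropy.units as u
from astropy.time import Time
from thejoker import RVData

t = Time([55557.0, 55555.0, 55556.0], format="mjd")
rv = [10.0, 12.0, 11.0] * u.km / u.s
rv_err = [0.1, 0.1, 0.1] * u.km / u.s

data_sorted = RVData(t, rv, rv_err)             # sorted on time, as before
data_as_is = RVData(t, rv, rv_err, sort=False)  # input row order preserved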
15 changes: 12 additions & 3 deletions thejoker/likelihood_helpers.py
@@ -66,8 +66,13 @@ def marginal_ln_likelihood_inmem(joker_helper, prior_samples_batch):
return np.array(ll)


def make_full_samples_inmem(joker_helper, prior_samples_batch, rng, n_linear_samples=1):
from .samples import JokerSamples
def make_full_samples_inmem(
joker_helper, prior_samples_batch, rng, n_linear_samples=1, SamplesCls=None
):
if SamplesCls is None:
from .samples import JokerSamples

SamplesCls = JokerSamples

if prior_samples_batch.dtype != np.float64:
prior_samples_batch = prior_samples_batch.astype(np.float64)
@@ -77,7 +82,7 @@ def make_full_samples_inmem(joker_helper, prior_samples_batch, rng, n_linear_sam
)

# unpack the raw samples
samples = JokerSamples.unpack(
samples = SamplesCls.unpack(
raw_samples,
joker_helper.internal_units,
t_ref=joker_helper.data.t_ref,
@@ -96,6 +101,7 @@ def rejection_sample_inmem(
max_posterior_samples=None,
n_linear_samples=1,
return_all_logprobs=False,
SamplesCls=None,
):
if max_posterior_samples is None:
max_posterior_samples = len(prior_samples_batch)
@@ -114,6 +120,7 @@
prior_samples_batch[good_samples_idx],
rng,
n_linear_samples=n_linear_samples,
SamplesCls=SamplesCls,
)

if ln_prior is not None and ln_prior is not False:
@@ -136,6 +143,7 @@ def iterative_rejection_inmem(
init_batch_size=None,
growth_factor=128,
n_linear_samples=1,
SamplesCls=None,
):
n_total_samples = len(prior_samples_batch)

@@ -219,6 +227,7 @@ def iterative_rejection_inmem(
prior_samples_batch[full_samples_idx],
rng,
n_linear_samples=n_linear_samples,
SamplesCls=SamplesCls,
)

# FIXME: copy-pasted from function above
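
The SamplesCls hook threaded through these helpers lets the caller swap in a different container class for the unpacked samples; when it is None the previously hard-coded JokerSamples is used. An illustrative sketch (MySB2Samples is a hypothetical stand-in for the real SB2 samples class, which is not part of this file):

from thejoker.samples import JokerSamples

class MySB2Samples(JokerSamples):
    """Hypothetical subclass; it inherits the ``unpack`` classmethod used above."""

# With the patch, the in-memory helpers can then be called as, e.g.:
#   samples = rejection_sample_inmem(joker_helper, prior_samples_batch, rng,
#                                    SamplesCls=MySB2Samples)
# so the raw samples are unpacked via MySB2Samples.unpack(...) instead of
# JokerSamples.unpack(...).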
9 changes: 7 additions & 2 deletions thejoker/multiproc_helpers.py
@@ -155,6 +155,7 @@ def make_full_samples(
samples_idx,
n_linear_samples=1,
n_batches=None,
SamplesCls=JokerSamples,
):
task_args = (prior_samples_file, joker_helper, n_linear_samples)
results = run_worker(
@@ -164,14 +165,14 @@
task_args=task_args,
n_batches=n_batches,
samples_idx=samples_idx,
rng=rng,
random_state=rng,
)

# Concatenate all of the raw samples arrays
raw_samples = np.concatenate(results)

# unpack the raw samples
samples = JokerSamples.unpack(
samples = SamplesCls.unpack(
raw_samples,
joker_helper.internal_units,
t_ref=joker_helper.data.t_ref,
@@ -195,6 +196,7 @@ def rejection_sample_helper(
n_batches=None,
randomize_prior_order=False,
return_all_logprobs=False,
SamplesCls=None,
):
# Total number of samples in the cache:
with tb.open_file(prior_samples_file, mode="r") as f:
@@ -271,6 +273,7 @@
full_samples_idx,
n_linear_samples=n_linear_samples,
n_batches=n_batches,
SamplesCls=SamplesCls,
)

if return_logprobs:
@@ -300,6 +303,7 @@ def iterative_rejection_helper(
return_logprobs=False,
n_batches=None,
randomize_prior_order=False,
SamplesCls=None,
):
# Total number of samples in the cache:
with tb.open_file(prior_samples_file, mode="r") as f:
@@ -412,6 +416,7 @@ def iterative_rejection_helper(
full_samples_idx,
n_linear_samples=n_linear_samples,
n_batches=n_batches,
SamplesCls=SamplesCls,
)

# FIXME: copy-pasted from function above
100 changes: 64 additions & 36 deletions thejoker/prior.py
@@ -38,6 +38,8 @@ def _validate_model(model):


class JokerPrior:
_sb2 = False

def __init__(self, pars=None, poly_trend=1, v0_offsets=None, model=None):
"""
This class controls the prior probability distributions for the
@@ -121,7 +123,9 @@ def __init__(self, pars=None, poly_trend=1, v0_offsets=None, model=None):
# are only used to validate that the units for each parameter are
# equivalent to these
self._nonlinear_equiv_units = get_nonlinear_equiv_units()
self._linear_equiv_units = get_linear_equiv_units(self.poly_trend)
self._linear_equiv_units = get_linear_equiv_units(
self.poly_trend, sb2=self._sb2
)
self._v0_offsets_equiv_units = get_v0_offsets_equiv_units(self.n_offsets)
self._all_par_unit_equiv = {
**self._nonlinear_equiv_units,
@@ -291,10 +295,7 @@ def __repr__(self):
def __str__(self):
return ", ".join(self.par_names)

@deprecated_renamed_argument(
"random_state", "rng", since="v1.3", warning_type=DeprecationWarning
)
def sample(
def _get_raw_samples(
self,
size=1,
generate_linear=False,
@@ -303,29 +304,6 @@ def sample(
dtype=None,
**kwargs,
):
"""
Generate random samples from the prior.

Parameters
----------
size : int (optional)
The number of samples to generate.
generate_linear : bool (optional)
Also generate samples in the linear parameters.
return_logprobs : bool (optional)
Generate the log-prior probability at the position of each sample.
**kwargs
Additional keyword arguments are passed to the
`~thejoker.JokerSamples` initializer.

Returns
-------
samples : `thejoker.Jokersamples`
The random samples.

"""
from .samples import JokerSamples

if dtype is None:
dtype = np.float64

@@ -339,11 +317,6 @@
)
}

if generate_linear:
par_names = self.par_names
else:
par_names = list(self._nonlinear_equiv_units.keys())

# MAJOR HACK RELATED TO UPSTREAM ISSUES WITH pymc3:
# init_shapes = {}
# for name, par in sub_pars.items():
@@ -374,12 +347,68 @@

logp.append(_logp)
log_prior = np.sum(logp, axis=0)
else:
log_prior = None

# CONTINUED MAJOR HACK RELATED TO UPSTREAM ISSUES WITH pymc3:
# for name, par in sub_pars.items():
# if hasattr(par, "distribution"):
# par.distribution.shape = init_shapes[name]

return raw_samples, sub_pars, log_prior

@deprecated_renamed_argument(
"random_state", "rng", since="v1.3", warning_type=DeprecationWarning
)
def sample(
self,
size=1,
generate_linear=False,
return_logprobs=False,
rng=None,
dtype=None,
**kwargs,
):
"""
Generate random samples from the prior.

.. note::

Right now, generating samples with the prior values is slow (i.e.
with ``return_logprobs=True``) because of pymc3 issues (see
discussion here:
https://discourse.pymc.io/t/draw-values-speed-scaling-with-transformed-variables/4076).
This will hopefully be resolved in the future...

Parameters
----------
size : int (optional)
The number of samples to generate.
generate_linear : bool (optional)
Also generate samples in the linear parameters.
return_logprobs : bool (optional)
Generate the log-prior probability at the position of each sample.
**kwargs
Additional keyword arguments are passed to the
`~thejoker.JokerSamples` initializer.

Returns
-------
samples : `thejoker.JokerSamples`
The random samples.

"""
from thejoker.samples import JokerSamples

raw_samples, sub_pars, log_prior = self._get_raw_samples(
size, generate_linear, return_logprobs, rng, dtype, **kwargs
)

if generate_linear:
par_names = self.par_names
else:
par_names = list(self._nonlinear_equiv_units.keys())

# Apply units if they are specified:
prior_samples = JokerSamples(
poly_trend=self.poly_trend, n_offsets=self.n_offsets, **kwargs
@@ -448,9 +477,8 @@ def default_nonlinear_prior(P_min=None, P_max=None, s=None, model=None, pars=Non

if isinstance(s, pt.TensorVariable):
pars["s"] = pars.get("s", s)
else:
if not hasattr(s, "unit") or not s.unit.is_equivalent(u.km / u.s):
raise u.UnitsError("Invalid unit for s: must be equivalent to km/s")
elif not hasattr(s, "unit") or not s.unit.is_equivalent(u.km / u.s):
raise u.UnitsError("Invalid unit for s: must be equivalent to km/s")

# dictionary of parameters to return
out_pars = {}
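
Two things happen in prior.py: the class-level _sb2 flag controls which linear parameters are expected (via get_linear_equiv_units, shown below), and sample() is split so that _get_raw_samples() can be reused by subclasses. A minimal sketch of how a subclass apparently opts in (the real JokerSB2Prior lives in prior_sb2.py, which is not shown in this diff):

from thejoker.prior import JokerPrior

class MySB2Prior(JokerPrior):
    # With the flag flipped, the linear parameters become K1, K2, v0, ...
    # instead of K, v0, ...; everything else is inherited unchanged.
    _sb2 = True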
17 changes: 12 additions & 5 deletions thejoker/prior_helpers.py
@@ -31,12 +31,19 @@ def validate_poly_trend(poly_trend):
return poly_trend, vtrend_names


def get_linear_equiv_units(poly_trend):
def get_linear_equiv_units(poly_trend, sb2=False):
poly_trend, v_names = validate_poly_trend(poly_trend)
return {
'K': u.m/u.s,
**{name: u.m/u.s/u.day**i for i, name in enumerate(v_names)}
}
if sb2:
return {
'K1': u.m/u.s,
'K2': u.m/u.s,
**{name: u.m/u.s/u.day**i for i, name in enumerate(v_names)}
}
else:
return {
'K': u.m/u.s,
**{name: u.m/u.s/u.day**i for i, name in enumerate(v_names)}
}


def validate_sigma_v(sigma_v, poly_trend, v_names):
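
For reference, the effect of the new sb2 flag on the linear-parameter units, assuming the trend names for poly_trend=1 are just ['v0'] as in the existing helper (the reprs in the comments are approximate):

from thejoker.prior_helpers import get_linear_equiv_units

get_linear_equiv_units(1)
# {'K': Unit("m / s"), 'v0': Unit("m / s")}

get_linear_equiv_units(1, sb2=True)
# {'K1': Unit("m / s"), 'K2': Unit("m / s"), 'v0': Unit("m / s")}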