Skip to content

Commit

Permalink
Catching nan stellar parameter errors and forcing them to be large fr…
Browse files Browse the repository at this point in the history
…actions of the value
  • Loading branch information
hposborn committed Aug 21, 2024
1 parent ed0eebb commit e20c787
Showing 1 changed file with 29 additions and 20 deletions.
49 changes: 29 additions & 20 deletions MonoTools/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class monoModel():
[type]: [description]
"""

#The default monoModel class. This is what we will use to build a Pymc3 model
#The default monoModel class. This is what we will use to build a pymc model

def __init__(self, ID, mission, lc=None, rvs=None, planets=None, overwrite=False, savefileloc=None, **kwargs):
"""Initialises MonoTools fit model
Expand Down Expand Up @@ -111,7 +111,7 @@ def __init__(self, ID, mission, lc=None, rvs=None, planets=None, overwrite=False
'derive_K':True, # If we have RVs, do we derive K for each alias or fit for a single K param
'pred_all':False, # Do we predict all time array, or only a cut-down version?
'use_multinest':False, # use_multinest - bool - currently not supported
'use_pymc3':True, # use_pymc3 - bool
'use_pymc':True, # use_pymc - bool
'bin_all':False, # bin_all - bool - Bin all points to 10mins (to speed up certain)
'bin_all_size':10/1440., # bin_all_size - float - Bin size if binning all points in minutes (default to 10mins)
'bin_oot':True, # bin_oot - bool - Bin points outside the cut_distance to 30mins
Expand Down Expand Up @@ -921,7 +921,16 @@ def init_starpars(self,Rstar=None,Teff=None,logg=None,FeH=0.0,rhostar=None,Mstar
self.rhostar=np.array(rhostar).astype(float)
if Mstar is None:
self.Mstar=rhostar[0]*self.Rstar[0]**3


if self.rhostar[0] is not None and (self.rhostar[1] is None or self.rhostar[0]==np.nan):
self.rhostar[1]=0.5*self.rhostar[0]
self.rhostar[2]=0.5*self.rhostar[0]
if self.logg[0] is not None and (self.logg[1] is None or self.logg[0]==np.nan):
self.logg[1]=0.25
self.logg[2]=0.25
if self.Rstar[0] is not None and (self.Rstar[1] is None or self.Rstar[0]==np.nan):
self.Rstar[1]=0.15*self.Rstar[0]
self.Rstar[2]=0.15*self.Rstar[0]

def get_savename(self, how='load',overwrite=None):
"""Adds unique savename prefixes to class (self.savenames) with two formats:
Expand Down Expand Up @@ -1183,7 +1192,7 @@ def init_model(self, overwrite=False, **kwargs):
per_index (float, optional): period prior index e.g. P^{index}. -8/3 in to Kipping 2018. Defaults to -8/3
derive_K (bool, optional): If we have RVs, do we derive K for each alias or fit for a single K param. Defaults to True
use_multinest (bool, optional): Use Multinest sampling [NOT SUPPORTED YET]. Defaults to False
use_pymc3 (bool, optional): Use PyMC3 sampling? Defaults to True
use_pymc (bool, optional): Use pymc sampling? Defaults to True
bin_oot (bool, optional): bool - Bin points outside the cut_distance to 30mins. Defaults to True
"""

Expand All @@ -1197,7 +1206,7 @@ def init_model(self, overwrite=False, **kwargs):
self.fit_params=self.fit_params+['ecc'] if self.assume_circ and 'ecc' not in self.fit_params else self.fit_params
self.fit_params=self.fit_params+['omega'] if self.assume_circ and 'omega' not in self.fit_params else self.fit_params
self.marginal_params=self.marginal_params+['K'] if hasattr(self,'rvs') and self.derive_K else self.marginal_params
assert self.use_multinest^self.use_pymc3 #Must have either use_multinest or use_pymc3, though use_multinest doesn't work
assert self.use_multinest^self.use_pymc #Must have either use_multinest or use_pymc, though use_multinest doesn't work
assert not (self.assume_circ and self.interpolate_v_prior and (len(self.monos)+len(self.duos)+len(self.trios)>0)) #Cannot interpolate_v_prior and assume circular unless we only have multiplanets
assert not ((len(self.trios+self.duos+self.monos)>1)*hasattr(self,'rvs')) #Cannot fit more than one planet with uncertain orbits with RVs (currently)

Expand Down Expand Up @@ -1230,7 +1239,7 @@ def init_model(self, overwrite=False, **kwargs):
# Initialising sampling models:
######################################

if self.use_pymc3:
if self.use_pymc:
self.init_pymc()
elif self.use_multinest:
self.run_multinest(**kwargs)
Expand Down Expand Up @@ -1394,7 +1403,7 @@ def init_interpolated_v_prior(self):
np.column_stack((np.tile(-310,len(logprob_arr[1:,0])),logprob_arr[1:,1:])), nout=1)

def init_pymc(self,ld_mult=1.5):
"""Initialise the PyMC3 sampler
"""Initialise the pymc sampler
"""
######################################
# Selecting lightcurve:
Expand All @@ -1411,7 +1420,7 @@ def init_pymc(self,ld_mult=1.5):
start=None

with pm.Model() as model:
if self.debug: print("Forming Pymc3 model with: monos:",self.monos,"multis:",self.multis,"duos:",self.duos,"trios:",self.trios)
if self.debug: print("Forming pymc model with: monos:",self.monos,"multis:",self.multis,"duos:",self.duos,"trios:",self.trios)

######################################
# Intialising Stellar Params:
Expand Down Expand Up @@ -2026,14 +2035,14 @@ def gen_lc(i_orbit, i_rpl, n_pl, mask=None,prefix='',make_deterministic=False,pr
Args:
i_orbit (xo.orbits.KeplerianOrbit): Planetary orbits (in units of solar radii) to create lightcurve
i_rpl (pymc3 variable): Planetary radii (in units of solar radii) to create lightcurve
i_rpl (pymc variable): Planetary radii (in units of solar radii) to create lightcurve
n_pl (int): Number of planets/orbits to generate lightcurves
mask (array, optional): Specific mask to the light curve]. Defaults to None.
prefix (str, optional): prefix to the PyMC3 variable name. Defaults to ''.
make_deterministic (bool, optional): Add the output as a deterministic PyMC3 variable (may be memory intensive). Defaults to False.
prefix (str, optional): prefix to the pymc variable name. Defaults to ''.
make_deterministic (bool, optional): Add the output as a deterministic pymc variable (may be memory intensive). Defaults to False.
Returns:
array OR pymc3 variable: 1D lightcurve
array OR pymc variable: 1D lightcurve
"""
#
trans_pred=[]
Expand Down Expand Up @@ -2109,14 +2118,14 @@ def create_orbit(pl, Rs, rho_S, pers, t0s, bs, n_marg=1, eccs=None, omegas=None)
Args:
pl (str): Planet name as seen in `mod.planets` dict
Rs (pymc3 variable): Solar radius (in Rsun)
rho_S (pymc3 variable): Solar density (in rho_sun)
pers (pymc3 variable): Periods
t0s (pymc3 variable): Epochs
bs (pymc3 variable): Impact parameters
Rs (pymc variable): Solar radius (in Rsun)
rho_S (pymc variable): Solar density (in rho_sun)
pers (pymc variable): Periods
t0s (pymc variable): Epochs
bs (pymc variable): Impact parameters
n_marg (int, optional): Number of periods to marginalise over. Defaults to 1.
eccs (pymc3 variable, optional): Orbital eccentricity. Defaults to None.
omegas (pymc3 variable, optional): Orbital argument of periasteron (omega). Defaults to None.
eccs (pymc variable, optional): Orbital eccentricity. Defaults to None.
omegas (pymc variable, optional): Orbital argument of periasteron (omega). Defaults to None.
Returns:
xo.orbits.KeplerianOrbit: Exoplanet Keplerian orbit object initialised to model the lightcurve
Expand Down Expand Up @@ -2650,7 +2659,7 @@ def create_orbit(pl, Rs, rho_S, pers, t0s, bs, n_marg=1, eccs=None, omegas=None)
self.init_soln = map_soln

def sample_model(self, n_draws=500, n_burn_in=None, overwrite=False, continue_sampling=False, n_chains=4, cores=4, **kwargs):
"""Run PyMC3 sampler
"""Run pymc sampler
Args:
n_draws (int, optional): Number of independent samples to draw from each chain. Defaults to 500.
Expand Down

0 comments on commit e20c787

Please sign in to comment.