
Commit 5521a1e
Still bug fixing pymc/arviz stuff
hposborn committed Jul 10, 2024
1 parent 777e9eb commit 5521a1e
Showing 1 changed file with 31 additions and 24 deletions: MonoTools/fit.py
@@ -207,7 +207,14 @@ def save_model_to_file(self, savefile=None, limit_size=False):
                 self.get_savename(how='save')
             savefile=self.savenames[0]+'_model.pickle'
         if hasattr(self,'trace'):
-            self.trace.to_netcdf(self.savenames[0]+'_trace.nc')
+            try:
+                self.trace.to_netcdf(self.savenames[0]+'_trace.nc')
+            except:
+                try:
+                    #Stacking/unstacking removes Multitrace objects:
+                    self.trace.unstack().to_netcdf(self.savenames[0]+'_trace.nc')
+                except:
+                    print("Still a save error after unstacking")
         excl_types=[az.InferenceData]
         cloudpickle.dumps({attr:getattr(self,attr) for attr in self.__dict__},open(savefile,'wb'))
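
A minimal standalone sketch of the save pattern this hunk introduces, assuming the trace is an ArviZ InferenceData (the helper names and the specific exception types are illustrative, not MonoTools API). Note also that cloudpickle.dump, rather than dumps, is the call that writes to an open file handle:

    import arviz as az
    import cloudpickle

    def save_trace(trace, stem):
        # netCDF cannot store multi-indexed (stacked) coords, so fall back
        # to unstack() if the first write fails
        try:
            trace.to_netcdf(stem + '_trace.nc')
        except (TypeError, ValueError):
            trace.unstack().to_netcdf(stem + '_trace.nc')

    def save_attrs(obj, savefile):
        # cloudpickle.dump writes straight to the file handle
        with open(savefile, 'wb') as f:
            cloudpickle.dump({attr: getattr(obj, attr) for attr in obj.__dict__}, f)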

@@ -4643,7 +4650,7 @@ def plot_corner(self,corner_vars=None,use_marg=True,truths=None):
            if not self.assume_circ:
                corner_vars+=['multi_ecc','multi_omega']
        '''
-        samples = pm.trace_to_dataframe(self.trace, varnames=corner_vars)
+        samples = self.make_table(cols=corner_vars)
         #print(samples.shape,samples.columns)
         assert samples.shape[1]<50
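
For context, pm.trace_to_dataframe was removed in PyMC v4, which is presumably why the corner-plot samples now come from make_table instead. A rough equivalent for flattening an InferenceData posterior to a DataFrame (a sketch only; assumes scalar posterior variables):

    import arviz as az

    def trace_to_df(idata, varnames):
        # collapse (chain, draw) into one sample axis, then tabulate
        post = idata.posterior[varnames].stack(sample=("chain", "draw"))
        return post.to_dataframe().reset_index(drop=True)[varnames]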

@@ -4715,13 +4722,13 @@ def make_table(self,short=True,save=True,cols=['all']):
                     cols_to_remove+=['mono_'+col+'s','duo_'+col+'s']
             medvars=[var for var in self.trace if not np.any([icol in var for icol in cols_to_remove])]
             #print(cols_to_remove, medvars)
-            df = pm.summary(self.trace,var_names=medvars,stat_funcs={"5%": lambda x: np.percentile(x, 5),
+            df = pm.summary(self.trace.unstack(),var_names=medvars,stat_funcs={"5%": lambda x: np.percentile(x, 5),
                                                         "-$1\sigma$": lambda x: np.percentile(x, 15.87),
                                                         "median": lambda x: np.percentile(x, 50),
                                                         "+$1\sigma$": lambda x: np.percentile(x, 84.13),
                                                         "95%": lambda x: np.percentile(x, 95)},round_to=5)
         else:
-            df = pm.summary(self.trace,var_names=cols,stat_funcs={"5%": lambda x: np.percentile(x, 5),
+            df = pm.summary(self.trace.unstack(),var_names=cols,stat_funcs={"5%": lambda x: np.percentile(x, 5),
                                                         "-$1\sigma$": lambda x: np.percentile(x, 15.87),
                                                         "median": lambda x: np.percentile(x, 50),
                                                         "+$1\sigma$": lambda x: np.percentile(x, 84.13),
@@ -4880,23 +4887,22 @@ def predict_future_transits(self, time_start=None, time_end=None, time_dur=180,
         print("time range",Time(time_start+self.lc.jd_base,format='jd').isot,
               "->",Time(time_end+self.lc.jd_base,format='jd').isot)
 
-
         if check_TESS:
             sect_start_ends=self.check_TESS()
 
-        all_trans_fin=pd.DataFrame()
         loopplanets = self.duos+self.trios+self.multis if include_multis else self.duos+self.trios
 
+        all_unq_trans=[]
         for pl in loopplanets:
             all_trans=pd.DataFrame()
             if pl in self.duos+self.trios:
                 sum_all_probs=np.logaddexp.reduce(np.nanmedian(self.trace['logprob_marg_'+pl],axis=0))
-                trans_p0=np.floor(np.nanmedian(time_start - self.trace['t0_2_'+pl])/np.nanmedian(self.trace['per_'+pl],axis=0))
-                trans_p1=np.ceil(np.nanmedian(time_end - self.trace['t0_2_'+pl])/np.nanmedian(self.trace['per_'+pl],axis=0))
+                trans_p0=np.floor(np.nanmedian(time_start - self.trace['t0_2_'+pl].values)/np.nanmedian(self.trace['per_'+pl].values,axis=0))
+                trans_p1=np.ceil(np.nanmedian(time_end - self.trace['t0_2_'+pl].values)/np.nanmedian(self.trace['per_'+pl].values,axis=0))
                 n_trans=trans_p1-trans_p0
             elif pl in self.multis:
-                trans_p0=[np.floor(np.nanmedian(time_start - self.trace['t0_'+pl])/np.nanmedian(self.trace['per_'+pl],axis=0))]
-                trans_p1=[np.ceil(np.nanmedian(time_end - self.trace['t0_'+pl])/np.nanmedian(self.trace['per_'+pl],axis=0))]
+                trans_p0=[np.floor(np.nanmedian(time_start - self.trace['t0_'+pl].values)/np.nanmedian(self.trace['per_'+pl].values))]
+                trans_p1=[np.ceil(np.nanmedian(time_end - self.trace['t0_'+pl].values)/np.nanmedian(self.trace['per_'+pl].values))]
                 n_trans=[trans_p1[0]-trans_p0[0]]
             #print(pl,trans_p0,trans_p1,n_trans)
             #print(np.nanmedian(self.trace['t0_2_'+pl])+np.nanmedian(self.trace['per_'+pl],axis=0)*trans_p0)
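
The .values additions throughout this hunk reflect the PyMC-v4/ArviZ move from MultiTrace numpy arrays to xarray DataArrays: .values drops the labelled (chain, draw) dims so positional axis=0 reductions behave as they did with the old flat sample arrays. An illustrative sketch (the shapes and dim names are hypothetical):

    import numpy as np
    import xarray as xr

    per = xr.DataArray(np.random.normal(10., 0.1, (4, 500, 3)),
                       dims=("chain", "draw", "n_alias"))
    flat = per.values.reshape(-1, per.sizes["n_alias"])  # (chain*draw, n_alias)
    med = np.nanmedian(flat, axis=0)                     # one median per period alias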
@@ -4909,21 +4915,21 @@ def predict_future_transits(self, time_start=None, time_end=None, time_dur=180,
             if 'tdur' in self.fit_params or pl in self.multis:
                 dur=np.nanpercentile(self.trace['tdur_'+pl],percentiles)
             naliases=[0] if pl in self.multis else np.arange(self.planets[pl]['npers'])
+            idfs=[]
             for nd in naliases:
                 if n_trans[nd]>0:
                     if pl in self.duos+self.trios:
                         int_alias=int(self.planets[pl]['period_int_aliases'][nd])
-                        transits=np.nanpercentile(np.vstack([self.trace['t0_2_'+pl]+ntr*self.trace['per_'+pl][:,nd] for ntr in np.arange(trans_p0[nd],trans_p1[nd])]),percentiles,axis=1)
+                        transits=np.nanpercentile(np.vstack([self.trace['t0_2_'+pl].values+ntr*self.trace['per_'+pl].values[:,nd] for ntr in np.arange(trans_p0[nd],trans_p1[nd])]),percentiles,axis=1)
                         if 'tdur' in self.marginal_params:
                             dur=np.nanpercentile(self.trace['tdur_'+pl][:,nd],percentiles)
                         logprobs=np.nanmedian(self.trace['logprob_marg_'+pl][:,nd])-sum_all_probs
                     else:
-                        transits=np.nanpercentile(np.vstack([self.trace['t0_'+pl]+ntr*self.trace['per_'+pl] for ntr in np.arange(trans_p0[nd],trans_p1[nd])]),percentiles,axis=1)
+                        transits=np.nanpercentile(np.column_stack([self.trace['t0_'+pl].values+ntr*self.trace['per_'+pl].values for ntr in np.arange(trans_p0[nd],trans_p1[nd],1.0)]),percentiles,axis=0)
                         int_alias=1
                         logprobs=np.array([0.0])
                     #Getting the aliases for this:
-
-                    idf=pd.DataFrame({'transit_mid_date':Time(transits[2]+self.lc.jd_base,format='jd').isot,
+                    idfs+=[pd.DataFrame({'transit_mid_date':Time(transits[2]+self.lc.jd_base,format='jd').isot,
                         'transit_mid_med':transits[2],
                         'transit_dur_med':np.tile(dur[2],len(transits[2])),
                         'transit_dur_-1sig':np.tile(dur[1],len(transits[2])),
Expand All @@ -4945,8 +4951,8 @@ def predict_future_transits(self, time_start=None, time_end=None, time_dur=180,
'prob':np.tile(np.exp(logprobs),len(transits[2])),
'planet_name':np.tile('multi_'+pl,len(transits[2])) if pl in self.multis else np.tile('duo_'+pl,len(transits[2])),
'alias_n':np.tile(nd,len(transits[2])),
'alias_p':np.tile(np.nanmedian(self.trace['per_'+pl][:,nd]),len(transits[2])) if pl in self.duos+self.trios else np.nanmedian(self.trace['per_'+pl])})
all_trans=all_trans.append(idf)
'alias_p':np.tile(np.nanmedian(self.trace['per_'+pl].values[:,nd]),len(transits[2])) if pl in self.duos+self.trios else np.tile(np.nanmedian(self.trace['per_'+pl].values),len(transits[2]))})]
all_trans=pd.concat(idfs)
unq_trans = all_trans.sort_values('log_prob').copy().drop_duplicates('transit_fractions')
unq_trans = unq_trans.set_index(np.arange(len(unq_trans)))
unq_trans['aliases_ns']=unq_trans['alias_n'].values.astype(str)
@@ -4960,7 +4966,8 @@ def predict_future_transits(self, time_start=None, time_end=None, time_dur=180,
                 unq_trans.loc[i,'aliases_ps']=','.join(list(np.round(oths['alias_p'].values,4).astype(str)))
                 unq_trans.loc[i,'num_aliases']=len(oths)
                 unq_trans.loc[i,'total_prob']=np.sum(oths['prob'].values)
-            all_trans_fin=all_trans_fin.append(unq_trans)
+            all_unq_trans+=[unq_trans]
+        all_trans_fin=pd.concat(all_unq_trans)
         all_trans_fin = all_trans_fin.loc[(all_trans_fin['transit_end_+2sig']>time_start)*(all_trans_fin['transit_start_-2sig']<time_end)].sort_values('transit_mid_med')
         all_trans_fin = all_trans_fin.set_index(np.arange(len(all_trans_fin)))
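
Both append sites in this function now follow the collect-then-concat pattern, which is required since pandas 2.0 removed DataFrame.append, and is also linear rather than quadratic in total rows. A minimal sketch of the pattern:

    import pandas as pd

    frames = []
    for chunk in ({"a": [1]}, {"a": [2, 3]}):
        frames.append(pd.DataFrame(chunk))   # accumulate pieces in a list
    combined = pd.concat(frames, ignore_index=True)  # concatenate once at the end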

@@ -5006,7 +5013,7 @@ def cheops_RMS(self, Gmag, tdur):
     def make_cheops_OR(self, DR2ID=None, pl=None, min_eff=45, oot_min_orbits=1.0, timing_sigma=3, t_start=None, t_end=None, Texp=None,
                        max_orbits=14, min_pretrans_orbits=0.5, min_intrans_orbits=None, orbits_flex=1.4, observe_sigma=2,
                        observe_threshold=None, max_ORs=None, prio_1_threshold=0.25, prio_3_threshold=0.0, targetnamestring=None,
-                       min_orbits=4.0, outfilesuffix='_output_ORs.csv', avoid_TESS=True, pre_post_TESS="pre"):
+                       min_orbits=4.0, outfilesuffix='_output_ORs.csv', avoid_TESS=True, pre_post_TESS="pre", prog_id="0072"):
         """Given a list of observable transits (which are outputted from `trace_to_cheops_transits`),
         create a csv which can be run by pycheops make_xml_files to produce input observing requests (both to FC and observing tool).
@@ -5057,7 +5064,6 @@ def make_cheops_OR(self, DR2ID=None, pl=None, min_eff=45, oot_min_orbits=1.0, timing_sigma=3, t_start=None, t_end=None, Texp=None,
         gaiainfo=Gaia.launch_job_async("SELECT * \
                                         FROM gaiadr2.gaia_source \
                                         WHERE gaiadr2.gaia_source.source_id="+str(DR2ID)).results.to_pandas().iloc[0]
-
         gaia_colour=(gaiainfo['phot_bp_mean_mag']-gaiainfo['phot_rp_mean_mag'])
         V=gaiainfo['phot_g_mean_mag']+0.0176+0.00686*gaia_colour+0.1732*gaia_colour**2
         Verr=1.09/gaiainfo['phot_g_mean_flux_over_error']+0.045858
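
The two magnitude lines in this context appear to implement the Gaia DR2 photometric transformation of Evans et al. (2018), which gives G - V = -0.01760 - 0.006860(BP-RP) - 0.1732(BP-RP)^2 with a quoted scatter of 0.045858 mag: V is G minus that expression, and e_Vmag adds the scatter to the G-band flux error propagated into magnitudes, where 1.09 ~ 2.5/ln(10) converts fractional flux error (1/SNR) to mag.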
@@ -5157,7 +5163,8 @@ def make_cheops_OR(self, DR2ID=None, pl=None, min_eff=45, oot_min_orbits=1.0, timing_sigma=3, t_start=None, t_end=None, Texp=None,
         ser['_DEJ2000']=old_radec.dec.to_string(sep=':')
         ser['pmra']=gaiainfo['pmra']
         ser['pmdec']=gaiainfo['pmdec']
-        ser['parallax']=gaiainfo['plx']
+
+        ser['parallax']=gaiainfo['plx'] if 'plx' in gaiainfo else gaiainfo['parallax']
         ser['SpTy']=SpTy
         ser['Gmag']=gaiainfo['phot_g_mean_mag']
         ser['dr2_g_mag']=gaiainfo['phot_g_mean_mag']
@@ -5167,7 +5174,7 @@ def make_cheops_OR(self, DR2ID=None, pl=None, min_eff=45, oot_min_orbits=1.0, timing_sigma=3, t_start=None, t_end=None, Texp=None,
         ser['Vmag']=V
         ser['e_Vmag']=Verr
 
-        ser['Programme_ID']='0048'
+        ser['Programme_ID']=prog_id
         ser['BJD_early']=t_start
         ser['BJD_late']=t_end
         #Total observing time must cover duration, and either the full timing bound (i.e. assuming 3 sigma), or the oot_min_orbits (if the timing precision is better than the oot_min_orbits)
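
The parallax line above now tolerates either Gaia column name ('plx' or 'parallax'). Assuming gaiainfo is a pandas Series, as the .to_pandas().iloc[0] earlier suggests, an equivalent fallback can be written with a .get chain:

    # equivalent fallback; Series.get returns the default when the key is absent
    ser['parallax'] = gaiainfo.get('plx', gaiainfo.get('parallax'))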
@@ -5236,7 +5243,7 @@ def make_cheops_OR(self, DR2ID=None, pl=None, min_eff=45, oot_min_orbits=1.0, timing_sigma=3, t_start=None, t_end=None, Texp=None,
                 #ser["EndPh1"]=((row['end_earliest']-row['mid'])/100)
                 #ser["Effic1"]=50
                 ser['N_Ranges']=0
-                out_tab=out_tab.append(pd.Series(ser,name=nper))
+                out_tab.loc[nper,list(ser.keys())]=pd.Series(ser,name=nper)
             out_tab['MinEffDur']=out_tab['MinEffDur'].values.astype(int)
             #print(98.77*60*out_tab['T_visit'].values)
             out_tab['T_visit']=(98.77*60*out_tab['T_visit'].values).astype(int) #in seconds
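
Here the third append site is replaced by .loc-based enlargement rather than concat, since rows are keyed by nper. A sketch of that pattern under the assumption that the frame's columns exist up front; note that .loc enlargement upcasts dtypes, which is presumably why the integer casts follow in the hunk above:

    import pandas as pd

    out = pd.DataFrame(columns=["a", "b"])
    for n, row in enumerate(({"a": 1, "b": 2.0}, {"a": 3, "b": 4.0})):
        # assigning a Series to .loc with a new index label grows the frame
        out.loc[n, list(row.keys())] = pd.Series(row)
    out["a"] = out["a"].astype(int)   # enlargement stores values as float/object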
@@ -5384,12 +5391,12 @@ def to_latex_table(self,varnames='all',order='columns'):
         print("Making Latex Table")
         if not hasattr(self,'savenames'):
             self.get_savename(how='save')
-        if self.tracemask is None:
+        if not hasattr(self,'tracemask') or self.tracemask is None:
             self.tracemask=np.tile(True,len(self.trace['Rs']))
         if varnames is None or varnames == 'all':
             varnames=[var for var in self.trace if var[-2:]!='__' and var not in ['gp_pred','light_curves']]
 
-        self.samples = pm.trace_to_dataframe(self.trace, varnames=varnames)
+        self.samples = self.make_table(cols=varnames)
         self.samples = self.samples.loc[self.tracemask]
         facts={'r_pl':109.07637,'Ms':1.0,'rho':1.0,"t0":1.0,"period":1.0,"vrel":1.0,"tdur":24}
         units={'r_pl':"$ R_\\oplus $",'Ms':"$ M_\\odot $",'rho':"$ \\rho_\\odot $",