From 154ee658e6d44b516ba214aeacce8a1e807091c0 Mon Sep 17 00:00:00 2001 From: hposborn Date: Thu, 11 Jul 2024 22:59:10 +0200 Subject: [PATCH] Removing extract/stack/unstack commands from arviz InferenceData objects which were causing from_netcdf() to return completely empty xarray Datasets --- MonoTools/fit.py | 225 +++++++++++++++++++++++------------------------ 1 file changed, 112 insertions(+), 113 deletions(-) diff --git a/MonoTools/fit.py b/MonoTools/fit.py index 51ac80c..78aada8 100755 --- a/MonoTools/fit.py +++ b/MonoTools/fit.py @@ -182,7 +182,7 @@ def load_model_from_file(self, loadfile=None): for key in pick: setattr(self,key,pick[key]) del pick - setattr(self,'trace',az.InferenceData.from_netcdf(loadfile.replace('_model.pickle','_trace.nc')).stack()) + setattr(self,'trace',az.InferenceData.from_netcdf(loadfile.replace('_model.pickle','_trace.nc'))) elif os.path.exists(loadfile): #Loading old version using pickle from pickled dictionary pick=pickle.load(open(loadfile,'rb')) @@ -212,7 +212,7 @@ def save_model_to_file(self, savefile=None, limit_size=False): except: try: #Stacking/unstacking removes Multitrace objects: - self.trace.unstack().to_netcdf(self.savenames[0]+'_trace.nc') + self.trace.to_netcdf(self.savenames[0]+'_trace.nc') except: print("Still a save error after unstacking") excl_types=[az.InferenceData] @@ -228,11 +228,11 @@ def save_model_to_file(self, savefile=None, limit_size=False): # self.init_trans_to_plot() # #And let's clip gp and lightcurves and pseudo-variables from the trace: - # remvars=[var for var in self.trace if (('gp_' in var or '_gp' in var or 'light_curve' in var) and np.product(self.trace[var].shape)>6*len(self.trace['Rs'])) or '__' in var] + # remvars=[var for var in self.trace.posterior if (('gp_' in var or '_gp' in var or 'light_curve' in var) and np.product(self.trace.posterior[var].shape)>6*len(self.trace.posterior['Rs'])) or '__' in var] # for key in remvars: # #Permanently deleting these values from the trace. # self.trace.remove_values(key) - # #medvars=[var for var in self.trace if 'gp_' not in var and '_gp' not in var and 'light_curve' not in var] + # #medvars=[var for var in self.trace.posterior if 'gp_' not in var and '_gp' not in var and 'light_curve' not in var] # n_bytes = 2**31 # max_bytes = 2**31-1 @@ -1345,7 +1345,7 @@ def init_GP(self, n_draws=900, max_len_lc=25000, use_binned=False): self.gp_init_soln = pmx.optimize(start=None, vars=optvars) if self.debug: print("sampling init GP", int(n_draws*0.66),"times with",len(self.lc.flux[mask]),"-point lightcurve") - self.gp_init_trace = az.extract(pm.sample(tune=int(n_draws*0.66), draws=n_draws, start=self.gp_init_soln, chains=4)) + self.gp_init_trace = pm.sample(tune=int(n_draws*0.66), draws=n_draws, start=self.gp_init_soln, chains=4) def init_interpolated_Mp_prior(self): """Initialise a 2D interpolated prior for the mass of a planet given the radius @@ -1824,8 +1824,8 @@ def init_pymc(self,ld_mult=1.5): if self.debug: print(np.isnan(self.model_time),np.isnan(self.model_flux),np.isnan(self.model_flux_err)) if self.train_GP: #Using histograms from the output of the previous GP training as priors for the true model. - minmaxs={var:np.percentile(self.gp_init_trace[var].values,[0.5,99.5]).astype(floattype) for var in self.gp_init_trace if '__' not in var and len(self.gp_init_trace[var].shape)==1} - hists={var:np.histogram(self.gp_init_trace[var].values,np.linspace(minmaxs[var][0],minmaxs[var][1],101))[0] for var in self.gp_init_trace if '__' not in var and len(self.gp_init_trace[var].shape)==1} + minmaxs={var:np.percentile(self.gp_init_trace.posterior[var].values,[0.5,99.5]).astype(floattype) for var in self.gp_init_trace.posterior if '__' not in var and len(self.gp_init_trace.posterior[var].shape)==1} + hists={var:np.histogram(self.gp_init_trace.posterior[var].values,np.linspace(minmaxs[var][0],minmaxs[var][1],101))[0] for var in self.gp_init_trace.posterior if '__' not in var and len(self.gp_init_trace.posterior[var].shape)==1} gpvars=[] if hasattr(self, 'periodic_kernel') and self.periodic_kernel is not None: if self.train_GP: @@ -2679,7 +2679,6 @@ def sample_model(self, n_draws=500, n_burn_in=None, overwrite=False, continue_sa self.trace = pm.sample(tune=n_burn_in, draws=n_draws, chains=n_chains, trace=self.trace, compute_convergence_checks=False)#, **kwargs) else: self.trace = pm.sample(tune=n_burn_in, draws=n_draws, start=self.init_soln, chains=n_chains, compute_convergence_checks=False)#, **kwargs) - self.trace=az.extract(self.trace) #Saving both the class and a pandas dataframe of output data. self.save_model_to_file() _=self.make_table(save=True) @@ -2731,13 +2730,13 @@ def init_gp_to_plot(self, n_samp=7, max_gp_len=12000, interp=True, newgp=False, self.gp_to_plot={'n_samp':n_samp} if hasattr(self,'trace'): #Using the output of the model trace - medvars=[var for var in self.trace if 'gp_' not in var and '_gp' not in var and 'light_curve' not in var] + medvars=[var for var in self.trace.posterior if 'gp_' not in var and '_gp' not in var and 'light_curve' not in var] self.meds={} for mv in medvars: - if len(self.trace[mv].shape)>1: - self.meds[mv]=np.median(self.trace[mv],axis=0) - elif len(self.trace[mv].shape)==1: - self.meds[mv]=np.median(self.trace[mv]) + if len(self.trace.posterior[mv].shape)>1: + self.meds[mv]=np.median(self.trace.posterior[mv],axis=0) + elif len(self.trace.posterior[mv].shape)==1: + self.meds[mv]=np.median(self.trace.posterior[mv]) else: self.meds=self.init_soln @@ -2813,9 +2812,9 @@ def init_gp_to_plot(self, n_samp=7, max_gp_len=12000, interp=True, newgp=False, #assert self.bin_oot stacktime=np.hstack((self.lc.time[0]-1,self.model_time,self.lc.time[-1]+1)) preds=[] - for i in np.random.choice(len(self.trace['phot_mean']),int(np.clip(10*n_samp,1,len(self.trace['phot_mean']))),replace=False): - smooth_func=interpolate.interp1d(stacktime, np.hstack((0,self.trace['gp_pred'][:,i],0)), kind='slinear') - preds+=[smooth_func(self.lc.time)+self.trace['phot_mean'].values[i]] + for i in np.random.choice(len(self.trace.posterior['phot_mean']),int(np.clip(10*n_samp,1,len(self.trace.posterior['phot_mean']))),replace=False): + smooth_func=interpolate.interp1d(stacktime, np.hstack((0,self.trace.posterior['gp_pred'][:,i],0)), kind='slinear') + preds+=[smooth_func(self.lc.time)+self.trace.posterior['phot_mean'].values[i]] prcnts=np.nanpercentile(np.column_stack(preds),[15.8655254, 50., 84.1344746],axis=1) self.gp_to_plot['gp_pred']=prcnts[1] self.gp_to_plot['gp_sd']=0.5*(prcnts[2]-prcnts[0]) @@ -2836,7 +2835,7 @@ def init_gp_to_plot(self, n_samp=7, max_gp_len=12000, interp=True, newgp=False, i_gp_pred=[] i_gp_var=[] for i in np.random.choice(len(self.trace),n_samp,replace=False): - sample=self.trace[i] + sample=self.trace.posterior[i] #print(np.exp(sample['logs2'])) i_gp.kernel = pymc_terms.SHOTerm(S0=sample['phot_S0'], w0=sample['phot_w0'], Q=1/np.sqrt(2)) i_gp.mean = sample['mean'] @@ -2872,7 +2871,7 @@ def init_trans_to_plot(self,n_samp=None,**kwargs): Args: n_samp (int, optional): Number of samples to use from the MCMC trace to generate the models & percentiles. Defaults to None. """ - n_samp=len(self.trace['phot_mean']) if n_samp is None else n_samp + n_samp=len(self.trace.posterior['phot_mean']) if n_samp is None else n_samp print("Initalising Transit models for plotting with n_samp=",n_samp) if not hasattr(self,'lc_regions'): self.init_plot(plot_type='lc',**kwargs) @@ -2881,28 +2880,28 @@ def init_trans_to_plot(self,n_samp=None,**kwargs): 'n_samp':n_samp} percentiles={'-2sig':2.2750132, '-1sig':15.8655254, 'med':50., '+1sig':84.1344746, '+2sig':97.7249868} - if hasattr(self,'trace') and 'marg_all_lc_model' in self.trace: - prcnt=np.percentile(self.trace['marg_all_lc_model'],list(percentiles.values()),axis=1) + if hasattr(self,'trace') and 'marg_all_lc_model' in self.trace.posterior: + prcnt=np.percentile(self.trace.posterior['marg_all_lc_model'],list(percentiles.values()),axis=1) self.trans_to_plot['model']['allpl']={list(percentiles.keys())[n]:prcnt[n] for n in range(5)} elif 'marg_all_lc_model' in self.init_soln: self.trans_to_plot['model']['allpl']['med']=self.init_soln['marg_all_lc_model'] else: print("marg_all_lc_model not in any optimised models") for pl in self.planets: - if hasattr(self,'trace') and pl+"_light_curves" in self.trace: - prcnt = np.percentile(self.trace[pl+"_light_curves"], list(percentiles.values()), axis=1) + if hasattr(self,'trace') and pl+"_light_curves" in self.trace.posterior: + prcnt = np.percentile(self.trace.posterior[pl+"_light_curves"], list(percentiles.values()), axis=1) self.trans_to_plot['model'][pl] = {list(percentiles.keys())[n]:prcnt[n] for n in range(5)} elif hasattr(self,'init_soln') and pl+"_light_curves" in self.init_soln: self.trans_to_plot['model'][pl] = {'med':self.init_soln[pl+"_light_curves"]} ''' self.trans_to_plot_i[pl]={} if pl in self.multis or self.interpolate_v_prior: - if hasattr(self,'trace') and 'light_curve_'+pl in self.trace: - if len(self.trace['mask_light_curves'].shape)>2: - prcnt = np.percentile(self.trace['multi_mask_light_curves'][:,:,self.multis.index(pl)], + if hasattr(self,'trace') and 'light_curve_'+pl in self.trace.posterior: + if len(self.trace.posterior['mask_light_curves'].shape)>2: + prcnt = np.percentile(self.trace.posterior['multi_mask_light_curves'][:,:,self.multis.index(pl)], (5,16,50,84,95),axis=0) else: - prcnt = np.percentile(self.trace['multi_mask_light_curves'], (5,16,50,84,95), axis=0) + prcnt = np.percentile(self.trace.posterior['multi_mask_light_curves'], (5,16,50,84,95), axis=0) elif 'multi_mask_light_curves' in self.init_soln: if len(self.init_soln['multi_mask_light_curves'].shape)==1: @@ -2912,8 +2911,8 @@ def init_trans_to_plot(self,n_samp=None,**kwargs): else: print('multi_mask_light_curves not in any optimised models') elif pl in self.duos or self.monos and not self.interpolate_v_prior: - if hasattr(self,'trace') and 'marg_light_curve_'+pl in self.trace: - prcnt=np.percentile(self.trace['marg_light_curve_'+pl],(5,16,50,84,95),axis=0) + if hasattr(self,'trace') and 'marg_light_curve_'+pl in self.trace.posterior: + prcnt=np.percentile(self.trace.posterior['marg_light_curve_'+pl],(5,16,50,84,95),axis=0) nms=['-2sig','-1sig','med','+1sig','+2sig'] self.trans_to_plot_i[pl] = {nms[n]:prcnt[n] for n in range(5)} elif 'marg_light_curve_'+pl in self.init_soln: @@ -2945,7 +2944,7 @@ def init_spline_to_plot(self,n_samp=None,**kwargs): Args: n_samp (int, optional): Number of samples to use from the MCMC trace to generate the models & percentiles. Defaults to None. """ - n_samp=len(self.trace['phot_mean']) if n_samp is None else n_samp + n_samp=len(self.trace.posterior['phot_mean']) if n_samp is None else n_samp if not hasattr(self,'lc_regions'): self.init_plot(plot_type='lc',**kwargs) self.spline_to_plot={'model':{'allpl':{}}, @@ -2954,7 +2953,7 @@ def init_spline_to_plot(self,n_samp=None,**kwargs): percentiles={'-2sig':2.2750132, '-1sig':15.8655254, 'med':50., '+1sig':84.1344746, '+2sig':97.7249868} if hasattr(self,'trace'): - prcnt=np.percentile(np.sum(np.dstack([self.trace['spline_model_'+pl] for pl in self.monos+self.duos+self.trios]),axis=2),list(percentiles.values()),axis=0) + prcnt=np.percentile(np.sum(np.dstack([self.trace.posterior['spline_model_'+pl] for pl in self.monos+self.duos+self.trios]),axis=2),list(percentiles.values()),axis=0) self.spline_to_plot['model']['allpl']={list(percentiles.keys())[n]:prcnt[n] for n in range(5)} elif 'marg_all_lc_model' in self.init_soln: self.spline_to_plot['model']['allpl']['med']=np.sum(np.vstack([self.init_soln['spline_model_'+pl] for pl in self.monos+self.duos+self.trios]),axis=0) @@ -2999,7 +2998,7 @@ def init_rvs_to_plot(self, n_samp=300, plot_alias='all'): self.rvs_to_plot['t']['time']=np.arange(np.min(self.rvs['time'])-5,np.max(self.rvs['time'])+5,0.5) if hasattr(self,'trace'): - samples=[self.trace[i] for i in np.random.choice(len(self.trace), n_samp,replace=False)] + samples=[self.trace.posterior[i] for i in np.random.choice(len(self.trace), n_samp,replace=False)] else: samples=[self.init_soln] all_rv_ts_i={pl:[] for pl in all_pls_in_rvs} @@ -3069,8 +3068,8 @@ def init_rvs_to_plot(self, n_samp=300, plot_alias='all'): if hasattr(self,'trace'): #Taking best-fit model: for pl in all_pls_in_rvs: - xprcnts = np.percentile(self.trace["marg_rv_model_"+pl], percentiles, axis=0) - xtrendprcnts = np.percentile(self.trace["marg_rv_model_"+pl]+self.trace["rv_trend"], percentiles, axis=0) + xprcnts = np.percentile(self.trace.posterior["marg_rv_model_"+pl], percentiles, axis=0) + xtrendprcnts = np.percentile(self.trace.posterior["marg_rv_model_"+pl]+self.trace.posterior["rv_trend"], percentiles, axis=0) self.rvs_to_plot['x'][pl]['marg'] = {nms[n]:xprcnts[n] for n in range(5)} self.rvs_to_plot['x'][pl]['marg+trend'] = {nms[n]:xtrendprcnts[n] for n in range(5)} @@ -3081,17 +3080,17 @@ def init_rvs_to_plot(self, n_samp=300, plot_alias='all'): if pl in self.trios+self.duos+self.monos: alltrvs = np.dstack(all_rv_ts_i[pl]) for i in range(self.n_margs[pl]): - xiprcnts=np.percentile(self.trace["model_rv_"+pl][:,:,i], percentiles, axis=0) + xiprcnts=np.percentile(self.trace.posterior["model_rv_"+pl][:,:,i], percentiles, axis=0) self.rvs_to_plot['x'][pl][i]={nms[n]:xiprcnts[n] for n in range(5)} tiprcnts=np.percentile(alltrvs[:,i,:], percentiles, axis=1) self.rvs_to_plot['t'][pl][i]={nms[n]:tiprcnts[n] for n in range(5)} #print(self.rvs_to_plot) if len(all_pls_in_rvs)>1: - iprcnts = np.percentile(self.trace["rv_trend"],percentiles, axis=0) + iprcnts = np.percentile(self.trace.posterior["rv_trend"],percentiles, axis=0) self.rvs_to_plot['x']["trend+offset"] = {nms[n]:iprcnts[n] for n in range(5)} - iprcnts = np.percentile(self.trace["marg_all_rv_model"],percentiles, axis=0) + iprcnts = np.percentile(self.trace.posterior["marg_all_rv_model"],percentiles, axis=0) self.rvs_to_plot['x']["all"] = {nms[n]:iprcnts[n] for n in range(5)} - iprcnts = np.percentile(self.trace["marg_all_rv_model"]+self.trace["rv_trend"],percentiles, axis=0) + iprcnts = np.percentile(self.trace.posterior["marg_all_rv_model"]+self.trace.posterior["rv_trend"],percentiles, axis=0) self.rvs_to_plot['x']["all+trend"] = {nms[n]:iprcnts[n] for n in range(5)} iprcnts = np.percentile(np.vstack(trends_i), percentiles, axis=0) @@ -3106,7 +3105,7 @@ def init_rvs_to_plot(self, n_samp=300, plot_alias='all'): percentiles, axis=0) self.rvs_to_plot['t']["all+trend"] = {nms[n]:iprcnts[n] for n in range(5)} else: - iprcnts = np.percentile(self.trace["rv_trend"], percentiles, axis=0) + iprcnts = np.percentile(self.trace.posterior["rv_trend"], percentiles, axis=0) self.rvs_to_plot['x']["trend+offset"] = {nms[n]:iprcnts[n] for n in range(5)} self.rvs_to_plot['x']["all"] = self.rvs_to_plot['x'][pl] self.rvs_to_plot['x']["all+trend"] = self.rvs_to_plot['x'][pl]["marg+trend"] @@ -3116,7 +3115,7 @@ def init_rvs_to_plot(self, n_samp=300, plot_alias='all'): #print(self.rvs_to_plot['t']["trend"]) self.rvs_to_plot['t']["all"] = self.rvs_to_plot['t'][pl] self.rvs_to_plot['t']["all+trend"] = self.rvs_to_plot['t'][pl]["marg+trend"] - iprcnts = np.percentile(self.trace["rv_offsets"], percentiles, axis=0) + iprcnts = np.percentile(self.trace.posterior["rv_offsets"], percentiles, axis=0) self.rvs_to_plot['x']["offsets"] = {nms[n]:iprcnts[n] for n in range(5)} elif hasattr(self,'init_soln'): for pl in marg_rv_ts_i: @@ -3191,7 +3190,7 @@ def init_plot(self, interactive=False, gap_thresh=10, plottype='lc',pointcol='k' self.lc_regions[nj]['cadence']=np.nanmedian(np.diff(self.lc.time[self.lc_regions[nj]['ix']])) self.lc_regions[nj]['mad']=1.06*np.nanmedian(abs(np.diff(getattr(self.lc,fx_lab)[self.lc_regions[nj]['ix']]))) self.lc_regions[nj]['minmax']=np.nanpercentile(getattr(self.lc,fx_bin_lab)[self.lc_regions[nj]['bin_ix']],[0.25,99.75]) - transmin=np.min(self.init_soln['marg_all_lc_model']) if not hasattr(self,'trace') else np.min(np.nanmedian(self.trace['marg_all_lc_model'],axis=0)) + transmin=np.min(self.init_soln['marg_all_lc_model']) if not hasattr(self,'trace') else np.min(np.nanmedian(self.trace.posterior['marg_all_lc_model'],axis=0)) minmax_global = (np.min([np.min([self.lc_regions[nj]['minmax'][0],transmin])-self.lc_regions[nj]['mad'] for nj in self.lc_regions]), np.max([self.lc_regions[nj]['minmax'][1]+self.lc_regions[nj]['mad'] for cad in self.lc_regions])) @@ -3273,8 +3272,8 @@ def plot_RVs(self, interactive=False, plot_alias='best', nbest=4, n_samp=300, ov #Here we'll choose the best RV curves to plot (in the case of mono/duos) nbests = self.n_margs[marg_pl] if plot_alias=='all' else nbest if hasattr(self,'trace'): - ibest = np.nanmedian(self.trace['logprob_marg_'+marg_pl],axis=0).argsort()[-1*nbests:] - heights = np.array([np.clip(np.nanmedian(self.trace['K_'+marg_pl][:,i]),0.5*averr,10000) for i in ibest]) + ibest = np.nanmedian(self.trace.posterior['logprob_marg_'+marg_pl],axis=0).argsort()[-1*nbests:] + heights = np.array([np.clip(np.nanmedian(self.trace.posterior['K_'+marg_pl][:,i]),0.5*averr,10000) for i in ibest]) elif hasattr(self,'init_soln'): ibest = self.init_soln['logprob_marg_'+marg_pl].argsort()[-1*nbests:] heights = np.array([np.clip(self.init_soln['K_'+marg_pl][i],0.5*averr,10000) for i in ibest]) @@ -3408,17 +3407,17 @@ def plot_RVs(self, interactive=False, plot_alias='best', nbest=4, n_samp=300, ov for n,pl in enumerate(list(self.planets.keys())+list(self.rvplanets.keys())): if hasattr(self,'trace'): - t0=np.nanmedian(self.trace['t0_'+pl]) + t0=np.nanmedian(self.trace.posterior['t0_'+pl]) if pl in self.multis or pl in self.rvplanets: - per=[np.nanmedian(self.trace['per_'+pl])] + per=[np.nanmedian(self.trace.posterior['per_'+pl])] alphas=[1.0] else: - alphas=np.clip(2*np.exp(np.nanmedian(self.trace['logprob_marg_'+pl],axis=0)),0.25,1.0) + alphas=np.clip(2*np.exp(np.nanmedian(self.trace.posterior['logprob_marg_'+pl],axis=0)),0.25,1.0) if pl in self.duos: - t0=np.nanmedian(self.trace['t0_2_'+pl]) - per=np.nanmedian(self.trace['per_'+pl],axis=0) + t0=np.nanmedian(self.trace.posterior['t0_2_'+pl]) + per=np.nanmedian(self.trace.posterior['per_'+pl],axis=0) elif pl in self.monos: - per=np.nanmedian(self.trace['per_'+pl],axis=0) + per=np.nanmedian(self.trace.posterior['per_'+pl],axis=0) #alphas=[alphas] elif hasattr(self,'init_soln'): t0=self.init_soln['t0_'+pl] @@ -3450,7 +3449,7 @@ def plot_RVs(self, interactive=False, plot_alias='best', nbest=4, n_samp=300, ov if pl==(self.duos+self.monos)[nmargtoplot]: for n,alias in enumerate(ibest): if hasattr(self,'trace'): - K=np.clip(np.nanmedian(self.trace['K_'+pl][:,alias]),averr,100000) + K=np.clip(np.nanmedian(self.trace.posterior['K_'+pl][:,alias]),averr,100000) else: K=np.clip(self.init_soln['K_'+pl][alias],averr,100000) @@ -3546,7 +3545,7 @@ def plot_RVs(self, interactive=False, plot_alias='best', nbest=4, n_samp=300, ov elif pl in other_pls: if hasattr(self,'trace'): - K=np.clip(np.nanmedian(self.trace['K_'+pl]),averr,100000) + K=np.clip(np.nanmedian(self.trace.posterior['K_'+pl]),averr,100000) else: K=np.clip(self.init_soln['K_'+pl],averr,100000) if interactive: @@ -3867,7 +3866,7 @@ def plot(self, interactive=False, n_samp=None, overwrite=False, interp=True, new self.lc.flux_err[self.lc_regions[key]['ix']])), binsize=29/1440) else: - phot_mean=np.nanmedian(self.trace['phot_mean']) if hasattr(self,'trace') else self.init_soln['phot_mean'] + phot_mean=np.nanmedian(self.trace.posterior['phot_mean']) if hasattr(self,'trace') else self.init_soln['phot_mean'] #Plotting each part of the lightcurve: @@ -4048,7 +4047,7 @@ def plot(self, interactive=False, n_samp=None, overwrite=False, interp=True, new lower_head=arrow_heads.TeeHead(line_color='#dddddd',line_alpha=0.5))) else: #Plotting detrended: - phot_mean=np.nanmedian(self.trace['phot_mean']) if hasattr(self,'trace') else self.init_soln['phot_mean'] + phot_mean=np.nanmedian(self.trace.posterior['phot_mean']) if hasattr(self,'trace') else self.init_soln['phot_mean'] if np.nanmedian(np.diff(self.lc.time[self.lc_regions[key]['ix']]))<1/72: errors = ColumnDataSource(data=dict(base=bin_resids[:,0], lower=bin_resids[:,1] - bin_resids[:,2], @@ -4233,17 +4232,17 @@ def plot(self, interactive=False, n_samp=None, overwrite=False, interp=True, new setattr(self.lc,'phase',{}) for n,pl in enumerate(self.planets): if hasattr(self,'trace'): - t0s=[np.nanmedian(self.trace['t0_'+pl]), np.nanmedian(self.trace['t0_2_'+pl]), np.nanmedian(self.trace['t0_3_'+pl])] if pl in self.trios else [np.nanmedian(self.trace['t0_'+pl])] + t0s=[np.nanmedian(self.trace.posterior['t0_'+pl]), np.nanmedian(self.trace.posterior['t0_2_'+pl]), np.nanmedian(self.trace.posterior['t0_3_'+pl])] if pl in self.trios else [np.nanmedian(self.trace.posterior['t0_'+pl])] if pl in self.multis or pl in self.rvplanets: - per=np.nanmedian(self.trace['per_'+pl]) + per=np.nanmedian(self.trace.posterior['per_'+pl]) elif pl in self.duos+self.trios: - per=np.max(np.nanmedian(self.trace['per_'+pl],axis=0)) + per=np.max(np.nanmedian(self.trace.posterior['per_'+pl],axis=0)) elif pl in self.monos: per=3e3 if 'tdur_'+pl in self.init_soln: - binsize=np.nanmedian(self.trace['tdur_'+pl])/n_intrans_bins + binsize=np.nanmedian(self.trace.posterior['tdur_'+pl])/n_intrans_bins elif 'tdur_'+pl+'[0]' in self.init_soln: - binsize=np.nanmedian(self.trace['tdur_'+pl+'[0]'])/n_intrans_bins + binsize=np.nanmedian(self.trace.posterior['tdur_'+pl+'[0]'])/n_intrans_bins elif hasattr(self,'init_soln'): t0s=[self.init_soln['t0_'+pl], self.init_soln['t0_2_'+pl], self.init_soln['t0_3_'+pl]] if pl in self.trios else [self.init_soln['t0_'+pl]] if pl in self.multis or pl in self.rvplanets: @@ -4500,13 +4499,13 @@ def plot_periods(self, plot_loc=None, ylog=True, xlog=True, nbins=25, plt.subplot(1,len(plot_pers),npl+1) if pl in self.duos+self.trios: #As we're using the nanmedian log10(prob)s for each period, we need to make sure their sums add to 1.0 - probs=logsumexp(np.log(extra_factor)+self.trace['logprob_marg_'+pl] - logsumexp(np.log(extra_factor)+self.trace['logprob_marg_'+pl]),axis=0)/np.log(10) - pers = np.nanmedian(self.trace['per_'+pl],axis=0) + probs=logsumexp(np.log(extra_factor)+self.trace.posterior['logprob_marg_'+pl] - logsumexp(np.log(extra_factor)+self.trace.posterior['logprob_marg_'+pl]),axis=0)/np.log(10) + pers = np.nanmedian(self.trace.posterior['per_'+pl],axis=0) pmax = np.nanmax(pers)*1.03 if pmax is None else pmax pmin = np.nanmin(pers)*0.9 if pmin is None else pmin ymax = np.max(probs[perspmin: bins=np.arange(np.floor(self.planets[pl]['per_gaps']['gap_starts'][ngap])-0.5, np.clip(np.ceil(self.planets[pl]['per_gaps']['gap_ends'][ngap])+0.5,0.0,pmax), 1.0) print(bins) - ncol=int(np.floor(np.clip(np.nanmedian(self.trace['logprob_marg_'+pl][:,ngap])-total_av_prob,-6,0))) + ncol=int(np.floor(np.clip(np.nanmedian(self.trace.posterior['logprob_marg_'+pl][:,ngap])-total_av_prob,-6,0))) #print(self.planets[pl]['per_gaps']['gap_starts'][ngap], - # ncol,np.nanmedian(self.trace['logprob_marg_'+pl][:,ngap])-total_av_prob) - #print(ngap,np.exp(self.trace['logprob_marg_'+pl][:,ngap]-total_prob)) + # ncol,np.nanmedian(self.trace.posterior['logprob_marg_'+pl][:,ngap])-total_av_prob) + #print(ngap,np.exp(self.trace.posterior['logprob_marg_'+pl][:,ngap]-total_prob)) if ncol not in cols: cols+=[ncol] - plt.hist(self.trace['per_'+pl][:,ngap], bins=bins, edgecolor=sns.color_palette()[0], - weights=np.exp(self.trace['logprob_marg_'+pl][:,ngap]-total_prob), + plt.hist(self.trace.posterior['per_'+pl][:,ngap], bins=bins, edgecolor=sns.color_palette()[0], + weights=np.exp(self.trace.posterior['logprob_marg_'+pl][:,ngap]-total_prob), color=pal[6+ncol],histtype="stepfilled",label=coldic[ncol]) else: - plt.hist(self.trace['per_'+pl][:,ngap], bins=bins, edgecolor=sns.color_palette()[0], - weights=np.exp(self.trace['logprob_marg_'+pl][:,ngap]-total_prob), + plt.hist(self.trace.posterior['per_'+pl][:,ngap], bins=bins, edgecolor=sns.color_palette()[0], + weights=np.exp(self.trace.posterior['logprob_marg_'+pl][:,ngap]-total_prob), color=pal[6+ncol],histtype="stepfilled") plt.title("Mono - "+str(pl)) @@ -4614,13 +4613,13 @@ def plot_corner(self,corner_vars=None,use_marg=True,truths=None): for pl in self.planets: for var in self.fit_params: - if var+'_'+pl in self.trace: + if var+'_'+pl in self.trace.posterior: corner_vars+=[var+'_'+pl] if pl in self.duos+self.trios: corner_vars+=['t0_2_'+pl] if use_marg: for var in self.marginal_params: - if var+'_marg_'+pl in self.trace: + if var+'_marg_'+pl in self.trace.posterior: corner_vars+=[var+'_marg_'+pl] print("variables for Corner:",corner_vars) @@ -4678,22 +4677,22 @@ def plot_corner(self,corner_vars=None,use_marg=True,truths=None): for mpl in self.monos: for n_gap in np.arange(self.planets[mpl]['ngaps']): sampl_loc=np.in1d(np.arange(0,len(samples),1),np.arange(n_pos*samples_len,(n_pos+1)*samples_len,1)) - samples.loc[sampl_loc,'per_marg_'+mpl]=self.trace['per_'+mpl][:,n_gap] + samples.loc[sampl_loc,'per_marg_'+mpl]=self.trace.posterior['per_'+mpl][:,n_gap] if 'tdur' in self.marginal_params: - samples.loc[sampl_loc,'tdur_marg_'+mpl]=self.trace['tdur_'+mpl][:,n_gap] + samples.loc[sampl_loc,'tdur_marg_'+mpl]=self.trace.posterior['tdur_'+mpl][:,n_gap] elif 'b' in self.marginal_params: - samples.loc[sampl_loc,'b_marg_'+mpl]=self.trace['b_'+mpl][:,n_gap] - samples.loc[sampl_loc,'log_prob']=self.trace['logprob_marg_'+mpl][:,n_gap] + samples.loc[sampl_loc,'b_marg_'+mpl]=self.trace.posterior['b_'+mpl][:,n_gap] + samples.loc[sampl_loc,'log_prob']=self.trace.posterior['logprob_marg_'+mpl][:,n_gap] n_pos+=1 for dpl in self.duos: for n_per in np.arange(len(self.planets[dpl]['period_aliases'])): sampl_loc=np.in1d(np.arange(len(samples)),np.arange(n_pos*samples_len,(n_pos+1)*samples_len)) - samples.loc[sampl_loc,'per_marg_'+dpl]=self.trace['per_'+dpl][:,n_per] + samples.loc[sampl_loc,'per_marg_'+dpl]=self.trace.posterior['per_'+dpl][:,n_per] if 'tdur' in self.marginal_params: - samples.loc[sampl_loc,'tdur_marg_'+dpl]=self.trace['tdur_'+dpl][:,n_per] + samples.loc[sampl_loc,'tdur_marg_'+dpl]=self.trace.posterior['tdur_'+dpl][:,n_per] elif 'b' in self.marginal_params: - samples.loc[sampl_loc,'b_marg_'+dpl]=self.trace['b_'+dpl][:,n_per] - samples.loc[sampl_loc,'log_prob'] = self.trace['logprob_marg_'+dpl][:,n_per] + samples.loc[sampl_loc,'b_marg_'+dpl]=self.trace.posterior['b_'+dpl][:,n_per] + samples.loc[sampl_loc,'log_prob'] = self.trace.posterior['logprob_marg_'+dpl][:,n_per] n_pos+=1 weight_samps = np.exp(samples["log_prob"]) fig = corner.corner(samples[[col for col in samples.columns if col!='log_prob']],weights=weight_samps); @@ -4722,15 +4721,15 @@ def make_table(self,short=True,save=True,cols=['all']): cols_to_remove+=['mono_uniform_index','logliks','_priors','logprob_marg','logrho_S'] for col in self.marginal_params: cols_to_remove+=['mono_'+col+'s','duo_'+col+'s'] - medvars=[var for var in self.trace if not np.any([icol in var for icol in cols_to_remove])] + medvars=[var for var in self.trace.posterior if not np.any([icol in var for icol in cols_to_remove])] #print(cols_to_remove, medvars) - df = pm.summary(self.trace.unstack(),var_names=medvars,stat_funcs={"5%": lambda x: np.percentile(x, 5), + df = pm.summary(self.trace,var_names=medvars,stat_funcs={"5%": lambda x: np.percentile(x, 5), "-$1\sigma$": lambda x: np.percentile(x, 15.87), "median": lambda x: np.percentile(x, 50), "+$1\sigma$": lambda x: np.percentile(x, 84.13), "95%": lambda x: np.percentile(x, 95)},round_to=5) else: - df = pm.summary(self.trace.unstack(),var_names=cols,stat_funcs={"5%": lambda x: np.percentile(x, 5), + df = pm.summary(self.trace,var_names=cols,stat_funcs={"5%": lambda x: np.percentile(x, 5), "-$1\sigma$": lambda x: np.percentile(x, 15.87), "median": lambda x: np.percentile(x, 50), "+$1\sigma$": lambda x: np.percentile(x, 84.13), @@ -4878,7 +4877,7 @@ def predict_future_transits(self, time_start=None, time_end=None, time_dur=180, print("time range",time_start.isot,"->",time_end.isot) time_start = time_start.jd - self.lc.jd_base time_end = time_end.jd - self.lc.jd_base - elif type(time_start) in [int, np.int64, float, np.float64] and abs(time_start-np.nanmedian(self.trace['t0_'+list(self.planets.keys())[0]]))>5000: + elif type(time_start) in [int, np.int64, float, np.float64] and abs(time_start-np.nanmedian(self.trace.posterior['t0_'+list(self.planets.keys())[0]]))>5000: #This looks like a proper julian date. Let's reformat to match the lightcurve time_start -= self.lc.jd_base time_end -= self.lc.jd_base @@ -4898,36 +4897,36 @@ def predict_future_transits(self, time_start=None, time_end=None, time_dur=180, for pl in loopplanets: all_trans=pd.DataFrame() if pl in self.duos+self.trios: - sum_all_probs=np.logaddexp.reduce(np.nanmedian(self.trace['logprob_marg_'+pl],axis=0)) - trans_p0=np.floor(np.nanmedian(time_start - self.trace['t0_2_'+pl].values)/np.nanmedian(self.trace['per_'+pl].values,axis=0)) - trans_p1=np.ceil(np.nanmedian(time_end - self.trace['t0_2_'+pl].values)/np.nanmedian(self.trace['per_'+pl].values,axis=0)) + sum_all_probs=np.logaddexp.reduce(np.nanmedian(self.trace.posterior['logprob_marg_'+pl],axis=0)) + trans_p0=np.floor(np.nanmedian(time_start - self.trace.posterior['t0_2_'+pl].values)/np.nanmedian(self.trace.posterior['per_'+pl].values,axis=0)) + trans_p1=np.ceil(np.nanmedian(time_end - self.trace.posterior['t0_2_'+pl].values)/np.nanmedian(self.trace.posterior['per_'+pl].values,axis=0)) n_trans=trans_p1-trans_p0 elif pl in self.multis: - trans_p0=[np.floor(np.nanmedian(time_start - self.trace['t0_'+pl].values)/np.nanmedian(self.trace['per_'+pl].values))] - trans_p1=[np.ceil(np.nanmedian(time_end - self.trace['t0_'+pl].values)/np.nanmedian(self.trace['per_'+pl].values))] + trans_p0=[np.floor(np.nanmedian(time_start - self.trace.posterior['t0_'+pl].values)/np.nanmedian(self.trace.posterior['per_'+pl].values))] + trans_p1=[np.ceil(np.nanmedian(time_end - self.trace.posterior['t0_'+pl].values)/np.nanmedian(self.trace.posterior['per_'+pl].values))] n_trans=[trans_p1[0]-trans_p0[0]] #print(pl,trans_p0,trans_p1,n_trans) - #print(np.nanmedian(self.trace['t0_2_'+pl])+np.nanmedian(self.trace['per_'+pl],axis=0)*trans_p0) - #print(np.nanmedian(self.trace['t0_2_'+pl])+np.nanmedian(self.trace['per_'+pl],axis=0)*trans_p1) + #print(np.nanmedian(self.trace.posterior['t0_2_'+pl])+np.nanmedian(self.trace.posterior['per_'+pl],axis=0)*trans_p0) + #print(np.nanmedian(self.trace.posterior['t0_2_'+pl])+np.nanmedian(self.trace.posterior['per_'+pl],axis=0)*trans_p1) nms=['-2sig','-1sig','med','+1sig','+2sig'] percentiles=(2.2750132, 15.8655254, 50., 84.1344746, 97.7249868) #Getting the important trace info (tcen, dur, etc) for each alias: if 'tdur' in self.fit_params or pl in self.multis: - dur=np.nanpercentile(self.trace['tdur_'+pl],percentiles) + dur=np.nanpercentile(self.trace.posterior['tdur_'+pl],percentiles) naliases=[0] if pl in self.multis else np.arange(self.planets[pl]['npers']) idfs=[] for nd in naliases: if n_trans[nd]>0: if pl in self.duos+self.trios: int_alias=int(self.planets[pl]['period_int_aliases'][nd]) - transits=np.nanpercentile(np.vstack([self.trace['t0_2_'+pl].values+ntr*self.trace['per_'+pl].values[:,nd] for ntr in np.arange(trans_p0[nd],trans_p1[nd])]),percentiles,axis=1) + transits=np.nanpercentile(np.vstack([self.trace.posterior['t0_2_'+pl].values+ntr*self.trace.posterior['per_'+pl].values[nd,:] for ntr in np.arange(trans_p0[nd],trans_p1[nd])]),percentiles,axis=1) if 'tdur' in self.marginal_params: - dur=np.nanpercentile(self.trace['tdur_'+pl][:,nd],percentiles) - logprobs=np.nanmedian(self.trace['logprob_marg_'+pl][:,nd])-sum_all_probs + dur=np.nanpercentile(self.trace.posterior['tdur_'+pl][:,nd],percentiles) + logprobs=np.nanmedian(self.trace.posterior['logprob_marg_'+pl][:,nd])-sum_all_probs else: - transits=np.nanpercentile(np.column_stack([self.trace['t0_'+pl].values+ntr*self.trace['per_'+pl].values for ntr in np.arange(trans_p0[nd],trans_p1[nd],1.0)]),percentiles,axis=0) + transits=np.nanpercentile(np.column_stack([self.trace.posterior['t0_'+pl].values+ntr*self.trace.posterior['per_'+pl].values for ntr in np.arange(trans_p0[nd],trans_p1[nd],1.0)]),percentiles,axis=0) int_alias=1 logprobs=np.array([0.0]) #Getting the aliases for this: @@ -4953,7 +4952,7 @@ def predict_future_transits(self, time_start=None, time_end=None, time_dur=180, 'prob':np.tile(np.exp(logprobs),len(transits[2])), 'planet_name':np.tile('multi_'+pl,len(transits[2])) if pl in self.multis else np.tile('duo_'+pl,len(transits[2])), 'alias_n':np.tile(nd,len(transits[2])), - 'alias_p':np.tile(np.nanmedian(self.trace['per_'+pl].values[:,nd]),len(transits[2])) if pl in self.duos+self.trios else np.tile(np.nanmedian(self.trace['per_'+pl].values),len(transits[2]))})] + 'alias_p':np.tile(np.nanmedian(self.trace.posterior['per_'+pl].values[:,nd]),len(transits[2])) if pl in self.duos+self.trios else np.tile(np.nanmedian(self.trace.posterior['per_'+pl].values),len(transits[2]))})] all_trans=pd.concat(idfs) unq_trans = all_trans.sort_values('log_prob').copy().drop_duplicates('transit_fractions') unq_trans = unq_trans.set_index(np.arange(len(unq_trans))) @@ -5124,9 +5123,9 @@ def make_cheops_OR(self, DR2ID=None, pl=None, min_eff=45, oot_min_orbits=1.0, ti searchpls=[pl] for ipl in searchpls: if self.n_margs[ipl]>1: - allprobs=np.exp(np.nanmedian(self.trace['logprob_marg_'+ipl],axis=0)) + allprobs=np.exp(np.nanmedian(self.trace.posterior['logprob_marg_'+ipl],axis=0)) allprobs/=np.sum(allprobs) #normalising - allpers=np.arange(self.trace['per_'+ipl].shape[1]) + allpers=np.arange(self.trace.posterior['per_'+ipl].shape[1]) else: allprobs=np.array([1.0]) allpers=np.array([0]) @@ -5147,8 +5146,8 @@ def make_cheops_OR(self, DR2ID=None, pl=None, min_eff=45, oot_min_orbits=1.0, ti if max_ORsobserve_threshold)>max_ORs and ipl not in self.multis: observe_threshold=np.sort(allprobs)[::-1][max_ORs] - depth=1e6*np.nanmedian(self.trace['ror_'+ipl])**2 - print("SNR for whole transit is: ",depth/self.cheops_RMS(gaiainfo['phot_g_mean_mag'], np.nanmedian(self.trace['tdur_'+ipl]))) + depth=1e6*np.nanmedian(self.trace.posterior['ror_'+ipl])**2 + print("SNR for whole transit is: ",depth/self.cheops_RMS(gaiainfo['phot_g_mean_mag'], np.nanmedian(self.trace.posterior['tdur_'+ipl]))) print("SNR for single orbit in/egress is: ",depth/self.cheops_RMS(gaiainfo['phot_g_mean_mag'], 0.5*98/1440)) prio_1_prob_threshold = np.ceil(np.sum(allprobs>observe_threshold)*prio_1_threshold) @@ -5158,7 +5157,7 @@ def make_cheops_OR(self, DR2ID=None, pl=None, min_eff=45, oot_min_orbits=1.0, ti #print(allpers,nper,allprobs[nper],observe_threshold) if allprobs[nper]>observe_threshold: ser={} - iper=np.nanmedian(self.trace['per_'+ipl][:,nper]) if len(self.trace['per_'+ipl].shape)>1 else np.nanmedian(self.trace['per_'+ipl]) + iper=np.nanmedian(self.trace.posterior['per_'+ipl][:,nper]) if len(self.trace.posterior['per_'+ipl].shape)>1 else np.nanmedian(self.trace.posterior['per_'+ipl]) ser['ObsReqName']=self.id_dic[self.mission]+str(self.ID)+'_'+ipl+'_period'+str(np.round(iper,2)).replace('.',';')+'_prob'+str(allprobs[nper])[:4] ser['Target']=self.id_dic[self.mission]+str(self.ID) if targetnamestring is None else targetnamestring ser['_RAJ2000']=old_radec.ra.to_string(unit=u.hourangle, sep=':') @@ -5181,12 +5180,12 @@ def make_cheops_OR(self, DR2ID=None, pl=None, min_eff=45, oot_min_orbits=1.0, ti ser['BJD_late']=t_end #Total observing time must cover duration, and either the full timing bound (i.e. assuming 3 sigma), or the oot_min_orbits (if the timing precision is better than the oot_min_orbits) - dur=np.nanpercentile(self.trace['tdur_'+ipl],[16,50,84]) - n_trans_av = np.round(((0.5*(t_end+t_start)-self.lc.jd_base)-np.nanmedian(self.trace['t0_'+ipl]))/iper) - if len(self.trace['per_'+ipl].shape)>1: - i_timing_bounds = np.percentile(self.trace['t0_'+ipl]+n_trans_av*self.trace['per_'+ipl][:,nper],[100*(1-stats.norm.cdf(timing_sigma)), 50, 100*stats.norm.cdf(timing_sigma)]) + dur=np.nanpercentile(self.trace.posterior['tdur_'+ipl],[16,50,84]) + n_trans_av = np.round(((0.5*(t_end+t_start)-self.lc.jd_base)-np.nanmedian(self.trace.posterior['t0_'+ipl]))/iper) + if len(self.trace.posterior['per_'+ipl].shape)>1: + i_timing_bounds = np.percentile(self.trace.posterior['t0_'+ipl]+n_trans_av*self.trace.posterior['per_'+ipl][:,nper],[100*(1-stats.norm.cdf(timing_sigma)), 50, 100*stats.norm.cdf(timing_sigma)]) else: - i_timing_bounds = np.percentile(self.trace['t0_'+ipl]+n_trans_av*self.trace['per_'+ipl],[100*(1-stats.norm.cdf(timing_sigma)), 50, 100*stats.norm.cdf(timing_sigma)]) + i_timing_bounds = np.percentile(self.trace.posterior['t0_'+ipl]+n_trans_av*self.trace.posterior['per_'+ipl],[100*(1-stats.norm.cdf(timing_sigma)), 50, 100*stats.norm.cdf(timing_sigma)]) timing_bounds = (i_timing_bounds[-1] - i_timing_bounds[0])*1440/98.7 dur_bounds = (dur[-1]-dur[0])*1440/98.7 if min_intrans_orbits is None: @@ -5217,7 +5216,7 @@ def make_cheops_OR(self, DR2ID=None, pl=None, min_eff=45, oot_min_orbits=1.0, ti ser['Texp']=Texp ser['MinEffDur']=min_eff ser['Gaia_DR2']=str(DR2ID) - ser['BJD_0']=self.lc.jd_base+np.nanmedian(self.trace['t0_'+ipl]) + ser['BJD_0']=self.lc.jd_base+np.nanmedian(self.trace.posterior['t0_'+ipl]) ser['Period']=iper #ser['T_visit']*0.5 @@ -5394,9 +5393,9 @@ def to_latex_table(self,varnames='all',order='columns'): if not hasattr(self,'savenames'): self.get_savename(how='save') if not hasattr(self,'tracemask') or self.tracemask is None: - self.tracemask=np.tile(True,len(self.trace['Rs'])) + self.tracemask=np.tile(True,len(self.trace.posterior['Rs'])) if varnames is None or varnames == 'all': - varnames=[var for var in self.trace if var[-2:]!='__' and var not in ['gp_pred','light_curves']] + varnames=[var for var in self.trace.posterior if var[-2:]!='__' and var not in ['gp_pred','light_curves']] self.samples = self.make_table(cols=varnames) self.samples = self.samples.loc[self.tracemask]