Skip to content

Commit

Permalink
fixing plotting bug with gps'
Browse files Browse the repository at this point in the history
  • Loading branch information
hposborn committed Aug 21, 2024
1 parent 09a4edb commit a3d1812
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions MonoTools/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2831,8 +2831,10 @@ def init_gp_to_plot(self, n_samp=7, max_gp_len=12000, interp=True, newgp=False,
#assert self.bin_oot
stacktime=np.hstack((self.lc.time[0]-1,self.model_time,self.lc.time[-1]+1))
preds=[]
ext_lightcurves=az.extract(self.trace.posterior,var_names=['gp_pred'])

for i in np.random.choice(len(self.trace.posterior['phot_mean']),int(np.clip(10*n_samp,1,len(self.trace.posterior['phot_mean']))),replace=False):
smooth_func=interpolate.interp1d(stacktime, np.hstack((0,self.trace.posterior['gp_pred'][:,i],0)), kind='slinear')
smooth_func=interpolate.interp1d(stacktime, np.hstack((0,ext_lightcurves['gp_pred'][:,i],0)), kind='slinear')
preds+=[smooth_func(self.lc.time)+self.trace.posterior['phot_mean'].values[i]]
prcnts=np.nanpercentile(np.column_stack(preds),[15.8655254, 50., 84.1344746],axis=1)
self.gp_to_plot['gp_pred']=prcnts[1]
Expand All @@ -2847,10 +2849,10 @@ def init_gp_to_plot(self, n_samp=7, max_gp_len=12000, interp=True, newgp=False,
transit_mask=~self.lc.in_trans['all'][self.lc_regions[key]['ix']])
i_kernel = pymc_terms.SHOTerm(S0=self.meds['phot_S0'], w0=self.meds['phot_w0'], Q=1/np.sqrt(2))
i_gp = celerite2.pymc.GaussianProcess(i_kernel, mean=self.meds['phot_mean'])
limit_mask_bool[n]={}
limit_mask_bool[key]={}
for nc,c in enumerate(cutBools):
limit_mask_bool[n][nc]=np.tile(False,len(self.lc.time))
limit_mask_bool[n][nc][self.lc_regions[key]['ix']][c]=self.lc_regions[key]['ix'][self.lc_regions[key]['ix']][c]
limit_mask_bool[key][nc]=np.tile(False,len(self.lc.time))
limit_mask_bool[key][nc][self.lc_regions[key]['ix']][c]=self.lc_regions[key]['ix'][self.lc_regions[key]['ix']][c]
i_gp_pred=[]
i_gp_var=[]
for i in np.random.choice(len(self.trace),n_samp,replace=False):
Expand All @@ -2859,18 +2861,18 @@ def init_gp_to_plot(self, n_samp=7, max_gp_len=12000, interp=True, newgp=False,
i_gp.kernel = pymc_terms.SHOTerm(S0=sample['phot_S0'], w0=sample['phot_w0'], Q=1/np.sqrt(2))
i_gp.mean = sample['mean']
i_gp.recompute(self.lc.time[limit_mask_bool[n][nc]],
np.sqrt(self.lc.flux_err[limit_mask_bool[n][nc]]**2 + \
np.dot(self.lc.flux_err_index[limit_mask_bool[n][nc]], np.exp(sample['logs2']))))
np.sqrt(self.lc.flux_err[limit_mask_bool[key][nc]]**2 + \
np.dot(self.lc.flux_err_index[limit_mask_bool[key][nc]], np.exp(sample['logs2']))))
marg_lc=np.tile(0.0,len(self.lc.time))
if hasattr(self,'pseudo_binlc') and len(self.trans_to_plot['all']['med'])==len(self.pseudo_binlc['time']):
marg_lc[self.lc.near_trans['all']]=sample['marg_all_lc_model'][self.pseudo_binlc['near_trans']]
elif hasattr(self,'lc_near_trans') and len(self.trans_to_plot['all']['med'])==len(self.lc_near_trans['time']):
marg_lc[self.lc.near_trans['all']]=sample['marg_all_lc_model'][key1][key2]
marg_lc[self.lc.near_trans['all']]=sample['marg_all_lc_model']#[key1][key2]
elif len(self.trans_to_plot['all']['med'])==len(self.lc.time):
marg_lc[self.lc.near_trans['all']]=sample['marg_all_lc_model'][key1][key2][self.lc.near_trans['all']]
marg_lc[self.lc.near_trans['all']]=sample['marg_all_lc_model'][self.lc.near_trans['all']]

#marg_lc[self.lc.near_trans['all']]=sample['marg_all_lc_model'][self.lc.near_trans['all']]
ii_gp_pred, ii_gp_var= i_gp.predict(self.lc.flux[limit_mask_bool[n][nc]] - marg_lc[limit_mask_bool[n][nc]],
ii_gp_pred, ii_gp_var= i_gp.predict(self.lc.flux[limit_mask_bool[key][nc]] - marg_lc[limit_mask_bool[n][nc]],
t=self.lc.time[self.lc_regions[key]['ix']][c].astype(floattype),
return_var=True, return_cov=False, include_mean=False)

Expand Down

0 comments on commit a3d1812

Please sign in to comment.