Skip to content

Commit

Permalink
adding kwargs to find_time_regions for large timeseries which break t…
Browse files Browse the repository at this point in the history
…he classic plotting
  • Loading branch information
Hugh Osborn authored and Hugh Osborn committed Jun 23, 2023
1 parent e00afd1 commit 93353e1
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions MonoTools/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2516,7 +2516,7 @@ def Table(self):
savefileloc=None, tracemask=tracemask)
'''

def init_gp_to_plot(self, n_samp=7, max_gp_len=12000, interp=True, newgp=False, overwrite=False):
def init_gp_to_plot(self, n_samp=7, max_gp_len=12000, interp=True, newgp=False, overwrite=False,**kwargs):
"""Initialise the GP model for plotting.
As it is memory-intensive to store predicted GP samples for each datapoint in the light curve during sampling,
Expand All @@ -2542,7 +2542,7 @@ def init_gp_to_plot(self, n_samp=7, max_gp_len=12000, interp=True, newgp=False,
elif interp:
from scipy import interpolate
if not hasattr(self,'lc_regions'):
self.init_plot(plot_type='lc')
self.init_plot(plot_type='lc',**kwargs)
gp_pred=[]
gp_sd=[]
self.gp_to_plot={'n_samp':n_samp}
Expand Down Expand Up @@ -2678,7 +2678,7 @@ def init_gp_to_plot(self, n_samp=7, max_gp_len=12000, interp=True, newgp=False,
self.gp_to_plot['gp_pred']=np.hstack(gp_pred)
self.gp_to_plot['gp_sd']=np.hstack(gp_sd)

def init_trans_to_plot(self,n_samp=None):
def init_trans_to_plot(self,n_samp=None,**kwargs):
"""Initialising the transit models to plot
The result is the `trans_to_plot` array, which is a dictionary of predicted transit flux model percentiles computed for each point in the time series.
Expand All @@ -2689,7 +2689,7 @@ def init_trans_to_plot(self,n_samp=None):
n_samp=len(self.trace['phot_mean']) if n_samp is None else n_samp
print("Initalising Transit models for plotting with n_samp=",n_samp)
if not hasattr(self,'lc_regions'):
self.init_plot(plot_type='lc')
self.init_plot(plot_type='lc',**kwargs)
self.trans_to_plot={'model':{'allpl':{}},
'all':{'allpl':{}},
'n_samp':n_samp}
Expand Down Expand Up @@ -2921,7 +2921,7 @@ def init_rvs_to_plot(self, n_samp=300, plot_alias='all'):
self.rvs_to_plot['x']["trend+offset"] = {'med':self.init_soln["rv_trend"]}
self.rvs_to_plot['x']["offsets"] = {'med':self.init_soln["rv_offsets"]}

def init_plot(self, interactive=False, gap_thresh=10, plottype='lc',pointcol='k',palette=None, ncols=None, plot_flat=False):
def init_plot(self, interactive=False, gap_thresh=10, plottype='lc',pointcol='k',palette=None, ncols=None, plot_flat=False,**kwargs):
"""Initialising plotting
Args:
Expand Down Expand Up @@ -2959,7 +2959,7 @@ def init_plot(self, interactive=False, gap_thresh=10, plottype='lc',pointcol='k'
else:
fx_lab='flux'
fx_bin_lab='bin_flux'
time_regions=tools.find_time_regions(self.lc.time[self.lc.mask])
time_regions=tools.find_time_regions(self.lc.time[self.lc.mask],**kwargs)
self.lc_regions={}
for nj in range(len(time_regions)):
self.lc_regions[nj]={'start':time_regions[nj][0],'end':time_regions[nj][1]}
Expand Down Expand Up @@ -3005,7 +3005,7 @@ def init_plot(self, interactive=False, gap_thresh=10, plottype='lc',pointcol='k'


def PlotRVs(self, interactive=False, plot_alias='best', nbest=4, n_samp=300, overwrite=False, return_fig=False, plot_resids=False,
plot_loc=None, palette=None, pointcol='k', plottype='png',raster=False, nmargtoplot=0, save=True):
plot_loc=None, palette=None, pointcol='k', plottype='png',raster=False, nmargtoplot=0, save=True,**kwargs):
"""Varied plotting function for RVs of MonoTransit model
Args:
Expand Down Expand Up @@ -3035,7 +3035,7 @@ def PlotRVs(self, interactive=False, plot_alias='best', nbest=4, n_samp=300, ove

# plot_alias - can be 'all' or 'best'. All will plot all aliases. Best will assume the highest logprob.
ncol=3+2*np.max(list(self.n_margs.values())) if plot_alias=='all' else 3+2*nbest
self.init_plot(plottype='rv', pointcol=pointcol, ncols=ncol)
self.init_plot(plottype='rv', pointcol=pointcol, ncols=ncol,**kwargs)

if not hasattr(self,'rvs_to_plot') or n_samp!=self.rvs_to_plot['n_samp'] or overwrite:
self.init_rvs_to_plot(n_samp, plot_alias)
Expand Down

0 comments on commit 93353e1

Please sign in to comment.