diff --git a/MonoTools/fit.py b/MonoTools/fit.py index 78aada8..d2ca77e 100755 --- a/MonoTools/fit.py +++ b/MonoTools/fit.py @@ -1824,8 +1824,11 @@ def init_pymc(self,ld_mult=1.5): if self.debug: print(np.isnan(self.model_time),np.isnan(self.model_flux),np.isnan(self.model_flux_err)) if self.train_GP: #Using histograms from the output of the previous GP training as priors for the true model. - minmaxs={var:np.percentile(self.gp_init_trace.posterior[var].values,[0.5,99.5]).astype(floattype) for var in self.gp_init_trace.posterior if '__' not in var and len(self.gp_init_trace.posterior[var].shape)==1} - hists={var:np.histogram(self.gp_init_trace.posterior[var].values,np.linspace(minmaxs[var][0],minmaxs[var][1],101))[0] for var in self.gp_init_trace.posterior if '__' not in var and len(self.gp_init_trace.posterior[var].shape)==1} + vars=[var for var in self.gp_init_trace.posterior if '__' not in var and np.product(self.gp_init_trace.posterior[var].shape)<5*len(self.gp_init_trace.posterior.chain)*len(self.gp_init_trace.posterior.draw)] + ext_gp_init_trace=az.extract(self.gp_init_trace.posterior,var_names=vars) + + minmaxs={var:np.percentile(ext_gp_init_trace[var].values,[0.5,99.5]).astype(floattype) for var in vars} + hists={var:np.histogram(ext_gp_init_trace[var].values,np.linspace(minmaxs[var][0],minmaxs[var][1],101))[0] for var in vars} gpvars=[] if hasattr(self, 'periodic_kernel') and self.periodic_kernel is not None: if self.train_GP: