diff --git a/MonoTools/fit.py b/MonoTools/fit.py index cdef361..7691c5c 100755 --- a/MonoTools/fit.py +++ b/MonoTools/fit.py @@ -1235,7 +1235,7 @@ def init_model(self, overwrite=False, **kwargs): elif self.use_multinest: self.run_multinest(**kwargs) - def init_GP(self, n_draws=900, max_len_lc=25000, use_binned=False, overwrite=False, n_chains=4, **kwargs): + def init_GP(self, n_draws=900, max_len_lc=25000, use_binned=False, overwrite=False, n_chains=4, cores=4, **kwargs): """Function to train GPs on out-of-transit photometry Args: @@ -1349,7 +1349,7 @@ def init_GP(self, n_draws=900, max_len_lc=25000, use_binned=False, overwrite=Fal self.gp_init_soln = pmx.optimize(start=None, vars=optvars) if self.debug: print("sampling init GP", int(n_draws*0.66),"times with",len(self.lc.flux[mask]),"-point lightcurve") - self.gp_init_trace = pm.sample(tune=int(n_draws*0.66), draws=n_draws, start=self.gp_init_soln, chains=n_chains, **kwargs) + self.gp_init_trace = pm.sample(tune=int(n_draws*0.66), draws=n_draws, start=self.gp_init_soln, chains=n_chains,cores=cores)# **kwargs) def init_interpolated_Mp_prior(self): """Initialise a 2D interpolated prior for the mass of a planet given the radius @@ -2649,7 +2649,7 @@ def create_orbit(pl, Rs, rho_S, pers, t0s, bs, n_marg=1, eccs=None, omegas=None) self.model = model self.init_soln = map_soln - def sample_model(self, n_draws=500, n_burn_in=None, overwrite=False, continue_sampling=False, n_chains=4, **kwargs): + def sample_model(self, n_draws=500, n_burn_in=None, overwrite=False, continue_sampling=False, n_chains=4, cores=4, **kwargs): """Run PyMC3 sampler Args: @@ -2683,9 +2683,9 @@ def sample_model(self, n_draws=500, n_burn_in=None, overwrite=False, continue_sa if self.debug: print(self.init_soln.keys()) if hasattr(self,'trace') and continue_sampling: print("Using already-generated MCMC trace as start point for new trace") - self.trace = pm.sample(tune=n_burn_in, draws=n_draws, chains=n_chains, trace=self.trace, compute_convergence_checks=False)#, **kwargs) + self.trace = pm.sample(tune=n_burn_in, draws=n_draws, chains=n_chains, trace=self.trace, compute_convergence_checks=False, cores=cores)#, **kwargs) else: - self.trace = pm.sample(tune=n_burn_in, draws=n_draws, start=self.init_soln, chains=n_chains, compute_convergence_checks=False)#, **kwargs) + self.trace = pm.sample(tune=n_burn_in, draws=n_draws, start=self.init_soln, chains=n_chains, compute_convergence_checks=False, cores=cores)#, **kwargs) #Saving both the class and a pandas dataframe of output data. self.save_model_to_file() _=self.make_table(save=True)