Skip to content

Commit

Permalink
changing pm.sample to remove kwargs but add cores
Browse files Browse the repository at this point in the history
  • Loading branch information
hposborn committed Jul 12, 2024
1 parent d260cb6 commit 4230899
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions MonoTools/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1235,7 +1235,7 @@ def init_model(self, overwrite=False, **kwargs):
elif self.use_multinest:
self.run_multinest(**kwargs)

def init_GP(self, n_draws=900, max_len_lc=25000, use_binned=False, overwrite=False, n_chains=4, **kwargs):
def init_GP(self, n_draws=900, max_len_lc=25000, use_binned=False, overwrite=False, n_chains=4, cores=4, **kwargs):
"""Function to train GPs on out-of-transit photometry
Args:
Expand Down Expand Up @@ -1349,7 +1349,7 @@ def init_GP(self, n_draws=900, max_len_lc=25000, use_binned=False, overwrite=Fal

self.gp_init_soln = pmx.optimize(start=None, vars=optvars)
if self.debug: print("sampling init GP", int(n_draws*0.66),"times with",len(self.lc.flux[mask]),"-point lightcurve")
self.gp_init_trace = pm.sample(tune=int(n_draws*0.66), draws=n_draws, start=self.gp_init_soln, chains=n_chains, **kwargs)
self.gp_init_trace = pm.sample(tune=int(n_draws*0.66), draws=n_draws, start=self.gp_init_soln, chains=n_chains,cores=cores)# **kwargs)

def init_interpolated_Mp_prior(self):
"""Initialise a 2D interpolated prior for the mass of a planet given the radius
Expand Down Expand Up @@ -2649,7 +2649,7 @@ def create_orbit(pl, Rs, rho_S, pers, t0s, bs, n_marg=1, eccs=None, omegas=None)
self.model = model
self.init_soln = map_soln

def sample_model(self, n_draws=500, n_burn_in=None, overwrite=False, continue_sampling=False, n_chains=4, **kwargs):
def sample_model(self, n_draws=500, n_burn_in=None, overwrite=False, continue_sampling=False, n_chains=4, cores=4, **kwargs):
"""Run PyMC3 sampler
Args:
Expand Down Expand Up @@ -2683,9 +2683,9 @@ def sample_model(self, n_draws=500, n_burn_in=None, overwrite=False, continue_sa
if self.debug: print(self.init_soln.keys())
if hasattr(self,'trace') and continue_sampling:
print("Using already-generated MCMC trace as start point for new trace")
self.trace = pm.sample(tune=n_burn_in, draws=n_draws, chains=n_chains, trace=self.trace, compute_convergence_checks=False)#, **kwargs)
self.trace = pm.sample(tune=n_burn_in, draws=n_draws, chains=n_chains, trace=self.trace, compute_convergence_checks=False, cores=cores)#, **kwargs)
else:
self.trace = pm.sample(tune=n_burn_in, draws=n_draws, start=self.init_soln, chains=n_chains, compute_convergence_checks=False)#, **kwargs)
self.trace = pm.sample(tune=n_burn_in, draws=n_draws, start=self.init_soln, chains=n_chains, compute_convergence_checks=False, cores=cores)#, **kwargs)
#Saving both the class and a pandas dataframe of output data.
self.save_model_to_file()
_=self.make_table(save=True)
Expand Down

0 comments on commit 4230899

Please sign in to comment.