Skip to content

Commit

Permalink
Merge pull request #318 from California-Planet-Search/next-release
Browse files Browse the repository at this point in the history
Version 1.4.1
  • Loading branch information
bjfultn authored Jun 5, 2020
2 parents 75fdf4e + e623851 commit 2f93693
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 19 deletions.
2 changes: 1 addition & 1 deletion radvel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def _custom_warningfmt(msg, *a, **b):
__all__ = ['model', 'likelihood', 'posterior', 'mcmc', 'prior', 'utils',
'fitting', 'report', 'cli', 'driver', 'gp']

__version__ = '1.4.0'
__version__ = '1.4.1'
__spec__ = __name__
__package__ = __path__[0]

Expand Down
3 changes: 3 additions & 0 deletions radvel/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ def main():
psr_mcmc.add_argument('--proceed', dest='proceed', action='store_true',
help='If True, MCMC chains will resume from the previous run'
)
psr_mcmc.add_argument('--headless', dest='headless', action='store_true',
help='If True, convergence stats will not display in real time'
)
psr_mcmc.set_defaults(func=radvel.driver.mcmc)

# Derive physical parameters
Expand Down
2 changes: 1 addition & 1 deletion radvel/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def mcmc(args):
minAfactor=args.minAfactor, maxArchange=args.maxArchange, burnAfactor=args.burnAfactor,
burnGR=args.burnGR, maxGR=args.maxGR, minTz=args.minTz, minsteps=args.minsteps,
minpercent=args.minpercent, thin=args.thin, serial=args.serial, save=args.save,
savename=backend_loc, proceed=args.proceed, proceedname=backend_loc)
savename=backend_loc, proceed=args.proceed, proceedname=backend_loc, headless=args.headless)

mintz = statevars.mintz
maxgr = statevars.maxgr
Expand Down
17 changes: 7 additions & 10 deletions radvel/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,18 +143,15 @@ def set_vary_params(self, param_values_array):
param_values_array = list(param_values_array)
i = 0
try:
for index in self.vary_params:
self.vector.vector[index][0] = param_values_array[i]
i += 1
assert i == len(param_values_array), \
"Length of array must match number of varied parameters"
if len(self.vary_params) != len(param_values_array):
self.list_vary_params()
except AttributeError:
self.list_vary_params()
for index in self.vary_params:
self.vector.vector[index][0] = param_values_array[i]
i += 1
assert i == len(param_values_array), \
"Length of array must match number of varied parameters"
for index in self.vary_params:
self.vector.vector[index][0] = param_values_array[i]
i += 1
assert i == len(param_values_array), \
"Length of array must match number of varied parameters"

def get_vary_params(self):
try:
Expand Down
17 changes: 10 additions & 7 deletions radvel/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def _status_message_CLI(statevars):
statevars.screen.refresh()


def convergence_check(minAfactor, maxArchange, maxGR, minTz, minsteps, minpercent):
def convergence_check(minAfactor, maxArchange, maxGR, minTz, minsteps, minpercent, headless):
"""Check for convergence
Check for convergence for a list of emcee samplers
Expand All @@ -116,6 +116,7 @@ def convergence_check(minAfactor, maxArchange, maxGR, minTz, minsteps, minpercen
will start after the minsteps threshold or the minpercent threshold has been hit.
minpercent (float): Minimum percentage of total steps before convergence tests are performed. Convergence checks
will start after the minsteps threshold or the minpercent threshold has been hit.
headless (bool): if set to true, the convergence statistics will not be displayed in real time.
"""

statevars.ar = 0
Expand Down Expand Up @@ -165,10 +166,11 @@ def convergence_check(minAfactor, maxArchange, maxGR, minTz, minsteps, minpercen
else:
statevars.mixcount = 0

if isnotebook():
_status_message_NB(statevars)
else:
_status_message_CLI(statevars)
if not headless:
if isnotebook():
_status_message_NB(statevars)
else:
_status_message_CLI(statevars)


def _domcmc(input_tuple):
Expand All @@ -186,7 +188,7 @@ def _domcmc(input_tuple):

def mcmc(post, nwalkers=50, nrun=10000, ensembles=8, checkinterval=50, minAfactor=40, maxArchange=.03, burnAfactor=25,
burnGR=1.03, maxGR=1.01, minTz=1000, minsteps=1000, minpercent=5, thin=1, serial=False, save=False,
savename=None, proceed=False, proceedname=None):
savename=None, proceed=False, proceedname=None, headless=False):
"""Run MCMC
Run MCMC chains using the emcee EnsambleSampler
Args:
Expand Down Expand Up @@ -215,6 +217,7 @@ def mcmc(post, nwalkers=50, nrun=10000, ensembles=8, checkinterval=50, minAfacto
savename (string): location of h5py file where MCMC chains will be saved for future use
proceed (bool): set to true to continue a previously saved run
proceedname (string): location of h5py file with previously MCMC run chains
headless (bool): if set to true, the convergence statistics will not display in real time
Returns:
DataFrame: DataFrame containing the MCMC samples
"""
Expand Down Expand Up @@ -385,7 +388,7 @@ def mcmc(post, nwalkers=50, nrun=10000, ensembles=8, checkinterval=50, minAfacto
statevars.interval = t2 - t1

convergence_check(minAfactor=minAfactor, maxArchange=maxArchange, maxGR=maxGR, minTz=minTz,
minsteps=minsteps, minpercent=minpercent)
minsteps=minsteps, minpercent=minpercent, headless=headless)

if save:
for i, sampler in enumerate(statevars.samplers):
Expand Down
1 change: 1 addition & 0 deletions radvel/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class _args(types.SimpleNamespace):
savename = 'rawchains.h5'
proceed = False
proceedname = None
headless=False


def _standard_run(setupfn, arguments):
Expand Down

0 comments on commit 2f93693

Please sign in to comment.