Skip to content

Commit

Permalink
Merge pull request #300 from California-Planet-Search/next-release
Browse files Browse the repository at this point in the history
Version 1.3.7
  • Loading branch information
bjfultn authored Mar 13, 2020
2 parents 261a820 + 8c9449b commit a935048
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 54 deletions.
2 changes: 1 addition & 1 deletion radvel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def _custom_warningfmt(msg, *a, **b):
__all__ = ['model', 'likelihood', 'posterior', 'mcmc', 'prior', 'utils',
'fitting', 'report', 'cli', 'driver', 'gp']

__version__ = '1.3.6'
__version__ = '1.3.7'
__spec__ = __name__
__package__ = __path__[0]

Expand Down
17 changes: 17 additions & 0 deletions radvel/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
'per tc secosw sesinw k',
'per tc ecosw esinw k',
'per tc e w k',
'per tc tp e w k', # print basis for summary report
'logper tc secosw sesinw k',
'logper tc secosw sesinw logk',
'per tc se w k',
Expand Down Expand Up @@ -183,6 +184,13 @@ def _setpar(key, new_value):
k = _getpar('k')
tp = timetrans_to_timeperi(tc, per, e, w)

if basis_name == 'per tc tp e w k':
per = _getpar('per')
tp = _getpar('tp')
e = _getpar('e')
w = _getpar('w')
k = _getpar('k')

if basis_name == 'per tc se w k':
# pull out parameters
per = _getpar('per')
Expand Down Expand Up @@ -355,6 +363,15 @@ def _delpar(key):
if not kwargs.get('keep', True):
_delpar('tp')

if newbasis == 'per tc tp e w k':
per = _getpar('per')
e = _getpar('e')
w = _getpar('w')
tp = _getpar('tp')

_setpar('tc', timeperi_to_timetrans(tp, per, e, w))
_setpar('w', w)

if newbasis == 'per tc se w k':
per = _getpar('per')
e = _getpar('e')
Expand Down
6 changes: 3 additions & 3 deletions radvel/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,10 @@ def main():
psr_mcmc.add_argument('--serial', dest='serial', action='store', default=False, type=bool,
help='''If True, run MCMC in serial instead of parallel. [False]'''
)
psr_mcmc.add_argument('--save', dest='save', action='store', default=False, type=bool,
help='If True, MCMC chains will be saved to be continued in a future run [False]'
psr_mcmc.add_argument('--save', dest='save', action='store_true',
help='If set, MCMC chains will be saved to be continued in a future run [False]'
)
psr_mcmc.add_argument('--proceed', dest='proceed', action='store', default=False, type=bool,
psr_mcmc.add_argument('--proceed', dest='proceed', action='store_true',
help='If True, MCMC chains will resume from the previous run'
)
psr_mcmc.set_defaults(func=radvel.driver.mcmc)
Expand Down
60 changes: 31 additions & 29 deletions radvel/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def plots(args):
args.plotkw['status'] = status
if 'saveplot' not in args.plotkw:
saveto = os.path.join(
args.outputdir,conf_base+'_rv_multipanel.pdf'
args.outputdir, conf_base+'_rv_multipanel.pdf'
)
else:
saveto = args.plotkw['saveplot']
Expand Down Expand Up @@ -106,19 +106,20 @@ def plots(args):

if ptype == 'trend':
nwalkers = status.getint('mcmc', 'nwalkers')
nensembles = status.getint('mcmc', 'nensembles')

saveto = os.path.join(args.outputdir, conf_base+'_trends.pdf')
Trend = mcmc_plots.TrendPlot(post, chains, nwalkers, saveto)
Trend = mcmc_plots.TrendPlot(post, chains, nwalkers, nensembles, saveto)
Trend.plot()

if ptype == 'derived':
assert status.has_section('derive'), \
"Must run `radvel derive` before plotting derived parameters"

P,_ = radvel.utils.initialize_posterior(config_file)
P, _ = radvel.utils.initialize_posterior(config_file)
chains = pd.read_csv(status.get('derive', 'chainfile'))
saveto = os.path.join(
args.outputdir,conf_base+'_corner_derived_pars.pdf'
args.outputdir, conf_base+'_corner_derived_pars.pdf'
)

Derived = mcmc_plots.DerivedPlot(chains, P, saveplot=saveto)
Expand Down Expand Up @@ -190,11 +191,11 @@ def mcmc(args):

print(msg1 + '\n' + msg2)

chains = radvel.mcmc(
post, nwalkers=args.nwalkers, nrun=args.nsteps, ensembles=args.ensembles, minAfactor=args.minAfactor,
maxArchange=args.maxArchange, burnAfactor=args.burnAfactor, burnGR=args.burnGR, maxGR=args.maxGR,
minTz=args.minTz, minsteps=args.minsteps, minpercent=args.minpercent, thin=args.thin, serial=args.serial,
save=args.save, savename=backend_loc, proceed=args.proceed, proceedname=backend_loc)
chains = radvel.mcmc(post, nwalkers=args.nwalkers, nrun=args.nsteps, ensembles=args.ensembles,
minAfactor=args.minAfactor, maxArchange=args.maxArchange, burnAfactor=args.burnAfactor,
burnGR=args.burnGR, maxGR=args.maxGR, minTz=args.minTz, minsteps=args.minsteps,
minpercent=args.minpercent, thin=args.thin, serial=args.serial, save=args.save,
savename=backend_loc, proceed=args.proceed, proceedname=backend_loc)

mintz = statevars.mintz
maxgr = statevars.maxgr
Expand Down Expand Up @@ -238,7 +239,7 @@ def mcmc(args):
med = synth_quantile[par][0.5]
high = synth_quantile[par][0.841] - med
low = med - synth_quantile[par][0.159]
err = np.mean([high,low])
err = np.mean([high, low])
if maxlike == -np.inf and med == -np.inf and np.isnan(low) and np.isnan(high):
err = 0.0
else:
Expand All @@ -264,7 +265,7 @@ def mcmc(args):
'{}_post_obj.pkl'.format(conf_base))
post.writeto(postfile)

csvfn = os.path.join(args.outputdir, conf_base+'_chains.csv.tar.bz2')
csvfn = os.path.join(args.outputdir, conf_base+'_chains.csv.bz2')
chains.to_csv(csvfn, compression='bz2')

auto = pd.DataFrame()
Expand All @@ -277,19 +278,19 @@ def mcmc(args):
auto.to_csv(autocorr, sep=',')

savestate = {'run': True,
'postfile': os.path.relpath(postfile),
'chainfile': os.path.relpath(csvfn),
'autocorrfile': os.path.relpath(autocorr),
'summaryfile': os.path.relpath(saveto),
'nwalkers': statevars.nwalkers,
'nensembles': args.ensembles,
'maxsteps': args.nsteps*statevars.nwalkers*args.ensembles,
'nsteps': statevars.ncomplete,
'nburn': statevars.nburn,
'minafactor': minafactor,
'maxarchange': maxarchange,
'minTz': mintz,
'maxGR': maxgr}
'postfile': os.path.relpath(postfile),
'chainfile': os.path.relpath(csvfn),
'autocorrfile': os.path.relpath(autocorr),
'summaryfile': os.path.relpath(saveto),
'nwalkers': statevars.nwalkers,
'nensembles': args.ensembles,
'maxsteps': args.nsteps*statevars.nwalkers*args.ensembles,
'nsteps': statevars.ncomplete,
'nburn': statevars.nburn,
'minafactor': minafactor,
'maxarchange': maxarchange,
'minTz': mintz,
'maxGR': maxgr}
save_status(statfile, 'mcmc', savestate)

statevars.reset()
Expand Down Expand Up @@ -386,7 +387,8 @@ def tables(args):
dchains = pd.read_csv(status.get('derive', 'chainfile'))
chains = chains.join(dchains, rsuffix='_derived')
derived = True
else: derived = False
else:
derived = False
report = radvel.report.RadvelReport(P, post, chains, minafactor, maxarchange, maxgr, mintz, derived=derived)
tabletex = radvel.report.TexTable(report)
attrdict = {'priors': 'tab_prior_summary', 'rv': 'tab_rv',
Expand All @@ -408,7 +410,7 @@ def tables(args):
elif tabtype == 'rv':
tex = getattr(tabletex, attrdict[tabtype])(name_in_title=args.name_in_title, max_lines=None)
elif tabtype == 'crit':
tex = getattr(tabletex, attrdict[tabtype])( name_in_title=args.name_in_title)
tex = getattr(tabletex, attrdict[tabtype])(name_in_title=args.name_in_title)
else:
if tabtype == 'derived':
assert status.has_option('derive', 'run'), \
Expand Down Expand Up @@ -502,7 +504,7 @@ def _get_colname(key):
_set_param('mpsini', mpsini)
outcols.append(_get_colname('mpsini'))

mtotal = mstar + (mpsini*c.M_earth.value)/c.M_sun.value # get total star plus planet mass
mtotal = mstar + (mpsini * c.M_earth.value) / c.M_sun.value # get total star plus planet mass
a = radvel.utils.semi_major_axis(per, mtotal) # changed from mstar to mtotal

_set_param('a', a)
Expand Down Expand Up @@ -539,7 +541,7 @@ def _get_colname(key):
post.writeto(postfile)
savestate['quantfile'] = os.path.relpath(csvfn)

csvfn = os.path.join(args.outputdir, conf_base+'_derived.csv.tar.bz2')
csvfn = os.path.join(args.outputdir, conf_base+'_derived.csv.bz2')
chains.to_csv(csvfn, columns=outcols, compression='bz2')
savestate['chainfile'] = os.path.relpath(csvfn)

Expand Down Expand Up @@ -581,7 +583,7 @@ def report(args):
compstats = eval(status.get('ic_compare', args.comptype))
except:
print("WARNING: Could not find {} model comparison \
in {}.\nPlease make sure that you have run `radvel ic` (or, e.g., `radvel \
in {}.\nPlease make sure that you have run `radvel ic -t {}` (or, e.g., `radvel \
ic -t nplanets e trend jit gp`)\
\nif you would like to include the model comparison table in the \
report.".format(args.comptype,
Expand Down
28 changes: 16 additions & 12 deletions radvel/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,13 +279,16 @@ def mcmc(post, nwalkers=50, nrun=10000, ensembles=8, checkinterval=50, minAfacto
pi = post.get_vary_params()
statevars.ndim = pi.size

if proceed:
if len(h5p.keys()) != (3 * statevars.ensembles + 6) or h5p['0_chain'].shape[2] != statevars.ndim or h5p['0_chain'].shape[1] != statevars.nwalkers:
raise ValueError('nensembles, nwalkers, and the number of parameters must be equal to those from previous run')
if nwalkers < 2 * statevars.ndim:
print("WARNING: Number of walkers is less than 2 times number of free parameters. " +
"Adjusting number of walkers to {}".format(2 * statevars.ndim))
statevars.nwalkers = 2 * statevars.ndim

if nwalkers < 2*statevars.ndim:
print("WARNING: Number of walkers is less than 2 times number of free parameters. Adjusting number of walkers to {}".format(2*statevars.ndim))
statevars.nwalkers = 2*statevars.ndim
if proceed:
if len(h5p.keys()) != (3 * statevars.ensembles + 6) or h5p['0_chain'].shape[2] != statevars.ndim \
or h5p['0_chain'].shape[1] != statevars.nwalkers:
raise ValueError('nensembles, nwalkers, and the number of ' +
'parameters must be equal to those from previous run.')

# set up perturbation size
pscales = []
Expand Down Expand Up @@ -316,9 +319,9 @@ def mcmc(post, nwalkers=50, nrun=10000, ensembles=8, checkinterval=50, minAfacto
if not proceed:
statevars.initial_positions.append(p0)
else:
statevars.initial_positions.append(statevars.prechains[i][-1,:,:])
statevars.initial_positions.append(statevars.prechains[i][-1, :, :])
statevars.samplers.append(emcee.EnsembleSampler(statevars.nwalkers, statevars.ndim, post.logprob_array,
threads=1))
threads=1))

if proceed:
for i, sampler in enumerate(statevars.samplers):
Expand All @@ -332,7 +335,7 @@ def mcmc(post, nwalkers=50, nrun=10000, ensembles=8, checkinterval=50, minAfacto
statevars.totsteps = nrun*statevars.nwalkers*statevars.ensembles
statevars.mixcount = 0
statevars.ismixed = 0
if proceed == True and statevars.preburned != 0:
if proceed and statevars.preburned != 0:
statevars.burn_complete = True
statevars.nburn = statevars.preburned
else:
Expand Down Expand Up @@ -376,9 +379,9 @@ def mcmc(post, nwalkers=50, nrun=10000, ensembles=8, checkinterval=50, minAfacto
statevars.interval = t2 - t1

convergence_check(minAfactor=minAfactor, maxArchange=maxArchange, maxGR=maxGR, minTz=minTz,
minsteps=minsteps, minpercent=minpercent)
minsteps=minsteps, minpercent=minpercent)

if save==True:
if save:
for i, sampler in enumerate(statevars.samplers):
str_chain = str(i) + '_chain'
str_log_prob = str(i) + '_log_prob'
Expand Down Expand Up @@ -580,6 +583,7 @@ def convergence_calculate(chains, oldautocorrelation, minAfactor, maxArchange, m
archange = np.divide(np.abs(np.subtract(autocorrelation, oldautocorrelation)), oldautocorrelation)

# well-mixed criteria
ismixed = min(tz) > minTz and max(gelmanrubin) < maxGR and np.amin(afactor) > minAfactor and np.amax(archange) < maxArchange
ismixed = min(tz) > minTz and max(gelmanrubin) < maxGR and \
np.amin(afactor) > minAfactor and np.amax(archange) < maxArchange

return (ismixed, afactor, archange, autocorrelation, gelmanrubin, tz)
2 changes: 1 addition & 1 deletion radvel/plot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
rcParams['font.size'] = 9
rcParams['lines.markersize'] = 5
rcParams['axes.grid'] = False
default_colors = ['orange', 'purple', 'magenta', 'pink', 'green', 'grey', 'red']
default_colors = ['orange', 'purple', 'magenta', 'pink', 'green', 'grey', 'red', 'blue', 'yellow']

highlight_format = dict(marker='o', ms=16, mfc='none', mew=2, mec='gold', zorder=99)

Expand Down
16 changes: 9 additions & 7 deletions radvel/plot/mcmc_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@ class TrendPlot(object):
"""

def __init__(self, post, chains, nwalkers, outfile=None):
def __init__(self, post, chains, nwalkers, nensembles, outfile=None):

self.chains = chains
self.outfile = outfile
self.nwalkers = nwalkers
self.nensembles = nensembles

self.labels = sorted([k for k in post.params.keys() if post.params[k].vary])
self.texlabels = [post.params.tex_labels().get(l, l) for l in self.labels]
Expand All @@ -45,16 +46,17 @@ def plot(self):
with PdfPages(self.outfile) as pdf:
for param, tex in zip(self.labels, self.texlabels):
flatchain = self.chains[param].values
wchain = flatchain.reshape((self.nwalkers, -1))
wchain = flatchain.reshape((self.nwalkers, self.nensembles, -1))

_ = pl.figure(figsize=(18, 10))
for w in range(self.nwalkers):
pl.plot(
wchain[w, :], '.', rasterized=True, color=self.colors[w],
markersize=3
)
for e in range(self.nensembles):
pl.plot(
wchain[w][e], '.', rasterized=True, color=self.colors[w],
markersize=4
)

pl.xlim(0, wchain.shape[1])
pl.xlim(0, wchain.shape[2])

pl.xlabel('Step Number')
try:
Expand Down
2 changes: 1 addition & 1 deletion radvel/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import radvel

env = Environment(loader=PackageLoader('radvel', 'templates'))
print_basis = 'per tc e w k'
print_basis = 'per tc tp e w k'
units = {
'per': 'days',
'tp': 'JD',
Expand Down

0 comments on commit a935048

Please sign in to comment.