-
Notifications
You must be signed in to change notification settings - Fork 2
/
fit_lcs.py
94 lines (86 loc) · 3.69 KB
/
fit_lcs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
import click
import pickle
import sncosmo
import logging
import numpy as np
# Configure the root logger so the module-level ``logging.info`` /
# ``logging.warning`` calls in this script emit INFO-level messages.
# ``Logger.setLevel`` returns None, so its result must not be assigned back
# to ``logger`` (doing so would leave ``logger`` as None).
logger = logging.getLogger()
logger.setLevel('INFO')

# Dataset names accepted on the command line; each one is read from
# ``<DATA_DIR>/<name>_lcs.pkl``.
DS_NAMES = ['jla', 'csp', 'des', 'foundation', 'ps1']
# Default input directory holding the pickled light-curve collections.
DATA_DIR = '/home/samdixon/sncosmo_lc_fits/data'
# Default output directory; fits are written to ``<OUT_DIR>/<dataset>/``.
OUT_DIR = '/home/samdixon/sncosmo_lc_fits/fits'
def _dump_pickle(obj, path):
    """Pickle *obj* to *path* atomically.

    The data is written to a sibling temporary file which is then moved over
    *path* with ``os.replace``, so an interrupted run never leaves a
    truncated pickle at *path*.
    """
    tmp_path = '{}.tmp'.format(path)
    with open(tmp_path, 'wb') as f:
        pickle.dump(obj, f)
    os.replace(tmp_path, path)


def fit_lc_and_save(lc, model_name, save_path, no_mc):
    """Fit one light curve with sncosmo and pickle the fit result to disk.

    A Minuit fit (``sncosmo.fit_lc``) is always run.  Unless *no_mc* is set,
    an MCMC fit (``sncosmo.mcmc_lc``) is then run starting from the Minuit
    best-fit model, on the data selected by the Minuit result's
    ``data_mask``.  If the MCMC step raises, the Minuit result is saved
    instead and a warning is logged.

    Parameters
    ----------
    lc
        Light-curve table indexable by column (uses ``lc['time']``) with a
        ``meta`` mapping providing ``'name'``, ``'z'``, ``'mwebv'`` and
        ``'t0'``.  A NaN ``mwebv`` is treated as 0; a NaN ``t0`` makes the
        fit start from the mean observation time with wide ``t0`` bounds.
    model_name : str
        sncosmo source name.  ``'salt2'`` enables model covariance, a
        (-15, 45) phase range and a (3000, 7000) wavelength range; other
        sources use a (-10, 40) phase range and no wavelength cut.
    save_path : str
        Destination file for the pickled result.
    no_mc : bool
        If true, skip MCMC and save the Minuit result directly.
    """
    name = lc.meta['name']
    model = sncosmo.Model(source=model_name,
                          effects=[sncosmo.CCM89Dust()],
                          effect_names=['mw'],
                          effect_frames=['obs'])
    is_salt2 = model_name == 'salt2'

    z = lc.meta['z']
    # Missing Milky Way E(B-V) is treated as no extinction.
    mwebv = 0 if np.isnan(lc.meta['mwebv']) else lc.meta['mwebv']

    bounds = {}
    if np.isnan(lc.meta['t0']):
        # No prior peak time: start at the mean epoch and allow the peak to
        # fall anywhere from 20 days before the first point to the last one.
        t0 = np.mean(lc['time'])
        bounds['t0'] = (min(lc['time']) - 20, max(lc['time']))
    else:
        t0 = lc.meta['t0']
        bounds['t0'] = (t0 - 5, t0 + 5)
    # Redshift is effectively pinned: it may only move by +/- 0.01%.
    bounds['z'] = ((1 - 1e-4) * z, (1 + 1e-4) * z)
    # Every source parameter after the first (presumably the amplitude —
    # confirm per source) is limited to (-10, 10).
    for param_name in model.source.param_names[1:]:
        bounds[param_name] = (-10, 10)
    bounds['x0'] = (0, 1)

    modelcov = is_salt2  # model covariance only supported for SALT2
    model.set(z=z, t0=t0, mwebv=mwebv)
    phase_range = (-15, 45) if is_salt2 else (-10, 40)
    wave_range = (3000, 7000) if is_salt2 else None

    # Vary all parameters except the last two.  NOTE(review): these are
    # presumably the 'mw' dust-effect parameters (mwebv held at the value set
    # above) — confirm against model.param_names.
    vparam_names = model.param_names[:-2]

    min_result, min_fit_model = sncosmo.fit_lc(lc, model,
                                               vparam_names,
                                               bounds=bounds,
                                               phase_range=phase_range,
                                               wave_range=wave_range,
                                               warn=False,
                                               modelcov=modelcov)

    result = min_result
    if not no_mc:
        cut_lc = sncosmo.select_data(lc, min_result['data_mask'])
        try:
            result, _ = sncosmo.mcmc_lc(cut_lc,
                                        min_fit_model,
                                        vparam_names,
                                        guess_t0=False,
                                        bounds=bounds,
                                        warn=False,
                                        nwalkers=10,
                                        modelcov=modelcov)
        except Exception:
            # MCMC is best-effort: fall back to Minuit, but keep the reason.
            # (Catching Exception rather than a bare except lets Ctrl-C and
            # SystemExit still stop the run.)
            logging.warning('MCMC fit to {} failed, using minuit result'.format(name),
                            exc_info=True)
            result = min_result

    _dump_pickle(result, save_path)
def _has_saved_fit(save_path):
    """Return True if *save_path* holds a readable pickled fit result.

    A missing file returns False.  A truncated or corrupt pickle (for example
    one left by an interrupted run) is logged and also returns False, so the
    light curve is refit instead of crashing the whole batch.
    """
    try:
        with open(save_path, 'rb') as f:
            pickle.load(f)
    except FileNotFoundError:
        return False
    except (EOFError, pickle.UnpicklingError):
        logging.warning('{} is unreadable, refitting'.format(save_path))
        return False
    return True


@click.command()
@click.argument('dataset', type=click.Choice(DS_NAMES))
@click.argument('start', default=0)
@click.argument('end', default=-1)
@click.option('--outdir', default=OUT_DIR)
@click.option('--model', default='salt2')
@click.option('--no_mc', is_flag=True)
@click.option('--datadir', default=DATA_DIR,
              help='Directory containing the <dataset>_lcs.pkl input files.')
def main(dataset, start, end, outdir, model, no_mc, datadir=DATA_DIR):
    """Fit light curves START..END (sorted by name) of DATASET.

    Light curves are read from ``<datadir>/<dataset>_lcs.pkl``.  Each result
    is pickled to ``<outdir>/<dataset>/<name>.pkl``.  Names that already have
    a readable result file are skipped.  END defaults to -1, meaning
    "through the last light curve".
    """
    data_path = os.path.join(datadir, '{}_lcs.pkl'.format(dataset))
    # NOTE(review): save_dir does not include the model name, so runs with
    # different --model values share (and skip on) the same result files.
    save_dir = os.path.join(outdir, dataset)
    os.makedirs(save_dir, exist_ok=True)

    with open(data_path, 'rb') as f:
        data = pickle.load(f)

    # Order by the string form of each key, but keep the original key for
    # the lookup below (indexing with str(key) fails for non-str keys).
    keys = sorted(data, key=str)
    # A literal -1 slice bound would silently drop the final light curve.
    stop = None if end == -1 else end

    for key in keys[start:stop]:
        sn_name = str(key)
        save_path = os.path.join(save_dir, '{}.pkl'.format(sn_name))
        if _has_saved_fit(save_path):
            logging.info('{} loaded from file'.format(sn_name))
            continue
        logging.info('Fitting {}'.format(sn_name))
        fit_lc_and_save(data[key], model, save_path, no_mc)


if __name__ == '__main__':
    main()