Skip to content

Commit

Permalink
Added Database.get_compare_sample for extracting the best-fit sample …
Browse files Browse the repository at this point in the history
…from CompareSpectra.compare_model
  • Loading branch information
tomasstolker committed May 4, 2021
1 parent 94563cc commit ab93f46
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 17 deletions.
5 changes: 3 additions & 2 deletions docs/overview.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ The following data and models are currently supported:

**Dust extinction**

- ISM empirical relation from `Cardelli et al. (1989) <https://ui.adsabs.harvard.edu/abs/1989ApJ...345..245C/abstract>`_
- ISM relation from `Cardelli et al. (1989) <https://ui.adsabs.harvard.edu/abs/1989ApJ...345..245C/abstract>`_
- Extinction cross sections computed with `PyMieScatt <https://pymiescatt.readthedocs.io>`_
- Optical constants compiled by `Mollière et al. (2019) <https://ui.adsabs.harvard.edu/abs/2019A%26A...627A..67M/abstract>`_

Expand All @@ -66,10 +66,11 @@ After adding the relevant data to the database, the user can take advantage of t
- Calculating synthetic photometry spectra (see :class:`~species.analysis.photometry.SyntheticPhotometry`).
- Interpolating and plotting model spectra (see :class:`~species.read.read_model.ReadModel` and :func:`~species.plot.plot_spectrum.plot_spectrum`).
- Grid retrievals with Bayesian inference (see :class:`~species.analysis.fit_model.FitModel` and :class:`~species.plot.plot_mcmc`).
- Comparing a spectrum with a full grid of model spectra (see :meth:`~species.analysis.compare_spectra.CompareSpectra.compare_model`).
- Free retrievals through a frontend for `petitRADTRANS <https://petitradtrans.readthedocs.io>`_ (see `AtmosphericRetrieval <https://github.com/tomasstolker/species/blob/retrieval/species/analysis/retrieval.py>`_ on the `retrieval branch <https://github.com/tomasstolker/species/tree/retrieval>`_).
- Creating color-magnitude diagrams (see :class:`~species.read.read_color.ReadColorMagnitude` and :class:`~species.plot.plot_color.plot_color_magnitude`).
- Creating color-color diagrams (see :class:`~species.read.read_color.ReadColorColor` and :class:`~species.plot.plot_color.plot_color_color`).
- Computing synthetic fluxes from isochrones and model spectra (see :class:`~species.read.read_isochrone.ReadIsochrone`).
- Flux calibration of photometric and spectroscopic data (see :class:`~species.read.read_calibration.ReadCalibration`, :class:`~species.analysis.fit_model.FitModel`, and :class:`~species.analysis.fit_spectrum.FitSpectrum`).
- Empirical comparison of spectra to infer the spectral type (see :class:`~species.analysis.empirical.CompareSpectra`).
- Empirical comparison of spectra to infer the spectral type (see :meth:`~species.analysis.compare_spectra.CompareSpectra.spectral_type`).
- Analyzing emission lines from accreting planets (see :class:`~species.analysis.emission_line.EmissionLine`).
27 changes: 25 additions & 2 deletions species/analysis/compare_spectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,10 @@ def compare_model(self,
av_points: Optional[Union[List[float], np.array]] = None) -> None:
"""
Method for finding the best fitting spectrum from a grid of atmospheric model spectra by
evaluating the goodness-of-fit statistic from Cushing et al. (2008).
evaluating the goodness-of-fit statistic from Cushing et al. (2008). Currently, this method
only supports model grids with only :math:`T_\mathrm{eff}` and :math:`\log(g)` as free
parameters (e.g. BT-Settl). Please create an issue on Github if support for models with
more than two parameters is required.
Parameters
----------
Expand Down Expand Up @@ -337,7 +340,7 @@ def compare_model(self,
elif isinstance(av_points, list):
av_points = np.array(av_points)

readmodel = read_model.ReadModel(model, wavel_range=None)
readmodel = read_model.ReadModel(model)

model_param = readmodel.get_parameters()
grid_points = readmodel.get_points()
Expand All @@ -359,17 +362,25 @@ def compare_model(self,
fit_stat = np.zeros(grid_shape)
flux_scaling = np.zeros(grid_shape)

count = 1

if len(coord_points) == 2:
n_iter = len(coord_points[0])*len(coord_points[1])

for i, item_i in enumerate(coord_points[0]):
for j, item_j in enumerate(coord_points[1]):
for k, spec_item in enumerate(self.spec_name):
print(f'Processing model spectrum {count}/{n_iter}...', end='')

obj_spec = self.object.get_spectrum()[spec_item][0]
obj_res = self.object.get_spectrum()[spec_item][3]

param_dict = {model_param[0]: item_i,
model_param[1]: item_j}

wavel_range = (0.9*obj_spec[0, 0], 1.1*obj_spec[-1, 0])
readmodel = read_model.ReadModel(model, wavel_range=wavel_range)

model_box = readmodel.get_data(param_dict,
spec_res=obj_res,
wavel_resample=obj_spec[:, 0])
Expand All @@ -382,19 +393,27 @@ def compare_model(self,
residual = obj_spec[:, 1] - flux_scaling[i, j, k]*model_box.flux
fit_stat[i, j, k] = np.sum(w_i * (residual/obj_spec[:, 2])**2)

count += 1

if len(coord_points) == 3:
n_iter = len(coord_points[0])*len(coord_points[1])*len(coord_points[2])

for i, item_i in enumerate(coord_points[0]):
for j, item_j in enumerate(coord_points[1]):
for k, item_k in enumerate(coord_points[2]):
for m, spec_item in enumerate(self.spec_name):
print(f'\rProcessing model spectrum {count}/{n_iter}...', end='')

obj_spec = self.object.get_spectrum()[spec_item][0]
obj_res = self.object.get_spectrum()[spec_item][3]

param_dict = {model_param[0]: item_i,
model_param[1]: item_j,
model_param[2]: item_k}

wavel_range = (0.9*obj_spec[0, 0], 1.1*obj_spec[-1, 0])
readmodel = read_model.ReadModel(model, wavel_range=wavel_range)

model_box = readmodel.get_data(param_dict,
spec_res=obj_res,
wavel_resample=obj_spec[:, 0])
Expand All @@ -407,6 +426,10 @@ def compare_model(self,
residual = obj_spec[:, 1] - flux_scaling[i, j, k, m]*model_box.flux
fit_stat[i, j, k, m] = np.sum(w_i * (residual/obj_spec[:, 2])**2)

count += 1

print(' [DONE]')

species_db = database.Database()

species_db.add_comparison(tag=tag,
Expand Down
98 changes: 95 additions & 3 deletions species/data/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
from typeguard import typechecked

from species.analysis import photometry
from species.core import box
from species.core import box, constants
from species.data import ames_cond, ames_dusty, atmo, blackbody, btcond, btcond_feh, btnextgen, \
btsettl, btsettl_cifist, companions, drift_phoenix, dust, exo_rem, \
filters, irtf, isochrones, leggett, petitcode, spex, vega, vlm_plx, \
kesseli2017
from species.read import read_calibration, read_filter, read_model, read_planck
from species.read import read_calibration, read_filter, read_model, read_object, read_planck
from species.util import data_util, dust_util, read_util


Expand Down Expand Up @@ -1376,6 +1376,70 @@ def get_median_sample(self,

return median_sample

@typechecked
def get_compare_sample(self,
                       tag: str,
                       spec_fix: Optional[str] = None) -> Dict[str, float]:
    """
    Function for extracting the best-fit parameters that were stored in the database after
    comparing one or multiple spectra with a grid of model spectra.

    Parameters
    ----------
    tag : str
        Database tag where the results from
        :meth:`~species.analysis.compare_spectra.CompareSpectra.compare_model` are stored.
    spec_fix : str, None
        After comparing multiple spectra with a model grid, one of the flux scalings needs to
        be used to calculate the planet radius (i.e. scaling = (radius/distance)^2). The name
        of this spectrum is specified as argument of ``spec_fix``. For the other spectra,
        the same radius will be used and an additional scaling parameter will be included in
        the returned dictionary. When passing the returned dictionary to
        :func:`~species.util.read_util.update_spectra`, the spectra can be updated with the
        derived scaling corrections. The argument can be set to ``None`` if a single spectrum
        was used for the comparison.

    Returns
    -------
    dict
        Dictionary with the best-fit parameters. It contains the best-fit value of each grid
        parameter, the ``distance``, the ``radius``, and (only when multiple spectra were
        compared) a ``scaling_<spec_name>`` entry for every spectrum other than ``spec_fix``.

    Raises
    ------
    ValueError
        If the comparison was done with multiple spectra while ``spec_fix`` is ``None``.
    """

    # The results are only read, so the file is opened read-only. This avoids
    # creating an empty database file when the path does not exist.
    with h5py.File(self.database, 'r') as h5_file:
        dset = h5_file[f'results/comparison/{tag}/goodness_of_fit']

        n_param = dset.attrs['n_param']
        n_spec_name = dset.attrs['n_spec_name']

        model_param = {}

        # Map the name of each grid parameter to its best-fit value
        for i in range(n_param):
            model_param[dset.attrs[f'parameter{i}']] = dset.attrs[f'best_param{i}']

        model_param['distance'] = dset.attrs['distance']

        if n_spec_name == 1:
            # Only a single spectrum was compared, so its name is stored at index 0
            single_name = dset.attrs['spec_name0']
            model_param['radius'] = dset.attrs[f'radius_{single_name}']

        else:
            if spec_fix is None:
                raise ValueError('The argument of \'spec_fix\' should be set when the results '
                                 'from CompareSpectra.compare_model have been obtained by '
                                 'combining multiple spectra (i.e. the argument of '
                                 '\'spec_name\' in CompareSpectra).')

            radius_fix = dset.attrs[f'radius_{spec_fix}']
            model_param['radius'] = radius_fix

            for i in range(n_spec_name):
                spec_name = dset.attrs[f'spec_name{i}']

                if spec_name == spec_fix:
                    # The reference spectrum defines the radius and needs no extra scaling
                    continue

                # Squared ratio of the reference radius to the radius derived
                # from this spectrum
                model_param[f'scaling_{spec_name}'] = (
                    radius_fix / dset.attrs[f'radius_{spec_name}'])**2

    return model_param

@typechecked
def get_mcmc_spectra(self,
tag: str,
Expand Down Expand Up @@ -1828,7 +1892,7 @@ def get_samples(self,
header += f'{item}'
if i != len(param) - 1:
header += ' - '

if out_file.endswith('.fits'):
fits.writeto(out_file, samples, overwrite=True)

Expand Down Expand Up @@ -1958,6 +2022,9 @@ def add_comparison(self,
None
"""

read_obj = read_object.ReadObject(object_name)
distance = read_obj.get_distance()[0] # (pc)

with h5py.File(self.database, 'a') as h5_file:

if 'results' not in h5_file:
Expand All @@ -1976,6 +2043,7 @@ def add_comparison(self,
dset.attrs['model'] = str(model)
dset.attrs['n_param'] = len(model_param)
dset.attrs['n_spec_name'] = len(spec_name)
dset.attrs['distance'] = distance

for i, item in enumerate(model_param):
dset.attrs[f'parameter{i}'] = item
Expand All @@ -1987,3 +2055,27 @@ def add_comparison(self,

for i, item in enumerate(coord_points):
h5_file.create_dataset(f'results/comparison/{tag}/coord_points{i}', data=item)

# Sum the goodness-of-fit of the different spectra
goodness_sum = np.sum(goodness_of_fit, axis=-1)

# Indices of the best-fit model
best_index = np.unravel_index(goodness_sum.argmin(), goodness_sum.shape)

print('Best-fit parameters:')

for i, item in enumerate(model_param):
best_param = coord_points[i][best_index[i]]
dset.attrs[f'best_param{i}'] = best_param

for i, item in enumerate(spec_name):
scaling = flux_scaling[best_index[0], best_index[1], best_index[2], i]

radius = np.sqrt(scaling * (distance*constants.PARSEC)**2) # (m)
radius /= constants.R_JUP # (Rjup)

dset.attrs[f'radius_{item}'] = radius
print(f' - {item} radius (Rjup) = {radius:.2f}')

dset.attrs[f'scaling_{item}'] = scaling
print(f' - {item} scaling = {scaling:.2e}')
2 changes: 1 addition & 1 deletion species/data/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def download_filter(filter_id: str) -> Tuple[Optional[np.ndarray],

os.remove('VisAO_zp_filter_curve.dat')

elif filter_id in ['LCO/VisAO.Ys', 'Magellan/VisAO.Ys']:
elif filter_id in ['LCO/VisAO.Ys', 'Magellan/VisAO.Ys']:
url = 'https://xwcl.science/magao/visao/VisAO_Ys_filter_curve.dat'
urllib.request.urlretrieve(url, 'VisAO_Ys_filter_curve.dat')

Expand Down
4 changes: 2 additions & 2 deletions species/data/spex.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,10 @@ def add_spex(input_path: str,
dset = database.create_dataset(f'spectra/spex/{name}', data=spdata)

dset.attrs['name'] = str(name).encode()

if sptype_opt is not None:
dset.attrs['sptype'] = str(sptype_opt).encode()
elif sptype_nir is not None:
elif sptype_nir is not None:
dset.attrs['sptype'] = str(sptype_nir).encode()
else:
dset.attrs['sptype'] = str('None').encode()
Expand Down
31 changes: 25 additions & 6 deletions species/plot/plot_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,11 @@

from matplotlib.ticker import AutoMinorLocator
from scipy.interpolate import interp1d, RegularGridInterpolator
from scipy.ndimage import gaussian_filter
from typeguard import typechecked

from species.core import constants
from species.read import read_object
from species.util import dust_util, read_util
from species.util import dust_util, read_util, plot_util


@typechecked
Expand Down Expand Up @@ -351,6 +350,12 @@ def plot_grid_statistic(tag: str,

n_param = dset.attrs['n_param']

read_obj = read_object.ReadObject(dset.attrs['object_name'])

n_wavel = 0
for key, value in read_obj.get_spectrum().items():
n_wavel += value[0].shape[0]

goodness_fit = np.array(dset)

model_param = []
Expand Down Expand Up @@ -407,6 +412,9 @@ def plot_grid_statistic(tag: str,
# Sum the goodness-of-fit of the different spectra
goodness_fit = np.sum(goodness_fit, axis=-1)

# Indices of the best-fit model
best_index = np.unravel_index(goodness_fit.argmin(), goodness_fit.shape)

# Make Teff the x axis and log(g) the y axis
goodness_fit = np.transpose(goodness_fit)

Expand All @@ -428,12 +436,11 @@ def plot_grid_statistic(tag: str,
x_grid, y_grid = np.meshgrid(x_new, y_new)

goodness_fit = fit_interp((y_grid, x_grid))
goodness_fit = gaussian_filter(goodness_fit, 1.)

c = ax.contourf(x_grid, y_grid, np.log10(goodness_fit))

cb = mpl.colorbar.Colorbar(ax=ax_cb, mappable=c, orientation='vertical',
ticklocation='right', format='%.2f')
ticklocation='right', format='%.1f')

cb.ax.tick_params(width=0.8, length=5, labelsize=12, direction='in', color='black')
cb.ax.set_ylabel(r'$\mathregular{log}\,G_k$', rotation=270, labelpad=22, fontsize=13.)
Expand All @@ -442,11 +449,23 @@ def plot_grid_statistic(tag: str,
extra_interp = RegularGridInterpolator((coord_points[1], coord_points[0]), extra_map)

extra_map = extra_interp((y_grid, x_grid))
extra_map = gaussian_filter(extra_map, 1.)

cs = ax.contour(x_grid, y_grid, extra_map, levels=5, colors='white', linewidths=0.5)
cs = ax.contour(x_grid, y_grid, extra_map, levels=10, colors='white', linewidths=0.7)
ax.clabel(cs, cs.levels, inline=True, fontsize=8, fmt='%1.1f')

ax.plot(coord_points[0][best_index[0]], coord_points[1][best_index[1]], marker='X',
ms=10., color='#eb4242', mfc='#eb4242', mec='black')

# best_param = (coord_points[0][best_index[0]], coord_points[1][best_index[1]])
#
# par_key, par_unit, par_label = plot_util.quantity_unit(model_param, object_type='planet')
#
# par_text = f'{par_label[0]} = {best_param[0]:.0f} {par_unit[0]}\n' \
# f'{par_label[1]} = {best_param[1]:.1f}'
#
# ax.annotate(par_text, (best_param[0]+50., best_param[1]), ha='left', va='center',
# color='white', fontsize=12.)

plt.savefig(os.getcwd()+'/'+output, bbox_inches='tight')
plt.clf()
plt.close()
Expand Down
2 changes: 1 addition & 1 deletion species/util/read_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ def gaussian_spectrum(wavel_range: Union[Tuple[float, float],
flux = model_param['gauss_amplitude'] * gauss_exp

if double_gaussian:
gauss_exp = np.exp(-0.5*(wavel-model_param['gauss_mean_2'])**2 /
gauss_exp = np.exp(-0.5*(wavel-model_param['gauss_mean_2'])**2 /
model_param['gauss_sigma_2']**2)

flux += model_param['gauss_amplitude_2'] * gauss_exp
Expand Down

0 comments on commit ab93f46

Please sign in to comment.