Skip to content

Commit

Permalink
Added flux_units parameter to get_mcmc_photometry method of Database
Browse files Browse the repository at this point in the history
  • Loading branch information
tomasstolker committed Apr 12, 2024
1 parent b38c9ac commit 0adc7fa
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 44 deletions.
106 changes: 68 additions & 38 deletions species/data/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -1435,7 +1435,9 @@ def add_object(
read_cov[spec_item] = None

elif isinstance(spec_value[1], str):
if spec_value[1].endswith(".fits") or spec_value[1].endswith(".fit"):
if spec_value[1].endswith(".fits") or spec_value[1].endswith(
".fit"
):
with fits.open(spec_value[1]) as hdulist:
if (
"INSTRU" in hdulist[0].header
Expand Down Expand Up @@ -1480,7 +1482,10 @@ def add_object(
correlation_to_covariance,
)

if data.ndim == 2 and data.shape[0] == data.shape[1]:
if (
data.ndim == 2
and data.shape[0] == data.shape[1]
):
if spec_item not in read_cov:
if (
data.shape[0]
Expand Down Expand Up @@ -1855,7 +1860,11 @@ def _fetch_bibcode(ref: str) -> Optional[str]:
)

bibcode[spec_tag] = _fetch_bibcode(row["reference"])
spectrum[spec_tag] = (spec_data, np.diag(spec_data[:, 2]**2), float(spec_res))
spectrum[spec_tag] = (
spec_data,
np.diag(spec_data[:, 2] ** 2),
float(spec_res),
)

if spectrum is None:
if app_mag is not None:
Expand Down Expand Up @@ -2760,6 +2769,7 @@ def get_mcmc_photometry(
filter_name: str,
burnin: Optional[int] = None,
phot_type: str = "magnitude",
flux_units: str = "W m-2 um-1",
) -> np.ndarray:
"""
Function for calculating synthetic magnitudes or fluxes
Expand All @@ -2779,11 +2789,16 @@ def get_mcmc_photometry(
have been sampled with ``emcee``.
phot_type : str
Photometry type ('magnitude' or 'flux').
flux_units : tuple(str, str), None
Flux units that will be used when the ``phot_type``
argument is set to ``'flux``. Supported units can
be found in the docstring of
:func:`~species.util.data_util.convert_units`.
Returns
-------
np.ndarray
Synthetic magnitudes or fluxes (W m-2 um-1).
Synthetic magnitudes or fluxes.
"""

if phot_type not in ["magnitude", "flux"]:
Expand All @@ -2795,49 +2810,50 @@ def get_mcmc_photometry(
if burnin is None:
burnin = 0

hdf5_file = h5py.File(self.database, "r")
dset = hdf5_file[f"results/fit/{tag}/samples"]

if "n_param" in dset.attrs:
n_param = dset.attrs["n_param"]
elif "nparam" in dset.attrs:
n_param = dset.attrs["nparam"]
with h5py.File(self.database, "r") as hdf5_file:
dset = hdf5_file[f"results/fit/{tag}/samples"]

spectrum_type = dset.attrs["type"]
spectrum_name = dset.attrs["spectrum"]
if "n_param" in dset.attrs:
n_param = dset.attrs["n_param"]
elif "nparam" in dset.attrs:
n_param = dset.attrs["nparam"]

if "binary" in dset.attrs:
binary = dset.attrs["binary"]
else:
binary = False
spectrum_type = dset.attrs["type"]
spectrum_name = dset.attrs["spectrum"]

if "parallax" in dset.attrs:
parallax = dset.attrs["parallax"]
else:
parallax = None
if "binary" in dset.attrs:
binary = dset.attrs["binary"]
else:
binary = False

if "distance" in dset.attrs:
distance = dset.attrs["distance"]
else:
distance = None
if "parallax" in dset.attrs:
parallax = dset.attrs["parallax"]
else:
parallax = None

samples = np.asarray(dset)
if "distance" in dset.attrs:
distance = dset.attrs["distance"]
else:
distance = None

if samples.ndim == 3:
if burnin > samples.shape[0]:
raise ValueError(
f"The 'burnin' value is larger than the number of steps "
f"({samples.shape[1]}) that are made by the walkers."
)
samples = np.asarray(dset)

samples = samples[burnin:, :, :]
samples = samples.reshape((samples.shape[0] * samples.shape[1], n_param))
if samples.ndim == 3:
if burnin > samples.shape[0]:
raise ValueError(
"The 'burnin' value is larger than the "
f"number of steps ({samples.shape[1]}) "
"that are made by the walkers."
)

param = []
for i in range(n_param):
param.append(dset.attrs[f"parameter{i}"])
samples = samples[burnin:, :, :]
samples = samples.reshape(
(samples.shape[0] * samples.shape[1], n_param)
)

hdf5_file.close()
param = []
for i in range(n_param):
param.append(dset.attrs[f"parameter{i}"])

if spectrum_type == "model":
if spectrum_name == "powerlaw":
Expand Down Expand Up @@ -2932,6 +2948,20 @@ def get_mcmc_photometry(
elif phot_type == "flux":
mcmc_phot[i], _ = readcalib.get_flux(model_param=model_param)

if phot_type == "flux":
from species.read.read_filter import ReadFilter
from species.util.data_util import convert_units

read_filt = ReadFilter(filter_name)
filt_wavel = read_filt.mean_wavelength()

wavel_ones = np.full(mcmc_phot.size, filt_wavel)
data_in = np.column_stack([wavel_ones, mcmc_phot])

data_out = convert_units(data_in, ("um", flux_units), convert_from=False)

mcmc_phot = data_out[:, 1]

return mcmc_phot

@typechecked
Expand Down
21 changes: 18 additions & 3 deletions species/phot/syn_phot.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,9 @@ def spectrum_to_flux(
f"{self.wavel_range[1]:.4f}) extends beyond "
f"the wavelength range of the spectrum "
f"({wavelength[0]:.4f}-{wavelength[-1]:.4f}). "
"The flux is set to NaN. Setting the 'threshold' "
"parameter will loosen the wavelength constraints."
"The synthetic flux is set to NaN. Setting "
"the 'threshold' parameter will loosen the "
"wavelength constraints."
)

syn_flux = np.nan
Expand All @@ -275,7 +276,21 @@ def spectrum_to_flux(

transmission = self.filter_interp(wavelength)

if (
if np.sum(wavelength) == 0.0:
# The wavelength array looks empty but it is and
# empty array inside another array so the size
# of the wavelength array is 1. The sum however
# is 0 so that is used to check if it is empty
warnings.warn(
f"The filter profile of {self.filter_name} "
f"({self.wavel_range[0]:.4f}-{self.wavel_range[1]:.4f}) "
f"lies outside the wavelength range of the spectrum. "
f"The synthetic flux is set to NaN."
)

syn_flux = np.nan

elif (
threshold is not None
and (transmission[0] > threshold or transmission[-1] > threshold)
and (
Expand Down
2 changes: 1 addition & 1 deletion species/plot/plot_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ def plot_posterior(
o_h_ratio = np.zeros(samples.shape[0])
c_o_ratio = np.zeros(samples.shape[0])

for i,sample_item in enumerate(samples):
for i, sample_item in enumerate(samples):
abund_dict = {}
for line_item in line_species:
abund_dict[line_item] = sample_item[abund_index[line_item]]
Expand Down
4 changes: 2 additions & 2 deletions species/read/read_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1534,12 +1534,12 @@ def get_flux(
if self.spectrum_interp is None:
self.interpolate_model()

spectrum = self.get_model(model_param)
model_box = self.get_model(model_param)

if synphot is None:
synphot = SyntheticPhotometry(self.filter_name)

model_flux = synphot.spectrum_to_flux(spectrum.wavelength, spectrum.flux)
model_flux = synphot.spectrum_to_flux(model_box.wavelength, model_box.flux)

if return_box:
model_mag = self.get_magnitude(model_param)
Expand Down

0 comments on commit 0adc7fa

Please sign in to comment.