From b7ab53546762b01f37c0cf6828a03c932d48d374 Mon Sep 17 00:00:00 2001 From: Tomas Stolker Date: Wed, 26 Jul 2023 12:50:13 +0200 Subject: [PATCH] Optimized CompareSpectra such that a model spectrum is only extracted once per grid point when using the compare_model method, using context managers for opening the HDF5 database throughout the database module, replaced the use of spectres with interp1d in ReadModel --- species/analysis/compare_spectra.py | 26 +- species/data/database.py | 1615 +++++++++++++-------------- species/data/model_data.json | 2 +- species/read/read_model.py | 96 +- 4 files changed, 856 insertions(+), 883 deletions(-) diff --git a/species/analysis/compare_spectra.py b/species/analysis/compare_spectra.py index 8c89f098..23d5cfba 100644 --- a/species/analysis/compare_spectra.py +++ b/species/analysis/compare_spectra.py @@ -553,26 +553,26 @@ def compare_model( if model_param[5] is not None: param_dict[model_param[5]] = coord_5_item + model_reader = read_model.ReadModel(model) + + model_box_full = model_reader.get_data(param_dict) + for spec_item in self.spec_name: obj_spec = self.object.get_spectrum()[spec_item][0] obj_res = self.object.get_spectrum()[spec_item][3] - wavel_range = ( - 0.9 * obj_spec[0, 0], - 1.1 * obj_spec[-1, 0], - ) + # Smooth model spectrum - model_reader = read_model.ReadModel( - model, wavel_range=wavel_range + model_flux = read_util.smooth_spectrum( + model_box_full.wavelength, model_box_full.flux, obj_res ) - model_box = model_reader.get_data( - param_dict, - spec_res=obj_res, - wavel_resample=obj_spec[:, 0], - ) + # Resample model spectrum + + flux_intep = interp1d(model_box_full.wavelength, model_flux) + model_flux = flux_intep(obj_spec[:, 0]) - nan_wavel = np.sum(np.isnan(model_box.flux)) + nan_wavel = np.sum(np.isnan(model_flux)) if nan_wavel > 0: warnings.warn( @@ -585,7 +585,7 @@ def compare_model( "calculating the goodness-of-fit statistic." 
) - model_spec[spec_item] = model_box.flux + model_spec[spec_item] = model_flux model_list = [] data_list = [] diff --git a/species/data/database.py b/species/data/database.py index b7ca7d5d..0e94297e 100644 --- a/species/data/database.py +++ b/species/data/database.py @@ -396,17 +396,14 @@ def add_dust(self) -> None: None """ - h5_file = h5py.File(self.database, "a") - - if "dust" in h5_file: - del h5_file["dust"] + with h5py.File(self.database, "a") as h5_file: + if "dust" in h5_file: + del h5_file["dust"] - h5_file.create_group("dust") + h5_file.create_group("dust") - dust.add_optical_constants(self.input_path, h5_file) - dust.add_cross_sections(self.input_path, h5_file) - - h5_file.close() + dust.add_optical_constants(self.input_path, h5_file) + dust.add_cross_sections(self.input_path, h5_file) @typechecked def add_accretion(self) -> None: @@ -427,16 +424,13 @@ def add_accretion(self) -> None: None """ - h5_file = h5py.File(self.database, "a") - - if "accretion" in h5_file: - del h5_file["accretion"] - - h5_file.create_group("accretion") + with h5py.File(self.database, "a") as h5_file: + if "accretion" in h5_file: + del h5_file["accretion"] - accretion.add_accretion_relation(self.input_path, h5_file) + h5_file.create_group("accretion") - h5_file.close() + accretion.add_accretion_relation(self.input_path, h5_file) @typechecked def add_filter( @@ -487,44 +481,42 @@ def add_filter( filter_split = filter_name.split("/") - h5_file = h5py.File(self.database, "a") - - if f"filters/{filter_name}" in h5_file: - del h5_file[f"filters/{filter_name}"] - - if "filters" not in h5_file: - h5_file.create_group("filters") + with h5py.File(self.database, "a") as h5_file: + if f"filters/{filter_name}" in h5_file: + del h5_file[f"filters/{filter_name}"] - if f"filters/{filter_split[0]}" not in h5_file: - h5_file.create_group(f"filters/{filter_split[0]}") + if "filters" not in h5_file: + h5_file.create_group("filters") - if filename is not None: - data = np.loadtxt(filename) - wavelength = data[:, 0] - transmission = data[:, 1] + if f"filters/{filter_split[0]}" not in h5_file: + h5_file.create_group(f"filters/{filter_split[0]}") - else: - wavelength, transmission, detector_type = filters.download_filter( - filter_name - ) + if filename is not None: + data = np.loadtxt(filename) + wavelength = data[:, 0] + transmission = data[:, 1] - if wavelength is not None and transmission is not None: - wavel_new = [wavelength[0]] - transm_new = [transmission[0]] + else: + wavelength, transmission, detector_type = filters.download_filter( + filter_name + ) - for i in range(wavelength.size - 1): - if wavelength[i + 1] > wavel_new[-1]: - # Required for the issue with the Keck/NIRC2.J filter on SVO - wavel_new.append(wavelength[i + 1]) - transm_new.append(transmission[i + 1]) + if wavelength is not None and transmission is not None: + wavel_new = [wavelength[0]] + transm_new = [transmission[0]] - dset = h5_file.create_dataset( - f"filters/{filter_name}", data=np.column_stack((wavel_new, transm_new)) - ) + for i in range(wavelength.size - 1): + if wavelength[i + 1] > wavel_new[-1]: + # Required for the issue with the Keck/NIRC2.J filter on SVO + wavel_new.append(wavelength[i + 1]) + transm_new.append(transmission[i + 1]) - dset.attrs["det_type"] = str(detector_type) + dset = h5_file.create_dataset( + f"filters/{filter_name}", + data=np.column_stack((wavel_new, transm_new)), + ) - h5_file.close() + dset.attrs["det_type"] = str(detector_type) if verbose: print(" [DONE]") @@ -572,98 +564,95 @@ def add_isochrones( 
DeprecationWarning, ) - h5_file = h5py.File(self.database, "a") - - if "isochrones" not in h5_file: - h5_file.create_group("isochrones") - - if model in ["manual", "marleau", "phoenix"]: - if f"isochrones/{tag}" in h5_file: - del h5_file[f"isochrones/{tag}"] - - elif model == "ames": - if "isochrones/ames-cond" in h5_file: - del h5_file["isochrones/ames-cond"] - if "isochrones/ames-dusty" in h5_file: - del h5_file["isochrones/ames-dusty"] - - elif model == "atmo": - if "isochrones/atmo-ceq" in h5_file: - del h5_file["isochrones/atmo-ceq"] - if "isochrones/atmo-neq-weak" in h5_file: - del h5_file["isochrones/atmo-neq-weak"] - if "isochrones/atmo-neq-strong" in h5_file: - del h5_file["isochrones/atmo-neq-strong"] - - elif model == "baraffe2015": - if "isochrones/baraffe2015" in h5_file: - del h5_file["isochrones/baraffe2015"] - - elif model == "bt-settl": - if "isochrones/bt-settl" in h5_file: - del h5_file["isochrones/bt-settl"] - - elif model == "nextgen": - if "isochrones/nextgen" in h5_file: - del h5_file["isochrones/nextgen"] - - elif model == "saumon2008": - if "isochrones/saumon2008-nc_solar" in h5_file: - del h5_file["isochrones/saumon2008-nc_solar"] - if "isochrones/saumon2008-nc_-0.3" in h5_file: - del h5_file["isochrones/saumon2008-nc_-0.3"] - if "isochrones/saumon2008-nc_+0.3" in h5_file: - del h5_file["isochrones/saumon2008-nc_+0.3"] - if "isochrones/saumon2008-f2_solar" in h5_file: - del h5_file["isochrones/saumon2008-f2_solar"] - if "isochrones/saumon2008-hybrid_solar" in h5_file: - del h5_file["isochrones/saumon2008-hybrid_solar"] - - elif model == "sonora": - if "isochrones/sonora+0.0" in h5_file: - del h5_file["isochrones/sonora+0.0"] - if "isochrones/sonora+0.5" in h5_file: - del h5_file["isochrones/sonora+0.5"] - if "isochrones/sonora-0.5" in h5_file: - del h5_file["isochrones/sonora-0.5"] - - if model == "ames": - isochrones.add_ames(h5_file, self.input_path) - - elif model == "atmo": - isochrones.add_atmo(h5_file, self.input_path) - - elif model == "baraffe2015": - isochrones.add_baraffe2015(h5_file, self.input_path) - - elif model == "bt-settl": - isochrones.add_btsettl(h5_file, self.input_path) - - elif model == "manual": - isochrones.add_manual(h5_file, tag, filename) - - elif model == "marleau": - isochrones.add_marleau(h5_file, tag, filename) - - elif model == "nextgen": - isochrones.add_nextgen(h5_file, self.input_path) - - elif model == "saumon2008": - isochrones.add_saumon(h5_file, self.input_path) - - elif model == "sonora": - isochrones.add_sonora(h5_file, self.input_path) - - else: - raise ValueError( - f"The evolutionary model '{model}' is " - "not supported. Please choose another " - "argument for 'model'. Have a look " - "at the documentation of add_isochrones " - "for details on the supported models." 
- ) + with h5py.File(self.database, "a") as h5_file: + if "isochrones" not in h5_file: + h5_file.create_group("isochrones") + + if model in ["manual", "marleau", "phoenix"]: + if f"isochrones/{tag}" in h5_file: + del h5_file[f"isochrones/{tag}"] + + elif model == "ames": + if "isochrones/ames-cond" in h5_file: + del h5_file["isochrones/ames-cond"] + if "isochrones/ames-dusty" in h5_file: + del h5_file["isochrones/ames-dusty"] + + elif model == "atmo": + if "isochrones/atmo-ceq" in h5_file: + del h5_file["isochrones/atmo-ceq"] + if "isochrones/atmo-neq-weak" in h5_file: + del h5_file["isochrones/atmo-neq-weak"] + if "isochrones/atmo-neq-strong" in h5_file: + del h5_file["isochrones/atmo-neq-strong"] + + elif model == "baraffe2015": + if "isochrones/baraffe2015" in h5_file: + del h5_file["isochrones/baraffe2015"] + + elif model == "bt-settl": + if "isochrones/bt-settl" in h5_file: + del h5_file["isochrones/bt-settl"] + + elif model == "nextgen": + if "isochrones/nextgen" in h5_file: + del h5_file["isochrones/nextgen"] + + elif model == "saumon2008": + if "isochrones/saumon2008-nc_solar" in h5_file: + del h5_file["isochrones/saumon2008-nc_solar"] + if "isochrones/saumon2008-nc_-0.3" in h5_file: + del h5_file["isochrones/saumon2008-nc_-0.3"] + if "isochrones/saumon2008-nc_+0.3" in h5_file: + del h5_file["isochrones/saumon2008-nc_+0.3"] + if "isochrones/saumon2008-f2_solar" in h5_file: + del h5_file["isochrones/saumon2008-f2_solar"] + if "isochrones/saumon2008-hybrid_solar" in h5_file: + del h5_file["isochrones/saumon2008-hybrid_solar"] + + elif model == "sonora": + if "isochrones/sonora+0.0" in h5_file: + del h5_file["isochrones/sonora+0.0"] + if "isochrones/sonora+0.5" in h5_file: + del h5_file["isochrones/sonora+0.5"] + if "isochrones/sonora-0.5" in h5_file: + del h5_file["isochrones/sonora-0.5"] + + if model == "ames": + isochrones.add_ames(h5_file, self.input_path) + + elif model == "atmo": + isochrones.add_atmo(h5_file, self.input_path) + + elif model == "baraffe2015": + isochrones.add_baraffe2015(h5_file, self.input_path) + + elif model == "bt-settl": + isochrones.add_btsettl(h5_file, self.input_path) + + elif model == "manual": + isochrones.add_manual(h5_file, tag, filename) + + elif model == "marleau": + isochrones.add_marleau(h5_file, tag, filename) + + elif model == "nextgen": + isochrones.add_nextgen(h5_file, self.input_path) + + elif model == "saumon2008": + isochrones.add_saumon(h5_file, self.input_path) + + elif model == "sonora": + isochrones.add_sonora(h5_file, self.input_path) - h5_file.close() + else: + raise ValueError( + f"The evolutionary model '{model}' is " + "not supported. Please choose another " + "argument for 'model'. Have a look " + "at the documentation of add_isochrones " + "for details on the supported models." 
+ ) @typechecked def add_model( @@ -1522,21 +1511,18 @@ def add_photometry(self, phot_library: str) -> None: None """ - h5_file = h5py.File(self.database, "a") - - if "photometry" not in h5_file: - h5_file.create_group("photometry") - - if "photometry/" + phot_library in h5_file: - del h5_file["photometry/" + phot_library] + with h5py.File(self.database, "a") as h5_file: + if "photometry" not in h5_file: + h5_file.create_group("photometry") - if phot_library[0:7] == "vlm-plx": - vlm_plx.add_vlm_plx(self.input_path, h5_file) + if "photometry/" + phot_library in h5_file: + del h5_file["photometry/" + phot_library] - elif phot_library[0:7] == "leggett": - leggett.add_leggett(self.input_path, h5_file) + if phot_library[0:7] == "vlm-plx": + vlm_plx.add_vlm_plx(self.input_path, h5_file) - h5_file.close() + elif phot_library[0:7] == "leggett": + leggett.add_leggett(self.input_path, h5_file) @typechecked def add_calibration( @@ -1593,102 +1579,99 @@ def add_calibration( if scaling is None: scaling = (1.0, 1.0) - h5_file = h5py.File(self.database, "a") - - if "spectra/calibration" not in h5_file: - h5_file.create_group("spectra/calibration") - - if "spectra/calibration/" + tag in h5_file: - del h5_file["spectra/calibration/" + tag] + with h5py.File(self.database, "a") as h5_file: + if "spectra/calibration" not in h5_file: + h5_file.create_group("spectra/calibration") - if filename is not None: - if filename[-5:] == ".fits": - data = fits.getdata(filename) + if "spectra/calibration/" + tag in h5_file: + del h5_file["spectra/calibration/" + tag] - if data.ndim != 2: - raise RuntimeError( - "The FITS file that is provided " - "as argument of 'filename' does " - "not contain a 2D dataset." - ) + if filename is not None: + if filename[-5:] == ".fits": + data = fits.getdata(filename) - if data.shape[1] != 3 and data.shape[0]: - warnings.warn( - f"Transposing the data that is read " - f"from {filename} because the shape " - f"is {data.shape} instead of " - f"{data.shape[1], data.shape[0]}." - ) + if data.ndim != 2: + raise RuntimeError( + "The FITS file that is provided " + "as argument of 'filename' does " + "not contain a 2D dataset." + ) - data = np.transpose(data) + if data.shape[1] != 3 and data.shape[0]: + warnings.warn( + f"Transposing the data that is read " + f"from {filename} because the shape " + f"is {data.shape} instead of " + f"{data.shape[1], data.shape[0]}." + ) - else: - data = np.loadtxt(filename) + data = np.transpose(data) - nan_idx = np.isnan(data[:, 1]) + else: + data = np.loadtxt(filename) - if np.sum(nan_idx) != 0: - data = data[~nan_idx, :] + nan_idx = np.isnan(data[:, 1]) - warnings.warn( - f"Found {np.sum(nan_idx)} fluxes with NaN in " - f"the data of {filename}. Removing the " - f"spectral fluxes that contain a NaN." - ) + if np.sum(nan_idx) != 0: + data = data[~nan_idx, :] - if units is None: - wavelength = scaling[0] * data[:, 0] # (um) - flux = scaling[1] * data[:, 1] # (W m-2 um-1) + warnings.warn( + f"Found {np.sum(nan_idx)} fluxes with NaN in " + f"the data of {filename}. Removing the " + f"spectral fluxes that contain a NaN." 
+ ) - else: - if units["wavelength"] == "um": + if units is None: wavelength = scaling[0] * data[:, 0] # (um) - - elif units["wavelength"] == "angstrom": - wavelength = scaling[0] * data[:, 0] * 1e-4 # (um) - - if units["flux"] == "w m-2 um-1": flux = scaling[1] * data[:, 1] # (W m-2 um-1) - elif units["flux"] == "w m-2": + else: if units["wavelength"] == "um": - flux = scaling[1] * data[:, 1] / wavelength # (W m-2 um-1) + wavelength = scaling[0] * data[:, 0] # (um) - if data.shape[1] == 3: - if units is None: - error = scaling[1] * data[:, 2] # (W m-2 um-1) + elif units["wavelength"] == "angstrom": + wavelength = scaling[0] * data[:, 0] * 1e-4 # (um) - else: if units["flux"] == "w m-2 um-1": - error = scaling[1] * data[:, 2] # (W m-2 um-1) + flux = scaling[1] * data[:, 1] # (W m-2 um-1) elif units["flux"] == "w m-2": if units["wavelength"] == "um": - error = scaling[1] * data[:, 2] / wavelength # (W m-2 um-1) + flux = scaling[1] * data[:, 1] / wavelength # (W m-2 um-1) - else: - error = np.repeat(0.0, wavelength.size) - - # nan_idx = np.isnan(flux) - # - # if np.sum(nan_idx) != 0: - # wavelength = wavelength[~nan_idx] - # flux = flux[~nan_idx] - # error = error[~nan_idx] - # - # warnings.warn( - # f"Found {np.sum(nan_idx)} fluxes with NaN in " - # f"the calibration spectrum. Removing the " - # f"spectral fluxes that contain a NaN." - # ) - - print(f"Adding calibration spectrum: {tag}...", end="", flush=True) - - h5_file.create_dataset( - f"spectra/calibration/{tag}", data=np.vstack((wavelength, flux, error)) - ) + if data.shape[1] == 3: + if units is None: + error = scaling[1] * data[:, 2] # (W m-2 um-1) - h5_file.close() + else: + if units["flux"] == "w m-2 um-1": + error = scaling[1] * data[:, 2] # (W m-2 um-1) + + elif units["flux"] == "w m-2": + if units["wavelength"] == "um": + error = scaling[1] * data[:, 2] / wavelength # (W m-2 um-1) + + else: + error = np.repeat(0.0, wavelength.size) + + # nan_idx = np.isnan(flux) + # + # if np.sum(nan_idx) != 0: + # wavelength = wavelength[~nan_idx] + # flux = flux[~nan_idx] + # error = error[~nan_idx] + # + # warnings.warn( + # f"Found {np.sum(nan_idx)} fluxes with NaN in " + # f"the calibration spectrum. Removing the " + # f"spectral fluxes that contain a NaN." + # ) + + print(f"Adding calibration spectrum: {tag}...", end="", flush=True) + + h5_file.create_dataset( + f"spectra/calibration/{tag}", data=np.vstack((wavelength, flux, error)) + ) print(" [DONE]") @@ -1722,33 +1705,30 @@ def add_spectrum( "use the add_spectra method instead." 
) - h5_file = h5py.File(self.database, "a") - - if "spectra" not in h5_file: - h5_file.create_group("spectra") - - if "spectra/" + spec_library in h5_file: - del h5_file["spectra/" + spec_library] + with h5py.File(self.database, "a") as h5_file: + if "spectra" not in h5_file: + h5_file.create_group("spectra") - if spec_library[0:5] == "vega": - vega.add_vega(self.input_path, h5_file) + if "spectra/" + spec_library in h5_file: + del h5_file["spectra/" + spec_library] - elif spec_library[0:5] == "irtf": - irtf.add_irtf(self.input_path, h5_file, sptypes) + if spec_library[0:5] == "vega": + vega.add_vega(self.input_path, h5_file) - elif spec_library[0:5] == "spex": - spex.add_spex(self.input_path, h5_file) + elif spec_library[0:5] == "irtf": + irtf.add_irtf(self.input_path, h5_file, sptypes) - elif spec_library[0:12] == "kesseli+2017": - kesseli2017.add_kesseli2017(self.input_path, h5_file) + elif spec_library[0:5] == "spex": + spex.add_spex(self.input_path, h5_file) - elif spec_library[0:13] == "bonnefoy+2014": - bonnefoy2014.add_bonnefoy2014(self.input_path, h5_file) + elif spec_library[0:12] == "kesseli+2017": + kesseli2017.add_kesseli2017(self.input_path, h5_file) - elif spec_library[0:11] == "allers+2013": - allers2013.add_allers2013(self.input_path, h5_file) + elif spec_library[0:13] == "bonnefoy+2014": + bonnefoy2014.add_bonnefoy2014(self.input_path, h5_file) - h5_file.close() + elif spec_library[0:11] == "allers+2013": + allers2013.add_allers2013(self.input_path, h5_file) @typechecked def add_spectra( @@ -1776,33 +1756,30 @@ def add_spectra( None """ - h5_file = h5py.File(self.database, "a") - - if "spectra" not in h5_file: - h5_file.create_group("spectra") - - if f"spectra/{spec_library}" in h5_file: - del h5_file["spectra/" + spec_library] + with h5py.File(self.database, "a") as h5_file: + if "spectra" not in h5_file: + h5_file.create_group("spectra") - if spec_library[0:5] == "vega": - vega.add_vega(self.input_path, h5_file) + if f"spectra/{spec_library}" in h5_file: + del h5_file["spectra/" + spec_library] - elif spec_library[0:5] == "irtf": - irtf.add_irtf(self.input_path, h5_file, sptypes) + if spec_library[0:5] == "vega": + vega.add_vega(self.input_path, h5_file) - elif spec_library[0:5] == "spex": - spex.add_spex(self.input_path, h5_file) + elif spec_library[0:5] == "irtf": + irtf.add_irtf(self.input_path, h5_file, sptypes) - elif spec_library[0:12] == "kesseli+2017": - kesseli2017.add_kesseli2017(self.input_path, h5_file) + elif spec_library[0:5] == "spex": + spex.add_spex(self.input_path, h5_file) - elif spec_library[0:13] == "bonnefoy+2014": - bonnefoy2014.add_bonnefoy2014(self.input_path, h5_file) + elif spec_library[0:12] == "kesseli+2017": + kesseli2017.add_kesseli2017(self.input_path, h5_file) - elif spec_library[0:11] == "allers+2013": - allers2013.add_allers2013(self.input_path, h5_file) + elif spec_library[0:13] == "bonnefoy+2014": + bonnefoy2014.add_bonnefoy2014(self.input_path, h5_file) - h5_file.close() + elif spec_library[0:11] == "allers+2013": + allers2013.add_allers2013(self.input_path, h5_file) @typechecked def add_samples( @@ -1864,80 +1841,77 @@ def add_samples( if spec_labels is None: spec_labels = [] - h5_file = h5py.File(self.database, "a") - - if "results" not in h5_file: - h5_file.create_group("results") - - if "results/fit" not in h5_file: - h5_file.create_group("results/fit") + with h5py.File(self.database, "a") as h5_file: + if "results" not in h5_file: + h5_file.create_group("results") - if f"results/fit/{tag}" in h5_file: - del 
h5_file[f"results/fit/{tag}"] + if "results/fit" not in h5_file: + h5_file.create_group("results/fit") - dset = h5_file.create_dataset(f"results/fit/{tag}/samples", data=samples) - h5_file.create_dataset(f"results/fit/{tag}/ln_prob", data=ln_prob) + if f"results/fit/{tag}" in h5_file: + del h5_file[f"results/fit/{tag}"] - if attr_dict is not None and "spec_type" in attr_dict: - dset.attrs["type"] = attr_dict["spec_type"] - else: - dset.attrs["type"] = str(spectrum[0]) + dset = h5_file.create_dataset(f"results/fit/{tag}/samples", data=samples) + h5_file.create_dataset(f"results/fit/{tag}/ln_prob", data=ln_prob) - if attr_dict is not None and "spec_name" in attr_dict: - dset.attrs["spectrum"] = attr_dict["spec_name"] - else: - dset.attrs["spectrum"] = str(spectrum[1]) + if attr_dict is not None and "spec_type" in attr_dict: + dset.attrs["type"] = attr_dict["spec_type"] + else: + dset.attrs["type"] = str(spectrum[0]) - dset.attrs["n_param"] = int(len(modelpar)) - dset.attrs["sampler"] = str(sampler) + if attr_dict is not None and "spec_name" in attr_dict: + dset.attrs["spectrum"] = attr_dict["spec_name"] + else: + dset.attrs["spectrum"] = str(spectrum[1]) - if parallax is not None: - dset.attrs["parallax"] = float(parallax) + dset.attrs["n_param"] = int(len(modelpar)) + dset.attrs["sampler"] = str(sampler) - if attr_dict is not None and "mean_accept" in attr_dict: - mean_accept = float(attr_dict["mean_accept"]) - dset.attrs["mean_accept"] = mean_accept - print(f"Mean acceptance fraction: {mean_accept:.3f}") + if parallax is not None: + dset.attrs["parallax"] = float(parallax) - elif mean_accept is not None: - dset.attrs["mean_accept"] = float(mean_accept) - print(f"Mean acceptance fraction: {mean_accept:.3f}") + if attr_dict is not None and "mean_accept" in attr_dict: + mean_accept = float(attr_dict["mean_accept"]) + dset.attrs["mean_accept"] = mean_accept + print(f"Mean acceptance fraction: {mean_accept:.3f}") - if ln_evidence is not None: - dset.attrs["ln_evidence"] = ln_evidence + elif mean_accept is not None: + dset.attrs["mean_accept"] = float(mean_accept) + print(f"Mean acceptance fraction: {mean_accept:.3f}") - count_scaling = 0 + if ln_evidence is not None: + dset.attrs["ln_evidence"] = ln_evidence - for i, item in enumerate(modelpar): - dset.attrs[f"parameter{i}"] = str(item) + count_scaling = 0 - if item in spec_labels: - dset.attrs[f"scaling{count_scaling}"] = str(item) - count_scaling += 1 + for i, item in enumerate(modelpar): + dset.attrs[f"parameter{i}"] = str(item) - dset.attrs["n_scaling"] = int(count_scaling) + if item in spec_labels: + dset.attrs[f"scaling{count_scaling}"] = str(item) + count_scaling += 1 - if "teff_0" in modelpar and "teff_1" in modelpar: - dset.attrs["binary"] = True - else: - dset.attrs["binary"] = False + dset.attrs["n_scaling"] = int(count_scaling) - print("Integrated autocorrelation time:") + if "teff_0" in modelpar and "teff_1" in modelpar: + dset.attrs["binary"] = True + else: + dset.attrs["binary"] = False - for i, item in enumerate(modelpar): - auto_corr = emcee.autocorr.integrated_time(samples[:, i], quiet=True)[0] + print("Integrated autocorrelation time:") - if np.allclose(samples[:, i], np.mean(samples[:, i])): - print(f" - {item}: fixed") - else: - print(f" - {item}: {auto_corr:.2f}") + for i, item in enumerate(modelpar): + auto_corr = emcee.autocorr.integrated_time(samples[:, i], quiet=True)[0] - dset.attrs[f"autocorrelation{i}"] = float(auto_corr) + if np.allclose(samples[:, i], np.mean(samples[:, i])): + print(f" - {item}: fixed") + else: + 
print(f" - {item}: {auto_corr:.2f}") - for key, value in attr_dict.items(): - dset.attrs[key] = value + dset.attrs[f"autocorrelation{i}"] = float(auto_corr) - h5_file.close() + for key, value in attr_dict.items(): + dset.attrs[key] = value @typechecked def get_probable_sample( @@ -1967,53 +1941,51 @@ def get_probable_sample( if burnin is None: burnin = 0 - h5_file = h5py.File(self.database, "r") - dset = h5_file[f"results/fit/{tag}/samples"] - - samples = np.asarray(dset) - ln_prob = np.asarray(h5_file[f"results/fit/{tag}/ln_prob"]) + with h5py.File(self.database, "r") as h5_file: + dset = h5_file[f"results/fit/{tag}/samples"] - if "n_param" in dset.attrs: - n_param = dset.attrs["n_param"] - elif "nparam" in dset.attrs: - n_param = dset.attrs["nparam"] + samples = np.asarray(dset) + ln_prob = np.asarray(h5_file[f"results/fit/{tag}/ln_prob"]) - if samples.ndim == 3: - if burnin > samples.shape[0]: - raise ValueError( - f"The 'burnin' value is larger than the number of steps " - f"({samples.shape[1]}) that are made by the walkers." - ) + if "n_param" in dset.attrs: + n_param = dset.attrs["n_param"] + elif "nparam" in dset.attrs: + n_param = dset.attrs["nparam"] - samples = samples[burnin:, :, :] - ln_prob = ln_prob[burnin:, :] + if samples.ndim == 3: + if burnin > samples.shape[0]: + raise ValueError( + f"The 'burnin' value is larger than the number of steps " + f"({samples.shape[1]}) that are made by the walkers." + ) - samples = np.reshape(samples, (-1, n_param)) - ln_prob = np.reshape(ln_prob, -1) + samples = samples[burnin:, :, :] + ln_prob = ln_prob[burnin:, :] - index_max = np.unravel_index(ln_prob.argmax(), ln_prob.shape) + samples = np.reshape(samples, (-1, n_param)) + ln_prob = np.reshape(ln_prob, -1) - # max_prob = ln_prob[index_max] - max_sample = samples[index_max] + index_max = np.unravel_index(ln_prob.argmax(), ln_prob.shape) - prob_sample = {} + # max_prob = ln_prob[index_max] + max_sample = samples[index_max] - for i in range(n_param): - par_key = dset.attrs[f"parameter{i}"] - par_value = max_sample[i] + prob_sample = {} - prob_sample[par_key] = par_value + for i in range(n_param): + par_key = dset.attrs[f"parameter{i}"] + par_value = max_sample[i] - if "parallax" not in prob_sample and "parallax" in dset.attrs: - prob_sample["parallax"] = dset.attrs["parallax"] + prob_sample[par_key] = par_value - elif "distance" not in prob_sample and "distance" in dset.attrs: - prob_sample["distance"] = dset.attrs["distance"] + if "parallax" not in prob_sample and "parallax" in dset.attrs: + prob_sample["parallax"] = dset.attrs["parallax"] - if "pt_smooth" in dset.attrs: - prob_sample["pt_smooth"] = dset.attrs["pt_smooth"] + elif "distance" not in prob_sample and "distance" in dset.attrs: + prob_sample["distance"] = dset.attrs["distance"] - h5_file.close() + if "pt_smooth" in dset.attrs: + prob_sample["pt_smooth"] = dset.attrs["pt_smooth"] return prob_sample @@ -2186,193 +2158,193 @@ def get_mcmc_spectra( if burnin is None: burnin = 0 - h5_file = h5py.File(self.database, "r") - dset = h5_file[f"results/fit/{tag}/samples"] - - spectrum_type = dset.attrs["type"] - spectrum_name = dset.attrs["spectrum"] + with h5py.File(self.database, "r") as h5_file: + dset = h5_file[f"results/fit/{tag}/samples"] - if "n_param" in dset.attrs: - n_param = dset.attrs["n_param"] - elif "nparam" in dset.attrs: - n_param = dset.attrs["nparam"] + spectrum_type = dset.attrs["type"] + spectrum_name = dset.attrs["spectrum"] - if "n_scaling" in dset.attrs: - n_scaling = dset.attrs["n_scaling"] - elif "nscaling" in 
dset.attrs: - n_scaling = dset.attrs["nscaling"] - else: - n_scaling = 0 + if "n_param" in dset.attrs: + n_param = dset.attrs["n_param"] + elif "nparam" in dset.attrs: + n_param = dset.attrs["nparam"] - if "n_error" in dset.attrs: - n_error = dset.attrs["n_error"] - else: - n_error = 0 + if "n_scaling" in dset.attrs: + n_scaling = dset.attrs["n_scaling"] + elif "nscaling" in dset.attrs: + n_scaling = dset.attrs["nscaling"] + else: + n_scaling = 0 - if "binary" in dset.attrs: - binary = dset.attrs["binary"] - else: - binary = False + if "n_error" in dset.attrs: + n_error = dset.attrs["n_error"] + else: + n_error = 0 - if "ext_filter" in dset.attrs: - ext_filter = dset.attrs["ext_filter"] - else: - ext_filter = None + if "binary" in dset.attrs: + binary = dset.attrs["binary"] + else: + binary = False - ignore_param = [] + if "ext_filter" in dset.attrs: + ext_filter = dset.attrs["ext_filter"] + else: + ext_filter = None - for i in range(n_scaling): - ignore_param.append(dset.attrs[f"scaling{i}"]) + ignore_param = [] - for i in range(n_error): - ignore_param.append(dset.attrs[f"error{i}"]) + for i in range(n_scaling): + ignore_param.append(dset.attrs[f"scaling{i}"]) - for i in range(n_param): - if dset.attrs[f"parameter{i}"][:9] == "corr_len_": - ignore_param.append(dset.attrs[f"parameter{i}"]) + for i in range(n_error): + ignore_param.append(dset.attrs[f"error{i}"]) - elif dset.attrs[f"parameter{i}"][:9] == "corr_amp_": - ignore_param.append(dset.attrs[f"parameter{i}"]) + for i in range(n_param): + if dset.attrs[f"parameter{i}"][:9] == "corr_len_": + ignore_param.append(dset.attrs[f"parameter{i}"]) - if spec_res is not None and spectrum_type == "calibration": - warnings.warn( - "Smoothing of the spectral resolution is not implemented for calibration " - "spectra." - ) + elif dset.attrs[f"parameter{i}"][:9] == "corr_amp_": + ignore_param.append(dset.attrs[f"parameter{i}"]) - if "parallax" in dset.attrs: - parallax = dset.attrs["parallax"] - else: - parallax = None + if spec_res is not None and spectrum_type == "calibration": + warnings.warn( + "Smoothing of the spectral resolution is not implemented for calibration " + "spectra." + ) - if "distance" in dset.attrs: - distance = dset.attrs["distance"] - else: - distance = None + if "parallax" in dset.attrs: + parallax = dset.attrs["parallax"] + else: + parallax = None - samples = np.asarray(dset) + if "distance" in dset.attrs: + distance = dset.attrs["distance"] + else: + distance = None - # samples = samples[samples[:, 2] > 100., ] + samples = np.asarray(dset) - if samples.ndim == 2: - rand_index = np.random.randint(samples.shape[0], size=random) - samples = samples[rand_index,] + # samples = samples[samples[:, 2] > 100., ] - elif samples.ndim == 3: - if burnin > samples.shape[0]: - raise ValueError( - f"The 'burnin' value is larger than the number of steps " - f"({samples.shape[1]}) that are made by the walkers." - ) + if samples.ndim == 2: + rand_index = np.random.randint(samples.shape[0], size=random) + samples = samples[rand_index,] - samples = samples[burnin:, :, :] + elif samples.ndim == 3: + if burnin > samples.shape[0]: + raise ValueError( + f"The 'burnin' value is larger than the number of steps " + f"({samples.shape[1]}) that are made by the walkers." 
+ ) - ran_walker = np.random.randint(samples.shape[0], size=random) - ran_step = np.random.randint(samples.shape[1], size=random) - samples = samples[ran_walker, ran_step, :] + samples = samples[burnin:, :, :] - param = [] - for i in range(n_param): - param.append(dset.attrs[f"parameter{i}"]) + ran_walker = np.random.randint(samples.shape[0], size=random) + ran_step = np.random.randint(samples.shape[1], size=random) + samples = samples[ran_walker, ran_step, :] - if spectrum_type == "model": - if spectrum_name == "planck": - readmodel = read_planck.ReadPlanck(wavel_range) + param = [] + for i in range(n_param): + param.append(dset.attrs[f"parameter{i}"]) - elif spectrum_name == "powerlaw": - pass + if spectrum_type == "model": + if spectrum_name == "planck": + readmodel = read_planck.ReadPlanck(wavel_range) - else: - readmodel = read_model.ReadModel(spectrum_name, wavel_range=wavel_range) + elif spectrum_name == "powerlaw": + pass - elif spectrum_type == "calibration": - readcalib = read_calibration.ReadCalibration( - spectrum_name, filter_name=None - ) + else: + readmodel = read_model.ReadModel( + spectrum_name, wavel_range=wavel_range + ) - boxes = [] + elif spectrum_type == "calibration": + readcalib = read_calibration.ReadCalibration( + spectrum_name, filter_name=None + ) - for i in tqdm(range(samples.shape[0]), desc="Getting MCMC spectra"): - model_param = {} - for j in range(samples.shape[1]): - if param[j] not in ignore_param: - model_param[param[j]] = samples[i, j] + boxes = [] - if "parallax" not in model_param and parallax is not None: - model_param["parallax"] = parallax + for i in tqdm(range(samples.shape[0]), desc="Getting MCMC spectra"): + model_param = {} + for j in range(samples.shape[1]): + if param[j] not in ignore_param: + model_param[param[j]] = samples[i, j] - elif "distance" not in model_param and distance is not None: - model_param["distance"] = distance + if "parallax" not in model_param and parallax is not None: + model_param["parallax"] = parallax - if spectrum_type == "model": - if spectrum_name == "planck": - specbox = readmodel.get_spectrum( - model_param, - spec_res, - smooth=True, - wavel_resample=wavel_resample, - ) + elif "distance" not in model_param and distance is not None: + model_param["distance"] = distance - elif spectrum_name == "powerlaw": - if wavel_resample is not None: - warnings.warn( - "The 'wavel_resample' parameter is not support by the " - "'powerlaw' model so the argument will be ignored." + if spectrum_type == "model": + if spectrum_name == "planck": + specbox = readmodel.get_spectrum( + model_param, + spec_res, + smooth=True, + wavel_resample=wavel_resample, ) - specbox = read_util.powerlaw_spectrum(wavel_range, model_param) + elif spectrum_name == "powerlaw": + if wavel_resample is not None: + warnings.warn( + "The 'wavel_resample' parameter is not support by the " + "'powerlaw' model so the argument will be ignored." 
+ ) - else: - if binary: - param_0 = read_util.binary_to_single(model_param, 0) + specbox = read_util.powerlaw_spectrum(wavel_range, model_param) - specbox_0 = readmodel.get_model( - param_0, - spec_res=spec_res, - wavel_resample=wavel_resample, - smooth=True, - ext_filter=ext_filter, - ) + else: + if binary: + param_0 = read_util.binary_to_single(model_param, 0) - param_1 = read_util.binary_to_single(model_param, 1) + specbox_0 = readmodel.get_model( + param_0, + spec_res=spec_res, + wavel_resample=wavel_resample, + smooth=True, + ext_filter=ext_filter, + ) - specbox_1 = readmodel.get_model( - param_1, - spec_res=spec_res, - wavel_resample=wavel_resample, - smooth=True, - ext_filter=ext_filter, - ) + param_1 = read_util.binary_to_single(model_param, 1) - flux_comb = ( - model_param["spec_weight"] * specbox_0.flux - + (1.0 - model_param["spec_weight"]) * specbox_1.flux - ) + specbox_1 = readmodel.get_model( + param_1, + spec_res=spec_res, + wavel_resample=wavel_resample, + smooth=True, + ext_filter=ext_filter, + ) - specbox = box.create_box( - boxtype="model", - model=spectrum_name, - wavelength=specbox_0.wavelength, - flux=flux_comb, - parameters=model_param, - quantity="flux", - ) + flux_comb = ( + model_param["spec_weight"] * specbox_0.flux + + (1.0 - model_param["spec_weight"]) * specbox_1.flux + ) - else: - specbox = readmodel.get_model( - model_param, - spec_res=spec_res, - wavel_resample=wavel_resample, - smooth=True, - ext_filter=ext_filter, - ) + specbox = box.create_box( + boxtype="model", + model=spectrum_name, + wavelength=specbox_0.wavelength, + flux=flux_comb, + parameters=model_param, + quantity="flux", + ) - elif spectrum_type == "calibration": - specbox = readcalib.get_spectrum(model_param) + else: + specbox = readmodel.get_model( + model_param, + spec_res=spec_res, + wavel_resample=wavel_resample, + smooth=True, + ext_filter=ext_filter, + ) - boxes.append(specbox) + elif spectrum_type == "calibration": + specbox = readcalib.get_spectrum(model_param) - h5_file.close() + boxes.append(specbox) return boxes @@ -2418,49 +2390,49 @@ def get_mcmc_photometry( if burnin is None: burnin = 0 - h5_file = h5py.File(self.database, "r") - dset = h5_file[f"results/fit/{tag}/samples"] - - if "n_param" in dset.attrs: - n_param = dset.attrs["n_param"] - elif "nparam" in dset.attrs: - n_param = dset.attrs["nparam"] + with h5py.File(self.database, "r") as h5_file: + dset = h5_file[f"results/fit/{tag}/samples"] - spectrum_type = dset.attrs["type"] - spectrum_name = dset.attrs["spectrum"] + if "n_param" in dset.attrs: + n_param = dset.attrs["n_param"] + elif "nparam" in dset.attrs: + n_param = dset.attrs["nparam"] - if "binary" in dset.attrs: - binary = dset.attrs["binary"] - else: - binary = False + spectrum_type = dset.attrs["type"] + spectrum_name = dset.attrs["spectrum"] - if "parallax" in dset.attrs: - parallax = dset.attrs["parallax"] - else: - parallax = None + if "binary" in dset.attrs: + binary = dset.attrs["binary"] + else: + binary = False - if "distance" in dset.attrs: - distance = dset.attrs["distance"] - else: - distance = None + if "parallax" in dset.attrs: + parallax = dset.attrs["parallax"] + else: + parallax = None - samples = np.asarray(dset) + if "distance" in dset.attrs: + distance = dset.attrs["distance"] + else: + distance = None - if samples.ndim == 3: - if burnin > samples.shape[0]: - raise ValueError( - f"The 'burnin' value is larger than the number of steps " - f"({samples.shape[1]}) that are made by the walkers." 
- ) + samples = np.asarray(dset) - samples = samples[burnin:, :, :] - samples = samples.reshape((samples.shape[0] * samples.shape[1], n_param)) + if samples.ndim == 3: + if burnin > samples.shape[0]: + raise ValueError( + f"The 'burnin' value is larger than the number of steps " + f"({samples.shape[1]}) that are made by the walkers." + ) - param = [] - for i in range(n_param): - param.append(dset.attrs[f"parameter{i}"]) + samples = samples[burnin:, :, :] + samples = samples.reshape( + (samples.shape[0] * samples.shape[1], n_param) + ) - h5_file.close() + param = [] + for i in range(n_param): + param.append(dset.attrs[f"parameter{i}"]) if spectrum_type == "model": if spectrum_name == "powerlaw": @@ -2581,71 +2553,69 @@ def get_object( print(f"Getting object: {object_name}...", end="", flush=True) - h5_file = h5py.File(self.database, "r") - dset = h5_file[f"objects/{object_name}"] - - if "parallax" in dset: - parallax = np.asarray(dset["parallax"]) - else: - parallax = None + with h5py.File(self.database, "r") as h5_file: + dset = h5_file[f"objects/{object_name}"] - if "distance" in dset: - distance = np.asarray(dset["distance"]) - else: - distance = None + if "parallax" in dset: + parallax = np.asarray(dset["parallax"]) + else: + parallax = None - if inc_phot: - magnitude = {} - flux = {} - mean_wavel = {} + if "distance" in dset: + distance = np.asarray(dset["distance"]) + else: + distance = None - for observatory in dset.keys(): - if observatory not in ["parallax", "distance", "spectrum"]: - for filter_name in dset[observatory]: - name = f"{observatory}/{filter_name}" + if inc_phot: + magnitude = {} + flux = {} + mean_wavel = {} - if isinstance(inc_phot, bool) or name in inc_phot: - magnitude[name] = dset[name][0:2] - flux[name] = dset[name][2:4] + for observatory in dset.keys(): + if observatory not in ["parallax", "distance", "spectrum"]: + for filter_name in dset[observatory]: + name = f"{observatory}/{filter_name}" - read_filt = read_filter.ReadFilter(name) - mean_wavel[name] = read_filt.mean_wavelength() + if isinstance(inc_phot, bool) or name in inc_phot: + magnitude[name] = dset[name][0:2] + flux[name] = dset[name][2:4] - phot_filters = list(magnitude.keys()) + read_filt = read_filter.ReadFilter(name) + mean_wavel[name] = read_filt.mean_wavelength() - else: - magnitude = None - flux = None - phot_filters = None - mean_wavel = None - - if inc_spec and f"objects/{object_name}/spectrum" in h5_file: - spectrum = {} - - for item in h5_file[f"objects/{object_name}/spectrum"]: - data_group = f"objects/{object_name}/spectrum/{item}" - - if isinstance(inc_spec, bool) or item in inc_spec: - if f"{data_group}/covariance" not in h5_file: - spectrum[item] = ( - np.asarray(h5_file[f"{data_group}/spectrum"]), - None, - None, - h5_file[f"{data_group}"].attrs["specres"], - ) + phot_filters = list(magnitude.keys()) - else: - spectrum[item] = ( - np.asarray(h5_file[f"{data_group}/spectrum"]), - np.asarray(h5_file[f"{data_group}/covariance"]), - np.asarray(h5_file[f"{data_group}/inv_covariance"]), - h5_file[f"{data_group}"].attrs["specres"], - ) + else: + magnitude = None + flux = None + phot_filters = None + mean_wavel = None + + if inc_spec and f"objects/{object_name}/spectrum" in h5_file: + spectrum = {} + + for item in h5_file[f"objects/{object_name}/spectrum"]: + data_group = f"objects/{object_name}/spectrum/{item}" + + if isinstance(inc_spec, bool) or item in inc_spec: + if f"{data_group}/covariance" not in h5_file: + spectrum[item] = ( + np.asarray(h5_file[f"{data_group}/spectrum"]), + None, + 
None, + h5_file[f"{data_group}"].attrs["specres"], + ) - else: - spectrum = None + else: + spectrum[item] = ( + np.asarray(h5_file[f"{data_group}/spectrum"]), + np.asarray(h5_file[f"{data_group}/covariance"]), + np.asarray(h5_file[f"{data_group}/inv_covariance"]), + h5_file[f"{data_group}"].attrs["specres"], + ) - h5_file.close() + else: + spectrum = None print(" [DONE]") @@ -2695,52 +2665,50 @@ def get_samples( if burnin is None: burnin = 0 - h5_file = h5py.File(self.database, "r") - dset = h5_file[f"results/fit/{tag}/samples"] - ln_prob = np.asarray(h5_file[f"results/fit/{tag}/ln_prob"]) - - attributes = {} - for item in dset.attrs: - attributes[item] = dset.attrs[item] + with h5py.File(self.database, "r") as h5_file: + dset = h5_file[f"results/fit/{tag}/samples"] + ln_prob = np.asarray(h5_file[f"results/fit/{tag}/ln_prob"]) - spectrum = dset.attrs["spectrum"] + attributes = {} + for item in dset.attrs: + attributes[item] = dset.attrs[item] - if "n_param" in dset.attrs: - n_param = dset.attrs["n_param"] - elif "nparam" in dset.attrs: - n_param = dset.attrs["nparam"] + spectrum = dset.attrs["spectrum"] - if "ln_evidence" in dset.attrs: - ln_evidence = dset.attrs["ln_evidence"] - else: - # For backward compatibility - ln_evidence = None + if "n_param" in dset.attrs: + n_param = dset.attrs["n_param"] + elif "nparam" in dset.attrs: + n_param = dset.attrs["nparam"] - samples = np.asarray(dset) + if "ln_evidence" in dset.attrs: + ln_evidence = dset.attrs["ln_evidence"] + else: + # For backward compatibility + ln_evidence = None - if samples.ndim == 3: - if burnin > samples.shape[0]: - raise ValueError( - f"The 'burnin' value is larger than the number of steps " - f"({samples.shape[1]}) that are made by the walkers." - ) + samples = np.asarray(dset) - samples = samples[burnin:, :, :] + if samples.ndim == 3: + if burnin > samples.shape[0]: + raise ValueError( + f"The 'burnin' value is larger than the number of steps " + f"({samples.shape[1]}) that are made by the walkers." + ) - if random is not None: - ran_walker = np.random.randint(samples.shape[0], size=random) - ran_step = np.random.randint(samples.shape[1], size=random) - samples = samples[ran_walker, ran_step, :] + samples = samples[burnin:, :, :] - elif samples.ndim == 2 and random is not None: - indices = np.random.randint(samples.shape[0], size=random) - samples = samples[indices, :] + if random is not None: + ran_walker = np.random.randint(samples.shape[0], size=random) + ran_step = np.random.randint(samples.shape[1], size=random) + samples = samples[ran_walker, ran_step, :] - param = [] - for i in range(n_param): - param.append(dset.attrs[f"parameter{i}"]) + elif samples.ndim == 2 and random is not None: + indices = np.random.randint(samples.shape[0], size=random) + samples = samples[indices, :] - h5_file.close() + param = [] + for i in range(n_param): + param.append(dset.attrs[f"parameter{i}"]) median_sample = self.get_median_sample(tag, burnin) prob_sample = self.get_probable_sample(tag, burnin) @@ -2751,7 +2719,7 @@ def get_samples( for i, item in enumerate(param): samples_dict[item] = list(samples[:, i]) - with open(json_file, "w") as out_file: + with open(json_file, "w", encoding="utf-8") as out_file: json.dump(samples_dict, out_file, indent=4) return box.create_box( @@ -2836,45 +2804,43 @@ def get_pt_profiles( of the array is (n_pressures, n_samples). 
""" - h5_file = h5py.File(self.database, "r") - dset = h5_file[f"results/fit/{tag}/samples"] + with h5py.File(self.database, "r") as h5_file: + dset = h5_file[f"results/fit/{tag}/samples"] - spectrum = dset.attrs["spectrum"] - pt_profile = dset.attrs["pt_profile"] + spectrum = dset.attrs["spectrum"] + pt_profile = dset.attrs["pt_profile"] - if spectrum != "petitradtrans": - raise ValueError( - f"The model spectrum of the posterior samples is '{spectrum}' " - f"instead of 'petitradtrans'. Extracting P-T profiles is " - f"therefore not possible." - ) + if spectrum != "petitradtrans": + raise ValueError( + f"The model spectrum of the posterior samples is '{spectrum}' " + f"instead of 'petitradtrans'. Extracting P-T profiles is " + f"therefore not possible." + ) - if "n_param" in dset.attrs: - n_param = dset.attrs["n_param"] - elif "nparam" in dset.attrs: - n_param = dset.attrs["nparam"] + if "n_param" in dset.attrs: + n_param = dset.attrs["n_param"] + elif "nparam" in dset.attrs: + n_param = dset.attrs["nparam"] - if "temp_nodes" in dset.attrs: - temp_nodes = dset.attrs["temp_nodes"] - else: - temp_nodes = 15 + if "temp_nodes" in dset.attrs: + temp_nodes = dset.attrs["temp_nodes"] + else: + temp_nodes = 15 - samples = np.asarray(dset) + samples = np.asarray(dset) - if random is None: - n_profiles = samples.shape[0] + if random is None: + n_profiles = samples.shape[0] - else: - n_profiles = random - - indices = np.random.randint(samples.shape[0], size=random) - samples = samples[indices, :] + else: + n_profiles = random - param_index = {} - for i in range(n_param): - param_index[dset.attrs[f"parameter{i}"]] = i + indices = np.random.randint(samples.shape[0], size=random) + samples = samples[indices, :] - h5_file.close() + param_index = {} + for i in range(n_param): + param_index[dset.attrs[f"parameter{i}"]] = i press = np.logspace(-6, 3, 180) # (bar) @@ -3532,10 +3498,7 @@ def add_retrieval( knot_temp = np.asarray(knot_temp) - if "pt_smooth" in sample_dict: - pt_smooth = sample_dict["pt_smooth"] - else: - pt_smooth = radtrans["pt_smooth"] + pt_smooth = sample_dict.get("pt_smooth", radtrans["pt_smooth"]) temp = retrieval_util.pt_spline_interp( knot_press, knot_temp, pressure, pt_smooth=pt_smooth @@ -3676,235 +3639,230 @@ def get_retrieval_spectra( # Open the HDF5 database - h5_file = h5py.File(database_path, "r") - - # Read the posterior samples + with h5py.File(database_path, "r") as h5_file: + # Read the posterior samples - dset = h5_file[f"results/fit/{tag}/samples"] - samples = np.asarray(dset) + dset = h5_file[f"results/fit/{tag}/samples"] + samples = np.asarray(dset) - # Select random samples + # Select random samples - if random is None: - # Required for the printed output in the for loop - random = samples.shape[0] + if random is None: + # Required for the printed output in the for loop + random = samples.shape[0] - else: - random_indices = np.random.randint(samples.shape[0], size=random) - samples = samples[random_indices, :] + else: + random_indices = np.random.randint(samples.shape[0], size=random) + samples = samples[random_indices, :] - # Get number of model parameters + # Get number of model parameters - if "n_param" in dset.attrs: - n_param = dset.attrs["n_param"] - elif "nparam" in dset.attrs: - n_param = dset.attrs["nparam"] + if "n_param" in dset.attrs: + n_param = dset.attrs["n_param"] + elif "nparam" in dset.attrs: + n_param = dset.attrs["nparam"] - # Get number of line and cloud species + # Get number of line and cloud species - n_line_species = dset.attrs["n_line_species"] 
- n_cloud_species = dset.attrs["n_cloud_species"] + n_line_species = dset.attrs["n_line_species"] + n_cloud_species = dset.attrs["n_cloud_species"] - # Get number of abundance nodes + # Get number of abundance nodes - if "abund_nodes" in dset.attrs: - if dset.attrs["abund_nodes"] == "None": - abund_nodes = None + if "abund_nodes" in dset.attrs: + if dset.attrs["abund_nodes"] == "None": + abund_nodes = None + else: + abund_nodes = dset.attrs["abund_nodes"] else: - abund_nodes = dset.attrs["abund_nodes"] - else: - abund_nodes = None + abund_nodes = None - # Convert numpy boolean to regular boolean + # Convert numpy boolean to regular boolean - scattering = bool(dset.attrs["scattering"]) + scattering = bool(dset.attrs["scattering"]) - # Get chemistry attributes + # Get chemistry attributes - chemistry = dset.attrs["chemistry"] + chemistry = dset.attrs["chemistry"] - if dset.attrs["quenching"] == "None": - quenching = None - else: - quenching = dset.attrs["quenching"] + if dset.attrs["quenching"] == "None": + quenching = None + else: + quenching = dset.attrs["quenching"] - # Get P-T profile attributes + # Get P-T profile attributes - pt_profile = dset.attrs["pt_profile"] + pt_profile = dset.attrs["pt_profile"] - if "pressure_grid" in dset.attrs: - pressure_grid = dset.attrs["pressure_grid"] - else: - pressure_grid = "smaller" - - # Get free temperarture nodes - - if "temp_nodes" in dset.attrs: - if dset.attrs["temp_nodes"] == "None": - temp_nodes = None + if "pressure_grid" in dset.attrs: + pressure_grid = dset.attrs["pressure_grid"] else: - temp_nodes = dset.attrs["temp_nodes"] - - else: - # For backward compatibility - temp_nodes = None + pressure_grid = "smaller" - # Get distance + # Get free temperarture nodes - if "parallax" in dset.attrs: - distance = 1e3 / dset.attrs["parallax"][0] - elif "distance" in dset.attrs: - distance = dset.attrs["distance"] - else: - distance = None - - # Get maximum pressure + if "temp_nodes" in dset.attrs: + if dset.attrs["temp_nodes"] == "None": + temp_nodes = None + else: + temp_nodes = dset.attrs["temp_nodes"] - if "max_press" in dset.attrs: - max_press = dset.attrs["max_press"] - else: - max_press = None + else: + # For backward compatibility + temp_nodes = None - # Get model parameters + # Get distance - parameters = [] - for i in range(n_param): - parameters.append(dset.attrs[f"parameter{i}"]) + if "parallax" in dset.attrs: + distance = 1e3 / dset.attrs["parallax"][0] + elif "distance" in dset.attrs: + distance = dset.attrs["distance"] + else: + distance = None - parameters = np.asarray(parameters) + # Get maximum pressure - # Get wavelength range for median cloud optical depth + if "max_press" in dset.attrs: + max_press = dset.attrs["max_press"] + else: + max_press = None - if "log_tau_cloud" in parameters and wavel_range is not None: - cloud_wavel = (dset.attrs["wavel_min"], dset.attrs["wavel_max"]) - else: - cloud_wavel = None + # Get model parameters - # Get wavelength range for spectrum + parameters = [] + for i in range(n_param): + parameters.append(dset.attrs[f"parameter{i}"]) - if wavel_range is None: - wavel_range = (dset.attrs["wavel_min"], dset.attrs["wavel_max"]) + parameters = np.asarray(parameters) - # Create dictionary with array indices of the model parameters + # Get wavelength range for median cloud optical depth - indices = {} - for item in parameters: - indices[item] = np.argwhere(parameters == item)[0][0] + if "log_tau_cloud" in parameters and wavel_range is not None: + cloud_wavel = (dset.attrs["wavel_min"], 
dset.attrs["wavel_max"]) + else: + cloud_wavel = None - # Create list with line species + # Get wavelength range for spectrum - line_species = [] - for i in range(n_line_species): - line_species.append(dset.attrs[f"line_species{i}"]) + if wavel_range is None: + wavel_range = (dset.attrs["wavel_min"], dset.attrs["wavel_max"]) - # Create list with cloud species + # Create dictionary with array indices of the model parameters - cloud_species = [] - for i in range(n_cloud_species): - cloud_species.append(dset.attrs[f"cloud_species{i}"]) + indices = {} + for item in parameters: + indices[item] = np.argwhere(parameters == item)[0][0] - # Get resolution mode + # Create list with line species - if "res_mode" in dset.attrs: - res_mode = dset.attrs["res_mode"] - else: - res_mode = "c-k" + line_species = [] + for i in range(n_line_species): + line_species.append(dset.attrs[f"line_species{i}"]) - # High-resolution downsampling factor + # Create list with cloud species - if "lbl_opacity_sampling" in dset.attrs: - lbl_opacity_sampling = dset.attrs["lbl_opacity_sampling"] - else: - lbl_opacity_sampling = None + cloud_species = [] + for i in range(n_cloud_species): + cloud_species.append(dset.attrs[f"cloud_species{i}"]) - # Create an instance of ReadRadtrans - # Afterwards, the names of the cloud_species have been shortened - # from e.g. 'MgSiO3(c)_cd' to 'MgSiO3(c)' + # Get resolution mode - read_rad = read_radtrans.ReadRadtrans( - line_species=line_species, - cloud_species=cloud_species, - scattering=scattering, - wavel_range=wavel_range, - pressure_grid=pressure_grid, - cloud_wavel=cloud_wavel, - max_press=max_press, - res_mode=res_mode, - lbl_opacity_sampling=lbl_opacity_sampling, - ) + if "res_mode" in dset.attrs: + res_mode = dset.attrs["res_mode"] + else: + res_mode = "c-k" - # Set quenching attribute such that the parameter of get_model is not required + # High-resolution downsampling factor - read_rad.quenching = quenching + if "lbl_opacity_sampling" in dset.attrs: + lbl_opacity_sampling = dset.attrs["lbl_opacity_sampling"] + else: + lbl_opacity_sampling = None - # pool = multiprocessing.Pool(os.cpu_count()) - # processes = [] + # Create an instance of ReadRadtrans + # Afterwards, the names of the cloud_species have been shortened + # from e.g. 
'MgSiO3(c)_cd' to 'MgSiO3(c)' - # Initiate empty list for ModelBox objects + read_rad = read_radtrans.ReadRadtrans( + line_species=line_species, + cloud_species=cloud_species, + scattering=scattering, + wavel_range=wavel_range, + pressure_grid=pressure_grid, + cloud_wavel=cloud_wavel, + max_press=max_press, + res_mode=res_mode, + lbl_opacity_sampling=lbl_opacity_sampling, + ) - boxes = [] + # Set quenching attribute such that the parameter of get_model is not required - for i, item in enumerate(samples): - print(f"\rGetting posterior spectra {i+1}/{random}...", end="") + read_rad.quenching = quenching - # Get the P-T smoothing parameter - if "pt_smooth" in dset.attrs: - pt_smooth = dset.attrs["pt_smooth"] + # pool = multiprocessing.Pool(os.cpu_count()) + # processes = [] - elif "pt_smooth_0" in parameters: - pt_smooth = {} - for j in range(temp_nodes - 1): - pt_smooth[f"pt_smooth_{j}"] = item[-1 * temp_nodes + j] + # Initiate empty list for ModelBox objects - else: - pt_smooth = item[indices["pt_smooth"]] + boxes = [] - # Calculate the petitRADTRANS spectrum + for i, item in enumerate(samples): + print(f"\rGetting posterior spectra {i+1}/{random}...", end="") - model_box = data_util.retrieval_spectrum( - indices=indices, - chemistry=chemistry, - pt_profile=pt_profile, - line_species=line_species, - cloud_species=cloud_species, - quenching=quenching, - spec_res=spec_res, - distance=distance, - pt_smooth=pt_smooth, - temp_nodes=temp_nodes, - abund_nodes=abund_nodes, - read_rad=read_rad, - sample=item, - ) + # Get the P-T smoothing parameter + if "pt_smooth" in dset.attrs: + pt_smooth = dset.attrs["pt_smooth"] - # Add the ModelBox to the list + elif "pt_smooth_0" in parameters: + pt_smooth = {} + for j in range(temp_nodes - 1): + pt_smooth[f"pt_smooth_{j}"] = item[-1 * temp_nodes + j] - boxes.append(model_box) + else: + pt_smooth = item[indices["pt_smooth"]] + + # Calculate the petitRADTRANS spectrum + + model_box = data_util.retrieval_spectrum( + indices=indices, + chemistry=chemistry, + pt_profile=pt_profile, + line_species=line_species, + cloud_species=cloud_species, + quenching=quenching, + spec_res=spec_res, + distance=distance, + pt_smooth=pt_smooth, + temp_nodes=temp_nodes, + abund_nodes=abund_nodes, + read_rad=read_rad, + sample=item, + ) - # proc = pool.apply_async(data_util.retrieval_spectrum, - # args=(indices, - # chemistry, - # pt_profile, - # line_species, - # cloud_species, - # quenching, - # spec_res, - # read_rad, - # item)) - # - # processes.append(proc) + # Add the ModelBox to the list - # pool.close() - # - # for i, item in enumerate(processes): - # boxes.append(item.get(timeout=30)) - # print(f'\rGetting posterior spectra {i+1}/{random}...', end='', flush=True) + boxes.append(model_box) - print(" [DONE]") + # proc = pool.apply_async(data_util.retrieval_spectrum, + # args=(indices, + # chemistry, + # pt_profile, + # line_species, + # cloud_species, + # quenching, + # spec_res, + # read_rad, + # item)) + # + # processes.append(proc) - # Close the HDF5 database + # pool.close() + # + # for i, item in enumerate(processes): + # boxes.append(item.get(timeout=30)) + # print(f'\rGetting posterior spectra {i+1}/{random}...', end='', flush=True) - h5_file.close() + print(" [DONE]") return boxes, read_rad @@ -4121,10 +4079,7 @@ def petitcode_param( knot_temp = np.asarray(knot_temp) - if "pt_smooth" in model_param: - pt_smooth = model_param["pt_smooth"] - else: - pt_smooth = 0.0 + pt_smooth = model_param.get("pt_smooth", 0.0) temperature = retrieval_util.pt_spline_interp( knot_press, 
knot_temp, pressure, pt_smooth=pt_smooth diff --git a/species/data/model_data.json b/species/data/model_data.json index 98f871b7..db16544f 100644 --- a/species/data/model_data.json +++ b/species/data/model_data.json @@ -63,7 +63,7 @@ "resolution": 3000, "teff range": [800, 3000], "reference": "Petrus et al. (2023)", - "url": "https://arxiv.org/abs/2207.06622" + "url": "https://ui.adsabs.harvard.edu/abs/2023A%26A...670L...9P/abstract" }, "blackbody": { "parameters": ["teff"], diff --git a/species/read/read_model.py b/species/read/read_model.py index 1c8c875a..25ffc09e 100644 --- a/species/read/read_model.py +++ b/species/read/read_model.py @@ -10,7 +10,7 @@ from typing import Dict, List, Optional, Tuple, Union import h5py -import spectres +# import spectres import numpy as np from PyAstronomy.pyasl import rotBroad, fastRotBroad @@ -841,9 +841,12 @@ def get_model( else: planck_box = readplanck.get_spectrum(disk_param, spec_res, smooth=False) - flux += spectres.spectres( - self.wl_points, planck_box.wavelength, planck_box.flux - ) + flux_interp = interp1d(planck_box.wavelength, planck_box.flux) + flux += flux_interp(self.wl_points) + + # flux += spectres.spectres( + # self.wl_points, planck_box.wavelength, planck_box.flux + # ) # Create ModelBox with the spectrum @@ -972,14 +975,17 @@ def get_model( # Resample the spectrum if wavel_resample is not None: - model_box.flux = spectres.spectres( - wavel_resample, - model_box.wavelength, - model_box.flux, - spec_errs=None, - fill=np.nan, - verbose=True, - ) + flux_interp = interp1d(model_box.wavelength, model_box.flux) + model_box.flux = flux_interp(wavel_resample) + + # model_box.flux = spectres.spectres( + # wavel_resample, + # model_box.wavelength, + # model_box.flux, + # spec_errs=None, + # fill=np.nan, + # verbose=True, + # ) model_box.wavelength = wavel_resample @@ -1004,14 +1010,17 @@ def get_model( wavel_resample = wavel_resample[indices] - model_box.flux = spectres.spectres( - wavel_resample, - model_box.wavelength, - model_box.flux, - spec_errs=None, - fill=np.nan, - verbose=True, - ) + flux_interp = interp1d(model_box.wavelength, model_box.flux) + model_box.flux = flux_interp(wavel_resample) + + # model_box.flux = spectres.spectres( + # wavel_resample, + # model_box.wavelength, + # model_box.flux, + # spec_errs=None, + # fill=np.nan, + # verbose=True, + # ) model_box.wavelength = wavel_resample @@ -1028,17 +1037,20 @@ def get_model( else: h5_file.close() - readcalib = read_calibration.ReadCalibration("vega", filter_name=None) - calibbox = readcalib.get_spectrum() + read_calib = read_calibration.ReadCalibration("vega", filter_name=None) + calib_box = read_calib.get_spectrum() - flux_vega, _ = spectres.spectres( - model_box.wavelength, - calibbox.wavelength, - calibbox.flux, - spec_errs=calibbox.error, - fill=np.nan, - verbose=True, - ) + flux_interp = interp1d(calib_box.wavelength, calib_box.flux) + flux_vega = flux_interp(model_box.wavelength) + + # flux_vega, _ = spectres.spectres( + # model_box.wavelength, + # calib_box.wavelength, + # calib_box.flux, + # spec_errs=calib_box.error, + # fill=np.nan, + # verbose=True, + # ) model_box.flux = -2.5 * np.log10(model_box.flux / flux_vega) model_box.quantity = "magnitude" @@ -1269,7 +1281,10 @@ def get_data( else: planck_box = readplanck.get_spectrum(disk_param, spec_res, smooth=False) - flux += spectres.spectres(wl_points, planck_box.wavelength, planck_box.flux) + flux_interp = interp1d(planck_box.wavelength, planck_box.flux) + flux += flux_interp(wl_points) + + # flux += 
spectres.spectres(wl_points, planck_box.wavelength, planck_box.flux) # Create ModelBox with the spectrum @@ -1344,14 +1359,17 @@ def get_data( # Resample the spectrum if wavel_resample is not None: - model_box.flux = spectres.spectres( - wavel_resample, - model_box.wavelength, - model_box.flux, - spec_errs=None, - fill=np.nan, - verbose=True, - ) + flux_interp = interp1d(model_box.wavelength, model_box.flux) + model_box.flux = flux_interp(wavel_resample) + + # model_box.flux = spectres.spectres( + # wavel_resample, + # model_box.wavelength, + # model_box.flux, + # spec_errs=None, + # fill=np.nan, + # verbose=True, + # ) model_box.wavelength = wavel_resample
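
Notes on the main changes

1. compare_model in CompareSpectra now extracts the model grid once per grid point and only repeats the cheap smoothing and resampling steps for each observed spectrum, instead of constructing a new ReadModel and calling get_data inside the loop over spectra. A minimal, self-contained sketch of the hoisting pattern; the Gaussian kernel and the constant-step FWHM conversion are illustrative stand-ins for read_util.smooth_spectrum, and the spectra dict is hypothetical:

    import numpy as np
    from scipy.interpolate import interp1d
    from scipy.ndimage import gaussian_filter1d

    # Expensive step, done once per grid point: extract the model spectrum
    model_wavel = np.linspace(1.0, 2.0, 5000)  # (um)
    model_flux = 1.0 - 0.5 * np.exp(-((model_wavel - 1.5) / 0.002) ** 2)

    # Cheap steps, repeated per observed spectrum: smooth to the data
    # resolution, then resample to the data wavelengths
    spectra = {
        "spec_a": (np.linspace(1.1, 1.9, 500), 500.0),
        "spec_b": (np.linspace(1.45, 1.8, 300), 3000.0),
    }

    model_spec = {}
    for name, (obj_wavel, obj_res) in spectra.items():
        # Kernel FWHM from R = lambda / delta_lambda, evaluated at 1.5 um
        # and assuming a constant wavelength step
        fwhm_pix = (1.5 / obj_res) / (model_wavel[1] - model_wavel[0])
        sigma_pix = fwhm_pix / (2.0 * np.sqrt(2.0 * np.log(2.0)))
        smoothed = gaussian_filter1d(model_flux, sigma_pix)
        model_spec[name] = interp1d(model_wavel, smoothed)(obj_wavel)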
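
2. Every h5py.File(...) / h5_file.close() pair in database.py was replaced by a with-block. h5py file objects support the context-manager protocol, so the file is closed even when an exception is raised inside the block, which the explicit close() calls did not guarantee. A minimal sketch of the pattern, with the "dust" group standing in for the various add_* methods:

    import h5py

    def add_group(database: str) -> None:
        # The file is closed on normal exit and also when an
        # exception propagates out of the block
        with h5py.File(database, "a") as h5_file:
            if "dust" in h5_file:
                del h5_file["dust"]

            h5_file.create_group("dust")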
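
3. spectres.spectres was swapped for scipy.interpolate.interp1d in ReadModel and in compare_model. The two are not equivalent: spectres performs flux-conserving rebinning, while interp1d does point-wise linear interpolation, so the change trades some accuracy on coarsely sampled output grids for speed. Note also that the replaced spectres calls used fill=np.nan for wavelengths outside the input grid, whereas interp1d with default arguments raises a ValueError there. A sketch of the NaN-filling variant, assuming that behaviour is wanted:

    import numpy as np
    from scipy.interpolate import interp1d

    wavel_in = np.linspace(1.0, 2.0, 1000)  # (um)
    flux_in = np.exp(-((wavel_in - 1.5) / 0.1) ** 2)  # (W m-2 um-1)

    # Output grid that extends slightly beyond the input grid
    wavel_out = np.linspace(0.95, 2.05, 200)

    flux_interp = interp1d(
        wavel_in, flux_in, bounds_error=False, fill_value=np.nan
    )
    flux_out = flux_interp(wavel_out)  # NaN outside [1.0, 2.0] um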
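
4. The if/else lookups for pt_smooth in add_retrieval and petitcode_param were condensed with dict.get, which returns the fallback when the key is absent:

    # Equivalent forms; the patch keeps the one-liner
    pt_smooth = sample_dict.get("pt_smooth", radtrans["pt_smooth"])

    if "pt_smooth" in sample_dict:
        pt_smooth = sample_dict["pt_smooth"]
    else:
        pt_smooth = radtrans["pt_smooth"]

One subtlety: with get() the fallback expression is always evaluated, even when the key is present, which is harmless here because radtrans["pt_smooth"] is a plain lookup.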
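
5. The spectres import in read_model.py and the superseded spectres calls are commented out rather than deleted. If the interp1d behaviour holds up in practice, a follow-up could remove the commented-out blocks so the module does not carry dead code.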