Skip to content

Commit

Permalink
Added param_fmt parameter to plot_spectrum, fixed issue with appendin…
Browse files Browse the repository at this point in the history
…g to database during multiprocessing, updated plot_util
  • Loading branch information
tomasstolker committed Feb 12, 2024
1 parent d491cf0 commit 093df1d
Show file tree
Hide file tree
Showing 4 changed files with 173 additions and 94 deletions.
51 changes: 49 additions & 2 deletions species/plot/plot_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import math
import warnings

from typing import Optional, Union, Tuple, List
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import matplotlib as mpl
Expand Down Expand Up @@ -53,6 +53,7 @@ def plot_spectrum(
quantity: str = "flux density",
output: Optional[str] = None,
leg_param: Optional[List[str]] = None,
param_fmt: Optional[Dict[str, str]] = None,
grid_hspace: float = 0.1,
inc_model_name: bool = False,
units: Tuple[str, str] = ("um", "W m-2 um-1"),
Expand Down Expand Up @@ -151,6 +152,13 @@ def plot_spectrum(
and 'luminosity' can be included. The default atmospheric
parameters are included in the legend if the argument is
set to ``None``.
param_fmt : dict(str, str), None
Dictionary with formats that will be used for the model
parameter. The parameters are included in the ``legend``
when plotting the model spectra. Default formats are
used if the argument of ``param_fmt`` is set to ``None``.
Formats should provided for example as '.2f' for two
decimals and '.0f' for zero decimals.
grid_hspace : float
The relative height spacing between subplots, expressed
as a fraction of the average axis height. The default
Expand Down Expand Up @@ -181,6 +189,44 @@ def plot_spectrum(
f"number of items in 'plot_kwargs' ({len(plot_kwargs)})."
)

if leg_param is None:
leg_param = []

if param_fmt is None:
param_fmt = {}

# Add missing parameter formats

param_add = ["teff", "disk_teff", "disk_radius"]

for param_item in param_add:
if param_item not in param_fmt:
param_fmt[param_item] = ".0f"

param_add = [
"radius",
"logg",
"feh",
"metallicity",
"fsed",
"distance",
"parallax",
"mass",
"ism_ext",
"lognorm_ext",
"powerlaw_ext",
]

for param_item in param_add:
if param_item not in param_fmt:
param_fmt[param_item] = ".1f"

param_add = ["co", "c_o_ratio", "ad_index", "luminosity"]

for param_item in param_add:
if param_item not in param_fmt:
param_fmt[param_item] = ".2f"

if residuals is not None and filters is not None:
fig = plt.figure(figsize=figsize)
grid_sp = mpl.gridspec.GridSpec(3, 1, height_ratios=[1, 3, 1])
Expand Down Expand Up @@ -500,10 +546,11 @@ def plot_spectrum(

label = create_model_label(
model_param=param,
object_type=object_type,
model_name=box_item.model,
inc_model_name=inc_model_name,
object_type=object_type,
leg_param=leg_param,
param_fmt=param_fmt,
)

else:
Expand Down
20 changes: 15 additions & 5 deletions species/read/read_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,17 +138,27 @@ def open_database(self) -> h5py._hl.files.File:
The HDF5 database.
"""

with h5py.File(self.database, "a") as hdf5_file:
if f"models/{self.model}" not in hdf5_file:
with h5py.File(self.database, "r") as hdf5_file:
if f"models/{self.model}" in hdf5_file:
model_found = True

else:
model_found = False

warnings.warn(
f"The '{self.model}' model spectra are not present "
"in the database. Will try to add the model grid. "
"If this does not work (e.g. currently without an "
"internet connection) then please use the "
"internet connection), then please use the "
"'add_model' method of 'Database' to add the "
"grid of spectra at a later moment."
)

if not model_found:
# This will not work when using multiprocessing.
# Model spectra should be added to the database
# before running FitModel with MPI
with h5py.File(self.database, "a") as hdf5_file:
add_model_grid(self.model, self.data_folder, hdf5_file)

return h5py.File(self.database, "r")
Expand Down Expand Up @@ -865,7 +875,7 @@ def get_model(
flux *= model_param["flux_scaling"]

elif "log_flux_scaling" in model_param:
flux *= 10.0**model_param["log_flux_scaling"]
flux *= 10.0 ** model_param["log_flux_scaling"]

# Add optional offset to the flux

Expand Down Expand Up @@ -1326,7 +1336,7 @@ def get_data(
flux *= model_param["flux_scaling"]

elif "log_flux_scaling" in model_param:
flux *= 10.0**model_param["log_flux_scaling"]
flux *= 10.0 ** model_param["log_flux_scaling"]

# Add optional offset to the flux

Expand Down
Loading

0 comments on commit 093df1d

Please sign in to comment.