Skip to content

Commit

Permalink
Several minor improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
tomasstolker committed Jun 4, 2024
1 parent 4054308 commit cb0b6eb
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 32 deletions.
3 changes: 2 additions & 1 deletion species/data/model_data/model_data.json
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,8 @@
"lambda/d_lambda": 5000,
"teff range": [2300, 12000],
"reference": "Husser et al. (2013)",
"url": "https://ui.adsabs.harvard.edu/abs/2013A%26A...553A...6H"
"url": "https://ui.adsabs.harvard.edu/abs/2013A%26A...553A...6H",
"information": "[alpha/Fe] = 0.0"
},
"saumon2008-clear": {
"parameters": ["teff", "logg"],
Expand Down
6 changes: 5 additions & 1 deletion species/plot/plot_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@
from species.read.read_model import ReadModel
from species.read.read_object import ReadObject
from species.util.dust_util import ism_extinction
from species.util.plot_util import create_model_label, create_param_format, update_labels
from species.util.plot_util import (
create_model_label,
create_param_format,
update_labels,
)
from species.util.spec_util import smooth_spectrum


Expand Down
27 changes: 20 additions & 7 deletions species/plot/plot_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def plot_posterior(
Include the mass in the posterior plot as calculated
from the surface gravity and radius.
inc_log_mass : bool
Include the logarithm of the mass, :math:`\\log10{M}`, in
Include the logarithm of the mass, :math:`\\log_{10}{M}`, in
the posterior plot, as calculated from the surface gravity
and radius.
inc_pt_param : bool
Expand Down Expand Up @@ -523,7 +523,10 @@ def plot_posterior(

else:
for disk_idx in range(100):
if f"disk_teff_{disk_idx}" in box.parameters and f"disk_radius_{disk_idx}" in box.parameters:
if (
f"disk_teff_{disk_idx}" in box.parameters
and f"disk_radius_{disk_idx}" in box.parameters
):
n_disk += 1
else:
break
Expand Down Expand Up @@ -568,8 +571,12 @@ def plot_posterior(
lum_disk = 0.0

for disk_idx in range(n_disk):
teff_index = np.argwhere(np.array(box.parameters) == f"disk_teff_{disk_idx}")[0]
radius_index = np.argwhere(np.array(box.parameters) == f"disk_radius_{disk_idx}")[0]
teff_index = np.argwhere(
np.array(box.parameters) == f"disk_teff_{disk_idx}"
)[0]
radius_index = np.argwhere(
np.array(box.parameters) == f"disk_radius_{disk_idx}"
)[0]

lum_disk += (
4.0
Expand Down Expand Up @@ -818,7 +825,9 @@ def plot_posterior(

for radius_idx in range(100):
if f"radius_{radius_idx}" in box.parameters:
radius_index = np.argwhere(np.array(box.parameters) == f"radius_{radius_idx}")[0]
radius_index = np.argwhere(
np.array(box.parameters) == f"radius_{radius_idx}"
)[0]
if object_type == "star":
samples[:, radius_index] *= constants.R_JUP / constants.R_SUN
else:
Expand Down Expand Up @@ -867,7 +876,9 @@ def plot_posterior(

for disk_idx in range(100):
if f"disk_radius_{disk_idx}" in box.parameters:
radius_index = np.argwhere(np.array(box.parameters) == f"disk_radius_{disk_idx}")[0]
radius_index = np.argwhere(
np.array(box.parameters) == f"disk_radius_{disk_idx}"
)[0]
if object_type == "star":
samples[:, radius_index] *= constants.R_JUP / constants.AU
else:
Expand All @@ -880,7 +891,9 @@ def plot_posterior(

for disk_idx in range(100):
if f"radius_bb_{disk_idx}" in box.parameters:
radius_index = np.argwhere(np.array(box.parameters) == f"radius_bb_{disk_idx}")[0]
radius_index = np.argwhere(
np.array(box.parameters) == f"radius_bb_{disk_idx}"
)[0]
if object_type == "star":
samples[:, radius_index] *= constants.R_JUP / constants.AU
else:
Expand Down
77 changes: 56 additions & 21 deletions species/plot/plot_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,11 +874,24 @@ def plot_spectrum(
fwhm_micron = transmission.filter_fwhm()

if isinstance(box_item.flux[filter_item][0], np.ndarray):
raise NotImplementedError(
"Unit conversion has not yet been implemented! "
"Please open an issue on Github."
wavel_array = np.full(
box_item.flux[filter_item].shape[1], wavel_micron
)

data_in = np.column_stack(
[
wavel_array,
box_item.flux[filter_item][0],
box_item.flux[filter_item][1],
]
)

data_out = convert_units(data_in, units, convert_from=False)

wavelength = data_out[:, 0]
flux_conv = data_out[:, 1]
sigma_conv = data_out[:, 2]

else:
data_in = np.column_stack(
[
Expand Down Expand Up @@ -918,10 +931,6 @@ def plot_spectrum(
# wavelength to frequency
fwhm = (fwhm_up + fwhm_down) / 2.0

if plot_kwargs[j] and filter_item in plot_kwargs[j]:
if "label" in plot_kwargs[j][filter_item]:
labels_data.append(plot_kwargs[j][filter_item]["label"])

if not plot_kwargs[j] or filter_item not in plot_kwargs[j]:
if not plot_kwargs[j]:
plot_kwargs[j] = {}
Expand All @@ -932,12 +941,13 @@ def plot_spectrum(
scale_tmp = flux_scaling / scaling

if isinstance(box_item.flux[filter_item][0], np.ndarray):
for i in range(box_item.flux[filter_item].shape[1]):
for phot_idx in range(box_item.flux[filter_item].shape[1]):
plot_obj = ax1.errorbar(
wavelength,
scale_tmp * box_item.flux[filter_item][0, i],
scale_tmp * box_item.flux[filter_item][0, phot_idx],
xerr=fwhm / 2.0,
yerr=scale_tmp * box_item.flux[filter_item][1, i],
yerr=scale_tmp
* box_item.flux[filter_item][1, phot_idx],
marker="s",
ms=5,
zorder=3,
Expand Down Expand Up @@ -970,28 +980,47 @@ def plot_spectrum(
if not isinstance(plot_kwargs[j][filter_item], list):
raise ValueError(
f"A list with {box_item.flux[filter_item].shape[1]} "
f"dictionaries are required because the filter "
f"dictionaries is required because the filter "
f"{filter_item} has {box_item.flux[filter_item].shape[1]} "
f"values."
)

for i in range(box_item.flux[filter_item].shape[1]):
if "zorder" not in plot_kwargs[j][filter_item][i]:
plot_kwargs[j][filter_item][i]["zorder"] = 3.0
for phot_idx in range(box_item.flux[filter_item].shape[1]):
if (
"zorder"
not in plot_kwargs[j][filter_item][phot_idx]
):
plot_kwargs[j][filter_item][phot_idx][
"zorder"
] = 3.0

if plot_kwargs[j] and filter_item in plot_kwargs[j]:
if "label" in plot_kwargs[j][filter_item][phot_idx]:
labels_data.append(
plot_kwargs[j][filter_item][phot_idx][
"label"
]
)

ax1.errorbar(
wavelength,
wavelength[phot_idx],
flux_scaling
* box_item.flux[filter_item][0, i]
* box_item.flux[filter_item][0, phot_idx]
/ scaling,
xerr=fwhm / 2.0,
yerr=flux_scaling
* box_item.flux[filter_item][1, i]
* box_item.flux[filter_item][1, phot_idx]
/ scaling,
**plot_kwargs[j][filter_item][i],
**plot_kwargs[j][filter_item][phot_idx],
)

else:
if plot_kwargs[j] and filter_item in plot_kwargs[j]:
if "label" in plot_kwargs[j][filter_item]:
labels_data.append(
plot_kwargs[j][filter_item]["label"]
)

if box_item.flux[filter_item][1] == 0.0:
if "zorder" not in plot_kwargs[j][filter_item]:
plot_kwargs[j][filter_item]["zorder"] = 3.0
Expand Down Expand Up @@ -1242,10 +1271,16 @@ def plot_spectrum(

for sigma_item in sigma_line:
if res_lim > sigma_item or (
ylim_res is not None and ylim_res[0] < -sigma_item and ylim_res[1] > sigma_item
ylim_res is not None
and ylim_res[0] < -sigma_item
and ylim_res[1] > sigma_item
):
ax3.axhline(-sigma_item, ls=":", lw=0.7, color="gray", dashes=(1, 4), zorder=0.5)
ax3.axhline(sigma_item, ls=":", lw=0.7, color="gray", dashes=(1, 4), zorder=0.5)
ax3.axhline(
-sigma_item, ls=":", lw=0.7, color="gray", dashes=(1, 4), zorder=0.5
)
ax3.axhline(
sigma_item, ls=":", lw=0.7, color="gray", dashes=(1, 4), zorder=0.5
)

if ylim_res is None:
ax3.set_ylim(-res_lim, res_lim)
Expand Down
4 changes: 2 additions & 2 deletions species/read/read_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from PyAstronomy.pyasl import rotBroad, fastRotBroad
from typeguard import typechecked
from scipy.integrate import simps
from scipy.integrate import simpson
from scipy.interpolate import interp1d, RegularGridInterpolator

from species.core import constants
Expand Down Expand Up @@ -1887,7 +1887,7 @@ def integrate_spectrum(self, model_param: Dict[str, float]) -> float:
4.0
* np.pi
* (model_param["radius"] * constants.R_JUP) ** 2
* simps(model_box.flux, model_box.wavelength)
* simpson(y=model_box.flux, x=model_box.wavelength)
)

return np.log10(bol_lum / constants.L_SUN)
Expand Down

0 comments on commit cb0b6eb

Please sign in to comment.