Skip to content

Commit

Permalink
Added the font_size parameter to plot_spectrum
Browse files Browse the repository at this point in the history
  • Loading branch information
tomasstolker committed Aug 19, 2024
1 parent c9a8149 commit 03b8bd0
Showing 1 changed file with 50 additions and 16 deletions.
66 changes: 50 additions & 16 deletions species/plot/plot_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def plot_spectrum(
grid_hspace: float = 0.1,
inc_model_name: bool = False,
units: Tuple[str, str] = ("um", "W m-2 um-1"),
font_size: Optional[Dict[str, float]] = None,
) -> mpl.figure.Figure:
"""
Function for plotting a spectral energy distribution and combining
Expand Down Expand Up @@ -173,6 +174,13 @@ def plot_spectrum(
Tuple with the wavelength and flux units. Supported
units can be found in the docstring of
:func:`~species.util.data_util.convert_units`.
font_size : dict(str, float), None
Dictionary with the font sizes. The keys can be set to
'xlabel', 'ylabel', 'title', and 'legend'. The values
should be set to the font sizes. Default font size are
used when setting the argument to ``None``. The legend
font size is not used if it is also set with the
``legend`` parameter.
Returns
-------
Expand Down Expand Up @@ -413,35 +421,52 @@ def plot_spectrum(
else:
y_unit = units[1]

if font_size is None:
font_size = {}

if "xlabel" not in font_size:
font_size["xlabel"] = 11.

if "ylabel" not in font_size:
font_size["ylabel"] = 11.

if "title" not in font_size:
font_size["title"] = 13.

if "legend" not in font_size:
font_size["legend"] = 9.

print(f"Font sizes: {font_size}")

if residuals is not None and filters is not None:
ax1.set_xlabel("")
ax2.set_xlabel("")
ax3.set_xlabel(x_label, fontsize=11)
ax3.set_xlabel(x_label, fontsize=font_size["xlabel"])

elif residuals is not None:
ax1.set_xlabel("")
ax3.set_xlabel(x_label, fontsize=11)
ax3.set_xlabel(x_label, fontsize=font_size["xlabel"])

elif filters is not None:
ax1.set_xlabel(x_label, fontsize=11)
ax1.set_xlabel(x_label, fontsize=font_size["xlabel"])
ax2.set_xlabel("")

else:
ax1.set_xlabel(x_label, fontsize=11)
ax1.set_xlabel(x_label, fontsize=font_size["xlabel"])

if filters is not None:
ax2.set_ylabel(r"$T_\lambda$", fontsize=11)
ax2.set_ylabel(r"$T_\lambda$", fontsize=font_size["ylabel"])

if residuals is not None:
if quantity == "flux density":
ax3.set_ylabel(r"$\Delta$$F_\lambda$ ($\sigma$)", fontsize=11)
ax3.set_ylabel(r"$\Delta$$F_\lambda$ ($\sigma$)", fontsize=font_size["ylabel"])

elif quantity == "flux":
ax3.set_ylabel(r"$\Delta$$F_\lambda$ ($\sigma$)", fontsize=11)
ax3.set_ylabel(r"$\Delta$$F_\lambda$ ($\sigma$)", fontsize=font_size["ylabel"])

if quantity == "magnitude":
scaling = 1.0
ax1.set_ylabel("Contrast (mag)", fontsize=11)
ax1.set_ylabel("Contrast (mag)", fontsize=font_size["ylabel"])

if ylim is not None:
ax1.set_ylim(ylim[0], ylim[1])
Expand Down Expand Up @@ -480,7 +505,7 @@ def plot_spectrum(
+ r"}$ W m$^{-2}$)"
)

ax1.set_ylabel(ylabel, fontsize=11)
ax1.set_ylabel(ylabel, fontsize=font_size["ylabel"])
ax1.set_ylim(ylim[0] / scaling, ylim[1] / scaling)

if ylim[0] < 0.0:
Expand All @@ -492,11 +517,11 @@ def plot_spectrum(
if quantity == "flux density":
ax1.set_ylabel(
rf"$F_\lambda$ ({y_unit})",
fontsize=11,
fontsize=font_size["ylabel"],
)

elif quantity == "flux":
ax1.set_ylabel(r"$\lambda$$F_\lambda$ (W m$^{-2}$)", fontsize=11)
ax1.set_ylabel(r"$\lambda$$F_\lambda$ (W m$^{-2}$)", fontsize=font_size["ylabel"])

scaling = 1.0

Expand Down Expand Up @@ -1337,9 +1362,9 @@ def plot_spectrum(

if title is not None:
if filters:
ax2.set_title(title, y=1.02, fontsize=13)
ax2.set_title(title, y=1.02, fontsize=font_size["title"])
else:
ax1.set_title(title, y=1.02, fontsize=13)
ax1.set_title(title, y=1.02, fontsize=font_size["title"])

handles, labels = ax1.get_legend_handles_labels()

Expand All @@ -1364,10 +1389,13 @@ def plot_spectrum(
model_handles,
model_labels,
loc=legend[0],
fontsize=10.0,
fontsize=font_size["legend"],
frameon=False,
)
else:
if "fontsize" not in legend[0]:
legend[0]["fontsize"] = font_size["legend"]

leg_1 = ax1.legend(model_handles, model_labels, **legend[0])

else:
Expand All @@ -1379,19 +1407,25 @@ def plot_spectrum(
data_handles,
data_labels,
loc=legend[1],
fontsize=8,
fontsize=font_size["legend"],
frameon=False,
)
else:
if "fontsize" not in legend[1]:
legend[1]["fontsize"] = font_size["legend"]

ax1.legend(data_handles, data_labels, **legend[1])

if leg_1 is not None:
ax1.add_artist(leg_1)

elif isinstance(legend, (str, tuple)):
ax1.legend(loc=legend, fontsize=8, frameon=False)
ax1.legend(loc=legend, fontsize=font_size["legend"], frameon=False)

else:
if "fontsize" not in legend:
legend["fontsize"] = font_size["legend"]

ax1.legend(**legend)

# if scale[0] == "log":
Expand Down

0 comments on commit 03b8bd0

Please sign in to comment.