diff --git a/species/plot/plot_spectrum.py b/species/plot/plot_spectrum.py index 65b518a..f78ad28 100644 --- a/species/plot/plot_spectrum.py +++ b/species/plot/plot_spectrum.py @@ -58,6 +58,7 @@ def plot_spectrum( grid_hspace: float = 0.1, inc_model_name: bool = False, units: Tuple[str, str] = ("um", "W m-2 um-1"), + font_size: Optional[Dict[str, float]] = None, ) -> mpl.figure.Figure: """ Function for plotting a spectral energy distribution and combining @@ -173,6 +174,13 @@ def plot_spectrum( Tuple with the wavelength and flux units. Supported units can be found in the docstring of :func:`~species.util.data_util.convert_units`. + font_size : dict(str, float), None + Dictionary with the font sizes. The keys can be set to + 'xlabel', 'ylabel', 'title', and 'legend'. The values + should be set to the font sizes. Default font size are + used when setting the argument to ``None``. The legend + font size is not used if it is also set with the + ``legend`` parameter. Returns ------- @@ -413,35 +421,52 @@ def plot_spectrum( else: y_unit = units[1] + if font_size is None: + font_size = {} + + if "xlabel" not in font_size: + font_size["xlabel"] = 11. + + if "ylabel" not in font_size: + font_size["ylabel"] = 11. + + if "title" not in font_size: + font_size["title"] = 13. + + if "legend" not in font_size: + font_size["legend"] = 9. + + print(f"Font sizes: {font_size}") + if residuals is not None and filters is not None: ax1.set_xlabel("") ax2.set_xlabel("") - ax3.set_xlabel(x_label, fontsize=11) + ax3.set_xlabel(x_label, fontsize=font_size["xlabel"]) elif residuals is not None: ax1.set_xlabel("") - ax3.set_xlabel(x_label, fontsize=11) + ax3.set_xlabel(x_label, fontsize=font_size["xlabel"]) elif filters is not None: - ax1.set_xlabel(x_label, fontsize=11) + ax1.set_xlabel(x_label, fontsize=font_size["xlabel"]) ax2.set_xlabel("") else: - ax1.set_xlabel(x_label, fontsize=11) + ax1.set_xlabel(x_label, fontsize=font_size["xlabel"]) if filters is not None: - ax2.set_ylabel(r"$T_\lambda$", fontsize=11) + ax2.set_ylabel(r"$T_\lambda$", fontsize=font_size["ylabel"]) if residuals is not None: if quantity == "flux density": - ax3.set_ylabel(r"$\Delta$$F_\lambda$ ($\sigma$)", fontsize=11) + ax3.set_ylabel(r"$\Delta$$F_\lambda$ ($\sigma$)", fontsize=font_size["ylabel"]) elif quantity == "flux": - ax3.set_ylabel(r"$\Delta$$F_\lambda$ ($\sigma$)", fontsize=11) + ax3.set_ylabel(r"$\Delta$$F_\lambda$ ($\sigma$)", fontsize=font_size["ylabel"]) if quantity == "magnitude": scaling = 1.0 - ax1.set_ylabel("Contrast (mag)", fontsize=11) + ax1.set_ylabel("Contrast (mag)", fontsize=font_size["ylabel"]) if ylim is not None: ax1.set_ylim(ylim[0], ylim[1]) @@ -480,7 +505,7 @@ def plot_spectrum( + r"}$ W m$^{-2}$)" ) - ax1.set_ylabel(ylabel, fontsize=11) + ax1.set_ylabel(ylabel, fontsize=font_size["ylabel"]) ax1.set_ylim(ylim[0] / scaling, ylim[1] / scaling) if ylim[0] < 0.0: @@ -492,11 +517,11 @@ def plot_spectrum( if quantity == "flux density": ax1.set_ylabel( rf"$F_\lambda$ ({y_unit})", - fontsize=11, + fontsize=font_size["ylabel"], ) elif quantity == "flux": - ax1.set_ylabel(r"$\lambda$$F_\lambda$ (W m$^{-2}$)", fontsize=11) + ax1.set_ylabel(r"$\lambda$$F_\lambda$ (W m$^{-2}$)", fontsize=font_size["ylabel"]) scaling = 1.0 @@ -1337,9 +1362,9 @@ def plot_spectrum( if title is not None: if filters: - ax2.set_title(title, y=1.02, fontsize=13) + ax2.set_title(title, y=1.02, fontsize=font_size["title"]) else: - ax1.set_title(title, y=1.02, fontsize=13) + ax1.set_title(title, y=1.02, fontsize=font_size["title"]) handles, labels = ax1.get_legend_handles_labels() @@ -1364,10 +1389,13 @@ def plot_spectrum( model_handles, model_labels, loc=legend[0], - fontsize=10.0, + fontsize=font_size["legend"], frameon=False, ) else: + if "fontsize" not in legend[0]: + legend[0]["fontsize"] = font_size["legend"] + leg_1 = ax1.legend(model_handles, model_labels, **legend[0]) else: @@ -1379,19 +1407,25 @@ def plot_spectrum( data_handles, data_labels, loc=legend[1], - fontsize=8, + fontsize=font_size["legend"], frameon=False, ) else: + if "fontsize" not in legend[1]: + legend[1]["fontsize"] = font_size["legend"] + ax1.legend(data_handles, data_labels, **legend[1]) if leg_1 is not None: ax1.add_artist(leg_1) elif isinstance(legend, (str, tuple)): - ax1.legend(loc=legend, fontsize=8, frameon=False) + ax1.legend(loc=legend, fontsize=font_size["legend"], frameon=False) else: + if "fontsize" not in legend: + legend["fontsize"] = font_size["legend"] + ax1.legend(**legend) # if scale[0] == "log":