Skip to content

Commit

Permalink
Merge pull request #13 from tomasstolker/legend
Browse files Browse the repository at this point in the history
Separation of data and model legend in plot_spectrum
  • Loading branch information
Tomas Stolker authored May 13, 2020
2 parents 2c956de + 32ff773 commit 32a65b8
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 62 deletions.
16 changes: 10 additions & 6 deletions species/plot/plot_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,6 @@
from species.util import plot_util


mpl.rcParams['font.serif'] = ['Bitstream Vera Serif']
mpl.rcParams['font.family'] = 'serif'

plt.rc('axes', edgecolor='black', linewidth=2.2)


@typechecked
def plot_color_magnitude(boxes: list,
objects: Optional[Union[List[Tuple[str, str, str, str]],
Expand Down Expand Up @@ -107,6 +101,11 @@ def plot_color_magnitude(boxes: list,
"""

mpl.rcParams['font.serif'] = ['Bitstream Vera Serif']
mpl.rcParams['font.family'] = 'serif'

plt.rc('axes', edgecolor='black', linewidth=2.2)

model_color = ('#234398', '#f6a432', 'black')
model_linestyle = ('-', '--', ':', '-.')

Expand Down Expand Up @@ -559,6 +558,11 @@ def plot_color_color(boxes: list,
None
"""

mpl.rcParams['font.serif'] = ['Bitstream Vera Serif']
mpl.rcParams['font.family'] = 'serif'

plt.rc('axes', edgecolor='black', linewidth=2.2)

model_color = ('#234398', '#f6a432', 'black')
model_linestyle = ('-', '--', ':', '-.')

Expand Down
21 changes: 15 additions & 6 deletions species/plot/plot_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,6 @@
from species.util import plot_util


mpl.rcParams['font.serif'] = ['Bitstream Vera Serif']
mpl.rcParams['font.family'] = 'serif'

plt.rc('axes', edgecolor='black', linewidth=2.2)


@typechecked
def plot_walkers(tag: str,
nsteps: Optional[int] = None,
Expand Down Expand Up @@ -52,6 +46,11 @@ def plot_walkers(tag: str,

print(f'Plotting walkers: {output}...', end='', flush=True)

mpl.rcParams['font.serif'] = ['Bitstream Vera Serif']
mpl.rcParams['font.family'] = 'serif'

plt.rc('axes', edgecolor='black', linewidth=2.2)

species_db = database.Database()
box = species_db.get_samples(tag)

Expand Down Expand Up @@ -154,6 +153,11 @@ def plot_posterior(tag: str,
None
"""

mpl.rcParams['font.serif'] = ['Bitstream Vera Serif']
mpl.rcParams['font.family'] = 'serif'

plt.rc('axes', edgecolor='black', linewidth=2.2)

if burnin is None:
burnin = 0

Expand Down Expand Up @@ -303,6 +307,11 @@ def plot_photometry(tag,
None
"""

mpl.rcParams['font.serif'] = ['Bitstream Vera Serif']
mpl.rcParams['font.family'] = 'serif'

plt.rc('axes', edgecolor='black', linewidth=2.2)

species_db = database.Database()

samples = species_db.get_mcmc_photometry(tag, burnin, filter_id)
Expand Down
179 changes: 129 additions & 50 deletions species/plot/plot_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,39 +4,39 @@

import os
import math
import warnings
import itertools

from typing import Optional, Union, Tuple, List

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

from typeguard import typechecked

from species.core import box, constants
from species.read import read_filter
from species.util import plot_util


mpl.rcParams['font.serif'] = ['Bitstream Vera Serif']
mpl.rcParams['font.family'] = 'serif'

plt.rc('axes', edgecolor='black', linewidth=2.2)
plt.rcParams['axes.axisbelow'] = False


def plot_spectrum(boxes,
filters=None,
residuals=None,
plot_kwargs=None,
xlim=None,
ylim=None,
ylim_res=None,
scale=('linear', 'linear'),
title=None,
offset=None,
legend=None,
figsize=(7., 5.),
object_type='planet',
quantity='flux',
output='spectrum.pdf'):
@typechecked
def plot_spectrum(boxes: list,
filters: Optional[List[str]] = None,
residuals: Optional[box.ResidualsBox] = None,
plot_kwargs: Optional[List[Optional[dict]]] = None,
xlim: Optional[Tuple[float, float]] = None,
ylim: Optional[Tuple[float, float]] = None,
ylim_res: Optional[Tuple[float, float]] = None,
scale: Optional[Tuple[str, str]] = None,
title: Optional[str] = None,
offset: Optional[Tuple[float, float]] = None,
legend: Union[str, dict, Tuple[float, float],
List[Optional[Union[dict, str, Tuple[float, float]]]]] = None,
figsize: Optional[Tuple[float, float]] = (7., 5.),
object_type: str = 'planet',
quantity: str = 'flux',
output: str = 'spectrum.pdf'):
"""
Parameters
----------
Expand Down Expand Up @@ -75,15 +75,20 @@ def plot_spectrum(boxes,
ylim_res : tuple(float, float), None
Limits of the residuals axis. Automatically chosen (based on the minimum and maximum
residual value) if set to None.
scale : tuple(str, str)
Scale of the axes ('linear' or 'log').
scale : tuple(str, str), None
Scale of the x and y axes ('linear' or 'log'). The scale is set to ``('linear', 'linear')``
if set to ``None``.
title : str
Title.
offset : tuple(float, float)
Offset for the label of the x- and y-axis.
legend : str, tuple, dict, None
Location of the legend (str, tuple) or a dictionary with the ``**kwargs`` of
``matplotlib.pyplot.legend``, e.g. ``{'loc': 'upper left', 'fontsize: 12.}``.
legend : str, tuple, dict, list(dict, dict), None
Location of the legend (str or tuple(float, float)) or a dictionary with the ``**kwargs``
of ``matplotlib.pyplot.legend``, for example ``{'loc': 'upper left', 'fontsize: 12.}``.
Alternatively, a list with two values can be provided to separate the model and data
handles in two legends. Each of these two elements can be set to ``None``. For example,
``[None, {'loc': 'upper left', 'fontsize: 12.}]``, if only the data points should be
included in a legend.
figsize : tuple(float, float)
Figure size.
object_type : str
Expand All @@ -100,6 +105,12 @@ def plot_spectrum(boxes,
None
"""

mpl.rcParams['font.serif'] = ['Bitstream Vera Serif']
mpl.rcParams['font.family'] = 'serif'

plt.rc('axes', edgecolor='black', linewidth=2.2)
plt.rcParams['axes.axisbelow'] = False

if plot_kwargs is None:
plot_kwargs = []

Expand Down Expand Up @@ -190,7 +201,7 @@ def plot_spectrum(boxes,
ax2.set_ylabel('Transmission', fontsize=13)

if residuals is not None:
ax3.set_ylabel(r'Residual ($\sigma$)', fontsize=13)
ax3.set_ylabel(r'$\Delta$$F_\lambda$ ($\sigma$)', fontsize=13)

if xlim is not None:
ax1.set_xlim(xlim[0], xlim[1])
Expand All @@ -213,7 +224,7 @@ def plot_spectrum(boxes,
exponent = math.floor(math.log10(ylim[1]))
scaling = 10.**exponent

ylabel = r'Flux (10$^{'+str(exponent)+r'}$ W m$^{-2}$ $\mu$m$^{-1}$)'
ylabel = r'$F_\lambda$ (10$^{'+str(exponent)+r'}$ W m$^{-2}$ $\mu$m$^{-1}$)'

ax1.set_ylabel(ylabel, fontsize=13)
ax1.set_ylim(ylim[0]/scaling, ylim[1]/scaling)
Expand All @@ -222,7 +233,7 @@ def plot_spectrum(boxes,
ax1.axhline(0.0, linestyle='--', color='gray', dashes=(2, 4), zorder=0.5)

else:
ax1.set_ylabel(r'Flux (W m$^{-2}$ $\mu$m$^{-1}$)', fontsize=13)
ax1.set_ylabel(r'$F_\lambda$ (W m$^{-2}$ $\mu$m$^{-1}$)', fontsize=13)
scaling = 1.

if filters is not None:
Expand Down Expand Up @@ -263,6 +274,9 @@ def plot_spectrum(boxes,
ax1.get_xaxis().set_label_coords(0.5, -0.12)
ax1.get_yaxis().set_label_coords(-0.1, 0.5)

if scale is None:
scale = ('linear', 'linear')

ax1.set_xscale(scale[0])
ax1.set_yscale(scale[1])

Expand Down Expand Up @@ -346,7 +360,7 @@ def plot_spectrum(boxes,
label = kwargs_copy['label']

del kwargs_copy['label']

ax1.plot(wavelength, masked/scaling, zorder=2, label=label, **kwargs_copy)

else:
Expand Down Expand Up @@ -391,8 +405,52 @@ def plot_spectrum(boxes,
zorder=3)

elif isinstance(boxitem, box.ObjectBox):
if boxitem.spectrum is not None:
spec_list = []
wavel_list = []

for item in boxitem.spectrum:
spec_list.append(item)
wavel_list.append(boxitem.spectrum[item][0][0, 0])

sort_index = np.argsort(wavel_list)
spec_sort = []

for i in range(sort_index.size):
spec_sort.append(spec_list[sort_index[i]])

for key in spec_sort:
masked = np.ma.array(boxitem.spectrum[key][0],
mask=np.isnan(boxitem.spectrum[key][0]))

if not plot_kwargs[j] or key not in plot_kwargs[j]:
plot_obj = ax1.errorbar(masked[:, 0], masked[:, 1]/scaling,
yerr=masked[:, 2]/scaling, ms=2, marker='s',
zorder=2.5, ls='none')

plot_kwargs[j][key] = {'marker': 's', 'ms': 2., 'ls': 'none',
'color': plot_obj[0].get_color()}

else:
ax1.errorbar(masked[:, 0], masked[:, 1]/scaling, yerr=masked[:, 2]/scaling,
zorder=2.5, **plot_kwargs[j][key])

if boxitem.flux is not None:
filter_list = []
wavel_list = []

for item in boxitem.flux:
read_filt = read_filter.ReadFilter(item)
filter_list.append(item)
wavel_list.append(read_filt.mean_wavelength())

sort_index = np.argsort(wavel_list)
filter_sort = []

for i in range(sort_index.size):
filter_sort.append(filter_list[sort_index[i]])

for item in filter_sort:
transmission = read_filter.ReadFilter(item)
wavelength = transmission.mean_wavelength()
fwhm = transmission.filter_fwhm()
Expand Down Expand Up @@ -430,23 +488,6 @@ def plot_spectrum(boxes,
ax1.errorbar(wavelength, boxitem.flux[item][0]/scaling, xerr=fwhm/2.,
yerr=boxitem.flux[item][1]/scaling, zorder=3, **plot_kwargs[j][item])

if boxitem.spectrum is not None:
for key, value in boxitem.spectrum.items():
masked = np.ma.array(boxitem.spectrum[key][0],
mask=np.isnan(boxitem.spectrum[key][0]))

if not plot_kwargs[j] or key not in plot_kwargs[j]:
plot_obj = ax1.errorbar(masked[:, 0], masked[:, 1]/scaling,
yerr=masked[:, 2]/scaling, ms=2, marker='s',
zorder=2.5, ls='none')

plot_kwargs[j][key] = {'marker': 's', 'ms': 2., 'ls': 'none',
'color': plot_obj[0].get_color()}

else:
ax1.errorbar(masked[:, 0], masked[:, 1]/scaling, yerr=masked[:, 2]/scaling,
zorder=2.5, **plot_kwargs[j][key])

elif isinstance(boxitem, box.SynphotBox):
for i, find_item in enumerate(boxes):
if isinstance(find_item, box.ObjectBox):
Expand Down Expand Up @@ -558,11 +599,49 @@ def plot_spectrum(boxes,
else:
ax1.set_title(title, y=1.02, fontsize=15)

handles, _ = ax1.get_legend_handles_labels()
handles, labels = ax1.get_legend_handles_labels()

if handles and legend is not None:
if isinstance(legend, (str, tuple)):
if isinstance(legend, list):
model_handles = []
data_handles = []

model_labels = []
data_labels = []

for i, item in enumerate(handles):
if isinstance(item, mpl.lines.Line2D):
model_handles.append(item)
model_labels.append(labels[i])

elif isinstance(item, mpl.container.ErrorbarContainer):
data_handles.append(item)
data_labels.append(labels[i])

else:
warnings.warn(f'The object type {item} is not implemented for the legend.')

if legend[0] is not None:
if isinstance(legend[0], (str, tuple)):
leg_1 = ax1.legend(model_handles, model_labels, loc=legend[0], fontsize=10., frameon=False)
else:
leg_1 = ax1.legend(model_handles, model_labels, **legend[0])

else:
leg_1 = None

if legend[1] is not None:
if isinstance(legend[1], (str, tuple)):
leg_2 = ax1.legend(data_handles, data_labels, loc=legend[1], fontsize=8, frameon=False)
else:
leg_2 = ax1.legend(data_handles, data_labels, **legend[1])

if leg_1 is not None:
ax1.add_artist(leg_1)

elif isinstance(legend, (str, tuple)):
ax1.legend(loc=legend, fontsize=8, frameon=False)

else:
ax1.legend(**legend)

Expand Down

0 comments on commit 32a65b8

Please sign in to comment.