Skip to content

Commit

Permalink
Use the global config object in ssp_plot_traces.py
Browse files Browse the repository at this point in the history
  • Loading branch information
claudiodsf committed Jul 9, 2024
1 parent e392ce1 commit 8371828
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 20 deletions.
2 changes: 1 addition & 1 deletion sourcespec2/source_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def main():

from .ssp_plot_spectra import plot_spectra
from .ssp_plot_traces import plot_traces
plot_traces(config, proc_st, ncols=2, block=False)
plot_traces(proc_st, ncols=2, block=False)
plot_spectra(config, spec_st, ncols=1, stack_plots=True)


Expand Down
2 changes: 1 addition & 1 deletion sourcespec2/source_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def main():
spec_st, specnoise_st, weight_st = build_spectra(proc_st)

from .ssp_plot_traces import plot_traces
plot_traces(config, proc_st)
plot_traces(proc_st)

# Spectral inversion
from .ssp_inversion import spectral_inversion
Expand Down
149 changes: 131 additions & 18 deletions sourcespec2/ssp_plot_traces.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from matplotlib import patches
import matplotlib.patheffects as PathEffects
from matplotlib.ticker import ScalarFormatter as sf
from .config import config
from .savefig import savefig
from ._version import get_versions
logger = logging.getLogger(__name__.rsplit('.', maxsplit=1)[-1])
Expand All @@ -40,8 +41,20 @@ def _set_format(self, vmin=None, vmax=None):
phase_label_color = {'P': 'black', 'S': 'black'}


def _nplots(config, st, maxlines, ncols):
"""Determine the number of lines and columns of the plot."""
def _nplots(st, maxlines, ncols):
"""
Determine the number of lines and columns of the plot.
:param st: Stream of traces.
:type st: :class:`obspy.core.stream.Stream`
:param maxlines: Maximum number of lines.
:type maxlines: int
:param ncols: Number of columns.
:type ncols: int
:return: Number of lines and columns.
:rtype: tuple of int
"""
# Remove the channel letter to determine the number of plots
if config.plot_traces_ignored:
nplots = len({tr.id[:-1] for tr in st})
Expand All @@ -55,7 +68,19 @@ def _nplots(config, st, maxlines, ncols):
return nlines, ncols


def _make_fig(config, nlines, ncols):
def _make_fig(nlines, ncols):
"""
Create a figure with a number of subplots.
:param nlines: Number of lines.
:type nlines: int
:param ncols: Number of columns.
:type ncols: int
:return: Figure and axes.
:rtype: tuple of :class:`matplotlib.figure.Figure` and list of
:class:`matplotlib.axes.Axes`
"""
figsize = (16, 9) if nlines <= 3 else (16, 18)
# high dpi needed to rasterize png
# vector formats (pdf, svg) do not have rasters
Expand Down Expand Up @@ -127,7 +152,15 @@ def _make_fig(config, nlines, ncols):
BBOX = None


def _savefig(config, figures, force_numbering=False):
def _savefig(figures, force_numbering=False):
"""
Save figures to file.
:param figures: Figures to save.
:type figures: list of :class:`matplotlib.figure.Figure`
:param force_numbering: Force figure numbering.
:type force_numbering: bool
"""
global BBOX # pylint: disable=global-statement
evid = config.event.event_id
figfile_base = os.path.join(config.options.outdir, f'{evid}.traces.')
Expand Down Expand Up @@ -165,7 +198,24 @@ def _savefig(config, figures, force_numbering=False):


def _plot_min_max(ax, x_vals, y_vals, linewidth, color, alpha, zorder):
"""Quick and dirty plot using less points. Useful for vector plotting."""
"""
Quick and dirty plot using less points. Useful for vector plotting.
:param ax: Axes object.
:type ax: :class:`matplotlib.axes.Axes`
:param x_vals: X values.
:type x_vals: :class:`numpy.ndarray`
:param y_vals: Y values.
:type y_vals: :class:`numpy.ndarray`
:param linewidth: Line width.
:type linewidth: float
:param color: Line color.
:type color: str
:param alpha: Line alpha.
:type alpha: float
:param zorder: Z-order.
:type zorder: int
"""
ax_width_in_pixels = int(np.ceil(ax.bbox.width))
nsamples = len(x_vals)
samples_per_pixel = int(np.ceil(nsamples / ax_width_in_pixels))
Expand All @@ -190,7 +240,15 @@ def _plot_min_max(ax, x_vals, y_vals, linewidth, color, alpha, zorder):


def _freq_string(freq):
"""Return a string representing the rounded frequency."""
"""
Return a string representing the rounded frequency.
:param freq: Frequency.
:type freq: float
:return: Frequency string.
:rtype: str
"""
# int or float notation for frequencies between 0.01 and 100
if 1e-2 <= freq <= 1e2:
int_freq = int(round(freq))
Expand All @@ -209,7 +267,25 @@ def _freq_string(freq):
)


def _plot_trace(config, trace, ntraces, tmax, ax, trans, trans3, path_effects):
def _plot_trace(trace, ntraces, tmax, ax, trans, trans3, path_effects):
"""
Plot a trace.
:param trace: Trace to plot.
:type trace: :class:`obspy.core.trace.Trace`
:param ntraces: Number of traces.
:type ntraces: int
:param tmax: Maximum value of the trace.
:type tmax: float
:param ax: Axes object.
:type ax: :class:`matplotlib.axes.Axes`
:param trans: Transformation for plotting phase labels.
:type trans: :class:`matplotlib.transforms.BboxTransformTo`
:param trans3: Transformation for plotting station info.
:type trans3: :class:`matplotlib.transforms.BboxTransformTo`
:param path_effects: Path effects for text.
:type path_effects: :class:`matplotlib
"""
# Origin and height to draw vertical patches for noise and signal windows
rectangle_patch_origin = 0
rectangle_patch_height = 1
Expand Down Expand Up @@ -294,6 +370,16 @@ def _plot_trace(config, trace, ntraces, tmax, ax, trans, trans3, path_effects):


def _add_station_info_text(trace, ax, path_effects):
"""
Add station information text to the plot.
:param trace: Trace.
:type trace: :class:`obspy.core.trace.Trace`
:param ax: Axes object.
:type ax: :class:`matplotlib.axes.Axes`
:param path_effects: Path effects for text.
:type path_effects: :class:`matplotlib.patheffects`
"""
with contextlib.suppress(AttributeError):
if ax.has_station_info_text:
return
Expand All @@ -315,7 +401,16 @@ def _add_station_info_text(trace, ax, path_effects):


def _add_labels(axes, plotn, ncols):
"""Add xlabels to the last row of plots."""
"""
Add xlabels to the last row of plots.
:param axes: Axes objects.
:type axes: list of :class:`matplotlib.axes.Axes`
:param plotn: Number of plots.
:type plotn: int
:param ncols: Number of columns.
:type ncols: int
"""
# A row has "ncols" plots: the last row is from `plotn-ncols` to `plotn`
n0 = max(plotn - ncols, 0)
for ax in axes[n0:plotn]:
Expand All @@ -324,14 +419,25 @@ def _add_labels(axes, plotn, ncols):


def _set_ylim(axes):
"""Set symmetric ylim."""
"""
Set symmetric ylim.
:param axes: Axes objects.
:type axes: list of :class:`matplotlib
"""
for ax in axes:
ylim = ax.get_ylim()
ymax = np.max(np.abs(ylim))
ax.set_ylim(-ymax, ymax)


def _trim_traces(config, st):
def _trim_traces(st):
"""
Trim traces to the time window of interest.
:param st: Stream of traces.
:type st: :class:`obspy.core.stream.Stream`
"""
for trace in st:
t1 = trace.stats.arrivals['N1'][1]
t2 = trace.stats.arrivals['S2'][1] + 2 * config.win_length
Expand All @@ -342,11 +448,18 @@ def _trim_traces(config, st):
trace.stats.time_offset = trace.stats.starttime - min_starttime


def plot_traces(config, st, ncols=None, block=True):
def plot_traces(st, ncols=None, block=True):
"""
Plot traces in the original instrument unit (velocity or acceleration).
Display to screen and/or save to file.
:param st: Stream of traces.
:type st: :class:`obspy.core.stream.Stream`
:param ncols: Number of columns in the plot (autoset if None).
:type ncols: int
:param block: If True, block execution until the plot window is closed.
:type block: bool
"""
# Check config, if we need to plot at all
if not config.plot_show and not config.plot_save:
Expand All @@ -358,8 +471,8 @@ def plot_traces(config, st, ncols=None, block=True):
ntr = len({t.id[:-1] for t in st})
ncols = 4 if ntr > 6 else 3

nlines, ncols = _nplots(config, st, config.plot_traces_maxrows, ncols)
fig, axes = _make_fig(config, nlines, ncols)
nlines, ncols = _nplots(st, config.plot_traces_maxrows, ncols)
fig, axes = _make_fig(nlines, ncols)
figures = [fig]
# Path effect to contour text in white
path_effects = [PathEffects.withStroke(linewidth=3, foreground='white')]
Expand Down Expand Up @@ -394,8 +507,8 @@ def plot_traces(config, st, ncols=None, block=True):
config.plot_save_format != 'pdf_multipage'
):
# save figure here to free up memory
_savefig(config, figures, force_numbering=True)
fig, axes = _make_fig(config, nlines, ncols)
_savefig(figures, force_numbering=True)
fig, axes = _make_fig(nlines, ncols)
figures.append(fig)
plotn = 1
ax = axes[plotn - 1]
Expand All @@ -419,13 +532,13 @@ def plot_traces(config, st, ncols=None, block=True):
transforms.blended_transform_factory(ax.transAxes, ax.transData)
trans3 = transforms.offset_copy(trans2, fig=fig, x=0, y=0.1)

_trim_traces(config, st_sel)
_trim_traces(st_sel)
max_values = [abs(tr.max()) for tr in st_sel]
ntraces = len(max_values)
tmax = max(max_values)
for trace in st_sel:
_plot_trace(
config, trace, ntraces, tmax, ax, trans, trans3, path_effects)
trace, ntraces, tmax, ax, trans, trans3, path_effects)

_set_ylim(axes)
# Add labels for the last figure
Expand All @@ -437,4 +550,4 @@ def plot_traces(config, st, ncols=None, block=True):
if config.plot_show:
plt.show(block=block)
if config.plot_save:
_savefig(config, figures)
_savefig(figures)

0 comments on commit 8371828

Please sign in to comment.