Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "Modifies QuickPlot.plot ..." #4622

Merged
merged 3 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
## Features
- Added two more submodels (options) for the SEI: Lars von Kolzenberg (2020) model and Tunneling Limit model ([#4394](https://github.com/pybamm-team/PyBaMM/pull/4394))

## Bug fixes

- Reverted modifications to quickplot from
[#4529](https://github.com/pybamm-team/PyBaMM/pull/4529) which caused
issues with the plots displaying correct variable names. ([]())
kratman marked this conversation as resolved.
Show resolved Hide resolved

## Breaking changes

- Double-layer SEI models have been removed (with the corresponding parameters). All models assume now a single SEI layer. ([#4470](https://github.com/pybamm-team/PyBaMM/pull/4470))
Expand Down
210 changes: 115 additions & 95 deletions src/pybamm/plotting/quick_plot.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#
# Class for quick plotting of variables from models
#
from __future__ import annotations
import os
import numpy as np
import pybamm
Expand Down Expand Up @@ -480,24 +479,24 @@ def reset_axis(self):
): # pragma: no cover
raise ValueError(f"Axis limits cannot be NaN for variables '{key}'")

def plot(self, t: float | list[float], dynamic: bool = False):
def plot(self, t, dynamic=False):
"""Produces a quick plot with the internal states at time t.

Parameters
----------
t : float or list of float
Dimensional time (in 'time_units') at which to plot. Can be a single time or a list of times.
t : float
Dimensional time (in 'time_units') at which to plot.
dynamic : bool, optional
Determine whether to allocate space for a slider at the bottom of the plot when generating a dynamic plot.
If True, creates a dynamic plot with a slider.
"""

plt = import_optional_dependency("matplotlib.pyplot")
gridspec = import_optional_dependency("matplotlib.gridspec")
cm = import_optional_dependency("matplotlib", "cm")
colors = import_optional_dependency("matplotlib", "colors")

if not isinstance(t, list):
t = [t]

t_in_seconds = t * self.time_scaling_factor
self.fig = plt.figure(figsize=self.figsize)

self.gridspec = gridspec.GridSpec(self.n_rows, self.n_cols)
Expand All @@ -509,11 +508,6 @@ def plot(self, t: float | list[float], dynamic: bool = False):
# initialize empty handles, to be created only if the appropriate plots are made
solution_handles = []

# Generate distinct colors for each time point
time_colors = plt.cm.coolwarm(
np.linspace(0, 1, len(t))
) # Use a colormap for distinct colors

for k, (key, variable_lists) in enumerate(self.variables.items()):
ax = self.fig.add_subplot(self.gridspec[k])
self.axes.add(key, ax)
Expand All @@ -524,17 +518,19 @@ def plot(self, t: float | list[float], dynamic: bool = False):
ax.xaxis.set_major_locator(plt.MaxNLocator(3))
self.plots[key] = defaultdict(dict)
variable_handles = []

# Set labels for the first subplot only (avoid repetition)
if variable_lists[0][0].dimensions == 0:
# 0D plot: plot as a function of time, indicating multiple times with lines
# 0D plot: plot as a function of time, indicating time t with a line
ax.set_xlabel(f"Time [{self.time_unit}]")
for i, variable_list in enumerate(variable_lists):
for j, variable in enumerate(variable_list):
linestyle = (
self.linestyles[i]
if len(variable_list) == 1
else self.linestyles[j]
)
if len(variable_list) == 1:
# single variable -> use linestyle to differentiate model
linestyle = self.linestyles[i]
else:
# multiple variables -> use linestyle to differentiate
# variables (color differentiates models)
linestyle = self.linestyles[j]
full_t = self.ts_seconds[i]
(self.plots[key][i][j],) = ax.plot(
full_t / self.time_scaling_factor,
Expand All @@ -546,104 +542,128 @@ def plot(self, t: float | list[float], dynamic: bool = False):
solution_handles.append(self.plots[key][i][0])
y_min, y_max = ax.get_ylim()
ax.set_ylim(y_min, y_max)

# Add vertical lines for each time in the list, using different colors for each time
for idx, t_single in enumerate(t):
t_in_seconds = t_single * self.time_scaling_factor
(self.time_lines[key],) = ax.plot(
[
t_in_seconds / self.time_scaling_factor,
t_in_seconds / self.time_scaling_factor,
],
[y_min, y_max],
"--", # Dashed lines
lw=1.5,
color=time_colors[idx], # Different color for each time
label=f"t = {t_single:.2f} {self.time_unit}",
)
ax.legend()

(self.time_lines[key],) = ax.plot(
[
t_in_seconds / self.time_scaling_factor,
t_in_seconds / self.time_scaling_factor,
],
[y_min, y_max],
"k--",
lw=1.5,
)
elif variable_lists[0][0].dimensions == 1:
# 1D plot: plot as a function of x at different times
# 1D plot: plot as a function of x at time t
# Read dictionary of spatial variables
spatial_vars = self.spatial_variable_dict[key]
spatial_var_name = next(iter(spatial_vars.keys()))
ax.set_xlabel(f"{spatial_var_name} [{self.spatial_unit}]")

for idx, t_single in enumerate(t):
t_in_seconds = t_single * self.time_scaling_factor

for i, variable_list in enumerate(variable_lists):
for j, variable in enumerate(variable_list):
linestyle = (
self.linestyles[i]
if len(variable_list) == 1
else self.linestyles[j]
)
(self.plots[key][i][j],) = ax.plot(
self.first_spatial_variable[key],
variable(t_in_seconds, **spatial_vars),
color=time_colors[idx], # Different color for each time
linestyle=linestyle,
label=f"t = {t_single:.2f} {self.time_unit}", # Add time label
zorder=10,
)
variable_handles.append(self.plots[key][0][j])
solution_handles.append(self.plots[key][i][0])

# Add a legend to indicate which plot corresponds to which time
ax.legend()

ax.set_xlabel(
f"{spatial_var_name} [{self.spatial_unit}]",
)
for i, variable_list in enumerate(variable_lists):
for j, variable in enumerate(variable_list):
if len(variable_list) == 1:
# single variable -> use linestyle to differentiate model
linestyle = self.linestyles[i]
else:
# multiple variables -> use linestyle to differentiate
# variables (color differentiates models)
linestyle = self.linestyles[j]
(self.plots[key][i][j],) = ax.plot(
self.first_spatial_variable[key],
variable(t_in_seconds, **spatial_vars),
color=self.colors[i],
linestyle=linestyle,
zorder=10,
)
variable_handles.append(self.plots[key][0][j])
solution_handles.append(self.plots[key][i][0])
# add lines for boundaries between subdomains
for boundary in variable_lists[0][0].internal_boundaries:
boundary_scaled = boundary * self.spatial_factor
ax.axvline(boundary_scaled, color="0.5", lw=1, zorder=0)
elif variable_lists[0][0].dimensions == 2:
# 2D plot: superimpose plots at different times
# Read dictionary of spatial variables
spatial_vars = self.spatial_variable_dict[key]
# there can only be one entry in the variable list
variable = variable_lists[0][0]

for t_single in t:
t_in_seconds = t_single * self.time_scaling_factor
# different order based on whether the domains are x-r, x-z or y-z, etc
if self.x_first_and_y_second[key] is False:
x_name = list(spatial_vars.keys())[1][0]
y_name = next(iter(spatial_vars.keys()))[0]
x = self.second_spatial_variable[key]
y = self.first_spatial_variable[key]
var = variable(t_in_seconds, **spatial_vars)
else:
x_name = next(iter(spatial_vars.keys()))[0]
y_name = list(spatial_vars.keys())[1][0]
x = self.first_spatial_variable[key]
y = self.second_spatial_variable[key]
var = variable(t_in_seconds, **spatial_vars).T

ax.set_xlabel(
f"{next(iter(spatial_vars.keys()))[0]} [{self.spatial_unit}]"
)
ax.set_ylabel(
f"{list(spatial_vars.keys())[1][0]} [{self.spatial_unit}]"
ax.set_xlabel(f"{x_name} [{self.spatial_unit}]")
ax.set_ylabel(f"{y_name} [{self.spatial_unit}]")
vmin, vmax = self.variable_limits[key]
# store the plot and the var data (for testing) as cant access
# z data from QuadMesh or QuadContourSet object
if self.is_y_z[key] is True:
self.plots[key][0][0] = ax.pcolormesh(
x,
y,
var,
vmin=vmin,
vmax=vmax,
shading=self.shading,
)
vmin, vmax = self.variable_limits[key]

# Use contourf and colorbars to represent the values
contour_plot = ax.contourf(
x, y, var, levels=100, vmin=vmin, vmax=vmax, cmap="coolwarm"
else:
self.plots[key][0][0] = ax.contourf(
x, y, var, levels=100, vmin=vmin, vmax=vmax
)
self.plots[key][0][0] = contour_plot
self.colorbars[key] = self.fig.colorbar(contour_plot, ax=ax)

self.plots[key][0][1] = var

ax.set_title(f"t = {t_single:.2f} {self.time_unit}")
self.plots[key][0][1] = var
if vmin is None and vmax is None:
vmin = ax_min(var)
vmax = ax_max(var)
self.colorbars[key] = self.fig.colorbar(
cm.ScalarMappable(colors.Normalize(vmin=vmin, vmax=vmax)),
ax=ax,
)
# Set either y label or legend entries
if len(key) == 1:
title = split_long_string(key[0])
ax.set_title(title, fontsize="medium")
else:
ax.legend(
variable_handles,
[split_long_string(s, 6) for s in key],
bbox_to_anchor=(0.5, 1),
loc="lower center",
)

# Set global legend if there are multiple models
# Set global legend
if len(self.labels) > 1:
fig_legend = self.fig.legend(
solution_handles, self.labels, loc="lower right"
)
# Get the position of the top of the legend in relative figure units
# There may be a better way ...
try:
legend_top_inches = fig_legend.get_window_extent(
renderer=self.fig.canvas.get_renderer()
).get_points()[1, 1]
fig_height_inches = (self.fig.get_size_inches() * self.fig.dpi)[1]
legend_top = legend_top_inches / fig_height_inches
except AttributeError: # pragma: no cover
# When testing the examples we set the matplotlib backend to "Template"
# which means that the above code doesn't work. Since this is just for
# that particular test we can just skip it
legend_top = 0
else:
fig_legend = None
legend_top = 0

# Fix layout for sliders if dynamic
# Fix layout
if dynamic:
slider_top = 0.05
else:
slider_top = 0
bottom = max(
fig_legend.get_window_extent(
renderer=self.fig.canvas.get_renderer()
).get_points()[1, 1]
if fig_legend
else 0,
slider_top,
)
bottom = max(legend_top, slider_top)
self.gridspec.tight_layout(self.fig, rect=[0, bottom, 1, 1])

def dynamic_plot(self, show_plot=True, step=None):
Expand Down
35 changes: 0 additions & 35 deletions tests/unit/test_plotting/test_quick_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,41 +252,6 @@ def test_simple_ode_model(self, solver):
solution, ["a", "b broadcasted"], variable_limits="bad variable limits"
)

# Test with a list of times
# Test for a 0D variable
quick_plot = pybamm.QuickPlot(solution, ["a"])
quick_plot.plot(t=[0, 1, 2])
assert len(quick_plot.plots[("a",)]) == 1

# Test for a 1D variable
quick_plot = pybamm.QuickPlot(solution, ["c broadcasted"])

time_list = [0.5, 1.5, 2.5]
quick_plot.plot(time_list)

ax = quick_plot.fig.axes[0]
lines = ax.get_lines()

variable_key = ("c broadcasted",)
if variable_key in quick_plot.variables:
num_variables = len(quick_plot.variables[variable_key][0])
else:
raise KeyError(
f"'{variable_key}' is not in quick_plot.variables. Available keys: "
+ str(quick_plot.variables.keys())
)

expected_lines = len(time_list) * num_variables

assert (
len(lines) == expected_lines
), f"Expected {expected_lines} superimposed lines, but got {len(lines)}"

# Test for a 2D variable
quick_plot = pybamm.QuickPlot(solution, ["2D variable"])
quick_plot.plot(t=[0, 1, 2])
assert len(quick_plot.plots[("2D variable",)]) == 1

# Test errors
with pytest.raises(ValueError, match="Mismatching variable domains"):
pybamm.QuickPlot(solution, [["a", "b broadcasted"]])
Expand Down
Loading