diff --git a/CHANGELOG.md b/CHANGELOG.md index c1cf4b91ed..afbc5073b0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,7 +3,6 @@ ## Bug fixes - Fixed bug in calculation of theoretical energy that made it very slow ([#3506](https://github.com/pybamm-team/PyBaMM/pull/3506)) -- Fixed a bug where the JaxSolver would fails when using GPU support with no input parameters ([#3423](https://github.com/pybamm-team/PyBaMM/pull/3423)) # [v23.9rc0](https://github.com/pybamm-team/PyBaMM/tree/v23.9rc0) - 2023-10-31 @@ -24,6 +23,7 @@ ## Bug fixes +- Fixed a bug where the JaxSolver would fail when using GPU support with no input parameters ([#3423](https://github.com/pybamm-team/PyBaMM/pull/3423)) - Fixed a bug where empty lists passed to QuickPlot resulted in an IndexError and did not return a meaningful error message ([#3359](https://github.com/pybamm-team/PyBaMM/pull/3359)) - Fixed a bug where there was a missing thermal conductivity in the thermal pouch cell models ([#3330](https://github.com/pybamm-team/PyBaMM/pull/3330)) - Fixed a bug that caused incorrect results of “{Domain} electrode thickness change [m]” due to the absence of dimension for the variable `electrode_thickness_change`([#3329](https://github.com/pybamm-team/PyBaMM/pull/3329)). @@ -62,7 +62,7 @@ - Added option to use an empirical hysteresis model for the diffusivity and exchange-current density ([#3194](https://github.com/pybamm-team/PyBaMM/pull/3194)) - Double-layer capacity can now be provided as a function of temperature ([#3174](https://github.com/pybamm-team/PyBaMM/pull/3174)) - `pybamm_install_jax` is deprecated. It is now replaced with `pip install pybamm[jax]` ([#3163](https://github.com/pybamm-team/PyBaMM/pull/3163)) -- PyBaMM now has optional dependencies that can be installed with the pattern `pip install pybamm[option]` e.g. `pybamm[plot]` ([#3044](https://github.com/pybamm-team/PyBaMM/pull/3044)) +- PyBaMM now has optional dependencies that can be installed with the pattern `pip install pybamm[option]` e.g. `pybamm[plot]` ([#3044](https://github.com/pybamm-team/PyBaMM/pull/3044), [#3475](https://github.com/pybamm-team/PyBaMM/pull/3475)) # [v23.5](https://github.com/pybamm-team/PyBaMM/tree/v23.5) - 2023-06-18 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index bec0fee02a..0a5b17bcb0 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -100,21 +100,52 @@ On the other hand... We _do_ want to compare several tools, to generate document Only 'core pybamm' is installed by default. The others have to be specified explicitly when running the installation command. -### Matplotlib +### Managing Optional Dependencies and Their Imports -We use Matplotlib in PyBaMM, but with two caveats: +PyBaMM utilizes optional dependencies to allow users to choose which additional libraries they want to use. Managing these optional dependencies and their imports is essential to provide flexibility to PyBaMM users. -First, Matplotlib should only be used in plotting methods, and these should _never_ be called by other PyBaMM methods. So users who don't like Matplotlib will not be forced to use it in any way. Use in notebooks is OK and encouraged. +PyBaMM provides a utility function `have_optional_dependency`, to check for the availability of optional dependencies within methods. This function can be used to conditionally import optional dependencies only if they are available. Here's how to use it: -Second, Matplotlib should never be imported at the module level, but always inside methods. For example: +Optional dependencies should never be imported at the module level, but always inside methods. For example: ``` -def plot_great_things(self, x, y, z): - import matplotlib.pyplot as pl +def use_pybtex(x,y,z): + pybtex = have_optional_dependency("pybtex") ... ``` -This allows people to (1) use PyBaMM without ever importing Matplotlib and (2) configure Matplotlib's back-end in their scripts, which _must_ be done before e.g. `pyplot` is first imported. +While importing a specific module instead of an entire package/library: + +```python +def use_parse_file(x, y, z): + parse_file = have_optional_dependency("pybtex.database", "parse_file") + ... +``` + +This allows people to (1) use PyBaMM without importing optional dependencies by default and (2) configure module-dependent functionalities in their scripts, which _must_ be done before e.g. `print_citations` method is first imported. + +**Writing Tests for Optional Dependencies** + +Whenever a new optional dependency is added for optional functionality, it is recommended to write a corresponding unit test in `test_util.py`. This ensures that an error is raised upon the absence of said dependency. Here's an example: + +```python +from tests import TestCase +import pybamm + + +class TestUtil(TestCase): + def test_optional_dependency(self): + # Test that an error is raised when pybtex is not available + with self.assertRaisesRegex( + ModuleNotFoundError, "Optional dependency pybtex is not available" + ): + sys.modules["pybtex"] = None + pybamm.function_using_pybtex(x, y, z) + + # Test that the function works when pybtex is available + sys.modules["pybtex"] = pybamm.util.have_optional_dependency("pybtex") + pybamm.function_using_pybtex(x, y, z) +``` ## Testing @@ -266,7 +297,6 @@ This also means that, if you can't fix the bug yourself, it will be much easier ``` This will start the debugger at the point where the `ValueError` was raised, and allow you to investigate further. Sometimes, it is more informative to put the try-except block further up the call stack than exactly where the error is raised. - 2. Warnings. If functions are raising warnings instead of errors, it can be hard to pinpoint where this is coming from. Here, you can use the `warnings` module to convert warnings to errors: ```python @@ -276,7 +306,6 @@ This also means that, if you can't fix the bug yourself, it will be much easier ``` Then you can use a try-except block, as in a., but with, for example, `RuntimeWarning` instead of `ValueError`. - 3. Stepping through the expression tree. Most calls in PyBaMM are operations on [expression trees](https://github.com/pybamm-team/PyBaMM/blob/develop/docs/source/examples/notebooks/expression_tree/expression-tree.ipynb). To view an expression tree in ipython, you can use the `render` command: ```python @@ -284,11 +313,8 @@ This also means that, if you can't fix the bug yourself, it will be much easier ``` You can then step through the expression tree, using the `children` attribute, to pinpoint exactly where a bug is coming from. For example, if `expression_tree.jac(y)` is failing, you can check `expression_tree.children[0].jac(y)`, then `expression_tree.children[0].children[0].jac(y)`, etc. - 3. To isolate whether a bug is in a model, its Jacobian or its simplified version, you can set the `use_jacobian` and/or `use_simplify` attributes of the model to `False` (they are both `True` by default for most models). - 4. If a model isn't giving the answer you expect, you can try comparing it to other models. For example, you can investigate parameter limits in which two models should give the same answer by setting some parameters to be small or zero. The `StandardOutputComparison` class can be used to compare some standard outputs from battery models. - 5. To get more information about what is going on under the hood, and hence understand what is causing the bug, you can set the [logging](https://realpython.com/python-logging/) level to `DEBUG` by adding the following line to your test or script: ```python3 diff --git a/docs/source/user_guide/installation/index.rst b/docs/source/user_guide/installation/index.rst index 6338323e79..2cf61093be 100644 --- a/docs/source/user_guide/installation/index.rst +++ b/docs/source/user_guide/installation/index.rst @@ -66,6 +66,7 @@ Package Minimum support `SciPy `__ 2.8.2 `CasADi `__ 3.6.0 `Xarray `__ 2023.04.0 +`Anytree `__ 2.4.3 ================================================================ ========================== .. _install.optional_dependencies: diff --git a/pybamm/__init__.py b/pybamm/__init__.py index 9aa1ca79a0..07d8a1c0ea 100644 --- a/pybamm/__init__.py +++ b/pybamm/__init__.py @@ -47,13 +47,13 @@ get_parameters_filepath, have_jax, install_jax, + have_optional_dependency, is_jax_compatible, get_git_commit_info, ) from .logger import logger, set_logging_level, get_new_logger from .settings import settings from .citations import Citations, citations, print_citations - # # Classes for the Expression Tree # diff --git a/pybamm/citations.py b/pybamm/citations.py index da619062e0..b72262989b 100644 --- a/pybamm/citations.py +++ b/pybamm/citations.py @@ -6,10 +6,8 @@ import pybamm import os import warnings -import pybtex from sys import _getframe -from pybtex.database import parse_file, parse_string, Entry -from pybtex.scanner import PybtexError +from pybamm.util import have_optional_dependency class Citations: @@ -76,6 +74,7 @@ def read_citations(self): """Reads the citations in `pybamm.CITATIONS.bib`. Other works can be cited by passing a BibTeX citation to :meth:`register`. """ + parse_file = have_optional_dependency("pybtex.database", "parse_file") citations_file = os.path.join(pybamm.root_dir(), "pybamm", "CITATIONS.bib") bib_data = parse_file(citations_file, bib_format="bibtex") for key, entry in bib_data.entries.items(): @@ -86,6 +85,7 @@ def _add_citation(self, key, entry): previous entry is overwritten """ + Entry = have_optional_dependency("pybtex.database", "Entry") # Check input types are correct if not isinstance(key, str) or not isinstance(entry, Entry): raise TypeError() @@ -151,6 +151,8 @@ def _parse_citation(self, key): key: str A BibTeX formatted citation """ + PybtexError = have_optional_dependency("pybtex.scanner", "PybtexError") + parse_string = have_optional_dependency("pybtex.database", "parse_string") try: # Parse string as a bibtex citation, and check that a citation was found bib_data = parse_string(key, bib_format="bibtex") @@ -217,6 +219,7 @@ def print(self, filename=None, output_format="text", verbose=False): """ # Parse citations that were not known keys at registration, but do not # fail if they cannot be parsed + pybtex = have_optional_dependency("pybtex") try: for key in self._unknown_citations: self._parse_citation(key) diff --git a/pybamm/expression_tree/array.py b/pybamm/expression_tree/array.py index a9141041b3..2736886d95 100644 --- a/pybamm/expression_tree/array.py +++ b/pybamm/expression_tree/array.py @@ -2,10 +2,10 @@ # NumpyArray class # import numpy as np -import sympy from scipy.sparse import csr_matrix, issparse import pybamm +from pybamm.util import have_optional_dependency class Array(pybamm.Symbol): @@ -125,6 +125,7 @@ def is_constant(self): def to_equation(self): """Returns the value returned by the node when evaluated.""" + sympy = have_optional_dependency("sympy") entries_list = self.entries.tolist() return sympy.Array(entries_list) diff --git a/pybamm/expression_tree/binary_operators.py b/pybamm/expression_tree/binary_operators.py index bde9a17271..bfb31596e6 100644 --- a/pybamm/expression_tree/binary_operators.py +++ b/pybamm/expression_tree/binary_operators.py @@ -4,11 +4,11 @@ import numbers import numpy as np -import sympy from scipy.sparse import csr_matrix, issparse import functools import pybamm +from pybamm.util import have_optional_dependency def _preprocess_binary(left, right): @@ -147,6 +147,7 @@ def _sympy_operator(self, left, right): def to_equation(self): """Convert the node and its subtree into a SymPy equation.""" + sympy = have_optional_dependency("sympy") if self.print_name is not None: return sympy.Symbol(self.print_name) else: @@ -323,6 +324,7 @@ def _binary_evaluate(self, left, right): def _sympy_operator(self, left, right): """Override :meth:`pybamm.BinaryOperator._sympy_operator`""" + sympy = have_optional_dependency("sympy") left = sympy.Matrix(left) right = sympy.Matrix(right) return left * right @@ -626,6 +628,7 @@ def _binary_new_copy(self, left, right): def _sympy_operator(self, left, right): """Override :meth:`pybamm.BinaryOperator._sympy_operator`""" + sympy = have_optional_dependency("sympy") return sympy.Min(left, right) @@ -662,6 +665,7 @@ def _binary_new_copy(self, left, right): def _sympy_operator(self, left, right): """Override :meth:`pybamm.BinaryOperator._sympy_operator`""" + sympy = have_optional_dependency("sympy") return sympy.Max(left, right) diff --git a/pybamm/expression_tree/concatenations.py b/pybamm/expression_tree/concatenations.py index 2185a0fad6..1c82aff122 100644 --- a/pybamm/expression_tree/concatenations.py +++ b/pybamm/expression_tree/concatenations.py @@ -5,10 +5,10 @@ from collections import defaultdict import numpy as np -import sympy from scipy.sparse import issparse, vstack import pybamm +from pybamm.util import have_optional_dependency class Concatenation(pybamm.Symbol): @@ -135,6 +135,7 @@ def is_constant(self): def _sympy_operator(self, *children): """Apply appropriate SymPy operators.""" + sympy = have_optional_dependency("sympy") self.concat_latex = tuple(map(sympy.latex, children)) if self.print_name is not None: diff --git a/pybamm/expression_tree/functions.py b/pybamm/expression_tree/functions.py index 80c2848ad9..0c7e98b508 100644 --- a/pybamm/expression_tree/functions.py +++ b/pybamm/expression_tree/functions.py @@ -3,13 +3,11 @@ # import numbers -import autograd import numpy as np -import sympy from scipy import special import pybamm - +from pybamm.util import have_optional_dependency class Function(pybamm.Symbol): """ @@ -96,6 +94,7 @@ def _function_diff(self, children, idx): Derivative with respect to child number 'idx'. See :meth:`pybamm.Symbol._diff()`. """ + autograd = have_optional_dependency("autograd") # Store differentiated function, needed in case we want to convert to CasADi if self.derivative == "autograd": return Function( @@ -202,6 +201,7 @@ def _sympy_operator(self, child): def to_equation(self): """Convert the node and its subtree into a SymPy equation.""" + sympy = have_optional_dependency("sympy") if self.print_name is not None: return sympy.Symbol(self.print_name) else: @@ -250,6 +250,7 @@ def _function_new_copy(self, children): def _sympy_operator(self, child): """Apply appropriate SymPy operators.""" + sympy = have_optional_dependency("sympy") class_name = self.__class__.__name__.lower() sympy_function = getattr(sympy, class_name) return sympy_function(child) @@ -267,6 +268,7 @@ def _function_diff(self, children, idx): def _sympy_operator(self, child): """Override :meth:`pybamm.Function._sympy_operator`""" + sympy = have_optional_dependency("sympy") return sympy.asinh(child) @@ -287,6 +289,7 @@ def _function_diff(self, children, idx): def _sympy_operator(self, child): """Override :meth:`pybamm.Function._sympy_operator`""" + sympy = have_optional_dependency("sympy") return sympy.atan(child) diff --git a/pybamm/expression_tree/independent_variable.py b/pybamm/expression_tree/independent_variable.py index efeb73f8bc..2f30da9a5e 100644 --- a/pybamm/expression_tree/independent_variable.py +++ b/pybamm/expression_tree/independent_variable.py @@ -1,9 +1,8 @@ # # IndependentVariable class # -import sympy - import pybamm +from pybamm.util import have_optional_dependency KNOWN_COORD_SYS = ["cartesian", "cylindrical polar", "spherical polar"] @@ -44,6 +43,7 @@ def _jac(self, variable): def to_equation(self): """Convert the node and its subtree into a SymPy equation.""" + sympy = have_optional_dependency("sympy") if self.print_name is not None: return sympy.Symbol(self.print_name) else: @@ -77,6 +77,7 @@ def _evaluate_for_shape(self): def to_equation(self): """Convert the node and its subtree into a SymPy equation.""" + sympy = have_optional_dependency("sympy") return sympy.Symbol("t") diff --git a/pybamm/expression_tree/operations/latexify.py b/pybamm/expression_tree/operations/latexify.py index 67e0199656..9f2949069e 100644 --- a/pybamm/expression_tree/operations/latexify.py +++ b/pybamm/expression_tree/operations/latexify.py @@ -5,10 +5,9 @@ import re import warnings -import sympy - import pybamm from pybamm.expression_tree.printing.sympy_overrides import custom_print_func +from pybamm.util import have_optional_dependency def get_rng_min_max_name(rng, min_or_max): @@ -88,6 +87,7 @@ def _get_bcs_displays(self, var): Returns a list of boundary condition equations with ranges in front of the equations. """ + sympy = have_optional_dependency("sympy") bcs_eqn_list = [] bcs = self.model.boundary_conditions.get(var, None) @@ -118,6 +118,7 @@ def _get_bcs_displays(self, var): def _get_param_var(self, node): """Returns a list of parameters and a list of variables.""" + sympy = have_optional_dependency("sympy") param_list = [] var_list = [] dfs_nodes = [node] @@ -160,6 +161,7 @@ def _get_param_var(self, node): return param_list, var_list def latexify(self, output_variables=None): + sympy = have_optional_dependency("sympy") # Voltage is the default output variable if it exists if output_variables is None: if "Voltage [V]" in self.model.variables: diff --git a/pybamm/expression_tree/parameter.py b/pybamm/expression_tree/parameter.py index 10addae464..eebe77ad2f 100644 --- a/pybamm/expression_tree/parameter.py +++ b/pybamm/expression_tree/parameter.py @@ -5,9 +5,9 @@ import sys import numpy as np -import sympy import pybamm +from pybamm.util import have_optional_dependency class Parameter(pybamm.Symbol): @@ -44,6 +44,7 @@ def is_constant(self): def to_equation(self): """Convert the node and its subtree into a SymPy equation.""" + sympy = have_optional_dependency("sympy") if self.print_name is not None: return sympy.Symbol(self.print_name) else: @@ -217,6 +218,7 @@ def _evaluate_for_shape(self): def to_equation(self): """Convert the node and its subtree into a SymPy equation.""" + sympy = have_optional_dependency("sympy") if self.print_name is not None: return sympy.Symbol(self.print_name) else: diff --git a/pybamm/expression_tree/printing/sympy_overrides.py b/pybamm/expression_tree/printing/sympy_overrides.py index a96aa19729..1898822ea8 100644 --- a/pybamm/expression_tree/printing/sympy_overrides.py +++ b/pybamm/expression_tree/printing/sympy_overrides.py @@ -8,11 +8,9 @@ class CustomPrint(LatexPrinter): """Override SymPy methods to match PyBaMM's requirements""" - def _print_Derivative(self, expr): """Override :meth:`sympy.printing.latex.LatexPrinter._print_Derivative`""" eqn = super()._print_Derivative(expr) - if getattr(expr, "force_partial", False) and "partial" not in eqn: var1, var2 = re.findall(r"^\\frac{(\w+)}{(\w+) .+", eqn)[0] eqn = eqn.replace(var1, "\partial").replace(var2, "\partial") diff --git a/pybamm/expression_tree/scalar.py b/pybamm/expression_tree/scalar.py index 3149bf7bee..0209c02a8e 100644 --- a/pybamm/expression_tree/scalar.py +++ b/pybamm/expression_tree/scalar.py @@ -2,10 +2,9 @@ # Scalar class # import numpy as np -import sympy import pybamm - +from pybamm.util import have_optional_dependency class Scalar(pybamm.Symbol): """ @@ -70,6 +69,7 @@ def is_constant(self): def to_equation(self): """Returns the value returned by the node when evaluated.""" + sympy = have_optional_dependency("sympy") if self.print_name is not None: return sympy.Symbol(self.print_name) else: diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index 5d28884ed5..8f1608e7ba 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -3,14 +3,12 @@ # import numbers -import anytree import numpy as np -import sympy -from anytree.exporter import DotExporter from scipy.sparse import csr_matrix, issparse from functools import lru_cache, cached_property import pybamm +from pybamm.util import have_optional_dependency from pybamm.expression_tree.printing.print_name import prettify_print_name DOMAIN_LEVELS = ["primary", "secondary", "tertiary", "quaternary"] @@ -442,6 +440,7 @@ def render(self): # pragma: no cover """ Print out a visual representation of the tree (this node and its children) """ + anytree = have_optional_dependency("anytree") for pre, _, node in anytree.RenderTree(self): if isinstance(node, pybamm.Scalar) and node.name != str(node.value): print("{}{} = {}".format(pre, node.name, node.value)) @@ -460,6 +459,7 @@ def visualise(self, filename): filename to output, must end in ".png" """ + DotExporter = have_optional_dependency("anytree.exporter", "DotExporter") # check that filename ends in .png. if filename[-4:] != ".png": raise ValueError("filename should end in .png") @@ -479,6 +479,7 @@ def relabel_tree(self, symbol, counter): Finds all children of a symbol and assigns them a new id so that they can be visualised properly using the graphviz output """ + anytree = have_optional_dependency("anytree") name = symbol.name if name == "div": name = "∇⋅" @@ -522,6 +523,7 @@ def pre_order(self): a b """ + anytree = have_optional_dependency("anytree") return anytree.PreOrderIter(self) def __str__(self): @@ -984,4 +986,5 @@ def print_name(self, name): self._print_name = prettify_print_name(name) def to_equation(self): + sympy = have_optional_dependency("sympy") return sympy.Symbol(str(self.name)) diff --git a/pybamm/expression_tree/unary_operators.py b/pybamm/expression_tree/unary_operators.py index 7f9c45775c..81c3dc28c2 100644 --- a/pybamm/expression_tree/unary_operators.py +++ b/pybamm/expression_tree/unary_operators.py @@ -4,11 +4,9 @@ import numbers import numpy as np -import sympy from scipy.sparse import csr_matrix, issparse -from sympy.vector.operators import Divergence as sympy_Divergence -from sympy.vector.operators import Gradient as sympy_Gradient import pybamm +from pybamm.util import have_optional_dependency class UnaryOperator(pybamm.Symbol): @@ -83,6 +81,7 @@ def _sympy_operator(self, child): def to_equation(self): """Convert the node and its subtree into a SymPy equation.""" + sympy = have_optional_dependency("sympy") if self.print_name is not None: return sympy.Symbol(self.print_name) else: @@ -368,6 +367,7 @@ def _unary_new_copy(self, child): def _sympy_operator(self, child): """Override :meth:`pybamm.UnaryOperator._sympy_operator`""" + sympy_Gradient = have_optional_dependency("sympy.vector.operators", "Gradient") return sympy_Gradient(child) @@ -403,6 +403,7 @@ def _unary_new_copy(self, child): def _sympy_operator(self, child): """Override :meth:`pybamm.UnaryOperator._sympy_operator`""" + sympy_Divergence = have_optional_dependency("sympy.vector.operators", "Divergence") return sympy_Divergence(child) @@ -579,6 +580,7 @@ def _evaluates_on_edges(self, dimension): def _sympy_operator(self, child): """Override :meth:`pybamm.UnaryOperator._sympy_operator`""" + sympy = have_optional_dependency("sympy") return sympy.Integral(child, sympy.Symbol("xn")) @@ -889,6 +891,7 @@ def _unary_new_copy(self, child): def _sympy_operator(self, child): """Override :meth:`pybamm.UnaryOperator._sympy_operator`""" + sympy = have_optional_dependency("sympy") if ( self.child.domain[0] in ["negative particle", "positive particle"] and self.side == "right" diff --git a/pybamm/expression_tree/variable.py b/pybamm/expression_tree/variable.py index f9f7d94efc..0d1e1fd424 100644 --- a/pybamm/expression_tree/variable.py +++ b/pybamm/expression_tree/variable.py @@ -3,9 +3,9 @@ # import numpy as np -import sympy import numbers import pybamm +from pybamm.util import have_optional_dependency class VariableBase(pybamm.Symbol): @@ -124,6 +124,7 @@ def _evaluate_for_shape(self): def to_equation(self): """Convert the node and its subtree into a SymPy equation.""" + sympy = have_optional_dependency("sympy") if self.print_name is not None: return sympy.Symbol(self.print_name) else: diff --git a/pybamm/meshes/scikit_fem_submeshes.py b/pybamm/meshes/scikit_fem_submeshes.py index f25dce80b1..23c024dbbb 100644 --- a/pybamm/meshes/scikit_fem_submeshes.py +++ b/pybamm/meshes/scikit_fem_submeshes.py @@ -3,10 +3,10 @@ # import pybamm from .meshes import SubMesh - -import skfem import numpy as np +from pybamm.util import have_optional_dependency + class ScikitSubMesh2D(SubMesh): """ @@ -27,6 +27,7 @@ class ScikitSubMesh2D(SubMesh): """ def __init__(self, edges, coord_sys, tabs): + skfem = have_optional_dependency("skfem") self.edges = edges self.nodes = dict.fromkeys(["y", "z"]) for var in self.nodes.keys(): diff --git a/pybamm/models/base_model.py b/pybamm/models/base_model.py index 41192dbe1f..08890757b7 100644 --- a/pybamm/models/base_model.py +++ b/pybamm/models/base_model.py @@ -9,7 +9,7 @@ import numpy as np import pybamm -from pybamm.expression_tree.operations.latexify import Latexify +from pybamm.util import have_optional_dependency class BaseModel: @@ -1055,14 +1055,43 @@ def generate( C.generate() def latexify(self, filename=None, newline=True, output_variables=None): - # For docstring, see pybamm.expression_tree.operations.latexify.Latexify + """ + Converts all model equations in latex. + + Parameters + ---------- + filename: str (optional) + Accepted file formats - any image format, pdf and tex + Default is None, When None returns all model equations in latex + If not None, returns all model equations in given file format. + + newline: bool (optional) + Default is True, If True, returns every equation in a new line. + If False, returns the list of all the equations. + + Load model + >>> model = pybamm.lithium_ion.SPM() + + This will returns all model equations in png + >>> model.latexify("equations.png") + + This will return all the model equations in latex + >>> model.latexify() + + This will return the list of all the model equations + >>> model.latexify(newline=False) + + This will return first five model equations + >>> model.latexify(newline=False)[1:5] + """ + sympy = have_optional_dependency("sympy") + if sympy: + from pybamm.expression_tree.operations.latexify import Latexify + return Latexify(self, filename, newline).latexify( output_variables=output_variables ) - # Set :meth:`latexify` docstring from :class:`Latexify` - latexify.__doc__ = Latexify.__doc__ - def process_parameters_and_discretise(self, symbol, parameter_values, disc): """ Process parameters and discretise a symbol using supplied parameter values diff --git a/pybamm/plotting/plot.py b/pybamm/plotting/plot.py index 19aa9dc5e0..88c8dfe442 100644 --- a/pybamm/plotting/plot.py +++ b/pybamm/plotting/plot.py @@ -3,6 +3,7 @@ # import pybamm from .quick_plot import ax_min, ax_max +from pybamm.util import have_optional_dependency def plot(x, y, ax=None, testing=False, **kwargs): @@ -25,7 +26,7 @@ def plot(x, y, ax=None, testing=False, **kwargs): Keyword arguments, passed to plt.plot """ - import matplotlib.pyplot as plt + plt = have_optional_dependency("matplotlib.pyplot") if not isinstance(x, pybamm.Array): raise TypeError("x must be 'pybamm.Array'") diff --git a/pybamm/plotting/plot2D.py b/pybamm/plotting/plot2D.py index 80bb5d0ee2..d4f6d31e3a 100644 --- a/pybamm/plotting/plot2D.py +++ b/pybamm/plotting/plot2D.py @@ -3,6 +3,7 @@ # import pybamm from .quick_plot import ax_min, ax_max +from pybamm.util import have_optional_dependency def plot2D(x, y, z, ax=None, testing=False, **kwargs): @@ -25,7 +26,7 @@ def plot2D(x, y, z, ax=None, testing=False, **kwargs): Whether to actually make the plot (turned off for unit tests) """ - import matplotlib.pyplot as plt + plt = have_optional_dependency("matplotlib.pyplot") if not isinstance(x, pybamm.Array): raise TypeError("x must be 'pybamm.Array'") diff --git a/pybamm/plotting/plot_summary_variables.py b/pybamm/plotting/plot_summary_variables.py index 6fe71518db..e50f38fddf 100644 --- a/pybamm/plotting/plot_summary_variables.py +++ b/pybamm/plotting/plot_summary_variables.py @@ -3,6 +3,7 @@ # import numpy as np import pybamm +from pybamm.util import have_optional_dependency def plot_summary_variables( @@ -25,7 +26,7 @@ def plot_summary_variables( Keyword arguments, passed to plt.subplots. """ - import matplotlib.pyplot as plt + plt = have_optional_dependency("matplotlib.pyplot") if isinstance(solutions, pybamm.Solution): solutions = [solutions] diff --git a/pybamm/plotting/plot_voltage_components.py b/pybamm/plotting/plot_voltage_components.py index ad0e9a8b71..a681094bea 100644 --- a/pybamm/plotting/plot_voltage_components.py +++ b/pybamm/plotting/plot_voltage_components.py @@ -3,6 +3,8 @@ # import numpy as np +from pybamm.util import have_optional_dependency + def plot_voltage_components( solution, @@ -32,7 +34,7 @@ def plot_voltage_components( Keyword arguments, passed to ax.fill_between """ - import matplotlib.pyplot as plt + plt = have_optional_dependency("matplotlib.pyplot") # Set a default value for alpha, the opacity kwargs_fill = {"alpha": 0.6, **kwargs_fill} diff --git a/pybamm/plotting/quick_plot.py b/pybamm/plotting/quick_plot.py index 5e9c9ef941..ff657ee375 100644 --- a/pybamm/plotting/quick_plot.py +++ b/pybamm/plotting/quick_plot.py @@ -5,6 +5,7 @@ import numpy as np import pybamm from collections import defaultdict +from pybamm.util import have_optional_dependency class LoopList(list): @@ -46,7 +47,7 @@ def split_long_string(title, max_words=None): def close_plots(): """Close all open figures""" - import matplotlib.pyplot as plt + plt = have_optional_dependency("matplotlib.pyplot") plt.close("all") @@ -469,9 +470,10 @@ def plot(self, t, dynamic=False): Dimensional time (in 'time_units') at which to plot. """ - import matplotlib.pyplot as plt - import matplotlib.gridspec as gridspec - from matplotlib import cm, colors + plt = have_optional_dependency("matplotlib.pyplot") + gridspec = have_optional_dependency("matplotlib.gridspec") + cm = have_optional_dependency("matplotlib", "cm") + colors = have_optional_dependency("matplotlib", "colors") t_in_seconds = t * self.time_scaling_factor self.fig = plt.figure(figsize=self.figsize) @@ -668,8 +670,8 @@ def dynamic_plot(self, testing=False, step=None): continuous_update=False, ) else: - import matplotlib.pyplot as plt - from matplotlib.widgets import Slider + plt = have_optional_dependency("matplotlib.pyplot") + Slider = have_optional_dependency("matplotlib.widgets", "Slider") # create an initial plot at time self.min_t self.plot(self.min_t, dynamic=True) @@ -773,8 +775,8 @@ def create_gif(self, number_of_images=80, duration=0.1, output_filename="plot.gi Name of the generated GIF file. """ - import imageio.v2 as imageio - import matplotlib.pyplot as plt + imageio = have_optional_dependency("imageio.v2") + plt = have_optional_dependency("matplotlib.pyplot") # time stamps at which the images/plots will be created time_array = np.linspace(self.min_t, self.max_t, num=number_of_images) diff --git a/pybamm/simulation.py b/pybamm/simulation.py index 380105d215..42bda08e31 100644 --- a/pybamm/simulation.py +++ b/pybamm/simulation.py @@ -8,7 +8,7 @@ import sys from functools import lru_cache from datetime import timedelta -import tqdm +from pybamm.util import have_optional_dependency def is_notebook(): @@ -726,13 +726,18 @@ def solve( # Update _solution self._solution = current_solution - for cycle_num, cycle_length in enumerate( - # tqdm is the progress bar. - tqdm.tqdm( + # check if a user has tqdm installed + if showprogress: + tqdm = have_optional_dependency("tqdm") + cycle_lengths = tqdm.tqdm( self.experiment.cycle_lengths, - disable=(not showprogress), desc="Cycling", - ), + ) + else: + cycle_lengths = self.experiment.cycle_lengths + + for cycle_num, cycle_length in enumerate( + cycle_lengths, start=1, ): logs["cycle number"] = ( diff --git a/pybamm/spatial_methods/scikit_finite_element.py b/pybamm/spatial_methods/scikit_finite_element.py index 0f0a42bbcb..2d51e16c32 100644 --- a/pybamm/spatial_methods/scikit_finite_element.py +++ b/pybamm/spatial_methods/scikit_finite_element.py @@ -6,7 +6,8 @@ from scipy.sparse import csr_matrix, csc_matrix from scipy.sparse.linalg import inv import numpy as np -import skfem + +from pybamm.util import have_optional_dependency class ScikitFiniteElement(pybamm.SpatialMethod): @@ -87,6 +88,7 @@ def gradient(self, symbol, discretised_symbol, boundary_conditions): to the y-component of the gradient and the second column corresponds to the z component of the gradient. """ + skfem = have_optional_dependency("skfem") domain = symbol.domain[0] mesh = self.mesh[domain] @@ -142,6 +144,7 @@ def gradient_matrix(self, symbol, boundary_conditions): :class:`pybamm.Matrix` The (sparse) finite element gradient matrix for the domain """ + skfem = have_optional_dependency("skfem") # get primary domain mesh domain = symbol.domain[0] mesh = self.mesh[domain] @@ -187,6 +190,7 @@ def laplacian(self, symbol, discretised_symbol, boundary_conditions): Contains the result of acting the discretised gradient on the child discretised_symbol """ + skfem = have_optional_dependency("skfem") domain = symbol.domain[0] mesh = self.mesh[domain] @@ -258,6 +262,7 @@ def stiffness_matrix(self, symbol, boundary_conditions): :class:`pybamm.Matrix` The (sparse) finite element stiffness matrix for the domain """ + skfem = have_optional_dependency("skfem") # get primary domain mesh domain = symbol.domain[0] mesh = self.mesh[domain] @@ -320,6 +325,7 @@ def definite_integral_matrix(self, child, vector_type="row"): :class:`pybamm.Matrix` The finite element integral vector for the domain """ + skfem = have_optional_dependency("skfem") # get primary domain mesh domain = child.domain[0] mesh = self.mesh[domain] @@ -381,6 +387,7 @@ def boundary_integral_vector(self, domain, region): :class:`pybamm.Matrix` The finite element integral vector for the domain """ + skfem = have_optional_dependency("skfem") # get primary domain mesh mesh = self.mesh[domain[0]] @@ -498,6 +505,7 @@ def assemble_mass_form(self, symbol, boundary_conditions, region="interior"): :class:`pybamm.Matrix` The (sparse) mass matrix for the spatial method. """ + skfem = have_optional_dependency("skfem") # get primary domain mesh domain = symbol.domain[0] mesh = self.mesh[domain] diff --git a/pybamm/util.py b/pybamm/util.py index 562352bfac..90cb290c6e 100644 --- a/pybamm/util.py +++ b/pybamm/util.py @@ -345,3 +345,26 @@ def install_jax(arguments=None): # pragma: no cover f"jaxlib>={JAXLIB_VERSION}", ] ) + +# https://docs.pybamm.org/en/latest/source/user_guide/contributing.html#managing-optional-dependencies-and-their-imports +def have_optional_dependency(module_name, attribute=None): + err_msg = f"Optional dependency {module_name} is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details." + try: + # Attempt to import the specified module + module = importlib.import_module(module_name) + + if attribute: + # If an attribute is specified, check if it's available + if hasattr(module, attribute): + imported_attribute = getattr(module, attribute) + return imported_attribute # Return the imported attribute + else: + # Raise an ModuleNotFoundError if the attribute is not available + raise ModuleNotFoundError(err_msg) # pragma: no cover + else: + # Return the entire module if no attribute is specified + return module + + except ModuleNotFoundError: + # Raise an ModuleNotFoundError if the module or attribute is not available + raise ModuleNotFoundError(err_msg) diff --git a/setup.py b/setup.py index f6fd37f75c..fca5b83de8 100644 --- a/setup.py +++ b/setup.py @@ -207,6 +207,7 @@ def compile_KLU(): "scipy>=1.3", "casadi>=3.6.0", "xarray", + "anytree>=2.4.3", ], extras_require={ "docs": [ diff --git a/tests/unit/test_expression_tree/test_binary_operators.py b/tests/unit/test_expression_tree/test_binary_operators.py index 6acd7c41b0..225f8e93c9 100644 --- a/tests/unit/test_expression_tree/test_binary_operators.py +++ b/tests/unit/test_expression_tree/test_binary_operators.py @@ -5,10 +5,10 @@ import unittest import numpy as np -import sympy from scipy.sparse import coo_matrix import pybamm +from pybamm.util import have_optional_dependency class TestBinaryOperators(TestCase): @@ -746,6 +746,7 @@ def test_inner_simplifications(self): self.assertEqual(pybamm.inner(a3, a3).evaluate(), 9) def test_to_equation(self): + sympy = have_optional_dependency("sympy") # Test print_name pybamm.Addition.print_name = "test" self.assertEqual(pybamm.Addition(1, 2).to_equation(), sympy.Symbol("test")) diff --git a/tests/unit/test_expression_tree/test_concatenations.py b/tests/unit/test_expression_tree/test_concatenations.py index df5add0f98..4b07b09fea 100644 --- a/tests/unit/test_expression_tree/test_concatenations.py +++ b/tests/unit/test_expression_tree/test_concatenations.py @@ -5,9 +5,9 @@ from tests import TestCase import numpy as np -import sympy import pybamm +from pybamm.util import have_optional_dependency from tests import get_discretisation_for_testing, get_mesh_for_testing @@ -370,6 +370,7 @@ def test_numpy_concatenation(self): ) def test_to_equation(self): + sympy = have_optional_dependency("sympy") a = pybamm.Symbol("a", domain="test a") b = pybamm.Symbol("b", domain="test b") func_symbol = sympy.Symbol(r"\begin{cases}a\\b\end{cases}") diff --git a/tests/unit/test_expression_tree/test_functions.py b/tests/unit/test_expression_tree/test_functions.py index ac5410d9e1..6d22571a01 100644 --- a/tests/unit/test_expression_tree/test_functions.py +++ b/tests/unit/test_expression_tree/test_functions.py @@ -5,10 +5,10 @@ import unittest import numpy as np -import sympy from scipy import special import pybamm +from pybamm.util import have_optional_dependency def test_function(arg): @@ -120,6 +120,7 @@ def test_function_unnamed(self): self.assertEqual(fun.name, "function (cos)") def test_to_equation(self): + sympy = have_optional_dependency("sympy") a = pybamm.Symbol("a", domain="test") # Test print_name diff --git a/tests/unit/test_expression_tree/test_independent_variable.py b/tests/unit/test_expression_tree/test_independent_variable.py index 95141f0f03..b748a6fbe9 100644 --- a/tests/unit/test_expression_tree/test_independent_variable.py +++ b/tests/unit/test_expression_tree/test_independent_variable.py @@ -4,9 +4,9 @@ from tests import TestCase import unittest -import sympy import pybamm +from pybamm.util import have_optional_dependency class TestIndependentVariable(TestCase): @@ -64,6 +64,7 @@ def test_spatial_variable_edge(self): self.assertTrue(x.evaluates_on_edges("primary")) def test_to_equation(self): + sympy = have_optional_dependency("sympy") # Test print_name func = pybamm.IndependentVariable("a") func.print_name = "test" diff --git a/tests/unit/test_expression_tree/test_operations/test_latexify.py b/tests/unit/test_expression_tree/test_operations/test_latexify.py index be7cc21115..7e0703534e 100644 --- a/tests/unit/test_expression_tree/test_operations/test_latexify.py +++ b/tests/unit/test_expression_tree/test_operations/test_latexify.py @@ -8,7 +8,6 @@ import uuid import pybamm -from pybamm.expression_tree.operations.latexify import Latexify class TestLatexify(TestCase): @@ -19,9 +18,6 @@ def test_latexify(self): model_spme = pybamm.lithium_ion.SPMe() func_spme = str(model_spme.latexify()) - # Test docstring - self.assertEqual(pybamm.BaseModel.latexify.__doc__, Latexify.__doc__) - # Test model name self.assertIn("Single Particle Model with electrolyte Equations", func_spme) diff --git a/tests/unit/test_expression_tree/test_parameter.py b/tests/unit/test_expression_tree/test_parameter.py index f67ee2dd62..d9a756b45d 100644 --- a/tests/unit/test_expression_tree/test_parameter.py +++ b/tests/unit/test_expression_tree/test_parameter.py @@ -5,9 +5,8 @@ import numbers import unittest -import sympy - import pybamm +from pybamm.util import have_optional_dependency class TestParameter(TestCase): @@ -21,6 +20,7 @@ def test_evaluate_for_shape(self): self.assertIsInstance(a.evaluate_for_shape(), numbers.Number) def test_to_equation(self): + sympy = have_optional_dependency("sympy") func = pybamm.Parameter("test_string") func1 = pybamm.Parameter("test_name") @@ -98,6 +98,7 @@ def _myfun(x): self.assertEqual(_myfun(x).print_name, None) def test_function_parameter_to_equation(self): + sympy = have_optional_dependency("sympy") func = pybamm.FunctionParameter("test", {"x": pybamm.Scalar(1)}) func1 = pybamm.FunctionParameter("func", {"var": pybamm.Variable("var")}) diff --git a/tests/unit/test_expression_tree/test_printing/test_sympy_overrides.py b/tests/unit/test_expression_tree/test_printing/test_sympy_overrides.py index b5ae229ae5..de3ff08c43 100644 --- a/tests/unit/test_expression_tree/test_printing/test_sympy_overrides.py +++ b/tests/unit/test_expression_tree/test_printing/test_sympy_overrides.py @@ -4,14 +4,14 @@ from tests import TestCase import unittest -import sympy - import pybamm from pybamm.expression_tree.printing.sympy_overrides import custom_print_func +from pybamm.util import have_optional_dependency class TestCustomPrint(TestCase): def test_print_Derivative(self): + sympy = have_optional_dependency("sympy") # Test force_partial der1 = sympy.Derivative("y", "x") der1.force_partial = True diff --git a/tests/unit/test_expression_tree/test_symbol.py b/tests/unit/test_expression_tree/test_symbol.py index 3f91633fbe..3eb7adae47 100644 --- a/tests/unit/test_expression_tree/test_symbol.py +++ b/tests/unit/test_expression_tree/test_symbol.py @@ -8,10 +8,10 @@ import numpy as np from scipy.sparse import csr_matrix, coo_matrix -import sympy import pybamm from pybamm.expression_tree.binary_operators import _Heaviside +from pybamm.util import have_optional_dependency class TestSymbol(TestCase): @@ -484,6 +484,7 @@ def test_test_shape(self): (y1 + y2).test_shape() def test_to_equation(self): + sympy = have_optional_dependency("sympy") self.assertEqual(pybamm.Symbol("test").to_equation(), sympy.Symbol("test")) def test_numpy_array_ufunc(self): diff --git a/tests/unit/test_expression_tree/test_unary_operators.py b/tests/unit/test_expression_tree/test_unary_operators.py index b0513c974b..fc845cb574 100644 --- a/tests/unit/test_expression_tree/test_unary_operators.py +++ b/tests/unit/test_expression_tree/test_unary_operators.py @@ -5,12 +5,10 @@ from tests import TestCase import numpy as np -import sympy from scipy.sparse import diags -from sympy.vector.operators import Divergence as sympy_Divergence -from sympy.vector.operators import Gradient as sympy_Gradient import pybamm +from pybamm.util import have_optional_dependency class TestUnaryOperators(TestCase): @@ -613,6 +611,11 @@ def test_not_constant(self): self.assertFalse((2 * a).is_constant()) def test_to_equation(self): + + sympy = have_optional_dependency("sympy") + sympy_Divergence = have_optional_dependency("sympy.vector.operators", "Divergence") + sympy_Gradient = have_optional_dependency("sympy.vector.operators", "Gradient") + a = pybamm.Symbol("a", domain="negative particle") b = pybamm.Symbol("b", domain="current collector") c = pybamm.Symbol("c", domain="test") diff --git a/tests/unit/test_expression_tree/test_variable.py b/tests/unit/test_expression_tree/test_variable.py index be791903e2..583008f882 100644 --- a/tests/unit/test_expression_tree/test_variable.py +++ b/tests/unit/test_expression_tree/test_variable.py @@ -5,9 +5,9 @@ import unittest import numpy as np -import sympy import pybamm +from pybamm.util import have_optional_dependency class TestVariable(TestCase): @@ -55,6 +55,7 @@ def test_variable_bounds(self): pybamm.Variable("var", bounds=(1, 1)) def test_to_equation(self): + sympy = have_optional_dependency("sympy") # Test print_name func = pybamm.Variable("test_string") func.print_name = "test" diff --git a/tests/unit/test_util.py b/tests/unit/test_util.py index c5060e65a6..730e4cc08d 100644 --- a/tests/unit/test_util.py +++ b/tests/unit/test_util.py @@ -10,7 +10,9 @@ import unittest from unittest.mock import patch from io import StringIO +from tempfile import TemporaryDirectory +anytree = sys.modules['anytree'] class TestUtil(TestCase): """ @@ -29,6 +31,7 @@ def test_rmse(self): pybamm.rmse(np.ones(5), np.zeros(3)) def test_is_constant_and_can_evaluate(self): + sys.modules['anytree'] = anytree symbol = pybamm.PrimaryBroadcast(0, "negative electrode") self.assertEqual(False, pybamm.is_constant_and_can_evaluate(symbol)) symbol = pybamm.StateVector(slice(0, 1)) @@ -88,6 +91,25 @@ def test_git_commit_info(self): self.assertIsInstance(git_commit_info, str) self.assertEqual(git_commit_info[:2], "v2") + def test_have_optional_dependency(self): + with self.assertRaisesRegex(ModuleNotFoundError, "Optional dependency pybtex is not available."): + pybtex = sys.modules['pybtex'] + sys.modules['pybtex'] = None + pybamm.print_citations() + with self.assertRaisesRegex(ModuleNotFoundError, "Optional dependency anytree is not available."): + with TemporaryDirectory() as dir_name: + sys.modules['anytree'] = None + test_stub = os.path.join(dir_name, "test_visualize") + test_name = f"{test_stub}.png" + c = pybamm.Variable("c", "negative electrode") + d = pybamm.Variable("d", "negative electrode") + sym = pybamm.div(c * pybamm.grad(c)) + (c / d + c - d) ** 5 + sym.visualise(test_name) + + sys.modules['pybtex'] = pybtex + pybamm.util.have_optional_dependency("pybtex") + pybamm.print_citations() + class TestSearch(TestCase): def test_url_gets_to_stdout(self):