From 4c237c12fb973827eb167f1390e285f3b3229e0a Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Thu, 26 Oct 2023 12:12:50 +0530 Subject: [PATCH 01/50] prevent `pybtex` default installation --- pybamm/__init__.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pybamm/__init__.py b/pybamm/__init__.py index 9aa1ca79a0..a8ffbcf83b 100644 --- a/pybamm/__init__.py +++ b/pybamm/__init__.py @@ -52,7 +52,13 @@ ) from .logger import logger, set_logging_level, get_new_logger from .settings import settings -from .citations import Citations, citations, print_citations +try: + import pybtex + + if pybtex is not None: + from .citations import Citations, citations, print_citations +except ImportError: + pass # # Classes for the Expression Tree From 6d30b3adea028d33ab3a377fe1fc870cf3f53abc Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Thu, 26 Oct 2023 20:42:17 +0530 Subject: [PATCH 02/50] resolve `anytree` default installation --- pybamm/expression_tree/symbol.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index 5d28884ed5..88c4d02ab8 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -3,10 +3,14 @@ # import numbers -import anytree + +try: + import anytree + from anytree.exporter import DotExporter +except ImportError: + pass import numpy as np import sympy -from anytree.exporter import DotExporter from scipy.sparse import csr_matrix, issparse from functools import lru_cache, cached_property From e3b3b35aa48fe60d9c08454574e2df8aa150b590 Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Thu, 26 Oct 2023 21:02:00 +0530 Subject: [PATCH 03/50] resolve `autograd` default imports --- pybamm/expression_tree/functions.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pybamm/expression_tree/functions.py b/pybamm/expression_tree/functions.py index 80c2848ad9..788af40d50 100644 --- a/pybamm/expression_tree/functions.py +++ b/pybamm/expression_tree/functions.py @@ -3,7 +3,10 @@ # import numbers -import autograd +try: + import autograd +except ImportError: + pass import numpy as np import sympy from scipy import special From 4dd231799f25f45238b58458a386c882fa64fc5e Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Thu, 26 Oct 2023 21:06:51 +0530 Subject: [PATCH 04/50] resolve `skfem` default imports --- pybamm/meshes/scikit_fem_submeshes.py | 5 ++++- pybamm/spatial_methods/scikit_finite_element.py | 5 ++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/pybamm/meshes/scikit_fem_submeshes.py b/pybamm/meshes/scikit_fem_submeshes.py index f25dce80b1..c067c43a8a 100644 --- a/pybamm/meshes/scikit_fem_submeshes.py +++ b/pybamm/meshes/scikit_fem_submeshes.py @@ -4,7 +4,10 @@ import pybamm from .meshes import SubMesh -import skfem +try: + import skfem +except ImportError: + pass import numpy as np diff --git a/pybamm/spatial_methods/scikit_finite_element.py b/pybamm/spatial_methods/scikit_finite_element.py index 0f0a42bbcb..7556645028 100644 --- a/pybamm/spatial_methods/scikit_finite_element.py +++ b/pybamm/spatial_methods/scikit_finite_element.py @@ -6,7 +6,10 @@ from scipy.sparse import csr_matrix, csc_matrix from scipy.sparse.linalg import inv import numpy as np -import skfem +try: + import skfem +except ImportError: + pass class ScikitFiniteElement(pybamm.SpatialMethod): From 9e24562b7bf00a7e1149b7910f3f2e2f6b3a0107 Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Thu, 26 Oct 2023 21:20:40 +0530 Subject: [PATCH 05/50] resolve `tqdm` default imports --- pybamm/simulation.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pybamm/simulation.py b/pybamm/simulation.py index 380105d215..dfca7e0583 100644 --- a/pybamm/simulation.py +++ b/pybamm/simulation.py @@ -8,7 +8,10 @@ import sys from functools import lru_cache from datetime import timedelta -import tqdm +try: + import tqdm +except ImportError: + pass def is_notebook(): From 50315e7f838ecbf2b3cf73e2d02239124aecb286 Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Sat, 28 Oct 2023 21:17:12 +0530 Subject: [PATCH 06/50] Raise import error for `anytree` requiring functions --- pybamm/expression_tree/symbol.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index 88c4d02ab8..6904854050 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -8,7 +8,9 @@ import anytree from anytree.exporter import DotExporter except ImportError: - pass + _has_anytree = False +else: + _has_anytree = True import numpy as np import sympy from scipy.sparse import csr_matrix, issparse @@ -446,6 +448,8 @@ def render(self): # pragma: no cover """ Print out a visual representation of the tree (this node and its children) """ + if not _has_anytree: + raise ImportError("Module 'anytree' is required to do this") for pre, _, node in anytree.RenderTree(self): if isinstance(node, pybamm.Scalar) and node.name != str(node.value): print("{}{} = {}".format(pre, node.name, node.value)) @@ -463,6 +467,8 @@ def visualise(self, filename): filename : str filename to output, must end in ".png" """ + if not _has_anytree: + raise ImportError("Module 'anytree' is required to do this") # check that filename ends in .png. if filename[-4:] != ".png": @@ -483,6 +489,8 @@ def relabel_tree(self, symbol, counter): Finds all children of a symbol and assigns them a new id so that they can be visualised properly using the graphviz output """ + if not _has_anytree: + raise ImportError("Module 'anytree' is required to do this") name = symbol.name if name == "div": name = "∇⋅" @@ -526,6 +534,8 @@ def pre_order(self): a b """ + if not _has_anytree: + raise ImportError("Module 'anytree' is required to do this") return anytree.PreOrderIter(self) def __str__(self): From e09fcea3888ce5571812f1a5a54b44d2176554db Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Mon, 30 Oct 2023 05:04:02 +0530 Subject: [PATCH 07/50] Make simple function to check optional dependency --- pybamm/__init__.py | 1 + pybamm/util.py | 9 +++++++++ 2 files changed, 10 insertions(+) diff --git a/pybamm/__init__.py b/pybamm/__init__.py index a8ffbcf83b..8f92c71e18 100644 --- a/pybamm/__init__.py +++ b/pybamm/__init__.py @@ -47,6 +47,7 @@ get_parameters_filepath, have_jax, install_jax, + have_optional_dependency, is_jax_compatible, get_git_commit_info, ) diff --git a/pybamm/util.py b/pybamm/util.py index 562352bfac..c98ee6beda 100644 --- a/pybamm/util.py +++ b/pybamm/util.py @@ -345,3 +345,12 @@ def install_jax(arguments=None): # pragma: no cover f"jaxlib>={JAXLIB_VERSION}", ] ) + + +def have_optional_dependency(module): + try: + importlib.import_module(module) + _has_module = True + except ImportError: + _has_module = False + return _has_module From a07b34251586319c4d98cd6a9d0c9ac4b49bab44 Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Mon, 30 Oct 2023 07:57:22 +0530 Subject: [PATCH 08/50] Make decorater function --- pybamm/__init__.py | 9 +-------- pybamm/citations.py | 1 + pybamm/util.py | 29 ++++++++++++++++++++++------- 3 files changed, 24 insertions(+), 15 deletions(-) diff --git a/pybamm/__init__.py b/pybamm/__init__.py index 8f92c71e18..07d8a1c0ea 100644 --- a/pybamm/__init__.py +++ b/pybamm/__init__.py @@ -53,14 +53,7 @@ ) from .logger import logger, set_logging_level, get_new_logger from .settings import settings -try: - import pybtex - - if pybtex is not None: - from .citations import Citations, citations, print_citations -except ImportError: - pass - +from .citations import Citations, citations, print_citations # # Classes for the Expression Tree # diff --git a/pybamm/citations.py b/pybamm/citations.py index da619062e0..87f8271dde 100644 --- a/pybamm/citations.py +++ b/pybamm/citations.py @@ -177,6 +177,7 @@ def _tag_citations(self): for key, entry in self._citation_tags.items(): print(f"{key} was cited due to the use of {entry}") + @pybamm.util.have_optional_dependency("pybtex") def print(self, filename=None, output_format="text", verbose=False): """Print all citations that were used for running simulations. The verbose option is provided to print tags for citations in the output such that it can diff --git a/pybamm/util.py b/pybamm/util.py index c98ee6beda..0b68173504 100644 --- a/pybamm/util.py +++ b/pybamm/util.py @@ -347,10 +347,25 @@ def install_jax(arguments=None): # pragma: no cover ) -def have_optional_dependency(module): - try: - importlib.import_module(module) - _has_module = True - except ImportError: - _has_module = False - return _has_module +def have_optional_dependency(module_name, attribute=None): + def decorator(func): + def wrapper(*args, **kwargs): + try: + module = importlib.import_module(module_name) + if attribute: + if hasattr(module, attribute): + imported_attribute = getattr(module, attribute) + print(f"The {module_name}.{attribute} is available.") + kwargs[attribute] = imported_attribute + else: + print(f"The {module_name}.{attribute} is not available.") + else: + print(f"The {module_name} module is available.") + return func(*args, **kwargs) + except ImportError: + if attribute: + print(f"The {module_name}.{attribute} is not available.") + else: + print(f"The {module_name} module is not available.") + return wrapper + return decorator From 9d9db2bf2dc089f17adabf7014ea1d63109c6883 Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Mon, 30 Oct 2023 16:31:25 +0530 Subject: [PATCH 09/50] Make normal reusable function for optional deps --- pybamm/util.py | 40 +++++++++++++++++++--------------------- 1 file changed, 19 insertions(+), 21 deletions(-) diff --git a/pybamm/util.py b/pybamm/util.py index 0b68173504..f481480635 100644 --- a/pybamm/util.py +++ b/pybamm/util.py @@ -348,24 +348,22 @@ def install_jax(arguments=None): # pragma: no cover def have_optional_dependency(module_name, attribute=None): - def decorator(func): - def wrapper(*args, **kwargs): - try: - module = importlib.import_module(module_name) - if attribute: - if hasattr(module, attribute): - imported_attribute = getattr(module, attribute) - print(f"The {module_name}.{attribute} is available.") - kwargs[attribute] = imported_attribute - else: - print(f"The {module_name}.{attribute} is not available.") - else: - print(f"The {module_name} module is available.") - return func(*args, **kwargs) - except ImportError: - if attribute: - print(f"The {module_name}.{attribute} is not available.") - else: - print(f"The {module_name} module is not available.") - return wrapper - return decorator + try: + module = importlib.import_module(module_name) + if attribute: + if hasattr(module, attribute): + imported_attribute = getattr(module, attribute) + print(f"The {module_name}.{attribute} is available.") + return imported_attribute + else: + print(f"The {module_name}.{attribute} is not available.") + return None + else: + print(f"The {module_name} module is available.") + return module + except ImportError: + if attribute: + print(f"The {module_name}.{attribute} is not available.") + else: + print(f"The {module_name} module is not available.") + return None From 34311ee63326d936bffa473acebdfc1462e6eb14 Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Mon, 30 Oct 2023 16:32:25 +0530 Subject: [PATCH 10/50] Update `citations.py` for `pybtex` as optional dependency --- pybamm/citations.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pybamm/citations.py b/pybamm/citations.py index 87f8271dde..fa3c4651b7 100644 --- a/pybamm/citations.py +++ b/pybamm/citations.py @@ -6,10 +6,10 @@ import pybamm import os import warnings -import pybtex +# import pybtex from sys import _getframe -from pybtex.database import parse_file, parse_string, Entry -from pybtex.scanner import PybtexError +# from pybtex.database import parse_file, parse_string, Entry +# from pybtex.scanner import PybtexError class Citations: @@ -76,6 +76,7 @@ def read_citations(self): """Reads the citations in `pybamm.CITATIONS.bib`. Other works can be cited by passing a BibTeX citation to :meth:`register`. """ + parse_file = pybamm.util.have_optional_dependency("pybtex.database","parse_file") citations_file = os.path.join(pybamm.root_dir(), "pybamm", "CITATIONS.bib") bib_data = parse_file(citations_file, bib_format="bibtex") for key, entry in bib_data.entries.items(): @@ -86,6 +87,7 @@ def _add_citation(self, key, entry): previous entry is overwritten """ + Entry = pybamm.util.have_optional_dependency("pybtex.database","Entry") # Check input types are correct if not isinstance(key, str) or not isinstance(entry, Entry): raise TypeError() @@ -151,6 +153,8 @@ def _parse_citation(self, key): key: str A BibTeX formatted citation """ + PybtexError = pybamm.util.have_optional_dependency("pybtex.scanner","PybtexError") + parse_string = pybamm.util.have_optional_dependency("pybtex.database","parse_string") try: # Parse string as a bibtex citation, and check that a citation was found bib_data = parse_string(key, bib_format="bibtex") @@ -177,7 +181,6 @@ def _tag_citations(self): for key, entry in self._citation_tags.items(): print(f"{key} was cited due to the use of {entry}") - @pybamm.util.have_optional_dependency("pybtex") def print(self, filename=None, output_format="text", verbose=False): """Print all citations that were used for running simulations. The verbose option is provided to print tags for citations in the output such that it can @@ -218,6 +221,7 @@ def print(self, filename=None, output_format="text", verbose=False): """ # Parse citations that were not known keys at registration, but do not # fail if they cannot be parsed + pybtex = pybamm.util.have_optional_dependency("pybtex") try: for key in self._unknown_citations: self._parse_citation(key) From 12180658beaf910c513756269b0f3c6df9a16941 Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Mon, 30 Oct 2023 16:42:31 +0530 Subject: [PATCH 11/50] Execute silently, raise ImportError & import function correctly --- pybamm/citations.py | 14 ++++++-------- pybamm/util.py | 10 +++------- 2 files changed, 9 insertions(+), 15 deletions(-) diff --git a/pybamm/citations.py b/pybamm/citations.py index fa3c4651b7..7d0959d89c 100644 --- a/pybamm/citations.py +++ b/pybamm/citations.py @@ -6,10 +6,8 @@ import pybamm import os import warnings -# import pybtex from sys import _getframe -# from pybtex.database import parse_file, parse_string, Entry -# from pybtex.scanner import PybtexError +from pybamm.util import have_optional_dependency class Citations: @@ -76,7 +74,7 @@ def read_citations(self): """Reads the citations in `pybamm.CITATIONS.bib`. Other works can be cited by passing a BibTeX citation to :meth:`register`. """ - parse_file = pybamm.util.have_optional_dependency("pybtex.database","parse_file") + parse_file = have_optional_dependency("pybtex.database","parse_file") citations_file = os.path.join(pybamm.root_dir(), "pybamm", "CITATIONS.bib") bib_data = parse_file(citations_file, bib_format="bibtex") for key, entry in bib_data.entries.items(): @@ -87,7 +85,7 @@ def _add_citation(self, key, entry): previous entry is overwritten """ - Entry = pybamm.util.have_optional_dependency("pybtex.database","Entry") + Entry = have_optional_dependency("pybtex.database","Entry") # Check input types are correct if not isinstance(key, str) or not isinstance(entry, Entry): raise TypeError() @@ -153,8 +151,8 @@ def _parse_citation(self, key): key: str A BibTeX formatted citation """ - PybtexError = pybamm.util.have_optional_dependency("pybtex.scanner","PybtexError") - parse_string = pybamm.util.have_optional_dependency("pybtex.database","parse_string") + PybtexError = have_optional_dependency("pybtex.scanner","PybtexError") + parse_string = have_optional_dependency("pybtex.database","parse_string") try: # Parse string as a bibtex citation, and check that a citation was found bib_data = parse_string(key, bib_format="bibtex") @@ -221,7 +219,7 @@ def print(self, filename=None, output_format="text", verbose=False): """ # Parse citations that were not known keys at registration, but do not # fail if they cannot be parsed - pybtex = pybamm.util.have_optional_dependency("pybtex") + pybtex = have_optional_dependency("pybtex") try: for key in self._unknown_citations: self._parse_citation(key) diff --git a/pybamm/util.py b/pybamm/util.py index f481480635..a2625e5405 100644 --- a/pybamm/util.py +++ b/pybamm/util.py @@ -353,17 +353,13 @@ def have_optional_dependency(module_name, attribute=None): if attribute: if hasattr(module, attribute): imported_attribute = getattr(module, attribute) - print(f"The {module_name}.{attribute} is available.") return imported_attribute else: - print(f"The {module_name}.{attribute} is not available.") - return None + raise ImportError(f"{module_name}.{attribute} is not available.") else: - print(f"The {module_name} module is available.") return module except ImportError: if attribute: - print(f"The {module_name}.{attribute} is not available.") + raise ImportError(f"{module_name}.{attribute} is not available.") else: - print(f"The {module_name} module is not available.") - return None + raise ImportError(f"{module_name} module is not available.") From 90ac2ee2f812c34dfaec49434c128c7f789937c6 Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Mon, 30 Oct 2023 16:56:59 +0530 Subject: [PATCH 12/50] Update `Symbol` for `anytree` as optional dependency --- pybamm/expression_tree/symbol.py | 21 +++++---------------- 1 file changed, 5 insertions(+), 16 deletions(-) diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index 6904854050..8ad717f7ff 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -3,20 +3,13 @@ # import numbers - -try: - import anytree - from anytree.exporter import DotExporter -except ImportError: - _has_anytree = False -else: - _has_anytree = True import numpy as np import sympy from scipy.sparse import csr_matrix, issparse from functools import lru_cache, cached_property import pybamm +from pybamm.util import have_optional_dependency from pybamm.expression_tree.printing.print_name import prettify_print_name DOMAIN_LEVELS = ["primary", "secondary", "tertiary", "quaternary"] @@ -448,8 +441,7 @@ def render(self): # pragma: no cover """ Print out a visual representation of the tree (this node and its children) """ - if not _has_anytree: - raise ImportError("Module 'anytree' is required to do this") + anytree = have_optional_dependency("anytree") for pre, _, node in anytree.RenderTree(self): if isinstance(node, pybamm.Scalar) and node.name != str(node.value): print("{}{} = {}".format(pre, node.name, node.value)) @@ -467,9 +459,8 @@ def visualise(self, filename): filename : str filename to output, must end in ".png" """ - if not _has_anytree: - raise ImportError("Module 'anytree' is required to do this") + DotExporter = have_optional_dependency("anytree.exporter","DotExporter") # check that filename ends in .png. if filename[-4:] != ".png": raise ValueError("filename should end in .png") @@ -489,8 +480,7 @@ def relabel_tree(self, symbol, counter): Finds all children of a symbol and assigns them a new id so that they can be visualised properly using the graphviz output """ - if not _has_anytree: - raise ImportError("Module 'anytree' is required to do this") + anytree = have_optional_dependency("anytree") name = symbol.name if name == "div": name = "∇⋅" @@ -534,8 +524,7 @@ def pre_order(self): a b """ - if not _has_anytree: - raise ImportError("Module 'anytree' is required to do this") + anytree = have_optional_dependency("anytree") return anytree.PreOrderIter(self) def __str__(self): From 3e686173bc57e434209a08346214c382a2519b41 Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Wed, 1 Nov 2023 13:38:54 +0530 Subject: [PATCH 13/50] Update `simulation` for `tqdm` as optional dependency --- pybamm/simulation.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pybamm/simulation.py b/pybamm/simulation.py index dfca7e0583..0b1a6b2525 100644 --- a/pybamm/simulation.py +++ b/pybamm/simulation.py @@ -8,10 +8,7 @@ import sys from functools import lru_cache from datetime import timedelta -try: - import tqdm -except ImportError: - pass +from pybamm.util import have_optional_dependency def is_notebook(): @@ -535,6 +532,7 @@ def solve( Additional key-word arguments passed to `solver.solve`. See :meth:`pybamm.BaseSolver.solve`. """ + tqdm = have_optional_dependency("tqdm") # Setup if solver is None: solver = self._solver From 5551dac392adfb0111b5b208a46d22c6bec747f2 Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Wed, 1 Nov 2023 13:47:08 +0530 Subject: [PATCH 14/50] Update `Function` class for `autograd` as optional dependency --- pybamm/expression_tree/functions.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/pybamm/expression_tree/functions.py b/pybamm/expression_tree/functions.py index 788af40d50..ebfb313199 100644 --- a/pybamm/expression_tree/functions.py +++ b/pybamm/expression_tree/functions.py @@ -3,16 +3,12 @@ # import numbers -try: - import autograd -except ImportError: - pass import numpy as np import sympy from scipy import special import pybamm - +from pybamm.util import have_optional_dependency class Function(pybamm.Symbol): """ @@ -99,6 +95,7 @@ def _function_diff(self, children, idx): Derivative with respect to child number 'idx'. See :meth:`pybamm.Symbol._diff()`. """ + autograd = have_optional_dependency("autograd") # Store differentiated function, needed in case we want to convert to CasADi if self.derivative == "autograd": return Function( From 64d9037299a1b48c0e2801919699a3465df935d4 Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Wed, 1 Nov 2023 14:03:03 +0530 Subject: [PATCH 15/50] Resolve `scikit-fem` based methods --- pybamm/meshes/scikit_fem_submeshes.py | 8 +++----- pybamm/spatial_methods/scikit_finite_element.py | 13 +++++++++---- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/pybamm/meshes/scikit_fem_submeshes.py b/pybamm/meshes/scikit_fem_submeshes.py index c067c43a8a..23c024dbbb 100644 --- a/pybamm/meshes/scikit_fem_submeshes.py +++ b/pybamm/meshes/scikit_fem_submeshes.py @@ -3,13 +3,10 @@ # import pybamm from .meshes import SubMesh - -try: - import skfem -except ImportError: - pass import numpy as np +from pybamm.util import have_optional_dependency + class ScikitSubMesh2D(SubMesh): """ @@ -30,6 +27,7 @@ class ScikitSubMesh2D(SubMesh): """ def __init__(self, edges, coord_sys, tabs): + skfem = have_optional_dependency("skfem") self.edges = edges self.nodes = dict.fromkeys(["y", "z"]) for var in self.nodes.keys(): diff --git a/pybamm/spatial_methods/scikit_finite_element.py b/pybamm/spatial_methods/scikit_finite_element.py index 7556645028..2d51e16c32 100644 --- a/pybamm/spatial_methods/scikit_finite_element.py +++ b/pybamm/spatial_methods/scikit_finite_element.py @@ -6,10 +6,8 @@ from scipy.sparse import csr_matrix, csc_matrix from scipy.sparse.linalg import inv import numpy as np -try: - import skfem -except ImportError: - pass + +from pybamm.util import have_optional_dependency class ScikitFiniteElement(pybamm.SpatialMethod): @@ -90,6 +88,7 @@ def gradient(self, symbol, discretised_symbol, boundary_conditions): to the y-component of the gradient and the second column corresponds to the z component of the gradient. """ + skfem = have_optional_dependency("skfem") domain = symbol.domain[0] mesh = self.mesh[domain] @@ -145,6 +144,7 @@ def gradient_matrix(self, symbol, boundary_conditions): :class:`pybamm.Matrix` The (sparse) finite element gradient matrix for the domain """ + skfem = have_optional_dependency("skfem") # get primary domain mesh domain = symbol.domain[0] mesh = self.mesh[domain] @@ -190,6 +190,7 @@ def laplacian(self, symbol, discretised_symbol, boundary_conditions): Contains the result of acting the discretised gradient on the child discretised_symbol """ + skfem = have_optional_dependency("skfem") domain = symbol.domain[0] mesh = self.mesh[domain] @@ -261,6 +262,7 @@ def stiffness_matrix(self, symbol, boundary_conditions): :class:`pybamm.Matrix` The (sparse) finite element stiffness matrix for the domain """ + skfem = have_optional_dependency("skfem") # get primary domain mesh domain = symbol.domain[0] mesh = self.mesh[domain] @@ -323,6 +325,7 @@ def definite_integral_matrix(self, child, vector_type="row"): :class:`pybamm.Matrix` The finite element integral vector for the domain """ + skfem = have_optional_dependency("skfem") # get primary domain mesh domain = child.domain[0] mesh = self.mesh[domain] @@ -384,6 +387,7 @@ def boundary_integral_vector(self, domain, region): :class:`pybamm.Matrix` The finite element integral vector for the domain """ + skfem = have_optional_dependency("skfem") # get primary domain mesh mesh = self.mesh[domain[0]] @@ -501,6 +505,7 @@ def assemble_mass_form(self, symbol, boundary_conditions, region="interior"): :class:`pybamm.Matrix` The (sparse) mass matrix for the spatial method. """ + skfem = have_optional_dependency("skfem") # get primary domain mesh domain = symbol.domain[0] mesh = self.mesh[domain] From 9ee911bd3cb3a16b869e17a62c0dcc26c16aca11 Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Wed, 1 Nov 2023 15:26:41 +0530 Subject: [PATCH 16/50] Resolve `sympy` based methods --- pybamm/expression_tree/array.py | 3 ++- pybamm/expression_tree/binary_operators.py | 6 +++++- pybamm/expression_tree/concatenations.py | 3 ++- pybamm/expression_tree/functions.py | 5 ++++- pybamm/expression_tree/independent_variable.py | 5 +++-- pybamm/expression_tree/operations/latexify.py | 6 ++++-- pybamm/expression_tree/parameter.py | 4 +++- pybamm/expression_tree/printing/sympy_overrides.py | 4 +++- pybamm/expression_tree/scalar.py | 4 ++-- pybamm/expression_tree/symbol.py | 2 +- pybamm/expression_tree/unary_operators.py | 9 ++++++--- pybamm/expression_tree/variable.py | 3 ++- tests/unit/test_expression_tree/test_binary_operators.py | 3 ++- tests/unit/test_expression_tree/test_concatenations.py | 3 ++- tests/unit/test_expression_tree/test_functions.py | 3 ++- .../test_expression_tree/test_independent_variable.py | 3 ++- tests/unit/test_expression_tree/test_parameter.py | 5 +++-- .../test_printing/test_sympy_overrides.py | 4 ++-- tests/unit/test_expression_tree/test_symbol.py | 3 ++- tests/unit/test_expression_tree/test_unary_operators.py | 9 ++++++--- tests/unit/test_expression_tree/test_variable.py | 3 ++- 21 files changed, 60 insertions(+), 30 deletions(-) diff --git a/pybamm/expression_tree/array.py b/pybamm/expression_tree/array.py index a9141041b3..2736886d95 100644 --- a/pybamm/expression_tree/array.py +++ b/pybamm/expression_tree/array.py @@ -2,10 +2,10 @@ # NumpyArray class # import numpy as np -import sympy from scipy.sparse import csr_matrix, issparse import pybamm +from pybamm.util import have_optional_dependency class Array(pybamm.Symbol): @@ -125,6 +125,7 @@ def is_constant(self): def to_equation(self): """Returns the value returned by the node when evaluated.""" + sympy = have_optional_dependency("sympy") entries_list = self.entries.tolist() return sympy.Array(entries_list) diff --git a/pybamm/expression_tree/binary_operators.py b/pybamm/expression_tree/binary_operators.py index 749384e9bc..9fc6d2642e 100644 --- a/pybamm/expression_tree/binary_operators.py +++ b/pybamm/expression_tree/binary_operators.py @@ -4,11 +4,11 @@ import numbers import numpy as np -import sympy from scipy.sparse import csr_matrix, issparse import functools import pybamm +from pybamm.util import have_optional_dependency def _preprocess_binary(left, right): @@ -147,6 +147,7 @@ def _sympy_operator(self, left, right): def to_equation(self): """Convert the node and its subtree into a SymPy equation.""" + sympy = have_optional_dependency("sympy") if self.print_name is not None: return sympy.Symbol(self.print_name) else: @@ -323,6 +324,7 @@ def _binary_evaluate(self, left, right): def _sympy_operator(self, left, right): """Override :meth:`pybamm.BinaryOperator._sympy_operator`""" + sympy = have_optional_dependency("sympy") left = sympy.Matrix(left) right = sympy.Matrix(right) return left * right @@ -626,6 +628,7 @@ def _binary_new_copy(self, left, right): def _sympy_operator(self, left, right): """Override :meth:`pybamm.BinaryOperator._sympy_operator`""" + sympy = have_optional_dependency("sympy") return sympy.Min(left, right) @@ -662,6 +665,7 @@ def _binary_new_copy(self, left, right): def _sympy_operator(self, left, right): """Override :meth:`pybamm.BinaryOperator._sympy_operator`""" + sympy = have_optional_dependency("sympy") return sympy.Max(left, right) diff --git a/pybamm/expression_tree/concatenations.py b/pybamm/expression_tree/concatenations.py index 2185a0fad6..1c82aff122 100644 --- a/pybamm/expression_tree/concatenations.py +++ b/pybamm/expression_tree/concatenations.py @@ -5,10 +5,10 @@ from collections import defaultdict import numpy as np -import sympy from scipy.sparse import issparse, vstack import pybamm +from pybamm.util import have_optional_dependency class Concatenation(pybamm.Symbol): @@ -135,6 +135,7 @@ def is_constant(self): def _sympy_operator(self, *children): """Apply appropriate SymPy operators.""" + sympy = have_optional_dependency("sympy") self.concat_latex = tuple(map(sympy.latex, children)) if self.print_name is not None: diff --git a/pybamm/expression_tree/functions.py b/pybamm/expression_tree/functions.py index ebfb313199..0c7e98b508 100644 --- a/pybamm/expression_tree/functions.py +++ b/pybamm/expression_tree/functions.py @@ -4,7 +4,6 @@ import numbers import numpy as np -import sympy from scipy import special import pybamm @@ -202,6 +201,7 @@ def _sympy_operator(self, child): def to_equation(self): """Convert the node and its subtree into a SymPy equation.""" + sympy = have_optional_dependency("sympy") if self.print_name is not None: return sympy.Symbol(self.print_name) else: @@ -250,6 +250,7 @@ def _function_new_copy(self, children): def _sympy_operator(self, child): """Apply appropriate SymPy operators.""" + sympy = have_optional_dependency("sympy") class_name = self.__class__.__name__.lower() sympy_function = getattr(sympy, class_name) return sympy_function(child) @@ -267,6 +268,7 @@ def _function_diff(self, children, idx): def _sympy_operator(self, child): """Override :meth:`pybamm.Function._sympy_operator`""" + sympy = have_optional_dependency("sympy") return sympy.asinh(child) @@ -287,6 +289,7 @@ def _function_diff(self, children, idx): def _sympy_operator(self, child): """Override :meth:`pybamm.Function._sympy_operator`""" + sympy = have_optional_dependency("sympy") return sympy.atan(child) diff --git a/pybamm/expression_tree/independent_variable.py b/pybamm/expression_tree/independent_variable.py index efeb73f8bc..4c139c30a8 100644 --- a/pybamm/expression_tree/independent_variable.py +++ b/pybamm/expression_tree/independent_variable.py @@ -1,9 +1,8 @@ # # IndependentVariable class # -import sympy - import pybamm +from pybamm.utili import have_optional_dependency KNOWN_COORD_SYS = ["cartesian", "cylindrical polar", "spherical polar"] @@ -44,6 +43,7 @@ def _jac(self, variable): def to_equation(self): """Convert the node and its subtree into a SymPy equation.""" + sympy = have_optional_dependency("sympy") if self.print_name is not None: return sympy.Symbol(self.print_name) else: @@ -77,6 +77,7 @@ def _evaluate_for_shape(self): def to_equation(self): """Convert the node and its subtree into a SymPy equation.""" + sympy = have_optional_dependency("sympy") return sympy.Symbol("t") diff --git a/pybamm/expression_tree/operations/latexify.py b/pybamm/expression_tree/operations/latexify.py index 67e0199656..9f2949069e 100644 --- a/pybamm/expression_tree/operations/latexify.py +++ b/pybamm/expression_tree/operations/latexify.py @@ -5,10 +5,9 @@ import re import warnings -import sympy - import pybamm from pybamm.expression_tree.printing.sympy_overrides import custom_print_func +from pybamm.util import have_optional_dependency def get_rng_min_max_name(rng, min_or_max): @@ -88,6 +87,7 @@ def _get_bcs_displays(self, var): Returns a list of boundary condition equations with ranges in front of the equations. """ + sympy = have_optional_dependency("sympy") bcs_eqn_list = [] bcs = self.model.boundary_conditions.get(var, None) @@ -118,6 +118,7 @@ def _get_bcs_displays(self, var): def _get_param_var(self, node): """Returns a list of parameters and a list of variables.""" + sympy = have_optional_dependency("sympy") param_list = [] var_list = [] dfs_nodes = [node] @@ -160,6 +161,7 @@ def _get_param_var(self, node): return param_list, var_list def latexify(self, output_variables=None): + sympy = have_optional_dependency("sympy") # Voltage is the default output variable if it exists if output_variables is None: if "Voltage [V]" in self.model.variables: diff --git a/pybamm/expression_tree/parameter.py b/pybamm/expression_tree/parameter.py index 10addae464..eebe77ad2f 100644 --- a/pybamm/expression_tree/parameter.py +++ b/pybamm/expression_tree/parameter.py @@ -5,9 +5,9 @@ import sys import numpy as np -import sympy import pybamm +from pybamm.util import have_optional_dependency class Parameter(pybamm.Symbol): @@ -44,6 +44,7 @@ def is_constant(self): def to_equation(self): """Convert the node and its subtree into a SymPy equation.""" + sympy = have_optional_dependency("sympy") if self.print_name is not None: return sympy.Symbol(self.print_name) else: @@ -217,6 +218,7 @@ def _evaluate_for_shape(self): def to_equation(self): """Convert the node and its subtree into a SymPy equation.""" + sympy = have_optional_dependency("sympy") if self.print_name is not None: return sympy.Symbol(self.print_name) else: diff --git a/pybamm/expression_tree/printing/sympy_overrides.py b/pybamm/expression_tree/printing/sympy_overrides.py index a96aa19729..59f9567c5d 100644 --- a/pybamm/expression_tree/printing/sympy_overrides.py +++ b/pybamm/expression_tree/printing/sympy_overrides.py @@ -3,9 +3,10 @@ # import re -from sympy.printing.latex import LatexPrinter +from pybamm.util import have_optional_dependency +LatexPrinter = have_optional_dependency("sympy.printing.latex","LatexPrinter") class CustomPrint(LatexPrinter): """Override SymPy methods to match PyBaMM's requirements""" @@ -21,4 +22,5 @@ def _print_Derivative(self, expr): def custom_print_func(expr, **settings): + have_optional_dependency("sympy.printing.latex","LatexPrinter") return CustomPrint().doprint(expr) diff --git a/pybamm/expression_tree/scalar.py b/pybamm/expression_tree/scalar.py index 3149bf7bee..0209c02a8e 100644 --- a/pybamm/expression_tree/scalar.py +++ b/pybamm/expression_tree/scalar.py @@ -2,10 +2,9 @@ # Scalar class # import numpy as np -import sympy import pybamm - +from pybamm.util import have_optional_dependency class Scalar(pybamm.Symbol): """ @@ -70,6 +69,7 @@ def is_constant(self): def to_equation(self): """Returns the value returned by the node when evaluated.""" + sympy = have_optional_dependency("sympy") if self.print_name is not None: return sympy.Symbol(self.print_name) else: diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index 8ad717f7ff..85c392e590 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -4,7 +4,6 @@ import numbers import numpy as np -import sympy from scipy.sparse import csr_matrix, issparse from functools import lru_cache, cached_property @@ -987,4 +986,5 @@ def print_name(self, name): self._print_name = prettify_print_name(name) def to_equation(self): + sympy = have_optional_dependency("sympy") return sympy.Symbol(str(self.name)) diff --git a/pybamm/expression_tree/unary_operators.py b/pybamm/expression_tree/unary_operators.py index 7f9c45775c..e555f48455 100644 --- a/pybamm/expression_tree/unary_operators.py +++ b/pybamm/expression_tree/unary_operators.py @@ -4,11 +4,9 @@ import numbers import numpy as np -import sympy from scipy.sparse import csr_matrix, issparse -from sympy.vector.operators import Divergence as sympy_Divergence -from sympy.vector.operators import Gradient as sympy_Gradient import pybamm +from pybamm.util import have_optional_dependency class UnaryOperator(pybamm.Symbol): @@ -83,6 +81,7 @@ def _sympy_operator(self, child): def to_equation(self): """Convert the node and its subtree into a SymPy equation.""" + sympy = have_optional_dependency("sympy") if self.print_name is not None: return sympy.Symbol(self.print_name) else: @@ -368,6 +367,7 @@ def _unary_new_copy(self, child): def _sympy_operator(self, child): """Override :meth:`pybamm.UnaryOperator._sympy_operator`""" + sympy_Gradient = have_optional_dependency("sympy.vector.operators","Gradient") return sympy_Gradient(child) @@ -403,6 +403,7 @@ def _unary_new_copy(self, child): def _sympy_operator(self, child): """Override :meth:`pybamm.UnaryOperator._sympy_operator`""" + sympy_Divergence = have_optional_dependency("sympy.vector.operators","Divergence") return sympy_Divergence(child) @@ -579,6 +580,7 @@ def _evaluates_on_edges(self, dimension): def _sympy_operator(self, child): """Override :meth:`pybamm.UnaryOperator._sympy_operator`""" + sympy = have_optional_dependency("sympy") return sympy.Integral(child, sympy.Symbol("xn")) @@ -889,6 +891,7 @@ def _unary_new_copy(self, child): def _sympy_operator(self, child): """Override :meth:`pybamm.UnaryOperator._sympy_operator`""" + sympy = have_optional_dependency("sympy") if ( self.child.domain[0] in ["negative particle", "positive particle"] and self.side == "right" diff --git a/pybamm/expression_tree/variable.py b/pybamm/expression_tree/variable.py index f9f7d94efc..0d1e1fd424 100644 --- a/pybamm/expression_tree/variable.py +++ b/pybamm/expression_tree/variable.py @@ -3,9 +3,9 @@ # import numpy as np -import sympy import numbers import pybamm +from pybamm.util import have_optional_dependency class VariableBase(pybamm.Symbol): @@ -124,6 +124,7 @@ def _evaluate_for_shape(self): def to_equation(self): """Convert the node and its subtree into a SymPy equation.""" + sympy = have_optional_dependency("sympy") if self.print_name is not None: return sympy.Symbol(self.print_name) else: diff --git a/tests/unit/test_expression_tree/test_binary_operators.py b/tests/unit/test_expression_tree/test_binary_operators.py index 6acd7c41b0..225f8e93c9 100644 --- a/tests/unit/test_expression_tree/test_binary_operators.py +++ b/tests/unit/test_expression_tree/test_binary_operators.py @@ -5,10 +5,10 @@ import unittest import numpy as np -import sympy from scipy.sparse import coo_matrix import pybamm +from pybamm.util import have_optional_dependency class TestBinaryOperators(TestCase): @@ -746,6 +746,7 @@ def test_inner_simplifications(self): self.assertEqual(pybamm.inner(a3, a3).evaluate(), 9) def test_to_equation(self): + sympy = have_optional_dependency("sympy") # Test print_name pybamm.Addition.print_name = "test" self.assertEqual(pybamm.Addition(1, 2).to_equation(), sympy.Symbol("test")) diff --git a/tests/unit/test_expression_tree/test_concatenations.py b/tests/unit/test_expression_tree/test_concatenations.py index df5add0f98..4b07b09fea 100644 --- a/tests/unit/test_expression_tree/test_concatenations.py +++ b/tests/unit/test_expression_tree/test_concatenations.py @@ -5,9 +5,9 @@ from tests import TestCase import numpy as np -import sympy import pybamm +from pybamm.util import have_optional_dependency from tests import get_discretisation_for_testing, get_mesh_for_testing @@ -370,6 +370,7 @@ def test_numpy_concatenation(self): ) def test_to_equation(self): + sympy = have_optional_dependency("sympy") a = pybamm.Symbol("a", domain="test a") b = pybamm.Symbol("b", domain="test b") func_symbol = sympy.Symbol(r"\begin{cases}a\\b\end{cases}") diff --git a/tests/unit/test_expression_tree/test_functions.py b/tests/unit/test_expression_tree/test_functions.py index ac5410d9e1..6d22571a01 100644 --- a/tests/unit/test_expression_tree/test_functions.py +++ b/tests/unit/test_expression_tree/test_functions.py @@ -5,10 +5,10 @@ import unittest import numpy as np -import sympy from scipy import special import pybamm +from pybamm.util import have_optional_dependency def test_function(arg): @@ -120,6 +120,7 @@ def test_function_unnamed(self): self.assertEqual(fun.name, "function (cos)") def test_to_equation(self): + sympy = have_optional_dependency("sympy") a = pybamm.Symbol("a", domain="test") # Test print_name diff --git a/tests/unit/test_expression_tree/test_independent_variable.py b/tests/unit/test_expression_tree/test_independent_variable.py index 95141f0f03..b748a6fbe9 100644 --- a/tests/unit/test_expression_tree/test_independent_variable.py +++ b/tests/unit/test_expression_tree/test_independent_variable.py @@ -4,9 +4,9 @@ from tests import TestCase import unittest -import sympy import pybamm +from pybamm.util import have_optional_dependency class TestIndependentVariable(TestCase): @@ -64,6 +64,7 @@ def test_spatial_variable_edge(self): self.assertTrue(x.evaluates_on_edges("primary")) def test_to_equation(self): + sympy = have_optional_dependency("sympy") # Test print_name func = pybamm.IndependentVariable("a") func.print_name = "test" diff --git a/tests/unit/test_expression_tree/test_parameter.py b/tests/unit/test_expression_tree/test_parameter.py index f67ee2dd62..d9a756b45d 100644 --- a/tests/unit/test_expression_tree/test_parameter.py +++ b/tests/unit/test_expression_tree/test_parameter.py @@ -5,9 +5,8 @@ import numbers import unittest -import sympy - import pybamm +from pybamm.util import have_optional_dependency class TestParameter(TestCase): @@ -21,6 +20,7 @@ def test_evaluate_for_shape(self): self.assertIsInstance(a.evaluate_for_shape(), numbers.Number) def test_to_equation(self): + sympy = have_optional_dependency("sympy") func = pybamm.Parameter("test_string") func1 = pybamm.Parameter("test_name") @@ -98,6 +98,7 @@ def _myfun(x): self.assertEqual(_myfun(x).print_name, None) def test_function_parameter_to_equation(self): + sympy = have_optional_dependency("sympy") func = pybamm.FunctionParameter("test", {"x": pybamm.Scalar(1)}) func1 = pybamm.FunctionParameter("func", {"var": pybamm.Variable("var")}) diff --git a/tests/unit/test_expression_tree/test_printing/test_sympy_overrides.py b/tests/unit/test_expression_tree/test_printing/test_sympy_overrides.py index b5ae229ae5..de3ff08c43 100644 --- a/tests/unit/test_expression_tree/test_printing/test_sympy_overrides.py +++ b/tests/unit/test_expression_tree/test_printing/test_sympy_overrides.py @@ -4,14 +4,14 @@ from tests import TestCase import unittest -import sympy - import pybamm from pybamm.expression_tree.printing.sympy_overrides import custom_print_func +from pybamm.util import have_optional_dependency class TestCustomPrint(TestCase): def test_print_Derivative(self): + sympy = have_optional_dependency("sympy") # Test force_partial der1 = sympy.Derivative("y", "x") der1.force_partial = True diff --git a/tests/unit/test_expression_tree/test_symbol.py b/tests/unit/test_expression_tree/test_symbol.py index 3f91633fbe..3eb7adae47 100644 --- a/tests/unit/test_expression_tree/test_symbol.py +++ b/tests/unit/test_expression_tree/test_symbol.py @@ -8,10 +8,10 @@ import numpy as np from scipy.sparse import csr_matrix, coo_matrix -import sympy import pybamm from pybamm.expression_tree.binary_operators import _Heaviside +from pybamm.util import have_optional_dependency class TestSymbol(TestCase): @@ -484,6 +484,7 @@ def test_test_shape(self): (y1 + y2).test_shape() def test_to_equation(self): + sympy = have_optional_dependency("sympy") self.assertEqual(pybamm.Symbol("test").to_equation(), sympy.Symbol("test")) def test_numpy_array_ufunc(self): diff --git a/tests/unit/test_expression_tree/test_unary_operators.py b/tests/unit/test_expression_tree/test_unary_operators.py index b0513c974b..d8bf30d79f 100644 --- a/tests/unit/test_expression_tree/test_unary_operators.py +++ b/tests/unit/test_expression_tree/test_unary_operators.py @@ -5,12 +5,10 @@ from tests import TestCase import numpy as np -import sympy from scipy.sparse import diags -from sympy.vector.operators import Divergence as sympy_Divergence -from sympy.vector.operators import Gradient as sympy_Gradient import pybamm +from pybamm.util import have_optional_dependency class TestUnaryOperators(TestCase): @@ -613,6 +611,11 @@ def test_not_constant(self): self.assertFalse((2 * a).is_constant()) def test_to_equation(self): + + sympy = have_optional_dependency("sympy") + sympy_Divergence = have_optional_dependency("sympy.vector.operators","Divergence") + sympy_Gradient = have_optional_dependency("sympy.vector.operators","Gradient") + a = pybamm.Symbol("a", domain="negative particle") b = pybamm.Symbol("b", domain="current collector") c = pybamm.Symbol("c", domain="test") diff --git a/tests/unit/test_expression_tree/test_variable.py b/tests/unit/test_expression_tree/test_variable.py index be791903e2..583008f882 100644 --- a/tests/unit/test_expression_tree/test_variable.py +++ b/tests/unit/test_expression_tree/test_variable.py @@ -5,9 +5,9 @@ import unittest import numpy as np -import sympy import pybamm +from pybamm.util import have_optional_dependency class TestVariable(TestCase): @@ -55,6 +55,7 @@ def test_variable_bounds(self): pybamm.Variable("var", bounds=(1, 1)) def test_to_equation(self): + sympy = have_optional_dependency("sympy") # Test print_name func = pybamm.Variable("test_string") func.print_name = "test" From efe887747422cf379b706af19e2dfd1116c5a5b7 Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Wed, 1 Nov 2023 15:43:29 +0530 Subject: [PATCH 17/50] Fix Typo --- pybamm/expression_tree/independent_variable.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pybamm/expression_tree/independent_variable.py b/pybamm/expression_tree/independent_variable.py index 4c139c30a8..2f30da9a5e 100644 --- a/pybamm/expression_tree/independent_variable.py +++ b/pybamm/expression_tree/independent_variable.py @@ -2,7 +2,7 @@ # IndependentVariable class # import pybamm -from pybamm.utili import have_optional_dependency +from pybamm.util import have_optional_dependency KNOWN_COORD_SYS = ["cartesian", "cylindrical polar", "spherical polar"] From 911105521377a8fa8e014a31f5c3e5a2a4b1fe7e Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Wed, 1 Nov 2023 15:55:42 +0530 Subject: [PATCH 18/50] Return more helpful message --- pybamm/util.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pybamm/util.py b/pybamm/util.py index a2625e5405..9af22a8ab3 100644 --- a/pybamm/util.py +++ b/pybamm/util.py @@ -355,11 +355,11 @@ def have_optional_dependency(module_name, attribute=None): imported_attribute = getattr(module, attribute) return imported_attribute else: - raise ImportError(f"{module_name}.{attribute} is not available.") + raise ImportError(f"Optional dependency {module_name}.{attribute} is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details.") else: return module except ImportError: if attribute: - raise ImportError(f"{module_name}.{attribute} is not available.") + raise ImportError(f"Optional dependency {module_name}.{attribute} is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details.") else: - raise ImportError(f"{module_name} module is not available.") + raise ImportError(f"Optional dependency {module_name} is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details.") From c65a2a29e785f71bfeb4905f45761a7c394000a9 Mon Sep 17 00:00:00 2001 From: Arjun Date: Fri, 3 Nov 2023 14:12:08 +0530 Subject: [PATCH 19/50] Abstraction to only show module name if not available Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com> --- pybamm/util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pybamm/util.py b/pybamm/util.py index 9af22a8ab3..8656f00701 100644 --- a/pybamm/util.py +++ b/pybamm/util.py @@ -355,11 +355,11 @@ def have_optional_dependency(module_name, attribute=None): imported_attribute = getattr(module, attribute) return imported_attribute else: - raise ImportError(f"Optional dependency {module_name}.{attribute} is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details.") + raise ImportError(f"Optional dependency {module_name} is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details.") else: return module except ImportError: if attribute: - raise ImportError(f"Optional dependency {module_name}.{attribute} is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details.") + raise ImportError(f"Optional dependency {module_name} is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details.") else: raise ImportError(f"Optional dependency {module_name} is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details.") From 4d32e32c7ac0a195d532f266de9dfaab26e11df7 Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Fri, 3 Nov 2023 23:48:03 +0530 Subject: [PATCH 20/50] Update docs for have_optional_deps --- CONTRIBUTING.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index bec0fee02a..8eceda7972 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -100,9 +100,9 @@ On the other hand... We _do_ want to compare several tools, to generate document Only 'core pybamm' is installed by default. The others have to be specified explicitly when running the installation command. -### Matplotlib +### Managing Optional Dependencies and Their Imports -We use Matplotlib in PyBaMM, but with two caveats: +PyBaMM utilizes optional dependencies to allow users to choose which additional libraries they want to use. Managing these optional dependencies and their imports is essential to provide flexibility to PyBaMM users. First, Matplotlib should only be used in plotting methods, and these should _never_ be called by other PyBaMM methods. So users who don't like Matplotlib will not be forced to use it in any way. Use in notebooks is OK and encouraged. From fd0916322fd390fee7502d9022d8e5536c59124a Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Mon, 6 Nov 2023 15:50:06 +0530 Subject: [PATCH 21/50] Update for `have_optional_dependency` --- CONTRIBUTING.md | 25 ++++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 8eceda7972..648996a024 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -54,7 +54,7 @@ You now have everything you need to start making changes! 10. [Test your code!](#testing) 11. PyBaMM has online documentation at http://docs.pybamm.org/. To make sure any new methods or classes you added show up there, please read the [documentation](#documentation) section. 12. If you added a major new feature, perhaps it should be showcased in an [example notebook](#example-notebooks). -13. When you feel your code is finished, or at least warrants serious discussion, run the [pre-commit checks](#pre-commit-checks) and then create a [pull request](https://help.github.com/articles/about-pull-requests/) (PR) on [PyBaMM's GitHub page](https://github.com/pybamm-team/PyBaMM). +13. When you feel your code is finished, or at least warrants serious discussion, run the [pre-commit checks](#pre-commit-checks) and then create a [pull request](https://help.github.com/articles/about-pull-requests/) (PR) on [PyBaMM's GitHub page](https://github.com/pybamm-team/PyBaMM). 14. Once a PR has been created, it will be reviewed by any member of the community. Changes might be suggested which you can make by simply adding new commits to the branch. When everything's finished, someone with the right GitHub permissions will merge your changes into PyBaMM main repository. Finally, if you really, really, _really_ love developing PyBaMM, have a look at the current [project infrastructure](#infrastructure). @@ -104,17 +104,25 @@ Only 'core pybamm' is installed by default. The others have to be specified expl PyBaMM utilizes optional dependencies to allow users to choose which additional libraries they want to use. Managing these optional dependencies and their imports is essential to provide flexibility to PyBaMM users. -First, Matplotlib should only be used in plotting methods, and these should _never_ be called by other PyBaMM methods. So users who don't like Matplotlib will not be forced to use it in any way. Use in notebooks is OK and encouraged. +PyBaMM provides a utility function `have_optional_dependency`, to check for the availability of optional dependencies within methods. This function can be used to conditionally import optional dependencies only if they are available. Here's how to use it: -Second, Matplotlib should never be imported at the module level, but always inside methods. For example: +Optional Dependencies should never be imported at the module level, but always inside methods. For example: ``` -def plot_great_things(self, x, y, z): - import matplotlib.pyplot as pl +def use_pybtex(x,y,z): + pybtex = have_optional_dependency("pybtex") ... ``` -This allows people to (1) use PyBaMM without ever importing Matplotlib and (2) configure Matplotlib's back-end in their scripts, which _must_ be done before e.g. `pyplot` is first imported. +While importing a specific attribute instead of whole module: + +``` +def use_parse_file(x,y,z): + parse_file = have_optional_dependency("pybtex.database","parse_file") + ... +``` + +This allows people to (1) use PyBaMM without importing Optional dependency by default and (2) configure module dependent functionality in their scripts, which _must_ be done before e.g. `print_citations` method is first imported. ## Testing @@ -266,7 +274,6 @@ This also means that, if you can't fix the bug yourself, it will be much easier ``` This will start the debugger at the point where the `ValueError` was raised, and allow you to investigate further. Sometimes, it is more informative to put the try-except block further up the call stack than exactly where the error is raised. - 2. Warnings. If functions are raising warnings instead of errors, it can be hard to pinpoint where this is coming from. Here, you can use the `warnings` module to convert warnings to errors: ```python @@ -276,7 +283,6 @@ This also means that, if you can't fix the bug yourself, it will be much easier ``` Then you can use a try-except block, as in a., but with, for example, `RuntimeWarning` instead of `ValueError`. - 3. Stepping through the expression tree. Most calls in PyBaMM are operations on [expression trees](https://github.com/pybamm-team/PyBaMM/blob/develop/docs/source/examples/notebooks/expression_tree/expression-tree.ipynb). To view an expression tree in ipython, you can use the `render` command: ```python @@ -284,11 +290,8 @@ This also means that, if you can't fix the bug yourself, it will be much easier ``` You can then step through the expression tree, using the `children` attribute, to pinpoint exactly where a bug is coming from. For example, if `expression_tree.jac(y)` is failing, you can check `expression_tree.children[0].jac(y)`, then `expression_tree.children[0].children[0].jac(y)`, etc. - 3. To isolate whether a bug is in a model, its Jacobian or its simplified version, you can set the `use_jacobian` and/or `use_simplify` attributes of the model to `False` (they are both `True` by default for most models). - 4. If a model isn't giving the answer you expect, you can try comparing it to other models. For example, you can investigate parameter limits in which two models should give the same answer by setting some parameters to be small or zero. The `StandardOutputComparison` class can be used to compare some standard outputs from battery models. - 5. To get more information about what is going on under the hood, and hence understand what is causing the bug, you can set the [logging](https://realpython.com/python-logging/) level to `DEBUG` by adding the following line to your test or script: ```python3 From 926f8d74b6a40c98a3c9a40a40af5ffb4fad154b Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Mon, 6 Nov 2023 16:00:02 +0530 Subject: [PATCH 22/50] Add comments to `have_optional_dependency` --- pybamm/util.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/pybamm/util.py b/pybamm/util.py index 8656f00701..78a5cff27d 100644 --- a/pybamm/util.py +++ b/pybamm/util.py @@ -346,19 +346,26 @@ def install_jax(arguments=None): # pragma: no cover ] ) - +# https://docs.pybamm.org/en/latest/source/user_guide/contributing.html#managing-optional-dependencies-and-their-imports def have_optional_dependency(module_name, attribute=None): try: + # Attempt to import the specified module module = importlib.import_module(module_name) + if attribute: + # If an attribute is specified, check if it's available if hasattr(module, attribute): imported_attribute = getattr(module, attribute) - return imported_attribute + return imported_attribute # Return the imported attribute else: + # Raise an ImportError if the attribute is not available raise ImportError(f"Optional dependency {module_name} is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details.") else: + # Return the entire module if no attribute is specified return module + except ImportError: + # Raise an ImportError if the module or attribute is not available if attribute: raise ImportError(f"Optional dependency {module_name} is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details.") else: From ec963d1ae3dbaf673e6f512156f6cddc36f4e52b Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Wed, 8 Nov 2023 23:19:19 +0530 Subject: [PATCH 23/50] Add `test_have_optional_dependency` --- tests/unit/test_util.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/unit/test_util.py b/tests/unit/test_util.py index c5060e65a6..8f706d8149 100644 --- a/tests/unit/test_util.py +++ b/tests/unit/test_util.py @@ -88,6 +88,10 @@ def test_git_commit_info(self): self.assertIsInstance(git_commit_info, str) self.assertEqual(git_commit_info[:2], "v2") + def test_have_optional_dependency(self): + with self.assertRaisesRegex(ImportError,"Optional dependency pybtex is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details."): + pybamm.print_citations() + class TestSearch(TestCase): def test_url_gets_to_stdout(self): From dd8a6f21d573f397767b97b662cd26c591d06e0d Mon Sep 17 00:00:00 2001 From: Arjun Date: Thu, 9 Nov 2023 18:52:12 +0530 Subject: [PATCH 24/50] Apply suggestions from code review Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com> --- CONTRIBUTING.md | 8 ++++---- pybamm/expression_tree/unary_operators.py | 4 ++-- pybamm/util.py | 4 ++-- tests/unit/test_expression_tree/test_unary_operators.py | 4 ++-- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 648996a024..78fbb0fdec 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -106,7 +106,7 @@ PyBaMM utilizes optional dependencies to allow users to choose which additional PyBaMM provides a utility function `have_optional_dependency`, to check for the availability of optional dependencies within methods. This function can be used to conditionally import optional dependencies only if they are available. Here's how to use it: -Optional Dependencies should never be imported at the module level, but always inside methods. For example: +Optional dependencies should never be imported at the module level, but always inside methods. For example: ``` def use_pybtex(x,y,z): @@ -114,15 +114,15 @@ def use_pybtex(x,y,z): ... ``` -While importing a specific attribute instead of whole module: +While importing a specific module instead of an entire package/library: -``` +```python def use_parse_file(x,y,z): parse_file = have_optional_dependency("pybtex.database","parse_file") ... ``` -This allows people to (1) use PyBaMM without importing Optional dependency by default and (2) configure module dependent functionality in their scripts, which _must_ be done before e.g. `print_citations` method is first imported. +This allows people to (1) use PyBaMM without importing optional dependencies by default and (2) configure module-dependent functionalities in their scripts, which _must_ be done before e.g. `print_citations` method is first imported. ## Testing diff --git a/pybamm/expression_tree/unary_operators.py b/pybamm/expression_tree/unary_operators.py index e555f48455..81c3dc28c2 100644 --- a/pybamm/expression_tree/unary_operators.py +++ b/pybamm/expression_tree/unary_operators.py @@ -367,7 +367,7 @@ def _unary_new_copy(self, child): def _sympy_operator(self, child): """Override :meth:`pybamm.UnaryOperator._sympy_operator`""" - sympy_Gradient = have_optional_dependency("sympy.vector.operators","Gradient") + sympy_Gradient = have_optional_dependency("sympy.vector.operators", "Gradient") return sympy_Gradient(child) @@ -403,7 +403,7 @@ def _unary_new_copy(self, child): def _sympy_operator(self, child): """Override :meth:`pybamm.UnaryOperator._sympy_operator`""" - sympy_Divergence = have_optional_dependency("sympy.vector.operators","Divergence") + sympy_Divergence = have_optional_dependency("sympy.vector.operators", "Divergence") return sympy_Divergence(child) diff --git a/pybamm/util.py b/pybamm/util.py index 78a5cff27d..b6825f7eda 100644 --- a/pybamm/util.py +++ b/pybamm/util.py @@ -359,12 +359,12 @@ def have_optional_dependency(module_name, attribute=None): return imported_attribute # Return the imported attribute else: # Raise an ImportError if the attribute is not available - raise ImportError(f"Optional dependency {module_name} is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details.") + raise ModuleNotFoundError(f"Optional dependency {module_name} is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details.") else: # Return the entire module if no attribute is specified return module - except ImportError: + except ModuleNotFoundError: # Raise an ImportError if the module or attribute is not available if attribute: raise ImportError(f"Optional dependency {module_name} is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details.") diff --git a/tests/unit/test_expression_tree/test_unary_operators.py b/tests/unit/test_expression_tree/test_unary_operators.py index d8bf30d79f..fc845cb574 100644 --- a/tests/unit/test_expression_tree/test_unary_operators.py +++ b/tests/unit/test_expression_tree/test_unary_operators.py @@ -613,8 +613,8 @@ def test_not_constant(self): def test_to_equation(self): sympy = have_optional_dependency("sympy") - sympy_Divergence = have_optional_dependency("sympy.vector.operators","Divergence") - sympy_Gradient = have_optional_dependency("sympy.vector.operators","Gradient") + sympy_Divergence = have_optional_dependency("sympy.vector.operators", "Divergence") + sympy_Gradient = have_optional_dependency("sympy.vector.operators", "Gradient") a = pybamm.Symbol("a", domain="negative particle") b = pybamm.Symbol("b", domain="current collector") From aa2327edd7dfd25298e7b4076902bca74814b880 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 9 Nov 2023 13:22:22 +0000 Subject: [PATCH 25/50] style: pre-commit fixes --- CONTRIBUTING.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 78fbb0fdec..de0d626940 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -117,8 +117,8 @@ def use_pybtex(x,y,z): While importing a specific module instead of an entire package/library: ```python -def use_parse_file(x,y,z): - parse_file = have_optional_dependency("pybtex.database","parse_file") +def use_parse_file(x, y, z): + parse_file = have_optional_dependency("pybtex.database", "parse_file") ... ``` From c28c7fbfe0f8a3405b2251cc84aa6b159a2891bb Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Thu, 9 Nov 2023 21:24:34 +0530 Subject: [PATCH 26/50] Raise simple ModuleNotFoundError even if attribute not found --- pybamm/util.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/pybamm/util.py b/pybamm/util.py index b6825f7eda..dee77b8841 100644 --- a/pybamm/util.py +++ b/pybamm/util.py @@ -358,15 +358,12 @@ def have_optional_dependency(module_name, attribute=None): imported_attribute = getattr(module, attribute) return imported_attribute # Return the imported attribute else: - # Raise an ImportError if the attribute is not available + # Raise an ModuleNotFoundError if the attribute is not available raise ModuleNotFoundError(f"Optional dependency {module_name} is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details.") else: # Return the entire module if no attribute is specified return module except ModuleNotFoundError: - # Raise an ImportError if the module or attribute is not available - if attribute: - raise ImportError(f"Optional dependency {module_name} is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details.") - else: - raise ImportError(f"Optional dependency {module_name} is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details.") + # Raise an ModuleNotFoundError if the module or attribute is not available + raise ModuleNotFoundError(f"Optional dependency {module_name} is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details.") From fd9ae61636bc92cc31e3efab7765f601218281d7 Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Thu, 9 Nov 2023 21:44:01 +0530 Subject: [PATCH 27/50] Set pybtex to None to avoid import --- tests/unit/test_util.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_util.py b/tests/unit/test_util.py index 8f706d8149..bfcac5fa5f 100644 --- a/tests/unit/test_util.py +++ b/tests/unit/test_util.py @@ -89,7 +89,8 @@ def test_git_commit_info(self): self.assertEqual(git_commit_info[:2], "v2") def test_have_optional_dependency(self): - with self.assertRaisesRegex(ImportError,"Optional dependency pybtex is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details."): + with self.assertRaisesRegex(ModuleNotFoundError,"Optional dependency pybtex.database is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details."): + sys.modules['pybtex'] = None pybamm.print_citations() From 7cb2ef69250dba93955ca46468c17ee35ad581ee Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Thu, 9 Nov 2023 22:20:28 +0530 Subject: [PATCH 28/50] Add more testcases for optional dependencies --- tests/unit/test_util.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_util.py b/tests/unit/test_util.py index bfcac5fa5f..8edf4ad6ec 100644 --- a/tests/unit/test_util.py +++ b/tests/unit/test_util.py @@ -11,7 +11,8 @@ from unittest.mock import patch from io import StringIO - +def test_function(arg): + return arg + arg class TestUtil(TestCase): """ Test the functionality in util.py @@ -89,9 +90,21 @@ def test_git_commit_info(self): self.assertEqual(git_commit_info[:2], "v2") def test_have_optional_dependency(self): - with self.assertRaisesRegex(ModuleNotFoundError,"Optional dependency pybtex.database is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details."): + with self.assertRaisesRegex(ModuleNotFoundError,"Optional dependency pybtex is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details."): sys.modules['pybtex'] = None pybamm.print_citations() + with self.assertRaisesRegex(ModuleNotFoundError,"Optional dependency tqdm is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details."): + sys.modules['tqdm'] = None + model = pybamm.BaseModel() + v = pybamm.Variable("v") + model.rhs = {v: -v} + model.initial_conditions = {v: 1} + sim = pybamm.Simulation(model) + sim.solve([0, 1]) + with self.assertRaisesRegex(ModuleNotFoundError,"Optional dependency autograd is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details."): + sys.modules['autograd'] = None + a = pybamm.StateVector(slice(0, 1)) + pybamm.Function(test_function, a) class TestSearch(TestCase): From b681bbc69e15171429ee5ea48a1282ba83fdc30c Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Thu, 9 Nov 2023 23:54:54 +0530 Subject: [PATCH 29/50] Add test for case if dependency is available --- tests/unit/test_util.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/unit/test_util.py b/tests/unit/test_util.py index 8edf4ad6ec..7b6864f443 100644 --- a/tests/unit/test_util.py +++ b/tests/unit/test_util.py @@ -101,10 +101,9 @@ def test_have_optional_dependency(self): model.initial_conditions = {v: 1} sim = pybamm.Simulation(model) sim.solve([0, 1]) - with self.assertRaisesRegex(ModuleNotFoundError,"Optional dependency autograd is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details."): - sys.modules['autograd'] = None - a = pybamm.StateVector(slice(0, 1)) - pybamm.Function(test_function, a) + + sys.modules['pybtex'] = pybamm.util.have_optional_dependency("pybtex") + pybamm.print_citations() class TestSearch(TestCase): From f2e37cf0439ec06a994b13c957894c2f194d1dc9 Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Fri, 10 Nov 2023 00:33:43 +0530 Subject: [PATCH 30/50] Reset pybtex to run dependent function --- tests/unit/test_util.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_util.py b/tests/unit/test_util.py index 7b6864f443..3ac78986bb 100644 --- a/tests/unit/test_util.py +++ b/tests/unit/test_util.py @@ -91,6 +91,7 @@ def test_git_commit_info(self): def test_have_optional_dependency(self): with self.assertRaisesRegex(ModuleNotFoundError,"Optional dependency pybtex is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details."): + pybtex = sys.modules['pybtex'] sys.modules['pybtex'] = None pybamm.print_citations() with self.assertRaisesRegex(ModuleNotFoundError,"Optional dependency tqdm is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details."): @@ -102,7 +103,8 @@ def test_have_optional_dependency(self): sim = pybamm.Simulation(model) sim.solve([0, 1]) - sys.modules['pybtex'] = pybamm.util.have_optional_dependency("pybtex") + sys.modules['pybtex'] = pybtex + pybamm.util.have_optional_dependency("pybtex") pybamm.print_citations() From 8d6db99511e3bdbe6f20bacac6714c85fc4a0b11 Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Fri, 10 Nov 2023 01:46:00 +0530 Subject: [PATCH 31/50] Add test for full coverage --- tests/unit/test_util.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/unit/test_util.py b/tests/unit/test_util.py index 3ac78986bb..a9f70bbcc7 100644 --- a/tests/unit/test_util.py +++ b/tests/unit/test_util.py @@ -10,6 +10,7 @@ import unittest from unittest.mock import patch from io import StringIO +from tempfile import TemporaryDirectory def test_function(arg): return arg + arg @@ -102,6 +103,15 @@ def test_have_optional_dependency(self): model.initial_conditions = {v: 1} sim = pybamm.Simulation(model) sim.solve([0, 1]) + with self.assertRaisesRegex(ModuleNotFoundError,"Optional dependency anytree is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details."): + with TemporaryDirectory() as dir_name: + sys.modules['anytree'] = None + test_stub = os.path.join(dir_name, "test_visualize") + test_name = f"{test_stub}.png" + c = pybamm.Variable("c", "negative electrode") + d = pybamm.Variable("d", "negative electrode") + sym = pybamm.div(c * pybamm.grad(c)) + (c / d + c - d) ** 5 + sym.visualise(test_name) sys.modules['pybtex'] = pybtex pybamm.util.have_optional_dependency("pybtex") From 700ab5af6b0805dc503ded489bbacf23a38cd2d7 Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Fri, 10 Nov 2023 02:07:17 +0530 Subject: [PATCH 32/50] Declare `anytree` onn top to pass `test_is_constant_and_can_evaluate` --- tests/unit/test_util.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/unit/test_util.py b/tests/unit/test_util.py index a9f70bbcc7..b9dd428a4b 100644 --- a/tests/unit/test_util.py +++ b/tests/unit/test_util.py @@ -14,6 +14,8 @@ def test_function(arg): return arg + arg + +anytree = sys.modules['anytree'] class TestUtil(TestCase): """ Test the functionality in util.py @@ -31,6 +33,7 @@ def test_rmse(self): pybamm.rmse(np.ones(5), np.zeros(3)) def test_is_constant_and_can_evaluate(self): + sys.modules['anytree'] = anytree symbol = pybamm.PrimaryBroadcast(0, "negative electrode") self.assertEqual(False, pybamm.is_constant_and_can_evaluate(symbol)) symbol = pybamm.StateVector(slice(0, 1)) From bfbe41e36d32b9afc7d8643fd078f9fe70333bff Mon Sep 17 00:00:00 2001 From: Arjun Date: Fri, 10 Nov 2023 18:20:35 +0530 Subject: [PATCH 33/50] Apply suggestions from code review Co-authored-by: Saransh Chopra --- pybamm/citations.py | 8 ++++---- pybamm/expression_tree/printing/sympy_overrides.py | 4 ++-- pybamm/expression_tree/symbol.py | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/pybamm/citations.py b/pybamm/citations.py index 7d0959d89c..b72262989b 100644 --- a/pybamm/citations.py +++ b/pybamm/citations.py @@ -74,7 +74,7 @@ def read_citations(self): """Reads the citations in `pybamm.CITATIONS.bib`. Other works can be cited by passing a BibTeX citation to :meth:`register`. """ - parse_file = have_optional_dependency("pybtex.database","parse_file") + parse_file = have_optional_dependency("pybtex.database", "parse_file") citations_file = os.path.join(pybamm.root_dir(), "pybamm", "CITATIONS.bib") bib_data = parse_file(citations_file, bib_format="bibtex") for key, entry in bib_data.entries.items(): @@ -85,7 +85,7 @@ def _add_citation(self, key, entry): previous entry is overwritten """ - Entry = have_optional_dependency("pybtex.database","Entry") + Entry = have_optional_dependency("pybtex.database", "Entry") # Check input types are correct if not isinstance(key, str) or not isinstance(entry, Entry): raise TypeError() @@ -151,8 +151,8 @@ def _parse_citation(self, key): key: str A BibTeX formatted citation """ - PybtexError = have_optional_dependency("pybtex.scanner","PybtexError") - parse_string = have_optional_dependency("pybtex.database","parse_string") + PybtexError = have_optional_dependency("pybtex.scanner", "PybtexError") + parse_string = have_optional_dependency("pybtex.database", "parse_string") try: # Parse string as a bibtex citation, and check that a citation was found bib_data = parse_string(key, bib_format="bibtex") diff --git a/pybamm/expression_tree/printing/sympy_overrides.py b/pybamm/expression_tree/printing/sympy_overrides.py index 59f9567c5d..64743f557d 100644 --- a/pybamm/expression_tree/printing/sympy_overrides.py +++ b/pybamm/expression_tree/printing/sympy_overrides.py @@ -6,7 +6,7 @@ from pybamm.util import have_optional_dependency -LatexPrinter = have_optional_dependency("sympy.printing.latex","LatexPrinter") +LatexPrinter = have_optional_dependency("sympy.printing.latex", "LatexPrinter") class CustomPrint(LatexPrinter): """Override SymPy methods to match PyBaMM's requirements""" @@ -22,5 +22,5 @@ def _print_Derivative(self, expr): def custom_print_func(expr, **settings): - have_optional_dependency("sympy.printing.latex","LatexPrinter") + have_optional_dependency("sympy.printing.latex", "LatexPrinter") return CustomPrint().doprint(expr) diff --git a/pybamm/expression_tree/symbol.py b/pybamm/expression_tree/symbol.py index 85c392e590..8f1608e7ba 100644 --- a/pybamm/expression_tree/symbol.py +++ b/pybamm/expression_tree/symbol.py @@ -459,7 +459,7 @@ def visualise(self, filename): filename to output, must end in ".png" """ - DotExporter = have_optional_dependency("anytree.exporter","DotExporter") + DotExporter = have_optional_dependency("anytree.exporter", "DotExporter") # check that filename ends in .png. if filename[-4:] != ".png": raise ValueError("filename should end in .png") From 2f1d3ceea469a4de03c9143706cd9f2e23bdd14d Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Fri, 10 Nov 2023 23:40:05 +0530 Subject: [PATCH 34/50] Shorten assert string --- tests/unit/test_util.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/unit/test_util.py b/tests/unit/test_util.py index b9dd428a4b..ea087ad4c4 100644 --- a/tests/unit/test_util.py +++ b/tests/unit/test_util.py @@ -12,10 +12,8 @@ from io import StringIO from tempfile import TemporaryDirectory -def test_function(arg): - return arg + arg - anytree = sys.modules['anytree'] + class TestUtil(TestCase): """ Test the functionality in util.py @@ -94,11 +92,11 @@ def test_git_commit_info(self): self.assertEqual(git_commit_info[:2], "v2") def test_have_optional_dependency(self): - with self.assertRaisesRegex(ModuleNotFoundError,"Optional dependency pybtex is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details."): + with self.assertRaisesRegex(ModuleNotFoundError,"Optional dependency pybtex is not available."): pybtex = sys.modules['pybtex'] sys.modules['pybtex'] = None pybamm.print_citations() - with self.assertRaisesRegex(ModuleNotFoundError,"Optional dependency tqdm is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details."): + with self.assertRaisesRegex(ModuleNotFoundError,"Optional dependency tqdm is not available."): sys.modules['tqdm'] = None model = pybamm.BaseModel() v = pybamm.Variable("v") @@ -106,7 +104,7 @@ def test_have_optional_dependency(self): model.initial_conditions = {v: 1} sim = pybamm.Simulation(model) sim.solve([0, 1]) - with self.assertRaisesRegex(ModuleNotFoundError,"Optional dependency anytree is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details."): + with self.assertRaisesRegex(ModuleNotFoundError,"Optional dependency anytree is not available."): with TemporaryDirectory() as dir_name: sys.modules['anytree'] = None test_stub = os.path.join(dir_name, "test_visualize") From fe6b9105f0ed0abad3e7284cb94edaea46056ed9 Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Sat, 11 Nov 2023 00:19:27 +0530 Subject: [PATCH 35/50] Improve readibility & add case to fix coverage --- pybamm/util.py | 5 +++-- tests/unit/test_util.py | 3 +++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/pybamm/util.py b/pybamm/util.py index dee77b8841..6c91948394 100644 --- a/pybamm/util.py +++ b/pybamm/util.py @@ -348,6 +348,7 @@ def install_jax(arguments=None): # pragma: no cover # https://docs.pybamm.org/en/latest/source/user_guide/contributing.html#managing-optional-dependencies-and-their-imports def have_optional_dependency(module_name, attribute=None): + err_msg = f"Optional dependency {module_name} is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details." try: # Attempt to import the specified module module = importlib.import_module(module_name) @@ -359,11 +360,11 @@ def have_optional_dependency(module_name, attribute=None): return imported_attribute # Return the imported attribute else: # Raise an ModuleNotFoundError if the attribute is not available - raise ModuleNotFoundError(f"Optional dependency {module_name} is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details.") + raise ModuleNotFoundError(err_msg) else: # Return the entire module if no attribute is specified return module except ModuleNotFoundError: # Raise an ModuleNotFoundError if the module or attribute is not available - raise ModuleNotFoundError(f"Optional dependency {module_name} is not available. See https://docs.pybamm.org/en/latest/source/user_guide/installation/index.html#optional-dependencies for more details.") + raise ModuleNotFoundError(err_msg) diff --git a/tests/unit/test_util.py b/tests/unit/test_util.py index ea087ad4c4..b2ef72fcbc 100644 --- a/tests/unit/test_util.py +++ b/tests/unit/test_util.py @@ -118,6 +118,9 @@ def test_have_optional_dependency(self): pybamm.util.have_optional_dependency("pybtex") pybamm.print_citations() + with self.assertRaisesRegex(ModuleNotFoundError,"Optional dependency flask is not available."): + pybamm.util.have_optional_dependency("flask","Flask") + class TestSearch(TestCase): def test_url_gets_to_stdout(self): From 6239653a57fff5a4f33c35b8919ab8f4814de6a7 Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Sat, 11 Nov 2023 00:58:15 +0530 Subject: [PATCH 36/50] Modify CONTRIBUTING.md for optional dependency tests --- CONTRIBUTING.md | 23 +++++++++++++++++++++++ pybamm/util.py | 2 +- 2 files changed, 24 insertions(+), 1 deletion(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index de0d626940..9a7e3d779d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -124,6 +124,29 @@ def use_parse_file(x, y, z): This allows people to (1) use PyBaMM without importing optional dependencies by default and (2) configure module-dependent functionalities in their scripts, which _must_ be done before e.g. `print_citations` method is first imported. +**Writing Tests for Optional Dependencies** + +Whenever a new optional dependency is added for optional functionality, it is recommended to write a corresponding unit test in _test_util.py_. This ensures that an error is raised upon the absence of said dependency. Here's an example: + +```python +from tests import TestCase +import pybamm + + +class TestUtil(TestCase): + def test_optional_dependency(self): + # Test that an error is raised when pybtex is not available + with self.assertRaisesRegex( + ModuleNotFoundError, "Optional dependency pybtex is not available" + ): + sys.modules["pybtex"] = None + pybamm.function_using_pybtex(x, y, z) + + # Test that the function works when pybtex is available + sys.modules["pybtex"] = pybamm.util.have_optional_dependency("pybtex") + pybamm.function_using_pybtex(x, y, z) +``` + ## Testing All code requires testing. We use the [unittest](https://docs.python.org/3.3/library/unittest.html) package for our tests. (These tests typically just check that the code runs without error, and so, are more _debugging_ than _testing_ in a strict sense. Nevertheless, they are very useful to have!) diff --git a/pybamm/util.py b/pybamm/util.py index 6c91948394..90cb290c6e 100644 --- a/pybamm/util.py +++ b/pybamm/util.py @@ -360,7 +360,7 @@ def have_optional_dependency(module_name, attribute=None): return imported_attribute # Return the imported attribute else: # Raise an ModuleNotFoundError if the attribute is not available - raise ModuleNotFoundError(err_msg) + raise ModuleNotFoundError(err_msg) # pragma: no cover else: # Return the entire module if no attribute is specified return module From 6f5823fed4825ab44da90a79f65a94246ffec0bb Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Sat, 11 Nov 2023 18:55:49 +0530 Subject: [PATCH 37/50] Prevent inheriting LatexPrinter instead use a function --- .../printing/sympy_overrides.py | 25 +++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/pybamm/expression_tree/printing/sympy_overrides.py b/pybamm/expression_tree/printing/sympy_overrides.py index 64743f557d..3e89542d10 100644 --- a/pybamm/expression_tree/printing/sympy_overrides.py +++ b/pybamm/expression_tree/printing/sympy_overrides.py @@ -6,21 +6,24 @@ from pybamm.util import have_optional_dependency -LatexPrinter = have_optional_dependency("sympy.printing.latex", "LatexPrinter") -class CustomPrint(LatexPrinter): +def custom_latex_printer(expr, **settings): + latex = have_optional_dependency("sympy","latex") + Derivative = have_optional_dependency("sympy","Derivative") + if isinstance(expr, Derivative) and getattr(expr, "force_partial", False): + latex_str = latex(expr, **settings) + var1, var2 = re.findall(r"^\\frac{(\w+)}{(\w+) .+", latex_str)[0] + latex_str = latex_str.replace(var1, "\partial").replace(var2, "\partial") + return latex_str + else: + return latex(expr, **settings) + +class CustomPrint: """Override SymPy methods to match PyBaMM's requirements""" def _print_Derivative(self, expr): """Override :meth:`sympy.printing.latex.LatexPrinter._print_Derivative`""" - eqn = super()._print_Derivative(expr) - - if getattr(expr, "force_partial", False) and "partial" not in eqn: - var1, var2 = re.findall(r"^\\frac{(\w+)}{(\w+) .+", eqn)[0] - eqn = eqn.replace(var1, "\partial").replace(var2, "\partial") - - return eqn - + return custom_latex_printer(expr) def custom_print_func(expr, **settings): have_optional_dependency("sympy.printing.latex", "LatexPrinter") - return CustomPrint().doprint(expr) + return CustomPrint()._print_Derivative(expr) From c093d44614e2dfdb3f24f96357169a9cfb3d6ca3 Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Sat, 11 Nov 2023 18:59:32 +0530 Subject: [PATCH 38/50] Remove redundant testcase --- tests/unit/test_util.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/unit/test_util.py b/tests/unit/test_util.py index b2ef72fcbc..ea087ad4c4 100644 --- a/tests/unit/test_util.py +++ b/tests/unit/test_util.py @@ -118,9 +118,6 @@ def test_have_optional_dependency(self): pybamm.util.have_optional_dependency("pybtex") pybamm.print_citations() - with self.assertRaisesRegex(ModuleNotFoundError,"Optional dependency flask is not available."): - pybamm.util.have_optional_dependency("flask","Flask") - class TestSearch(TestCase): def test_url_gets_to_stdout(self): From a5d25736d79c0571eccada2ddee7a330fbb7b2dc Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Mon, 13 Nov 2023 17:22:05 +0530 Subject: [PATCH 39/50] Add `anytree` to required & install `[plot,cite]` in `examples` session --- noxfile.py | 2 +- setup.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/noxfile.py b/noxfile.py index 430ad59659..83f4c3d717 100644 --- a/noxfile.py +++ b/noxfile.py @@ -101,7 +101,7 @@ def run_unit(session): def run_examples(session): """Run the examples tests for Jupyter notebooks.""" set_environment_variables(PYBAMM_ENV, session=session) - session.install("-e", ".[all,dev]", silent=False) + session.install("-e", ".[plot,cite,dev]", silent=False) notebooks_to_test = session.posargs if session.posargs else [] session.run("pytest", "--nbmake", *notebooks_to_test, external=True) diff --git a/setup.py b/setup.py index f6fd37f75c..fca5b83de8 100644 --- a/setup.py +++ b/setup.py @@ -207,6 +207,7 @@ def compile_KLU(): "scipy>=1.3", "casadi>=3.6.0", "xarray", + "anytree>=2.4.3", ], extras_require={ "docs": [ From a3952ddfef3361d4411a045228abbda98ef8a840 Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Mon, 13 Nov 2023 20:32:26 +0530 Subject: [PATCH 40/50] Set iterator based upon `tqdm` --- noxfile.py | 2 +- pybamm/simulation.py | 18 +++++++++++++----- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/noxfile.py b/noxfile.py index 83f4c3d717..0d5d6e0d20 100644 --- a/noxfile.py +++ b/noxfile.py @@ -101,7 +101,7 @@ def run_unit(session): def run_examples(session): """Run the examples tests for Jupyter notebooks.""" set_environment_variables(PYBAMM_ENV, session=session) - session.install("-e", ".[plot,cite,dev]", silent=False) + session.install("-e", ".[plot,cite,examples,dev]", silent=False) notebooks_to_test = session.posargs if session.posargs else [] session.run("pytest", "--nbmake", *notebooks_to_test, external=True) diff --git a/pybamm/simulation.py b/pybamm/simulation.py index 0b1a6b2525..49b46f1dac 100644 --- a/pybamm/simulation.py +++ b/pybamm/simulation.py @@ -532,7 +532,10 @@ def solve( Additional key-word arguments passed to `solver.solve`. See :meth:`pybamm.BaseSolver.solve`. """ - tqdm = have_optional_dependency("tqdm") + try: + tqdm = have_optional_dependency("tqdm") + except ModuleNotFoundError: + tqdm = False # Setup if solver is None: solver = self._solver @@ -727,13 +730,18 @@ def solve( # Update _solution self._solution = current_solution - for cycle_num, cycle_length in enumerate( - # tqdm is the progress bar. - tqdm.tqdm( + if tqdm: + iterator = tqdm.tqdm( self.experiment.cycle_lengths, disable=(not showprogress), desc="Cycling", - ), + ) + else: + iterator = self.experiment.cycle_lengths + + for cycle_num, cycle_length in enumerate( + # tqdm is the progress bar. + iterator, start=1, ): logs["cycle number"] = ( From ae22805107b98aae6867c0cf91c4d8fcd6ba8ba6 Mon Sep 17 00:00:00 2001 From: Saransh Chopra Date: Mon, 13 Nov 2023 20:47:10 +0530 Subject: [PATCH 41/50] Clean up tqdm mess --- noxfile.py | 2 +- pybamm/simulation.py | 16 ++++++---------- tests/unit/test_util.py | 6 +++--- 3 files changed, 10 insertions(+), 14 deletions(-) diff --git a/noxfile.py b/noxfile.py index 0d5d6e0d20..430ad59659 100644 --- a/noxfile.py +++ b/noxfile.py @@ -101,7 +101,7 @@ def run_unit(session): def run_examples(session): """Run the examples tests for Jupyter notebooks.""" set_environment_variables(PYBAMM_ENV, session=session) - session.install("-e", ".[plot,cite,examples,dev]", silent=False) + session.install("-e", ".[all,dev]", silent=False) notebooks_to_test = session.posargs if session.posargs else [] session.run("pytest", "--nbmake", *notebooks_to_test, external=True) diff --git a/pybamm/simulation.py b/pybamm/simulation.py index 49b46f1dac..42bda08e31 100644 --- a/pybamm/simulation.py +++ b/pybamm/simulation.py @@ -532,10 +532,6 @@ def solve( Additional key-word arguments passed to `solver.solve`. See :meth:`pybamm.BaseSolver.solve`. """ - try: - tqdm = have_optional_dependency("tqdm") - except ModuleNotFoundError: - tqdm = False # Setup if solver is None: solver = self._solver @@ -730,18 +726,18 @@ def solve( # Update _solution self._solution = current_solution - if tqdm: - iterator = tqdm.tqdm( + # check if a user has tqdm installed + if showprogress: + tqdm = have_optional_dependency("tqdm") + cycle_lengths = tqdm.tqdm( self.experiment.cycle_lengths, - disable=(not showprogress), desc="Cycling", ) else: - iterator = self.experiment.cycle_lengths + cycle_lengths = self.experiment.cycle_lengths for cycle_num, cycle_length in enumerate( - # tqdm is the progress bar. - iterator, + cycle_lengths, start=1, ): logs["cycle number"] = ( diff --git a/tests/unit/test_util.py b/tests/unit/test_util.py index ea087ad4c4..5079842003 100644 --- a/tests/unit/test_util.py +++ b/tests/unit/test_util.py @@ -92,11 +92,11 @@ def test_git_commit_info(self): self.assertEqual(git_commit_info[:2], "v2") def test_have_optional_dependency(self): - with self.assertRaisesRegex(ModuleNotFoundError,"Optional dependency pybtex is not available."): + with self.assertRaisesRegex(ModuleNotFoundError, "Optional dependency pybtex is not available."): pybtex = sys.modules['pybtex'] sys.modules['pybtex'] = None pybamm.print_citations() - with self.assertRaisesRegex(ModuleNotFoundError,"Optional dependency tqdm is not available."): + with self.assertRaisesRegex(ModuleNotFoundError, "Optional dependency tqdm is not available."): sys.modules['tqdm'] = None model = pybamm.BaseModel() v = pybamm.Variable("v") @@ -104,7 +104,7 @@ def test_have_optional_dependency(self): model.initial_conditions = {v: 1} sim = pybamm.Simulation(model) sim.solve([0, 1]) - with self.assertRaisesRegex(ModuleNotFoundError,"Optional dependency anytree is not available."): + with self.assertRaisesRegex(ModuleNotFoundError, "Optional dependency anytree is not available."): with TemporaryDirectory() as dir_name: sys.modules['anytree'] = None test_stub = os.path.join(dir_name, "test_visualize") From b5f74ad7b76fef4adb6c6496f0456dba8f182d24 Mon Sep 17 00:00:00 2001 From: Saransh Chopra Date: Mon, 13 Nov 2023 21:03:17 +0530 Subject: [PATCH 42/50] Fix matplotlib errors --- pybamm/plotting/plot.py | 3 ++- pybamm/plotting/plot2D.py | 3 ++- pybamm/plotting/plot_summary_variables.py | 3 ++- pybamm/plotting/plot_voltage_components.py | 4 +++- pybamm/plotting/quick_plot.py | 18 ++++++++++-------- 5 files changed, 19 insertions(+), 12 deletions(-) diff --git a/pybamm/plotting/plot.py b/pybamm/plotting/plot.py index 19aa9dc5e0..88c8dfe442 100644 --- a/pybamm/plotting/plot.py +++ b/pybamm/plotting/plot.py @@ -3,6 +3,7 @@ # import pybamm from .quick_plot import ax_min, ax_max +from pybamm.util import have_optional_dependency def plot(x, y, ax=None, testing=False, **kwargs): @@ -25,7 +26,7 @@ def plot(x, y, ax=None, testing=False, **kwargs): Keyword arguments, passed to plt.plot """ - import matplotlib.pyplot as plt + plt = have_optional_dependency("matplotlib.pyplot") if not isinstance(x, pybamm.Array): raise TypeError("x must be 'pybamm.Array'") diff --git a/pybamm/plotting/plot2D.py b/pybamm/plotting/plot2D.py index 80bb5d0ee2..d4f6d31e3a 100644 --- a/pybamm/plotting/plot2D.py +++ b/pybamm/plotting/plot2D.py @@ -3,6 +3,7 @@ # import pybamm from .quick_plot import ax_min, ax_max +from pybamm.util import have_optional_dependency def plot2D(x, y, z, ax=None, testing=False, **kwargs): @@ -25,7 +26,7 @@ def plot2D(x, y, z, ax=None, testing=False, **kwargs): Whether to actually make the plot (turned off for unit tests) """ - import matplotlib.pyplot as plt + plt = have_optional_dependency("matplotlib.pyplot") if not isinstance(x, pybamm.Array): raise TypeError("x must be 'pybamm.Array'") diff --git a/pybamm/plotting/plot_summary_variables.py b/pybamm/plotting/plot_summary_variables.py index 6fe71518db..e50f38fddf 100644 --- a/pybamm/plotting/plot_summary_variables.py +++ b/pybamm/plotting/plot_summary_variables.py @@ -3,6 +3,7 @@ # import numpy as np import pybamm +from pybamm.util import have_optional_dependency def plot_summary_variables( @@ -25,7 +26,7 @@ def plot_summary_variables( Keyword arguments, passed to plt.subplots. """ - import matplotlib.pyplot as plt + plt = have_optional_dependency("matplotlib.pyplot") if isinstance(solutions, pybamm.Solution): solutions = [solutions] diff --git a/pybamm/plotting/plot_voltage_components.py b/pybamm/plotting/plot_voltage_components.py index ad0e9a8b71..a681094bea 100644 --- a/pybamm/plotting/plot_voltage_components.py +++ b/pybamm/plotting/plot_voltage_components.py @@ -3,6 +3,8 @@ # import numpy as np +from pybamm.util import have_optional_dependency + def plot_voltage_components( solution, @@ -32,7 +34,7 @@ def plot_voltage_components( Keyword arguments, passed to ax.fill_between """ - import matplotlib.pyplot as plt + plt = have_optional_dependency("matplotlib.pyplot") # Set a default value for alpha, the opacity kwargs_fill = {"alpha": 0.6, **kwargs_fill} diff --git a/pybamm/plotting/quick_plot.py b/pybamm/plotting/quick_plot.py index 5e9c9ef941..00a07d16a1 100644 --- a/pybamm/plotting/quick_plot.py +++ b/pybamm/plotting/quick_plot.py @@ -5,6 +5,7 @@ import numpy as np import pybamm from collections import defaultdict +from pybamm.util import have_optional_dependency class LoopList(list): @@ -46,7 +47,7 @@ def split_long_string(title, max_words=None): def close_plots(): """Close all open figures""" - import matplotlib.pyplot as plt + plt = have_optional_dependency("matplotlib", "pyplot") plt.close("all") @@ -469,9 +470,10 @@ def plot(self, t, dynamic=False): Dimensional time (in 'time_units') at which to plot. """ - import matplotlib.pyplot as plt - import matplotlib.gridspec as gridspec - from matplotlib import cm, colors + plt = have_optional_dependency("matplotlib.pyplot") + gridspec = have_optional_dependency("matplotlib.gridspec") + cm = have_optional_dependency("matplotlib", "cm") + colors = have_optional_dependency("matplotlib", "colors") t_in_seconds = t * self.time_scaling_factor self.fig = plt.figure(figsize=self.figsize) @@ -668,8 +670,8 @@ def dynamic_plot(self, testing=False, step=None): continuous_update=False, ) else: - import matplotlib.pyplot as plt - from matplotlib.widgets import Slider + plt = have_optional_dependency("matplotlib.pyplot") + Slider = have_optional_dependency("matplotlib.widgets", "Slider") # create an initial plot at time self.min_t self.plot(self.min_t, dynamic=True) @@ -773,8 +775,8 @@ def create_gif(self, number_of_images=80, duration=0.1, output_filename="plot.gi Name of the generated GIF file. """ - import imageio.v2 as imageio - import matplotlib.pyplot as plt + imageio = have_optional_dependency("imageio.v2") + plt = have_optional_dependency("matplotlib.pyplot") # time stamps at which the images/plots will be created time_array = np.linspace(self.min_t, self.max_t, num=number_of_images) From 78792bcb3aa73840e2db9378792a67a5ae91739f Mon Sep 17 00:00:00 2001 From: Saransh Chopra Date: Mon, 13 Nov 2023 21:13:14 +0530 Subject: [PATCH 43/50] Apply suggestions from code review --- CONTRIBUTING.md | 4 ++-- pybamm/expression_tree/printing/sympy_overrides.py | 4 ++-- pybamm/plotting/quick_plot.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 9a7e3d779d..0a5b17bcb0 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -54,7 +54,7 @@ You now have everything you need to start making changes! 10. [Test your code!](#testing) 11. PyBaMM has online documentation at http://docs.pybamm.org/. To make sure any new methods or classes you added show up there, please read the [documentation](#documentation) section. 12. If you added a major new feature, perhaps it should be showcased in an [example notebook](#example-notebooks). -13. When you feel your code is finished, or at least warrants serious discussion, run the [pre-commit checks](#pre-commit-checks) and then create a [pull request](https://help.github.com/articles/about-pull-requests/) (PR) on [PyBaMM's GitHub page](https://github.com/pybamm-team/PyBaMM). +13. When you feel your code is finished, or at least warrants serious discussion, run the [pre-commit checks](#pre-commit-checks) and then create a [pull request](https://help.github.com/articles/about-pull-requests/) (PR) on [PyBaMM's GitHub page](https://github.com/pybamm-team/PyBaMM). 14. Once a PR has been created, it will be reviewed by any member of the community. Changes might be suggested which you can make by simply adding new commits to the branch. When everything's finished, someone with the right GitHub permissions will merge your changes into PyBaMM main repository. Finally, if you really, really, _really_ love developing PyBaMM, have a look at the current [project infrastructure](#infrastructure). @@ -126,7 +126,7 @@ This allows people to (1) use PyBaMM without importing optional dependencies by **Writing Tests for Optional Dependencies** -Whenever a new optional dependency is added for optional functionality, it is recommended to write a corresponding unit test in _test_util.py_. This ensures that an error is raised upon the absence of said dependency. Here's an example: +Whenever a new optional dependency is added for optional functionality, it is recommended to write a corresponding unit test in `test_util.py`. This ensures that an error is raised upon the absence of said dependency. Here's an example: ```python from tests import TestCase diff --git a/pybamm/expression_tree/printing/sympy_overrides.py b/pybamm/expression_tree/printing/sympy_overrides.py index 3e89542d10..ec70de22b2 100644 --- a/pybamm/expression_tree/printing/sympy_overrides.py +++ b/pybamm/expression_tree/printing/sympy_overrides.py @@ -7,8 +7,8 @@ def custom_latex_printer(expr, **settings): - latex = have_optional_dependency("sympy","latex") - Derivative = have_optional_dependency("sympy","Derivative") + latex = have_optional_dependency("sympy", "latex") + Derivative = have_optional_dependency("sympy", "Derivative") if isinstance(expr, Derivative) and getattr(expr, "force_partial", False): latex_str = latex(expr, **settings) var1, var2 = re.findall(r"^\\frac{(\w+)}{(\w+) .+", latex_str)[0] diff --git a/pybamm/plotting/quick_plot.py b/pybamm/plotting/quick_plot.py index 00a07d16a1..ff657ee375 100644 --- a/pybamm/plotting/quick_plot.py +++ b/pybamm/plotting/quick_plot.py @@ -47,7 +47,7 @@ def split_long_string(title, max_words=None): def close_plots(): """Close all open figures""" - plt = have_optional_dependency("matplotlib", "pyplot") + plt = have_optional_dependency("matplotlib.pyplot") plt.close("all") From a8ac4c784d4b9299d1c1f8beb7e656b1be7b1cc1 Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Mon, 13 Nov 2023 21:34:56 +0530 Subject: [PATCH 44/50] Remove test for tqdm as ModuleNotFoundError no longer being raised for `Simulation.solve()` --- tests/unit/test_util.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/unit/test_util.py b/tests/unit/test_util.py index 5079842003..730e4cc08d 100644 --- a/tests/unit/test_util.py +++ b/tests/unit/test_util.py @@ -96,14 +96,6 @@ def test_have_optional_dependency(self): pybtex = sys.modules['pybtex'] sys.modules['pybtex'] = None pybamm.print_citations() - with self.assertRaisesRegex(ModuleNotFoundError, "Optional dependency tqdm is not available."): - sys.modules['tqdm'] = None - model = pybamm.BaseModel() - v = pybamm.Variable("v") - model.rhs = {v: -v} - model.initial_conditions = {v: 1} - sim = pybamm.Simulation(model) - sim.solve([0, 1]) with self.assertRaisesRegex(ModuleNotFoundError, "Optional dependency anytree is not available."): with TemporaryDirectory() as dir_name: sys.modules['anytree'] = None From bd2d009c74d29a326857277129f2c713f8a1020e Mon Sep 17 00:00:00 2001 From: Saransh Chopra Date: Mon, 13 Nov 2023 22:55:28 +0530 Subject: [PATCH 45/50] Fix sympy overrides --- .../printing/sympy_overrides.py | 28 +++++-------- pybamm/models/base_model.py | 39 ++++++++++++++++--- 2 files changed, 44 insertions(+), 23 deletions(-) diff --git a/pybamm/expression_tree/printing/sympy_overrides.py b/pybamm/expression_tree/printing/sympy_overrides.py index 3e89542d10..d127534a0e 100644 --- a/pybamm/expression_tree/printing/sympy_overrides.py +++ b/pybamm/expression_tree/printing/sympy_overrides.py @@ -3,27 +3,19 @@ # import re -from pybamm.util import have_optional_dependency +from sympy.printing.latex import LatexPrinter -def custom_latex_printer(expr, **settings): - latex = have_optional_dependency("sympy","latex") - Derivative = have_optional_dependency("sympy","Derivative") - if isinstance(expr, Derivative) and getattr(expr, "force_partial", False): - latex_str = latex(expr, **settings) - var1, var2 = re.findall(r"^\\frac{(\w+)}{(\w+) .+", latex_str)[0] - latex_str = latex_str.replace(var1, "\partial").replace(var2, "\partial") - return latex_str - else: - return latex(expr, **settings) +class CustomPrint(LatexPrinter): + """Override SymPy methods to match PyBaMM's requirements""" + def _print_Derivative(self, expr): + """Override :meth:`sympy.printing.latex.LatexPrinter._print_Derivative`""" + eqn = super()._print_Derivative(expr) + if getattr(expr, "force_partial", False) and "partial" not in eqn: + var1, var2 = re.findall(r"^\\frac{(\w+)}{(\w+) .+", eqn)[0] + eqn = eqn.replace(var1, "\partial").replace(var2, "\partial") -class CustomPrint: - """Override SymPy methods to match PyBaMM's requirements""" - - def _print_Derivative(self, expr): - """Override :meth:`sympy.printing.latex.LatexPrinter._print_Derivative`""" - return custom_latex_printer(expr) + return eqn def custom_print_func(expr, **settings): - have_optional_dependency("sympy.printing.latex", "LatexPrinter") return CustomPrint()._print_Derivative(expr) diff --git a/pybamm/models/base_model.py b/pybamm/models/base_model.py index 41192dbe1f..08890757b7 100644 --- a/pybamm/models/base_model.py +++ b/pybamm/models/base_model.py @@ -9,7 +9,7 @@ import numpy as np import pybamm -from pybamm.expression_tree.operations.latexify import Latexify +from pybamm.util import have_optional_dependency class BaseModel: @@ -1055,14 +1055,43 @@ def generate( C.generate() def latexify(self, filename=None, newline=True, output_variables=None): - # For docstring, see pybamm.expression_tree.operations.latexify.Latexify + """ + Converts all model equations in latex. + + Parameters + ---------- + filename: str (optional) + Accepted file formats - any image format, pdf and tex + Default is None, When None returns all model equations in latex + If not None, returns all model equations in given file format. + + newline: bool (optional) + Default is True, If True, returns every equation in a new line. + If False, returns the list of all the equations. + + Load model + >>> model = pybamm.lithium_ion.SPM() + + This will returns all model equations in png + >>> model.latexify("equations.png") + + This will return all the model equations in latex + >>> model.latexify() + + This will return the list of all the model equations + >>> model.latexify(newline=False) + + This will return first five model equations + >>> model.latexify(newline=False)[1:5] + """ + sympy = have_optional_dependency("sympy") + if sympy: + from pybamm.expression_tree.operations.latexify import Latexify + return Latexify(self, filename, newline).latexify( output_variables=output_variables ) - # Set :meth:`latexify` docstring from :class:`Latexify` - latexify.__doc__ = Latexify.__doc__ - def process_parameters_and_discretise(self, symbol, parameter_values, disc): """ Process parameters and discretise a symbol using supplied parameter values From 9d342c0f34ef4b7a579a011526d5441c0069bfc9 Mon Sep 17 00:00:00 2001 From: Saransh Chopra Date: Mon, 13 Nov 2023 23:00:40 +0530 Subject: [PATCH 46/50] fix tabs --- .../expression_tree/printing/sympy_overrides.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pybamm/expression_tree/printing/sympy_overrides.py b/pybamm/expression_tree/printing/sympy_overrides.py index d127534a0e..e189e536d7 100644 --- a/pybamm/expression_tree/printing/sympy_overrides.py +++ b/pybamm/expression_tree/printing/sympy_overrides.py @@ -7,15 +7,15 @@ class CustomPrint(LatexPrinter): - """Override SymPy methods to match PyBaMM's requirements""" - def _print_Derivative(self, expr): - """Override :meth:`sympy.printing.latex.LatexPrinter._print_Derivative`""" - eqn = super()._print_Derivative(expr) - if getattr(expr, "force_partial", False) and "partial" not in eqn: - var1, var2 = re.findall(r"^\\frac{(\w+)}{(\w+) .+", eqn)[0] - eqn = eqn.replace(var1, "\partial").replace(var2, "\partial") + """Override SymPy methods to match PyBaMM's requirements""" + def _print_Derivative(self, expr): + """Override :meth:`sympy.printing.latex.LatexPrinter._print_Derivative`""" + eqn = super()._print_Derivative(expr) + if getattr(expr, "force_partial", False) and "partial" not in eqn: + var1, var2 = re.findall(r"^\\frac{(\w+)}{(\w+) .+", eqn)[0] + eqn = eqn.replace(var1, "\partial").replace(var2, "\partial") - return eqn + return eqn def custom_print_func(expr, **settings): return CustomPrint()._print_Derivative(expr) From 2e30131d7ef35c19f538065fc894dc4bf32236f8 Mon Sep 17 00:00:00 2001 From: Saransh Chopra Date: Mon, 13 Nov 2023 23:05:22 +0530 Subject: [PATCH 47/50] Fix CustomPrinter --- pybamm/expression_tree/printing/sympy_overrides.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pybamm/expression_tree/printing/sympy_overrides.py b/pybamm/expression_tree/printing/sympy_overrides.py index e189e536d7..678d4f5a37 100644 --- a/pybamm/expression_tree/printing/sympy_overrides.py +++ b/pybamm/expression_tree/printing/sympy_overrides.py @@ -12,10 +12,10 @@ def _print_Derivative(self, expr): """Override :meth:`sympy.printing.latex.LatexPrinter._print_Derivative`""" eqn = super()._print_Derivative(expr) if getattr(expr, "force_partial", False) and "partial" not in eqn: - var1, var2 = re.findall(r"^\\frac{(\w+)}{(\w+) .+", eqn)[0] - eqn = eqn.replace(var1, "\partial").replace(var2, "\partial") + var1, var2 = re.findall(r"^\\frac{(\w+)}{(\w+) .+", eqn)[0] + eqn = eqn.replace(var1, "\partial").replace(var2, "\partial") return eqn def custom_print_func(expr, **settings): - return CustomPrint()._print_Derivative(expr) + return CustomPrint().doprint(expr) From 47113627907f9f2baf7936b82a202ab1fd1bbaf3 Mon Sep 17 00:00:00 2001 From: Saransh Chopra Date: Mon, 13 Nov 2023 23:30:31 +0530 Subject: [PATCH 48/50] Fix test --- pybamm/expression_tree/printing/sympy_overrides.py | 1 + .../test_expression_tree/test_operations/test_latexify.py | 4 ---- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/pybamm/expression_tree/printing/sympy_overrides.py b/pybamm/expression_tree/printing/sympy_overrides.py index 678d4f5a37..1898822ea8 100644 --- a/pybamm/expression_tree/printing/sympy_overrides.py +++ b/pybamm/expression_tree/printing/sympy_overrides.py @@ -17,5 +17,6 @@ def _print_Derivative(self, expr): return eqn + def custom_print_func(expr, **settings): return CustomPrint().doprint(expr) diff --git a/tests/unit/test_expression_tree/test_operations/test_latexify.py b/tests/unit/test_expression_tree/test_operations/test_latexify.py index be7cc21115..7e0703534e 100644 --- a/tests/unit/test_expression_tree/test_operations/test_latexify.py +++ b/tests/unit/test_expression_tree/test_operations/test_latexify.py @@ -8,7 +8,6 @@ import uuid import pybamm -from pybamm.expression_tree.operations.latexify import Latexify class TestLatexify(TestCase): @@ -19,9 +18,6 @@ def test_latexify(self): model_spme = pybamm.lithium_ion.SPMe() func_spme = str(model_spme.latexify()) - # Test docstring - self.assertEqual(pybamm.BaseModel.latexify.__doc__, Latexify.__doc__) - # Test model name self.assertIn("Single Particle Model with electrolyte Equations", func_spme) From 05c74e41e54a88bea8d0cfaaf24ee93b3f3457b4 Mon Sep 17 00:00:00 2001 From: "arjxn.py" Date: Tue, 14 Nov 2023 01:07:30 +0530 Subject: [PATCH 49/50] Add `anytree` to required deps in docs --- docs/source/user_guide/installation/index.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/user_guide/installation/index.rst b/docs/source/user_guide/installation/index.rst index 6338323e79..2cf61093be 100644 --- a/docs/source/user_guide/installation/index.rst +++ b/docs/source/user_guide/installation/index.rst @@ -66,6 +66,7 @@ Package Minimum support `SciPy `__ 2.8.2 `CasADi `__ 3.6.0 `Xarray `__ 2023.04.0 +`Anytree `__ 2.4.3 ================================================================ ========================== .. _install.optional_dependencies: From 4d118abc78115b9f27ffd970d64d52ca697294a9 Mon Sep 17 00:00:00 2001 From: Saransh Chopra Date: Tue, 14 Nov 2023 14:42:02 +0530 Subject: [PATCH 50/50] Fix CHANGELOG --- CHANGELOG.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index b02df8ed4c..afbc5073b0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,7 @@ ## Bug fixes -- Fixed a bug where the JaxSolver would fails when using GPU support with no input parameters ([#3423](https://github.com/pybamm-team/PyBaMM/pull/3423)) +- Fixed bug in calculation of theoretical energy that made it very slow ([#3506](https://github.com/pybamm-team/PyBaMM/pull/3506)) # [v23.9rc0](https://github.com/pybamm-team/PyBaMM/tree/v23.9rc0) - 2023-10-31 @@ -23,6 +23,7 @@ ## Bug fixes +- Fixed a bug where the JaxSolver would fail when using GPU support with no input parameters ([#3423](https://github.com/pybamm-team/PyBaMM/pull/3423)) - Fixed a bug where empty lists passed to QuickPlot resulted in an IndexError and did not return a meaningful error message ([#3359](https://github.com/pybamm-team/PyBaMM/pull/3359)) - Fixed a bug where there was a missing thermal conductivity in the thermal pouch cell models ([#3330](https://github.com/pybamm-team/PyBaMM/pull/3330)) - Fixed a bug that caused incorrect results of “{Domain} electrode thickness change [m]” due to the absence of dimension for the variable `electrode_thickness_change`([#3329](https://github.com/pybamm-team/PyBaMM/pull/3329)). @@ -61,7 +62,7 @@ - Added option to use an empirical hysteresis model for the diffusivity and exchange-current density ([#3194](https://github.com/pybamm-team/PyBaMM/pull/3194)) - Double-layer capacity can now be provided as a function of temperature ([#3174](https://github.com/pybamm-team/PyBaMM/pull/3174)) - `pybamm_install_jax` is deprecated. It is now replaced with `pip install pybamm[jax]` ([#3163](https://github.com/pybamm-team/PyBaMM/pull/3163)) -- PyBaMM now has optional dependencies that can be installed with the pattern `pip install pybamm[option]` e.g. `pybamm[plot]` ([#3044](https://github.com/pybamm-team/PyBaMM/pull/3044)) +- PyBaMM now has optional dependencies that can be installed with the pattern `pip install pybamm[option]` e.g. `pybamm[plot]` ([#3044](https://github.com/pybamm-team/PyBaMM/pull/3044), [#3475](https://github.com/pybamm-team/PyBaMM/pull/3475)) # [v23.5](https://github.com/pybamm-team/PyBaMM/tree/v23.5) - 2023-06-18