From 9b6652b025accd78cc9bb71d4192d19b3ca81e03 Mon Sep 17 00:00:00 2001 From: Santhosh Sundaram Date: Tue, 9 Jul 2024 23:00:44 +0530 Subject: [PATCH 1/6] resolved merge conflicts --- pyproject.toml | 3 + src/pybamm_cookiecutter/__init__.py | 4 + src/pybamm_cookiecutter/models/__init__.py | 1 + src/pybamm_cookiecutter/models/input/SPM.py | 182 ++++++++++++++++++++ src/pybamm_cookiecutter/models/model.py | 72 ++++++++ tests/test_model_entry_points.py | 21 +++ tests/test_parameter_entry_points.py | 27 +++ 7 files changed, 310 insertions(+) create mode 100644 src/pybamm_cookiecutter/models/__init__.py create mode 100644 src/pybamm_cookiecutter/models/input/SPM.py create mode 100644 src/pybamm_cookiecutter/models/model.py create mode 100644 tests/test_model_entry_points.py create mode 100644 tests/test_parameter_entry_points.py diff --git a/pyproject.toml b/pyproject.toml index dc0bda6..9228987 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,9 @@ Changelog = "https://github.com/pybamm-team/pybamm-cookiecutter/releases" [project.entry-points."cookie_parameter_sets"] Chen2020 = "pybamm_cookiecutter.parameters.input.Chen2020:get_parameter_values" +[project.entry-points."model_entry_points"] +SPM = "pybamm_cookiecutter.models.input.SPM:SPM" + [tool.hatch] version.source = "vcs" build.hooks.vcs.version-file = "src/pybamm_cookiecutter/_version.py" diff --git a/src/pybamm_cookiecutter/__init__.py b/src/pybamm_cookiecutter/__init__.py index 6dd71d2..3a99d69 100644 --- a/src/pybamm_cookiecutter/__init__.py +++ b/src/pybamm_cookiecutter/__init__.py @@ -9,9 +9,13 @@ from ._version import version as __version__ from .parameters.parameter_sets import parameter_sets +from .models.model import models, model_instance __all__ : list[str] = [ "__version__", "pybamm", "parameter_sets", + "models", + "model_instance", + ] diff --git a/src/pybamm_cookiecutter/models/__init__.py b/src/pybamm_cookiecutter/models/__init__.py new file mode 100644 index 0000000..545acc4 --- /dev/null +++ b/src/pybamm_cookiecutter/models/__init__.py @@ -0,0 +1 @@ +__all__ = ["model",] diff --git a/src/pybamm_cookiecutter/models/input/SPM.py b/src/pybamm_cookiecutter/models/input/SPM.py new file mode 100644 index 0000000..92bf2b2 --- /dev/null +++ b/src/pybamm_cookiecutter/models/input/SPM.py @@ -0,0 +1,182 @@ +# +# Basic Single Particle Model (SPM) +# +import pybamm + + +class SPM(pybamm.lithium_ion.BaseModel): + """Single Particle Model (SPM) model of a lithium-ion battery, from + :footcite:t:`Marquis2019`. + + This class differs from the :class:`pybamm.lithium_ion.SPM` model class in that it + shows the whole model in a single class. This comes at the cost of flexibility in + combining different physical effects, and in general the main SPM class should be + used instead. + + Parameters + ---------- + name : str, optional + The name of the model. + """ + + def __init__(self, name="Single Particle Model"): + super().__init__({}, name) + pybamm.citations.register("Marquis2019") + # `param` is a class containing all the relevant parameters and functions for + # this model. These are purely symbolic at this stage, and will be set by the + # `ParameterValues` class when the model is processed. + param = self.param + + ###################### + # Variables + ###################### + # Variables that depend on time only are created without a domain + Q = pybamm.Variable("Discharge capacity [A.h]") + # Variables that vary spatially are created with a domain + c_s_n = pybamm.Variable( + "X-averaged negative particle concentration [mol.m-3]", + domain="negative particle", + ) + c_s_p = pybamm.Variable( + "X-averaged positive particle concentration [mol.m-3]", + domain="positive particle", + ) + + # Constant temperature + T = param.T_init + + ###################### + # Other set-up + ###################### + + # Current density + i_cell = param.current_density_with_time + a_n = 3 * param.n.prim.epsilon_s_av / param.n.prim.R_typ + a_p = 3 * param.p.prim.epsilon_s_av / param.p.prim.R_typ + j_n = i_cell / (param.n.L * a_n) + j_p = -i_cell / (param.p.L * a_p) + + ###################### + # State of Charge + ###################### + I = param.current_with_time + # The `rhs` dictionary contains differential equations, with the key being the + # variable in the d/dt + self.rhs[Q] = I / 3600 + # Initial conditions must be provided for the ODEs + self.initial_conditions[Q] = pybamm.Scalar(0) + + ###################### + # Particles + ###################### + + # The div and grad operators will be converted to the appropriate matrix + # multiplication at the discretisation stage + N_s_n = -param.n.prim.D(c_s_n, T) * pybamm.grad(c_s_n) + N_s_p = -param.p.prim.D(c_s_p, T) * pybamm.grad(c_s_p) + self.rhs[c_s_n] = -pybamm.div(N_s_n) + self.rhs[c_s_p] = -pybamm.div(N_s_p) + # Surf takes the surface value of a variable, i.e. its boundary value on the + # right side. This is also accessible via `boundary_value(x, "right")`, with + # "left" providing the boundary value of the left side + c_s_surf_n = pybamm.surf(c_s_n) + c_s_surf_p = pybamm.surf(c_s_p) + # Boundary conditions must be provided for equations with spatial derivatives + self.boundary_conditions[c_s_n] = { + "left": (pybamm.Scalar(0), "Neumann"), + "right": ( + -j_n / (param.F * pybamm.surf(param.n.prim.D(c_s_n, T))), + "Neumann", + ), + } + self.boundary_conditions[c_s_p] = { + "left": (pybamm.Scalar(0), "Neumann"), + "right": ( + -j_p / (param.F * pybamm.surf(param.p.prim.D(c_s_p, T))), + "Neumann", + ), + } + # c_n_init and c_p_init are functions of r and x, but for the SPM we + # take the x-averaged value since there is no x-dependence in the particles + self.initial_conditions[c_s_n] = pybamm.x_average(param.n.prim.c_init) + self.initial_conditions[c_s_p] = pybamm.x_average(param.p.prim.c_init) + # Events specify points at which a solution should terminate + sto_surf_n = c_s_surf_n / param.n.prim.c_max + sto_surf_p = c_s_surf_p / param.p.prim.c_max + self.events += [ + pybamm.Event( + "Minimum negative particle surface stoichiometry", + pybamm.min(sto_surf_n) - 0.01, + ), + pybamm.Event( + "Maximum negative particle surface stoichiometry", + (1 - 0.01) - pybamm.max(sto_surf_n), + ), + pybamm.Event( + "Minimum positive particle surface stoichiometry", + pybamm.min(sto_surf_p) - 0.01, + ), + pybamm.Event( + "Maximum positive particle surface stoichiometry", + (1 - 0.01) - pybamm.max(sto_surf_p), + ), + ] + + # Note that the SPM does not have any algebraic equations, so the `algebraic` + # dictionary remains empty + + ###################### + # (Some) variables + ###################### + # Interfacial reactions + RT_F = param.R * T / param.F + j0_n = param.n.prim.j0(param.c_e_init_av, c_s_surf_n, T) + j0_p = param.p.prim.j0(param.c_e_init_av, c_s_surf_p, T) + eta_n = (2 / param.n.prim.ne) * RT_F * pybamm.arcsinh(j_n / (2 * j0_n)) + eta_p = (2 / param.p.prim.ne) * RT_F * pybamm.arcsinh(j_p / (2 * j0_p)) + phi_s_n = 0 + phi_e = -eta_n - param.n.prim.U(sto_surf_n, T) + phi_s_p = eta_p + phi_e + param.p.prim.U(sto_surf_p, T) + V = phi_s_p + num_cells = pybamm.Parameter( + "Number of cells connected in series to make a battery" + ) + + whole_cell = ["negative electrode", "separator", "positive electrode"] + # The `variables` dictionary contains all variables that might be useful for + # visualising the solution of the model + # Primary broadcasts are used to broadcast scalar quantities across a domain + # into a vector of the right shape, for multiplying with other vectors + self.variables = { + "Time [s]": pybamm.t, + "Discharge capacity [A.h]": Q, + "X-averaged negative particle concentration [mol.m-3]": c_s_n, + "Negative particle surface " + "concentration [mol.m-3]": pybamm.PrimaryBroadcast( + c_s_surf_n, "negative electrode" + ), + "Electrolyte concentration [mol.m-3]": pybamm.PrimaryBroadcast( + param.c_e_init_av, whole_cell + ), + "X-averaged positive particle concentration [mol.m-3]": c_s_p, + "Positive particle surface " + "concentration [mol.m-3]": pybamm.PrimaryBroadcast( + c_s_surf_p, "positive electrode" + ), + "Current [A]": I, + "Current variable [A]": I, # for compatibility with pybamm.Experiment + "Negative electrode potential [V]": pybamm.PrimaryBroadcast( + phi_s_n, "negative electrode" + ), + "Electrolyte potential [V]": pybamm.PrimaryBroadcast(phi_e, whole_cell), + "Positive electrode potential [V]": pybamm.PrimaryBroadcast( + phi_s_p, "positive electrode" + ), + "Voltage [V]": V, + "Battery voltage [V]": V * num_cells, + } + # Events specify points at which a solution should terminate + self.events += [ + pybamm.Event("Minimum voltage [V]", V - param.voltage_low_cut), + pybamm.Event("Maximum voltage [V]", param.voltage_high_cut - V), + ] diff --git a/src/pybamm_cookiecutter/models/model.py b/src/pybamm_cookiecutter/models/model.py new file mode 100644 index 0000000..31a01a2 --- /dev/null +++ b/src/pybamm_cookiecutter/models/model.py @@ -0,0 +1,72 @@ +import importlib.metadata +from collections.abc import Mapping +from typing import Callable +import textwrap + +class ModelEntryPoints(Mapping): + + def __init__(self): + self._all_models = dict() + for entry_point in self.get_entries("model_entry_points"): + self._all_models[entry_point.name] = entry_point + print(self._all_models) + + @staticmethod + def get_entries(group_name): + return importlib.metadata.entry_points(group=group_name) + + def __new__(cls): + """Ensure only one instance""" + if not hasattr(cls, "instance"): + cls.instance = super().__new__(cls) + return cls.instance + + def __getitem__(self, key) -> dict: + return self._load_entry_point(key)() + + def _load_entry_point(self, key) -> Callable: + """Check that ``key`` is a registered ``model_entry_points``, + and return the entry point for the model, loading it needed. + """ + if key not in self._all_models: + raise KeyError(f"Unknown Model: {key}") + model = self._all_models[key] + try: + model = self._all_models[key] = model.load() + except AttributeError: + pass + return model + + def __iter__(self): + return self._all_models.__iter__() + + def __len__(self) -> int: + return len(self._all_models) + + def get_docstring(self, key): + """Return the docstring for the ``key`` model""" + return textwrap.dedent(self._load_entry_point(key).__doc__) + + def __getattribute__(self, name): + try: + return super().__getattribute__(name) + except AttributeError as error: + raise error + +#: Singleton Instance of :class:ModelEntryPoints""" +model_instance = ModelEntryPoints() + +def models(model:str): + """ + Returns the loaded model object + + Parameters + ---------- + model : str + The model name or author name of the model mentioned at the model entry point. + Returns + ------- + pybamm.model + Model object of the initialised model. + """ + return model_instance[model] diff --git a/tests/test_model_entry_points.py b/tests/test_model_entry_points.py new file mode 100644 index 0000000..0feaa32 --- /dev/null +++ b/tests/test_model_entry_points.py @@ -0,0 +1,21 @@ +import pytest +import pybamm_cookiecutter +import importlib.util +import sys +from pathlib import Path + +def test_entry_points(): + """Test if the entry points are loaded correctly.""" + + entry_points = list(pybamm_cookiecutter.model_instance) + models = Path("src/pybamm_cookiecutter/models/input/").glob("*.py") + # Making a list Parameter sets in the parameters/input directory + models = [x.stem for x in models] + + assert models == entry_points, "Entry points missing either in pyproject.toml or in the input directory" + +def test_entry_point_load(): + """Testing if the values get loaded via parameter entry points and are equal when loaded through entry points""" + # Loading parameter_sets through entry points + model_instance = pybamm_cookiecutter.models("SPM") + assert model_instance != None diff --git a/tests/test_parameter_entry_points.py b/tests/test_parameter_entry_points.py new file mode 100644 index 0000000..58cee98 --- /dev/null +++ b/tests/test_parameter_entry_points.py @@ -0,0 +1,27 @@ +import pytest +import pybamm_cookiecutter +import importlib.util +import sys +from pathlib import Path + +def test_entry_points(): + """Test if the entry points are loaded correctly.""" + + entry_points = list(pybamm_cookiecutter.parameter_sets) + parameter_sets = Path("src/pybamm_cookiecutter/parameters/input/").glob("*.py") + # Making a list Parameter sets in the parameters/input directory + parameter_sets = [x.stem for x in parameter_sets] + + assert parameter_sets == entry_points, "Entry points missing either in pyproject.toml or in the input directory" + +def test_entry_point_load(): + """Testing if the values get loaded via parameter entry points and are equal when loaded through entry points""" + # Loading parameter_sets through entry points + parameters = pybamm_cookiecutter.parameter_sets['Chen2020'] + # Loading parameter sets through the source file by dynamically loading Chen2020.py as a module + spec = importlib.util.spec_from_file_location("Chen2020mod", "src/pybamm_cookiecutter/parameters/input/Chen2020.py") + chen_module = importlib.util.module_from_spec(spec) + sys.modules["Chen2020mod"] = chen_module + spec.loader.exec_module(chen_module) + parameters_from_file = chen_module.get_parameter_values() + assert parameters.keys() == parameters_from_file.keys(), "The keys in the module and local input file are not the same" From 74bad64827e99f3c122b3fca260e7689ec755a73 Mon Sep 17 00:00:00 2001 From: Santhosh Sundaram Date: Tue, 9 Jul 2024 23:05:41 +0530 Subject: [PATCH 2/6] Added model entry points --- src/pybamm_cookiecutter/models/input/SPM.py | 33 ++++++++++++++++++++ src/pybamm_cookiecutter/models/model.py | 34 ++++++++++++++++++++- 2 files changed, 66 insertions(+), 1 deletion(-) diff --git a/src/pybamm_cookiecutter/models/input/SPM.py b/src/pybamm_cookiecutter/models/input/SPM.py index 92bf2b2..e94741a 100644 --- a/src/pybamm_cookiecutter/models/input/SPM.py +++ b/src/pybamm_cookiecutter/models/input/SPM.py @@ -1,3 +1,36 @@ +""" +This code is adopted from the PyBaMM project under the BSD-3-Clause + +Copyright (c) 2018-2024, the PyBaMM team. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + + # # Basic Single Particle Model (SPM) # diff --git a/src/pybamm_cookiecutter/models/model.py b/src/pybamm_cookiecutter/models/model.py index 31a01a2..c517593 100644 --- a/src/pybamm_cookiecutter/models/model.py +++ b/src/pybamm_cookiecutter/models/model.py @@ -1,3 +1,36 @@ +""" +This code is adopted from the PyBaMM project under the BSD-3-Clause + +Copyright (c) 2018-2024, the PyBaMM team. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +""" + + import importlib.metadata from collections.abc import Mapping from typing import Callable @@ -9,7 +42,6 @@ def __init__(self): self._all_models = dict() for entry_point in self.get_entries("model_entry_points"): self._all_models[entry_point.name] = entry_point - print(self._all_models) @staticmethod def get_entries(group_name): From 679b8b61513b769a16a60a63acf481ed082036d7 Mon Sep 17 00:00:00 2001 From: Santhosh Sundaram Date: Sat, 13 Jul 2024 20:06:14 +0530 Subject: [PATCH 3/6] unified entry points for parameter sets and models --- pyproject.toml | 2 +- src/pybamm_cookiecutter/__init__.py | 6 +- .../parameter_sets.py => entry_point.py} | 70 ++++++++---- src/pybamm_cookiecutter/models/__init__.py | 1 - src/pybamm_cookiecutter/models/input/SPM.py | 1 - src/pybamm_cookiecutter/models/model.py | 104 ------------------ .../parameters/__init__.py | 3 - tests/test_model_entry_points.py | 4 +- 8 files changed, 53 insertions(+), 138 deletions(-) rename src/pybamm_cookiecutter/{parameters/parameter_sets.py => entry_point.py} (60%) delete mode 100644 src/pybamm_cookiecutter/models/__init__.py delete mode 100644 src/pybamm_cookiecutter/models/model.py delete mode 100644 src/pybamm_cookiecutter/parameters/__init__.py diff --git a/pyproject.toml b/pyproject.toml index 9228987..797a805 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ Homepage = "https://github.com/pybamm-team/pybamm-cookiecutter" Discussions = "https://github.com/pybamm-team/pybamm-cookiecutter/discussions" Changelog = "https://github.com/pybamm-team/pybamm-cookiecutter/releases" -[project.entry-points."cookie_parameter_sets"] +[project.entry-points."parameter_sets"] Chen2020 = "pybamm_cookiecutter.parameters.input.Chen2020:get_parameter_values" [project.entry-points."model_entry_points"] diff --git a/src/pybamm_cookiecutter/__init__.py b/src/pybamm_cookiecutter/__init__.py index 3a99d69..6e69f35 100644 --- a/src/pybamm_cookiecutter/__init__.py +++ b/src/pybamm_cookiecutter/__init__.py @@ -8,14 +8,12 @@ import pybamm from ._version import version as __version__ -from .parameters.parameter_sets import parameter_sets -from .models.model import models, model_instance +from .entry_point import Model, parameter_sets, models __all__ : list[str] = [ "__version__", "pybamm", "parameter_sets", + "Model", "models", - "model_instance", - ] diff --git a/src/pybamm_cookiecutter/parameters/parameter_sets.py b/src/pybamm_cookiecutter/entry_point.py similarity index 60% rename from src/pybamm_cookiecutter/parameters/parameter_sets.py rename to src/pybamm_cookiecutter/entry_point.py index bcca96e..71323f0 100644 --- a/src/pybamm_cookiecutter/parameters/parameter_sets.py +++ b/src/pybamm_cookiecutter/entry_point.py @@ -37,10 +37,11 @@ from collections.abc import Mapping from typing import Callable -class ParameterSets(Mapping): +class EntryPoint(Mapping): """ - Dict-like interface for accessing parameter sets through entry points in cookiecutter template. - Access via :py:data:`pybamm_cookiecutter.parameter_sets` + Dict-like interface for accessing parameter sets and models through entry points in cookiecutter template. + Access via :py:data:`pybamm_cookiecutter.parameter_sets` for parameter_sets + Access via :py:data:`pybamm_cookiecutter.Model` for Models Examples -------- @@ -48,8 +49,10 @@ class ParameterSets(Mapping): >>> import pybamm_cookiecutter >>> list(pybamm_cookiecutter.parameter_sets) ['Chen2020', ...] + >>> list(pybamm_cookiecutter.models) + ['SPM', ...] - Get the docstring for a parameter set: + Get the docstring for a parameter set/model: >>> print(pybamm_cookiecutter.parameter_sets.get_docstring("Ai2020")) @@ -62,11 +65,16 @@ class ParameterSets(Mapping): """ - def __init__(self): - """Dict of entry points for parameter sets, lazily load entry points as""" - self.__all_parameter_sets = dict() - for entry_point in self.get_entries("cookie_parameter_sets"): - self.__all_parameter_sets[entry_point.name] = entry_point + _instances = 0 + def __init__(self, group): + """Dict of entry points for parameter sets or models, lazily load entry points as""" + if not hasattr(self, 'initialized'): # Ensure __init__ is called once per instance + self.initialized = True + EntryPoint._instances += 1 + self._all_entries = dict() + self.group = group + for entry_point in self.get_entries(self.group): + self._all_entries[entry_point.name] = entry_point @staticmethod def get_entries(group_name): @@ -76,9 +84,9 @@ def get_entries(group_name): else: return importlib.metadata.entry_points(group=group_name) - def __new__(cls): - """Ensure only one instance of ParameterSets exists""" - if not hasattr(cls, "instance"): + def __new__(cls, group): + """Ensure only two instances of entry points exist, one for parameter sets and the other for models""" + if EntryPoint._instances < 2: cls.instance = super().__new__(cls) return cls.instance @@ -86,25 +94,25 @@ def __getitem__(self, key) -> dict: return self._load_entry_point(key)() def _load_entry_point(self, key) -> Callable: - """Check that ``key`` is a registered ``cookie_parameter_sets``, - and return the entry point for the parameter set, loading it needed.""" - if key not in self.__all_parameter_sets: - raise KeyError(f"Unknown parameter set: {key}") - ps = self.__all_parameter_sets[key] + """Check that ``key`` is a registered ``parameter_sets`` or ``model_entry_points`` , + and return the entry point for the parameter set/model, loading it needed.""" + if key not in self._all_entries: + raise KeyError(f"Unknown parameter set or model: {key}") + ps = self._all_entries[key] try: - ps = self.__all_parameter_sets[key] = ps.load() + ps = self._all_entries[key] = ps.load() except AttributeError: pass return ps def __iter__(self): - return self.__all_parameter_sets.__iter__() + return self._all_entries.__iter__() def __len__(self) -> int: - return len(self.__all_parameter_sets) + return len(self._all_entries) def get_docstring(self, key): - """Return the docstring for the ``key`` parameter set""" + """Return the docstring for the ``key`` parameter set or model""" return textwrap.dedent(self._load_entry_point(key).__doc__) def __getattribute__(self, name): @@ -114,4 +122,22 @@ def __getattribute__(self, name): raise error #: Singleton Instance of :class:ParameterSets """ -parameter_sets = ParameterSets() +parameter_sets = EntryPoint(group="parameter_sets") + +#: Singleton Instance of :class:ModelEntryPoints""" +models = EntryPoint(group="model_entry_points") + +def Model(model:str): + """ + Returns the loaded model object + + Parameters + ---------- + model : str + The model name or author name of the model mentioned at the model entry point. + Returns + ------- + pybamm.model + Model object of the initialised model. + """ + return models[model] diff --git a/src/pybamm_cookiecutter/models/__init__.py b/src/pybamm_cookiecutter/models/__init__.py deleted file mode 100644 index 545acc4..0000000 --- a/src/pybamm_cookiecutter/models/__init__.py +++ /dev/null @@ -1 +0,0 @@ -__all__ = ["model",] diff --git a/src/pybamm_cookiecutter/models/input/SPM.py b/src/pybamm_cookiecutter/models/input/SPM.py index e94741a..b6c0526 100644 --- a/src/pybamm_cookiecutter/models/input/SPM.py +++ b/src/pybamm_cookiecutter/models/input/SPM.py @@ -36,7 +36,6 @@ # import pybamm - class SPM(pybamm.lithium_ion.BaseModel): """Single Particle Model (SPM) model of a lithium-ion battery, from :footcite:t:`Marquis2019`. diff --git a/src/pybamm_cookiecutter/models/model.py b/src/pybamm_cookiecutter/models/model.py deleted file mode 100644 index c517593..0000000 --- a/src/pybamm_cookiecutter/models/model.py +++ /dev/null @@ -1,104 +0,0 @@ -""" -This code is adopted from the PyBaMM project under the BSD-3-Clause - -Copyright (c) 2018-2024, the PyBaMM team. -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -* Redistributions of source code must retain the above copyright notice, this - list of conditions and the following disclaimer. - -* Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -* Neither the name of the copyright holder nor the names of its - contributors may be used to endorse or promote products derived from - this software without specific prior written permission. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -""" - - -import importlib.metadata -from collections.abc import Mapping -from typing import Callable -import textwrap - -class ModelEntryPoints(Mapping): - - def __init__(self): - self._all_models = dict() - for entry_point in self.get_entries("model_entry_points"): - self._all_models[entry_point.name] = entry_point - - @staticmethod - def get_entries(group_name): - return importlib.metadata.entry_points(group=group_name) - - def __new__(cls): - """Ensure only one instance""" - if not hasattr(cls, "instance"): - cls.instance = super().__new__(cls) - return cls.instance - - def __getitem__(self, key) -> dict: - return self._load_entry_point(key)() - - def _load_entry_point(self, key) -> Callable: - """Check that ``key`` is a registered ``model_entry_points``, - and return the entry point for the model, loading it needed. - """ - if key not in self._all_models: - raise KeyError(f"Unknown Model: {key}") - model = self._all_models[key] - try: - model = self._all_models[key] = model.load() - except AttributeError: - pass - return model - - def __iter__(self): - return self._all_models.__iter__() - - def __len__(self) -> int: - return len(self._all_models) - - def get_docstring(self, key): - """Return the docstring for the ``key`` model""" - return textwrap.dedent(self._load_entry_point(key).__doc__) - - def __getattribute__(self, name): - try: - return super().__getattribute__(name) - except AttributeError as error: - raise error - -#: Singleton Instance of :class:ModelEntryPoints""" -model_instance = ModelEntryPoints() - -def models(model:str): - """ - Returns the loaded model object - - Parameters - ---------- - model : str - The model name or author name of the model mentioned at the model entry point. - Returns - ------- - pybamm.model - Model object of the initialised model. - """ - return model_instance[model] diff --git a/src/pybamm_cookiecutter/parameters/__init__.py b/src/pybamm_cookiecutter/parameters/__init__.py deleted file mode 100644 index 6586100..0000000 --- a/src/pybamm_cookiecutter/parameters/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from __future__ import annotations - -__all__ = ["parameter_sets",] diff --git a/tests/test_model_entry_points.py b/tests/test_model_entry_points.py index 0feaa32..abe284b 100644 --- a/tests/test_model_entry_points.py +++ b/tests/test_model_entry_points.py @@ -7,7 +7,7 @@ def test_entry_points(): """Test if the entry points are loaded correctly.""" - entry_points = list(pybamm_cookiecutter.model_instance) + entry_points = list(pybamm_cookiecutter.models) models = Path("src/pybamm_cookiecutter/models/input/").glob("*.py") # Making a list Parameter sets in the parameters/input directory models = [x.stem for x in models] @@ -17,5 +17,5 @@ def test_entry_points(): def test_entry_point_load(): """Testing if the values get loaded via parameter entry points and are equal when loaded through entry points""" # Loading parameter_sets through entry points - model_instance = pybamm_cookiecutter.models("SPM") + model_instance = pybamm_cookiecutter.Model("SPM") assert model_instance != None From db17f5132e6b4de5aa0c67303cc086eb623f5308 Mon Sep 17 00:00:00 2001 From: Santhosh Sundaram Date: Mon, 15 Jul 2024 23:18:20 +0530 Subject: [PATCH 4/6] added examples and code review suggestions --- pyproject.toml | 2 +- src/pybamm_cookiecutter/entry_point.py | 19 +++++++++++++++---- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 797a805..3d0f514 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,7 +64,7 @@ Changelog = "https://github.com/pybamm-team/pybamm-cookiecutter/releases" [project.entry-points."parameter_sets"] Chen2020 = "pybamm_cookiecutter.parameters.input.Chen2020:get_parameter_values" -[project.entry-points."model_entry_points"] +[project.entry-points."models"] SPM = "pybamm_cookiecutter.models.input.SPM:SPM" [tool.hatch] diff --git a/src/pybamm_cookiecutter/entry_point.py b/src/pybamm_cookiecutter/entry_point.py index 71323f0..691cd86 100644 --- a/src/pybamm_cookiecutter/entry_point.py +++ b/src/pybamm_cookiecutter/entry_point.py @@ -61,8 +61,11 @@ class EntryPoint(Mapping): :footcite:t:`rieger2016new` and references therein. ... - See also: :ref:`adding-parameter-sets` - + >>> print(pybamm_cookiecutter.models.get_docstring("SPM")) + + Single Particle Model (SPM) model of a lithium-ion battery, from :footcite:t:`Marquis2019`. This class differs from the :class:`pybamm.lithium_ion.SPM` model class in that it shows the whole model in a single class. This comes at the cost of flexibility in combining different physical effects, and in general the main SPM class should be used instead. + ... + See also: :ref:`adding-parameter-sets` """ _instances = 0 @@ -94,7 +97,7 @@ def __getitem__(self, key) -> dict: return self._load_entry_point(key)() def _load_entry_point(self, key) -> Callable: - """Check that ``key`` is a registered ``parameter_sets`` or ``model_entry_points`` , + """Check that ``key`` is a registered ``parameter_sets`` or ``models` , and return the entry point for the parameter set/model, loading it needed.""" if key not in self._all_entries: raise KeyError(f"Unknown parameter set or model: {key}") @@ -125,7 +128,7 @@ def __getattribute__(self, name): parameter_sets = EntryPoint(group="parameter_sets") #: Singleton Instance of :class:ModelEntryPoints""" -models = EntryPoint(group="model_entry_points") +models = EntryPoint(group="models") def Model(model:str): """ @@ -139,5 +142,13 @@ def Model(model:str): ------- pybamm.model Model object of the initialised model. + Examples + -------- + Listing available models: + >>> import pybamm_cookiecutter + >>> list(pybamm_cookiecutter.models) + ['SPM', ...] + >>> pybamm_cookiecutter.Model('Author/Year') + """ return models[model] From b8f8757d2373e56a2cb2c5d6d3af62b049b7c4b0 Mon Sep 17 00:00:00 2001 From: Santhosh Sundaram Date: Mon, 15 Jul 2024 23:25:39 +0530 Subject: [PATCH 5/6] merged entry points tests into a single file --- tests/test_entry_points.py | 22 +++++++++++++++++++--- tests/test_model_entry_points.py | 21 --------------------- tests/test_parameter_entry_points.py | 27 --------------------------- 3 files changed, 19 insertions(+), 51 deletions(-) delete mode 100644 tests/test_model_entry_points.py delete mode 100644 tests/test_parameter_entry_points.py diff --git a/tests/test_entry_points.py b/tests/test_entry_points.py index 6845a43..b37d291 100644 --- a/tests/test_entry_points.py +++ b/tests/test_entry_points.py @@ -4,8 +4,8 @@ import sys from pathlib import Path -def test_entry_points(): - """Test if the entry points are loaded correctly.""" +def test_parameter_sets_entry_points(): + """Test if the parameter_sets via entry points are loaded correctly.""" entry_points = list(pybamm_cookiecutter.parameter_sets) parameter_sets = Path("src/pybamm_cookiecutter/parameters/input/").glob("*.py") @@ -14,7 +14,7 @@ def test_entry_points(): assert parameter_sets == entry_points, "Entry points missing either in pyproject.toml or in the input directory" -def test_entry_point_load(): +def test_parameter_sets_entry_point_load(): """Testing if the values get loaded via parameter entry points and are equal when loaded through entry points""" # Loading parameter_sets through entry points parameters = pybamm_cookiecutter.parameter_sets['Chen2020'] @@ -25,3 +25,19 @@ def test_entry_point_load(): spec.loader.exec_module(chen_module) parameters_from_file = chen_module.get_parameter_values() assert parameters.keys() == parameters_from_file.keys(), f"The keys in the module and local input file are not the same, expected {parameters.keys} got {parameters_from_file.keys()}" + +def test_model_entry_points(): + """Test if the models via entry points are loaded correctly.""" + + entry_points = list(pybamm_cookiecutter.models) + models = Path("src/pybamm_cookiecutter/models/input/").glob("*.py") + # Making a list Parameter sets in the parameters/input directory + models = [x.stem for x in models] + + assert models == entry_points, "Entry points missing either in pyproject.toml or in the input directory" + +def test_model_entry_point_load(): + """Testing if the model gets initialised and returned.""" + # Loading parameter_sets through entry points + model_instance = pybamm_cookiecutter.Model("SPM") + assert model_instance != None diff --git a/tests/test_model_entry_points.py b/tests/test_model_entry_points.py deleted file mode 100644 index abe284b..0000000 --- a/tests/test_model_entry_points.py +++ /dev/null @@ -1,21 +0,0 @@ -import pytest -import pybamm_cookiecutter -import importlib.util -import sys -from pathlib import Path - -def test_entry_points(): - """Test if the entry points are loaded correctly.""" - - entry_points = list(pybamm_cookiecutter.models) - models = Path("src/pybamm_cookiecutter/models/input/").glob("*.py") - # Making a list Parameter sets in the parameters/input directory - models = [x.stem for x in models] - - assert models == entry_points, "Entry points missing either in pyproject.toml or in the input directory" - -def test_entry_point_load(): - """Testing if the values get loaded via parameter entry points and are equal when loaded through entry points""" - # Loading parameter_sets through entry points - model_instance = pybamm_cookiecutter.Model("SPM") - assert model_instance != None diff --git a/tests/test_parameter_entry_points.py b/tests/test_parameter_entry_points.py deleted file mode 100644 index 58cee98..0000000 --- a/tests/test_parameter_entry_points.py +++ /dev/null @@ -1,27 +0,0 @@ -import pytest -import pybamm_cookiecutter -import importlib.util -import sys -from pathlib import Path - -def test_entry_points(): - """Test if the entry points are loaded correctly.""" - - entry_points = list(pybamm_cookiecutter.parameter_sets) - parameter_sets = Path("src/pybamm_cookiecutter/parameters/input/").glob("*.py") - # Making a list Parameter sets in the parameters/input directory - parameter_sets = [x.stem for x in parameter_sets] - - assert parameter_sets == entry_points, "Entry points missing either in pyproject.toml or in the input directory" - -def test_entry_point_load(): - """Testing if the values get loaded via parameter entry points and are equal when loaded through entry points""" - # Loading parameter_sets through entry points - parameters = pybamm_cookiecutter.parameter_sets['Chen2020'] - # Loading parameter sets through the source file by dynamically loading Chen2020.py as a module - spec = importlib.util.spec_from_file_location("Chen2020mod", "src/pybamm_cookiecutter/parameters/input/Chen2020.py") - chen_module = importlib.util.module_from_spec(spec) - sys.modules["Chen2020mod"] = chen_module - spec.loader.exec_module(chen_module) - parameters_from_file = chen_module.get_parameter_values() - assert parameters.keys() == parameters_from_file.keys(), "The keys in the module and local input file are not the same" From 5dfaeec4002ea1816410b414028498dc3c339fbf Mon Sep 17 00:00:00 2001 From: Santhosh <52504160+santacodes@users.noreply.github.com> Date: Thu, 18 Jul 2024 11:47:35 +0530 Subject: [PATCH 6/6] Update tests/test_entry_points.py Co-authored-by: Agriya Khetarpal <74401230+agriyakhetarpal@users.noreply.github.com> --- tests/test_entry_points.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_entry_points.py b/tests/test_entry_points.py index b37d291..ed2b1e7 100644 --- a/tests/test_entry_points.py +++ b/tests/test_entry_points.py @@ -40,4 +40,4 @@ def test_model_entry_point_load(): """Testing if the model gets initialised and returned.""" # Loading parameter_sets through entry points model_instance = pybamm_cookiecutter.Model("SPM") - assert model_instance != None + assert model_instance is not None