Skip to content

Commit

Permalink
Merge pull request pybamm-team#3335 from kratman/feat/replacePkgResou…
Browse files Browse the repository at this point in the history
…rces

Replace deprecated pkg_resources
  • Loading branch information
Saransh-cpp authored Sep 22, 2023
2 parents c72a2ef + 31c362d commit 3c48029
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 7 deletions.
11 changes: 10 additions & 1 deletion pybamm/parameters/parameter_sets.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
import warnings
import importlib.metadata
import textwrap
Expand Down Expand Up @@ -37,9 +38,17 @@ class ParameterSets(Mapping):
def __init__(self):
# Dict of entry points for parameter sets, lazily load entry points as
self.__all_parameter_sets = dict()
for entry_point in importlib.metadata.entry_points()["pybamm_parameter_sets"]:
for entry_point in self.get_entries("pybamm_parameter_sets"):
self.__all_parameter_sets[entry_point.name] = entry_point

@staticmethod
def get_entries(group_name):
# Wrapper for the importlib version logic
if sys.version_info < (3, 10): # pragma: no cover
return importlib.metadata.entry_points()[group_name]
else:
return importlib.metadata.entry_points(group=group_name)

def __new__(cls):
"""Ensure only one instance of ParameterSets exists"""
if not hasattr(cls, "instance"):
Expand Down
9 changes: 5 additions & 4 deletions pybamm/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from warnings import warn

import numpy as np
import pkg_resources
import importlib.metadata

import pybamm

Expand Down Expand Up @@ -271,9 +271,10 @@ def have_jax():

def is_jax_compatible():
"""Check if the available version of jax and jaxlib are compatible with PyBaMM"""
return pkg_resources.get_distribution("jax").version.startswith(
JAX_VERSION
) and pkg_resources.get_distribution("jaxlib").version.startswith(JAXLIB_VERSION)
return (
importlib.metadata.distribution("jax").version.startswith(JAX_VERSION)
and importlib.metadata.distribution("jaxlib").version.startswith(JAXLIB_VERSION)
)


def is_constant_and_can_evaluate(symbol):
Expand Down
3 changes: 1 addition & 2 deletions tests/unit/test_parameters/test_parameter_sets_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from tests import TestCase

import pybamm
import pkg_resources
import unittest


Expand All @@ -25,7 +24,7 @@ def test_all_registered(self):
"""Check that all parameter sets have been registered with the
``pybamm_parameter_sets`` entry point"""
known_entry_points = set(
ep.name for ep in pkg_resources.iter_entry_points("pybamm_parameter_sets")
ep.name for ep in pybamm.parameter_sets.get_entries("pybamm_parameter_sets")
)
self.assertEqual(set(pybamm.parameter_sets.keys()), known_entry_points)
self.assertEqual(len(known_entry_points), len(pybamm.parameter_sets))
Expand Down

0 comments on commit 3c48029

Please sign in to comment.