diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index b6994795d6..f92ee76c9e 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -48,8 +48,7 @@ jobs: echo "tag=all" >> "$GITHUB_OUTPUT" fi - - name: Build and push Docker image to Docker Hub (no solvers) - if: matrix.build-args == 'No solvers' + - name: Build and push Docker image to Docker Hub (${{ matrix.build-args }}) uses: docker/build-push-action@v5 with: context: . @@ -58,29 +57,5 @@ jobs: push: true platforms: linux/amd64, linux/arm64 - - name: Build and push Docker image to Docker Hub (with ODES and IDAKLU solvers) - if: matrix.build-args == 'ODES' || matrix.build-args == 'IDAKLU' - uses: docker/build-push-action@v5 - with: - context: . - file: scripts/Dockerfile - tags: pybamm/pybamm:${{ steps.tags.outputs.tag }} - push: true - build-args: ${{ matrix.build-args }}=true - platforms: linux/amd64, linux/arm64 - - - name: Build and push Docker image to Docker Hub (with ALL and JAX solvers) - if: matrix.build-args == 'ALL' || matrix.build-args == 'JAX' - uses: docker/build-push-action@v5 - with: - context: . - file: scripts/Dockerfile - tags: pybamm/pybamm:${{ steps.tags.outputs.tag }} - push: true - build-args: ${{ matrix.build-args }}=true - # exclude arm64 for JAX and ALL builds for now, see - # https://github.com/google/jax/issues/13608 - platforms: linux/amd64 - - name: List built image(s) run: docker images diff --git a/CHANGELOG.md b/CHANGELOG.md index 5204e0bc82..e9160e5081 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,14 @@ - Fixed bug in calculation of theoretical energy that made it very slow ([#3506](https://github.com/pybamm-team/PyBaMM/pull/3506)) - The irreversible plating model now increments `f"{Domain} dead lithium concentration [mol.m-3]"`, not `f"{Domain} lithium plating concentration [mol.m-3]"` as it did previously. ([#3485](https://github.com/pybamm-team/PyBaMM/pull/3485)) +## Optimizations + +- Updated `jax` and `jaxlib` to the latest available versions and added Windows (Python 3.9+) support for the Jax solver ([#3550](https://github.com/pybamm-team/PyBaMM/pull/3550)) + +## Breaking changes + +- Dropped support for the `[jax]` extra, i.e., the Jax solver when running on Python 3.8. The Jax solver is now available on Python 3.9 and above ([#3550](https://github.com/pybamm-team/PyBaMM/pull/3550)) + # [v23.9](https://github.com/pybamm-team/PyBaMM/tree/v23.9) - 2023-10-31 ## Features diff --git a/docs/source/user_guide/installation/GNU-linux.rst b/docs/source/user_guide/installation/GNU-linux.rst index ca95bbe1b5..cf027db587 100644 --- a/docs/source/user_guide/installation/GNU-linux.rst +++ b/docs/source/user_guide/installation/GNU-linux.rst @@ -133,7 +133,10 @@ Optional - JaxSolver ~~~~~~~~~~~~~~~~~~~~ Users can install ``jax`` and ``jaxlib`` to use the Jax solver. -Currently, only GNU/Linux and macOS are supported. + +.. note:: + + The Jax solver is not supported on Python 3.8. It is supported on Python 3.9, 3.10, and 3.11. .. code:: bash diff --git a/docs/source/user_guide/installation/index.rst b/docs/source/user_guide/installation/index.rst index 2b8b7fe304..65cbad33fb 100644 --- a/docs/source/user_guide/installation/index.rst +++ b/docs/source/user_guide/installation/index.rst @@ -217,13 +217,13 @@ Dependency Minimum Version p Jax dependencies ^^^^^^^^^^^^^^^^^ -Installable with ``pip install "pybamm[jax]"`` +Installable with ``pip install "pybamm[jax]"``, currently supported on Python 3.9-3.11. ========================================================================= ================== ================== ======================= Dependency Minimum Version pip extra Notes ========================================================================= ================== ================== ======================= -`JAX `__ 0.4.8 jax For JAX solvers -`jaxlib `__ 0.4.7 jax Support library for JAX +`JAX `__ 0.4.20 jax For the JAX solver +`jaxlib `__ 0.4.20 jax Support library for JAX ========================================================================= ================== ================== ======================= .. _install.odes_dependencies: diff --git a/docs/source/user_guide/installation/windows.rst b/docs/source/user_guide/installation/windows.rst index 5b104e91bd..5ad77b6f7f 100644 --- a/docs/source/user_guide/installation/windows.rst +++ b/docs/source/user_guide/installation/windows.rst @@ -66,6 +66,21 @@ installed automatically when you install PyBaMM using ``pip``. For an introduction to virtual environments, see (https://realpython.com/python-virtual-environments-a-primer/). +Optional - JaxSolver +~~~~~~~~~~~~~~~~~~~~ + +Users can install ``jax`` and ``jaxlib`` to use the Jax solver. + +.. note:: + + The Jax solver is not supported on Python 3.8. It is supported on Python 3.9, 3.10, and 3.11. + +.. code:: bash + + pip install "pybamm[jax]" + +The ``pip install "pybamm[jax]"`` command automatically downloads and installs ``pybamm`` and the compatible versions of ``jax`` and ``jaxlib`` on your system. (``pybamm_install_jax`` is deprecated.) + Uninstall PyBaMM ---------------- diff --git a/noxfile.py b/noxfile.py index fd033a4573..297fc5b3d7 100644 --- a/noxfile.py +++ b/noxfile.py @@ -61,9 +61,12 @@ def run_coverage(session): set_environment_variables(PYBAMM_ENV, session=session) session.install("coverage", silent=False) if sys.platform != "win32": - session.install("-e", ".[all,odes,jax]", silent=False) + session.install("-e", ".[all,jax,odes]", silent=False) else: - session.install("-e", ".[all]", silent=False) + if sys.version_info < (3, 9): + session.install("-e", ".[all]", silent=False) + else: + session.install("-e", ".[all,jax]", silent=False) session.run("coverage", "run", "run-tests.py", "--nosub") session.run("coverage", "combine") session.run("coverage", "xml") @@ -74,9 +77,12 @@ def run_integration(session): """Run the integration tests.""" set_environment_variables(PYBAMM_ENV, session=session) if sys.platform != "win32": - session.install("-e", ".[all,odes,jax]", silent=False) + session.install("-e", ".[all,jax,odes]", silent=False) else: - session.install("-e", ".[all]", silent=False) + if sys.version_info < (3, 9): + session.install("-e", ".[all]", silent=False) + else: + session.install("-e", ".[all,jax]", silent=False) session.run("python", "run-tests.py", "--integration") @@ -92,9 +98,12 @@ def run_unit(session): """Run the unit tests.""" set_environment_variables(PYBAMM_ENV, session=session) if sys.platform != "win32": - session.install("-e", ".[all,odes,jax]", silent=False) + session.install("-e", ".[all,jax,odes]", silent=False) else: - session.install("-e", ".[all]", silent=False) + if sys.version_info < (3, 9): + session.install("-e", ".[all]", silent=False) + else: + session.install("-e", ".[all,jax]", silent=False) session.run("python", "run-tests.py", "--unit") @@ -144,7 +153,24 @@ def set_dev(session): external=True, ) else: - session.run(python, "-m", "pip", "install", "-e", ".[all,dev]", external=True) + if sys.version_info < (3, 9): + session.run( + python, + "-m", + "pip", + "install", + ".[all,dev]", + external=True, + ) + else: + session.run( + python, + "-m", + "pip", + "install", + ".[all,dev,jax]", + external=True, + ) @nox.session(name="tests") @@ -152,9 +178,12 @@ def run_tests(session): """Run the unit tests and integration tests sequentially.""" set_environment_variables(PYBAMM_ENV, session=session) if sys.platform != "win32": - session.install("-e", ".[all,odes,jax]", silent=False) + session.install("-e", ".[all,jax,odes]", silent=False) else: - session.install("-e", ".[all]", silent=False) + if sys.version_info < (3, 9): + session.install("-e", ".[all]", silent=False) + else: + session.install("-e", ".[all,jax]", silent=False) session.run("python", "run-tests.py", "--all") diff --git a/pyproject.toml b/pyproject.toml index f02286ad18..7d25c8e140 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ requires = [ "setuptools>=64", "wheel", # On Windows, use the CasADi vcpkg registry and CMake bundled from MSVC - "casadi>=3.6.0; platform_system!='Windows'", + "casadi>=3.6.3; platform_system!='Windows'", "cmake; platform_system!='Windows'", ] build-backend = "setuptools.build_meta" @@ -110,13 +110,13 @@ dev = [ "nbmake", ] # Reading CSV files -pandas = [ +pandas = [ "pandas>=1.5.0", ] # For the Jax solver. Note: these must be kept in sync with the versions defined in pybamm/util.py. jax = [ - "jax>=0.4,<=0.5", - "jaxlib>=0.4,<=0.5", + "jax==0.4.20; python_version >= '3.9'", + "jaxlib==0.4.20; python_version >= '3.9'", ] # For the scikits.odes solver odes = [ diff --git a/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py b/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py index ca36804ba0..df33e0fe27 100644 --- a/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py +++ b/tests/unit/test_expression_tree/test_operations/test_evaluate_python.py @@ -503,7 +503,7 @@ def test_evaluator_jax(self): expr = pybamm.exp(a * b) evaluator = pybamm.EvaluatorJax(expr) result = evaluator(t=None, y=np.array([[2], [3]])) - self.assertEqual(result, np.exp(6)) + np.testing.assert_array_almost_equal(result, np.exp(6), decimal=15) # test a constant expression expr = pybamm.Scalar(2) * pybamm.Scalar(3)