Skip to content

Commit 138cbf2

Browse files
authored
Merge pull request pybamm-team#3423 from jsbrittain/jax_gpu
JaxSolver fails when using GPU support with no input parameters
2 parents 618b481 + 6cc3940 commit 138cbf2

File tree

3 files changed

+6
-1
lines changed

3 files changed

+6
-1
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# [Unreleased](https://github.com/pybamm-team/PyBaMM/)
22

3+
## Bug fixes
4+
5+
- Fixed a bug where the JaxSolver would fails when using GPU support with no input parameters ([#3423](https://github.com/pybamm-team/PyBaMM/pull/3423))
6+
37
# [v23.9rc0](https://github.com/pybamm-team/PyBaMM/tree/v23.9rc0) - 2023-10-31
48

59
## Features

docs/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@
154154
"navbar_end": ["theme-switcher", "navbar-icon-links"],
155155
# add Algolia to the persistent navbar, this removes the default search icon
156156
"navbar_persistent": "algolia-searchbox",
157+
"navigation_with_keys": False,
157158
"use_edit_page_button": True,
158159
"analytics": {
159160
"plausible_analytics_domain": "docs.pybamm.org",

pybamm/solvers/jax_solver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def _integrate(self, model, t_eval, inputs=None):
215215

216216
y = []
217217
platform = jax.lib.xla_bridge.get_backend().platform.casefold()
218-
if platform.startswith("cpu"):
218+
if len(inputs) <= 1 or platform.startswith("cpu"):
219219
# cpu execution runs faster when multithreaded
220220
async def solve_model_for_inputs():
221221
async def solve_model_async(inputs_v):

0 commit comments

Comments
 (0)