From 9914842b986516dbfc8f3fb56cee139b676d343c Mon Sep 17 00:00:00 2001 From: Gregor Boehl Date: Mon, 8 Apr 2024 09:06:48 +0200 Subject: [PATCH] jax 0.4.25 compatibility (PjitFunction) --- econpizza/parser/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/econpizza/parser/__init__.py b/econpizza/parser/__init__.py index 68a89e3..417ae9d 100644 --- a/econpizza/parser/__init__.py +++ b/econpizza/parser/__init__.py @@ -148,7 +148,7 @@ def _load_external_functions_file(model, context): module = _load_as_module(model["functions_file"]) def func_or_compiled(func): return isinstance( - func, jaxlib.xla_extension.CompiledFunction) or isfunction(func) + func, jaxlib.xla_extension.PjitFunction) or isfunction(func) for m in getmembers(module, func_or_compiled): context[m[0]] = m[1]