
Tracer error for pure_callbacks #25704

Answered by jakevdp
SNMS95 asked this question in Q&A
Dec 31, 2024 · 1 comment · 3 replies

That said, static arguments to `jax.jit` must be hashable, and dicts are not, so you could instead keep `problem_dict` static by closing over it, for example with `functools.partial`:

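```python
import jax
from functools import partial

problem_dict = {'scale_v': 1, 'scale_z': 1.0}
j1 = partial(jv, 1, problem_dict=problem_dict)  # jv is defined in the question (not shown here)
print(jax.grad(jax.jit(j1))(2.0))
```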
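For a self-contained illustration, the sketch below reproduces the three situations using a hypothetical stand-in for `jv`, whose definition is not shown in the thread. The names `model`, `_sin_cb`, `_callback` and the `_host_*` helpers are invented, and NumPy's `sin`/`cos` replace whatever host computation the original callback performed. The stand-in reads `problem_dict['scale_z']` with `float()`, which only works on concrete Python values, and it pairs `jax.pure_callback` with `jax.custom_jvp` so that `jax.grad` has a derivative rule to use. The exact exception raised for an unhashable static argument may vary between JAX versions, so the sketch catches both `ValueError` and `TypeError`.

```python
from functools import partial

import jax
import numpy as np


def _host_sin(z):
    # Runs on the host with concrete NumPy values, never tracers.
    return np.asarray(np.sin(z), dtype=z.dtype)


def _host_cos(z):
    return np.asarray(np.cos(z), dtype=z.dtype)


def _callback(host_fn, z):
    # Call a host function from traced code; the output shape/dtype must be declared.
    return jax.pure_callback(host_fn, jax.ShapeDtypeStruct(z.shape, z.dtype), z)


@jax.custom_jvp
def _sin_cb(z):
    return _callback(_host_sin, z)


@_sin_cb.defjvp
def _sin_cb_jvp(primals, tangents):
    # pure_callback has no autodiff rule, so supply d/dz sin(z) = cos(z) by hand.
    (z,), (z_dot,) = primals, tangents
    return _sin_cb(z), _callback(_host_cos, z) * z_dot


def model(x, problem_dict):
    # Hypothetical stand-in for `jv`: float() needs a concrete Python number.
    scale_z = float(problem_dict['scale_z'])
    return _sin_cb(scale_z * x)


problem_dict = {'scale_v': 1, 'scale_z': 1.0}

# 1) As an ordinary jit argument, the dict's values become tracers.
try:
    jax.jit(model)(2.0, problem_dict)
except jax.errors.ConcretizationTypeError:
    print("traced dict values cannot be passed to float()")

# 2) As a static argument, the dict is rejected because it is unhashable.
try:
    jax.jit(model, static_argnames='problem_dict')(2.0, problem_dict=problem_dict)
except (ValueError, TypeError):
    print("static arguments must be hashable")

# 3) Bound outside jit, the dict stays a plain Python object.
m1 = partial(model, problem_dict=problem_dict)
grad_at_2 = jax.grad(jax.jit(m1))(2.0)
print(grad_at_2)  # approximately cos(2.0) = -0.416
```

Binding the dict before calling `jax.jit` means JAX never receives it as an argument, so it is neither traced nor hashed; its values are baked into the traced computation as constants, so using a different dict means wrapping and jitting a new function.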

Answer selected by SNMS95
2 participants