Skip to content

Commit 5031b6f

Browse files
Merge pull request #25625 from mattjj:ref-errors-5
PiperOrigin-RevId: 708196762
2 parents 05f3a70 + b6482f1 commit 5031b6f

File tree

6 files changed

+151
-46
lines changed

6 files changed

+151
-46
lines changed

jax/_src/custom_derivatives.py

Lines changed: 49 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,10 @@
3030
from jax._src.ad_util import (
3131
stop_gradient_p, SymbolicZero, Zero, zeros_like_aval)
3232
from jax._src.api_util import (
33-
argnums_partial, flatten_fun_nokwargs, resolve_kwargs)
33+
argnums_partial, flatten_fun_nokwargs, resolve_kwargs, fun_signature,
34+
_arg_names)
3435
from jax._src.errors import UnexpectedTracerError
36+
from jax._src.state.types import AbstractRef
3537
from jax._src.interpreters import ad
3638
from jax._src.interpreters import batching
3739
from jax._src.interpreters import mlir
@@ -41,8 +43,8 @@
4143
from jax._src.lax import lax
4244
from jax._src.tree_util import (
4345
tree_flatten, tree_unflatten, tree_map, treedef_is_leaf, treedef_tuple,
44-
register_pytree_node_class, tree_leaves, tree_flatten_with_path, keystr,
45-
treedef_children)
46+
register_pytree_node_class, tree_leaves, tree_flatten_with_path,
47+
tree_leaves_with_path, keystr, treedef_children)
4648
from jax._src.util import (cache, safe_zip, safe_map, split_list, Unhashable,
4749
unzip2)
4850

@@ -608,16 +610,50 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable
608610
fwd_, bwd = lu.wrap_init(fwd), lu.wrap_init(self.bwd)
609611
args_flat, in_tree = tree_flatten(dyn_args)
610612
in_avals = [core.get_aval(x) for x in args_flat]
613+
if config.mutable_array_checks.value:
614+
f_ = _check_primal_refs(f_, self.nondiff_argnums)
611615
flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
612-
flat_fwd, out_trees = _flatten_fwd(fwd_, self.symbolic_zeros, primal_name,
613-
fwd_name, in_tree, out_type)
616+
flat_fwd, out_trees = _flatten_fwd(
617+
fwd_, self.nondiff_argnums, self.symbolic_zeros, primal_name,
618+
fwd_name, in_tree, out_type)
614619
flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees).call_wrapped
615620
out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
616621
*args_flat, out_trees=out_trees,
617622
symbolic_zeros=self.symbolic_zeros)
618623
_, (out_tree, _) = lu.merge_linear_aux(out_type, out_trees)
619624
return tree_unflatten(out_tree, out_flat)
620625

626+
@lu.transformation2
def _check_primal_refs(f, nondiff_argnums, *args):
  """Run ``f`` on ``args`` with mutable-array-reference checks on both sides.

  Before the call, ``_check_for_aliased_refs`` inspects the arguments; after
  the call, ``_check_for_returned_refs`` inspects the result, labelled as the
  'primal' function. The result of ``f`` is returned unchanged when neither
  check raises.
  """
  _check_for_aliased_refs(f, nondiff_argnums, args)
  primal_out = f(*args)
  _check_for_returned_refs(f, primal_out, 'primal')
  return primal_out
632+
633+
def _check_for_aliased_refs(f, nondiff_argnums, args):
634+
leaves = tree_leaves(args)
635+
refs: dict[int, int] = {}
636+
for i, x in enumerate(leaves):
637+
if (isinstance((a := core.get_aval(x)), AbstractRef) and
638+
(dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i):
639+
arg_names = _arg_names(fun_signature(f), args, {}, nondiff_argnums, ())
640+
if arg_names is None:
641+
arg_names = [f'flat index {j}' for j in range(len(leaves))]
642+
raise ValueError(
643+
"only one reference to a mutable array may be passed as an argument "
644+
f"to a function, but custom_vjp function {f} got the same mutable "
645+
f"array reference of type {a.str_short()} at {arg_names[dup_idx]} and"
646+
f" {arg_names[i]}.")
647+
648+
def _check_for_returned_refs(f, out, kind):
649+
leaves = tree_leaves_with_path(out)
650+
for path, leaf in leaves:
651+
if isinstance((a := core.get_aval(leaf)), AbstractRef):
652+
loc = f' at output tree path {keystr(path)}' if path else ''
653+
raise ValueError(f"custom_vjp {kind} function {f} returned a mutable "
654+
f"a array reference of type {a.str_short()}{loc}, "
655+
"but mutable array references cannot be returned.")
656+
621657
@dataclasses.dataclass
622658
class CustomVJPPrimal:
623659
"""Primal to a ``custom_vjp``'s forward rule when ``symbolic_zeros`` is set"""
@@ -655,14 +691,18 @@ def _check_for_tracers(x):
655691
raise UnexpectedTracerError(msg)
656692

657693
@partial(lu.transformation_with_aux2, use_eq_store=True)
658-
def _flatten_fwd(f, store, symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type,
659-
*args):
694+
def _flatten_fwd(f, store, nondiff_argnums, symbolic_zeros, primal_name,
695+
fwd_name, in_tree, maybe_out_type, *args):
660696
if symbolic_zeros:
661697
args = [CustomVJPPrimal(x, z) for x, z in zip(args[::2], args[1::2])]
662698
else:
663699
args = args[::2]
664700
py_args = tree_unflatten(in_tree, args)
701+
if config.mutable_array_checks.value:
702+
_check_for_aliased_refs(f, nondiff_argnums, py_args)
665703
pair_out = f(*py_args)
704+
if config.mutable_array_checks.value:
705+
_check_for_returned_refs(f, pair_out, 'fwd')
666706
if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2:
667707
msg = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} "
668708
"must produce a pair (list or tuple of length two) where the first "
@@ -1393,8 +1433,8 @@ def wrapped_fwd(*args, **kwargs) -> tuple[ReturnValue, Any]:
13931433
fwd_ = lu.wrap_init(fwd)
13941434
args_flat, in_tree = tree_flatten(dyn_args)
13951435
flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
1396-
flat_fwd, out_trees = _flatten_fwd(fwd_, False, primal_name, fwd_name,
1397-
in_tree, out_type)
1436+
flat_fwd, out_trees = _flatten_fwd(fwd_, nondiff_argnums, False,
1437+
primal_name, fwd_name, in_tree, out_type)
13981438
flat_fwd = _fix_fwd_args(flat_fwd)
13991439

14001440
in_avals = [core.get_aval(x) for x in args_flat]

jax/_src/interpreters/ad.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,9 @@ def full_lower(self):
539539
def to_concrete_value(self):
540540
return core.to_concrete_value(self.primal)
541541

542+
def get_referent(self):
  """Return ``core.get_referent`` applied to this object's ``primal``."""
  primal = self.primal
  return core.get_referent(primal)
544+
542545
def _primal_tangent_shapes_match(primal, tangent):
543546
if type(tangent) is not Zero:
544547
primal_aval = get_aval(primal).strip_weak_type()

jax/_src/interpreters/partial_eval.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2010,8 +2010,8 @@ def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees,
20102010
def fwd_jaxpr_from_zeros(*zeros):
20112011
for store in fwd.stores: store and store.reset()
20122012
fwd_ = _interleave_fun(fwd, zeros)
2013-
jaxpr, _, consts, atr = trace_to_jaxpr_dynamic(fwd_, in_avals)
2014-
if atr: raise NotImplementedError
2013+
jaxpr, _, consts, attrs = trace_to_jaxpr_dynamic(fwd_, in_avals)
2014+
if attrs: raise NotImplementedError
20152015
return jaxpr, consts
20162016

20172017
out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
@@ -2154,14 +2154,14 @@ def trace_to_jaxpr_dynamic(
21542154
ans = fun.call_wrapped(*in_tracers)
21552155

21562156
out_tracers = map(trace.to_jaxpr_tracer, ans)
2157-
_check_no_refs(debug_info, out_tracers)
2157+
_check_no_returned_refs(debug_info, out_tracers)
21582158
jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers)
21592159
del trace, fun, in_tracers, out_tracers, ans
21602160

21612161
config.enable_checks.value and core.check_jaxpr(jaxpr)
21622162
return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked
21632163

2164-
def _check_no_refs(
2164+
def _check_no_returned_refs(
21652165
dbg: lu.TracingDebugInfo | None,
21662166
out_tracers: Sequence[DynamicJaxprTracer]
21672167
) -> None:

jax/_src/lax/control_flow/common.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,6 @@ def _initial_style_jaxprs_with_common_consts(
8989

9090
jaxprs, all_consts, all_out_trees, all_attrs_tracked = zip(*jaxpr_data)
9191
all_const_avals = [map(core.get_aval, consts) for consts in all_consts]
92-
# If we get a `Ref` in the consts, we know it must come from an outer
93-
# `run_state`. We also know it shouldn't be boxed up in another tracer.
94-
# We assert that it is in fact a DynamicJaxprTracer
95-
for consts, consts_avals in zip(all_consts, all_const_avals):
96-
for c, aval in zip(consts, consts_avals):
97-
if isinstance(aval, state.AbstractRef):
98-
assert isinstance(c, pe.DynamicJaxprTracer)
9992

10093
# TODO(sharadmv,mattjj): we could dedup *all consts* instead of just the Refs.
10194

jax/_src/lax/control_flow/conditionals.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525

2626
from jax.tree_util import tree_flatten, tree_unflatten
2727
from jax._src import ad_util
28+
from jax._src.api_util import (
29+
_check_no_aliased_ref_args, _check_no_aliased_closed_over_refs)
2830
from jax._src import config
2931
from jax._src import core
3032
from jax._src import dispatch
@@ -136,8 +138,14 @@ def switch(index, branches, *operands):
136138
ops, ops_tree = tree_flatten(operands)
137139
ops_avals = tuple(map(core.get_aval, ops))
138140

141+
if config.mutable_array_checks.value:
142+
dbg = pe.debug_info(branches[0], ops_tree, None, False, 'switch')
143+
_check_no_aliased_ref_args(dbg, ops_avals, ops)
144+
139145
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
140146
branches, ops_tree, ops_avals, primitive_name='switch')
147+
if config.mutable_array_checks.value:
148+
_check_no_aliased_closed_over_refs(dbg, (*jaxprs[0].consts, *consts), ops)
141149
for i, (out_tree, jaxpr) in enumerate(zip(out_trees[1:], jaxprs[1:])):
142150
_check_tree_and_avals(f"branch 0 and {i + 1} outputs",
143151
out_trees[0], jaxprs[0].out_avals,
@@ -228,11 +236,14 @@ def cond(pred, true_fun, false_fun, *operands):
228236
ops, ops_tree = tree_flatten(operands)
229237
ops_avals = tuple(map(core.get_aval, ops))
230238

239+
if config.mutable_array_checks.value:
240+
dbg = pe.debug_info(true_fun, ops_tree, None, False, 'cond')
241+
_check_no_aliased_ref_args(dbg, ops_avals, ops)
231242
jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
232243
(true_fun, false_fun), ops_tree, ops_avals, 'cond')
233-
if any(isinstance(op_aval, AbstractRef) for op_aval in ops_avals):
234-
raise ValueError("Cannot pass `Ref`s into `cond`.")
235244
true_jaxpr, false_jaxpr = jaxprs
245+
if config.mutable_array_checks.value:
246+
_check_no_aliased_closed_over_refs(dbg, (*true_jaxpr.consts, *consts), ops)
236247

237248
out_tree, false_out_tree = out_trees
238249
if any(isinstance(out_aval, AbstractRef) for out_aval in

tests/mutable_array_test.py

Lines changed: 82 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -306,30 +306,88 @@ def test_return_from_cond(self):
306306
ValueError, "traced for cond returned a mutable array reference of type"):
307307
jax.lax.cond(True, lambda: core.mutable_array(1.0), lambda: core.mutable_array(2.0))
308308

309-
# TODO test_argument_aliases_cond
310-
# TODO test_closure_and_argument_aliases_cond
311-
312-
# TODO test_return_from_custom_jvp/vjp
313-
# TODO test_argument_aliases_custom_jvp/vjp
314-
# TODO test_closure_and_argument_aliases_custom_jvp/vjp
315-
316-
# TODO(mattjj): enable when cond works with mutable arrays
317-
# @parameterized.parameters([False, True])
318-
# def test_cond_both_branches_close_over_same_mutable_array(self, jit):
319-
# # see also test_cond_with_ref_reuse in state_test.py
320-
# x_ref = core.mutable_array(0.)
321-
# def f(pred):
322-
# def true_fun():
323-
# x_ref[()] = 1.
324-
# def false_fun():
325-
# x_ref[()] = 2.
326-
# jax.lax.cond(pred, true_fun, false_fun)
327-
# if jit:
328-
# f = jax.jit(f)
329-
# out_true = f(True)
330-
# self.assertAllClose(x_ref[...], 1.)
331-
# out_false = f(False)
332-
# self.assertAllClose(x_ref[...], 2.)
309+
def test_argument_aliases_cond(self):
  """Passing the same ref as two `cond` operands raises a ValueError."""
  ref = core.mutable_array(0.)
  # The branch parameter names x1/x2 are what the expected message mentions.
  expected = r"for cond.*at both x1 and x2"
  with self.assertRaisesRegex(ValueError, expected):
    jax.lax.cond(True, lambda x1, x2: ..., lambda x1, x2: ..., ref, ref)
313+
314+
def test_closure_and_argument_aliases_cond(self):
  """A ref both closed over by the branches and passed as operand is rejected."""
  x_ref = core.mutable_array(0.)
  expected = r"closed over and passed as the argument y_ref"
  with self.assertRaisesRegex(ValueError, expected):
    # Both branches read the closed-over x_ref and the y_ref operand, which
    # here are the same reference.
    jax.lax.cond(
        True,
        lambda y_ref: x_ref[...] + y_ref[...],
        lambda y_ref: x_ref[...] + y_ref[...],
        x_ref)
322+
323+
@parameterized.parameters([False, True])
def test_return_from_custom_vjp_primal(self, jit):
  """A custom_vjp primal function that returns its ref argument raises."""
  @jax.custom_vjp
  def f(ref):
    return ref
  f.defvjp(lambda ref: ..., lambda *_: ...)

  fn = jax.jit(f) if jit else f
  ref_arg = core.mutable_array(0.)
  with self.assertRaisesRegex(ValueError, "custom_vjp primal function"):
    fn(ref_arg)
335+
336+
@parameterized.parameters([False, True])
def test_return_from_custom_vjp_fwd(self, jit):
  """A custom_vjp fwd rule that returns a ref as its residual raises."""
  @jax.custom_vjp
  def f(x, ref):
    return x
  # The fwd rule hands the ref back in the second slot of its output pair.
  f.defvjp(lambda x, ref: (x, ref), lambda ref, g: g)

  fn = jax.jit(f) if jit else f
  ref_arg = core.mutable_array(0.)
  with self.assertRaisesRegex(ValueError, "custom_vjp fwd function"):
    jax.vjp(fn, 3., ref_arg)
348+
349+
@parameterized.parameters([False, True])
def test_argument_aliases_custom_vjp_primal(self, jit):
  """Calling a custom_vjp function with one ref in both argument slots raises."""
  # The parameter names x_ref/y_ref are what the expected message mentions.
  @jax.custom_vjp
  def f(x_ref, y_ref):
    ...
  f.defvjp(lambda x_ref, y_ref: (None, None), lambda _, g: (None, None))

  fn = jax.jit(f) if jit else f
  shared_ref = core.mutable_array(0.)
  with self.assertRaisesRegex(ValueError, "x_ref and y_ref"):
    fn(shared_ref, shared_ref)
360+
361+
@parameterized.parameters([False, True])
def test_argument_aliases_custom_vjp_fwd(self, jit):
  """Taking `jax.vjp` of a custom_vjp function with aliased ref args raises."""
  # The parameter names x_ref/y_ref are what the expected message mentions.
  @jax.custom_vjp
  def f(x_ref, y_ref):
    ...
  f.defvjp(lambda x_ref, y_ref: (None, None), lambda _, g: (None, None))

  fn = jax.jit(f) if jit else f
  shared_ref = core.mutable_array(0.)
  with self.assertRaisesRegex(ValueError, "x_ref and y_ref"):
    jax.vjp(fn, shared_ref, shared_ref)
372+
373+
# TODO(mattjj): add test test_closure_and_argument_aliases_custom_vjp
374+
375+
@parameterized.parameters([False, True])
def test_cond_both_branches_close_over_same_mutable_array(self, jit):
  """Each `cond` branch writes a different value to one closed-over ref."""
  # see also test_cond_with_ref_reuse in state_test.py
  x_ref = core.mutable_array(0.)

  def f(pred):
    def true_fun():
      x_ref[()] = 1.
    def false_fun():
      x_ref[()] = 2.
    jax.lax.cond(pred, true_fun, false_fun)

  if jit:
    f = jax.jit(f)
  # f has no return value; only its effect on x_ref is checked, so the
  # results are not bound to (previously unused) locals.
  f(True)
  self.assertAllClose(x_ref[...], 1.)
  f(False)
  self.assertAllClose(x_ref[...], 2.)
333391

334392

335393
if __name__ == '__main__':

0 commit comments

Comments
 (0)