|
30 | 30 | from jax._src.ad_util import (
|
31 | 31 | stop_gradient_p, SymbolicZero, Zero, zeros_like_aval)
|
32 | 32 | from jax._src.api_util import (
|
33 |
| - argnums_partial, flatten_fun_nokwargs, resolve_kwargs) |
| 33 | + argnums_partial, flatten_fun_nokwargs, resolve_kwargs, fun_signature, |
| 34 | + _arg_names) |
34 | 35 | from jax._src.errors import UnexpectedTracerError
|
| 36 | +from jax._src.state.types import AbstractRef |
35 | 37 | from jax._src.interpreters import ad
|
36 | 38 | from jax._src.interpreters import batching
|
37 | 39 | from jax._src.interpreters import mlir
|
|
41 | 43 | from jax._src.lax import lax
|
42 | 44 | from jax._src.tree_util import (
|
43 | 45 | tree_flatten, tree_unflatten, tree_map, treedef_is_leaf, treedef_tuple,
|
44 |
| - register_pytree_node_class, tree_leaves, tree_flatten_with_path, keystr, |
45 |
| - treedef_children) |
| 46 | + register_pytree_node_class, tree_leaves, tree_flatten_with_path, |
| 47 | + tree_leaves_with_path, keystr, treedef_children) |
46 | 48 | from jax._src.util import (cache, safe_zip, safe_map, split_list, Unhashable,
|
47 | 49 | unzip2)
|
48 | 50 |
|
@@ -608,16 +610,50 @@ def __call__(self, *args: Any, **kwargs: Any) -> ReturnValue: # pytype: disable
|
608 | 610 | fwd_, bwd = lu.wrap_init(fwd), lu.wrap_init(self.bwd)
|
609 | 611 | args_flat, in_tree = tree_flatten(dyn_args)
|
610 | 612 | in_avals = [core.get_aval(x) for x in args_flat]
|
| 613 | + if config.mutable_array_checks.value: |
| 614 | + f_ = _check_primal_refs(f_, self.nondiff_argnums) |
611 | 615 | flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
|
612 |
| - flat_fwd, out_trees = _flatten_fwd(fwd_, self.symbolic_zeros, primal_name, |
613 |
| - fwd_name, in_tree, out_type) |
| 616 | + flat_fwd, out_trees = _flatten_fwd( |
| 617 | + fwd_, self.nondiff_argnums, self.symbolic_zeros, primal_name, |
| 618 | + fwd_name, in_tree, out_type) |
614 | 619 | flat_bwd = _flatten_bwd(bwd, in_tree, in_avals, out_trees).call_wrapped
|
615 | 620 | out_flat = custom_vjp_call_p.bind(flat_fun, flat_fwd, flat_bwd,
|
616 | 621 | *args_flat, out_trees=out_trees,
|
617 | 622 | symbolic_zeros=self.symbolic_zeros)
|
618 | 623 | _, (out_tree, _) = lu.merge_linear_aux(out_type, out_trees)
|
619 | 624 | return tree_unflatten(out_tree, out_flat)
|
620 | 625 |
|
@lu.transformation2
def _check_primal_refs(f, nondiff_argnums, *args):
  """Wrap a primal function with mutable-array-reference safety checks.

  Before calling ``f``, verifies that no mutable array reference is passed
  in twice (aliased) among ``args``; after calling, verifies that ``f`` did
  not return any mutable array reference in its output.
  """
  _check_for_aliased_refs(f, nondiff_argnums, args)
  result = f(*args)
  _check_for_returned_refs(f, result, 'primal')
  return result
| 632 | + |
| 633 | +def _check_for_aliased_refs(f, nondiff_argnums, args): |
| 634 | + leaves = tree_leaves(args) |
| 635 | + refs: dict[int, int] = {} |
| 636 | + for i, x in enumerate(leaves): |
| 637 | + if (isinstance((a := core.get_aval(x)), AbstractRef) and |
| 638 | + (dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i): |
| 639 | + arg_names = _arg_names(fun_signature(f), args, {}, nondiff_argnums, ()) |
| 640 | + if arg_names is None: |
| 641 | + arg_names = [f'flat index {j}' for j in range(len(leaves))] |
| 642 | + raise ValueError( |
| 643 | + "only one reference to a mutable array may be passed as an argument " |
| 644 | + f"to a function, but custom_vjp function {f} got the same mutable " |
| 645 | + f"array reference of type {a.str_short()} at {arg_names[dup_idx]} and" |
| 646 | + f" {arg_names[i]}.") |
| 647 | + |
| 648 | +def _check_for_returned_refs(f, out, kind): |
| 649 | + leaves = tree_leaves_with_path(out) |
| 650 | + for path, leaf in leaves: |
| 651 | + if isinstance((a := core.get_aval(leaf)), AbstractRef): |
| 652 | + loc = f' at output tree path {keystr(path)}' if path else '' |
| 653 | + raise ValueError(f"custom_vjp {kind} function {f} returned a mutable " |
| 654 | + f"a array reference of type {a.str_short()}{loc}, " |
| 655 | + "but mutable array references cannot be returned.") |
| 656 | + |
621 | 657 | @dataclasses.dataclass
|
622 | 658 | class CustomVJPPrimal:
|
623 | 659 | """Primal to a ``custom_vjp``'s forward rule when ``symbolic_zeros`` is set"""
|
@@ -655,14 +691,18 @@ def _check_for_tracers(x):
|
655 | 691 | raise UnexpectedTracerError(msg)
|
656 | 692 |
|
657 | 693 | @partial(lu.transformation_with_aux2, use_eq_store=True)
|
658 |
| -def _flatten_fwd(f, store, symbolic_zeros, primal_name, fwd_name, in_tree, maybe_out_type, |
659 |
| - *args): |
| 694 | +def _flatten_fwd(f, store, nondiff_argnums, symbolic_zeros, primal_name, |
| 695 | + fwd_name, in_tree, maybe_out_type, *args): |
660 | 696 | if symbolic_zeros:
|
661 | 697 | args = [CustomVJPPrimal(x, z) for x, z in zip(args[::2], args[1::2])]
|
662 | 698 | else:
|
663 | 699 | args = args[::2]
|
664 | 700 | py_args = tree_unflatten(in_tree, args)
|
| 701 | + if config.mutable_array_checks.value: |
| 702 | + _check_for_aliased_refs(f, nondiff_argnums, py_args) |
665 | 703 | pair_out = f(*py_args)
|
| 704 | + if config.mutable_array_checks.value: |
| 705 | + _check_for_returned_refs(f, pair_out, 'fwd') |
666 | 706 | if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2:
|
667 | 707 | msg = (f"Custom VJP fwd rule {fwd_name} for function {primal_name} "
|
668 | 708 | "must produce a pair (list or tuple of length two) where the first "
|
@@ -1393,8 +1433,8 @@ def wrapped_fwd(*args, **kwargs) -> tuple[ReturnValue, Any]:
|
1393 | 1433 | fwd_ = lu.wrap_init(fwd)
|
1394 | 1434 | args_flat, in_tree = tree_flatten(dyn_args)
|
1395 | 1435 | flat_fun, out_type = _flatten_fun_nokwargs(f_, in_tree)
|
1396 |
| - flat_fwd, out_trees = _flatten_fwd(fwd_, False, primal_name, fwd_name, |
1397 |
| - in_tree, out_type) |
| 1436 | + flat_fwd, out_trees = _flatten_fwd(fwd_, nondiff_argnums, False, |
| 1437 | + primal_name, fwd_name, in_tree, out_type) |
1398 | 1438 | flat_fwd = _fix_fwd_args(flat_fwd)
|
1399 | 1439 |
|
1400 | 1440 | in_avals = [core.get_aval(x) for x in args_flat]
|
|
0 commit comments