From fc38f21e5d1e4505ac33459efd06908c2e158d4c Mon Sep 17 00:00:00 2001 From: Anselm Levskaya Date: Fri, 13 Dec 2024 06:53:32 -0800 Subject: [PATCH] Introduce a mode for a faster nn.scan mode that avoids extra jax retracing. This adds a new keyword option to linen nn.scan `check_constancy_invariants` that defaults to True for the existing behavior. Setting it to False however avoids an extra jax trace to hoist scan loop constants out of the loop and to check for non-data-dependence of broadcast variables and body function outputs marked constant. The time savings from not running this extra trace and static check can be considerable when tracing and compiling larger models. PiperOrigin-RevId: 705869200 --- flax/core/axes_scan.py | 99 ++++++++++++++++++++-------- flax/core/lift.py | 9 ++- flax/linen/transforms.py | 7 ++ tests/linen/linen_transforms_test.py | 47 +++++++++++++ 4 files changed, 134 insertions(+), 28 deletions(-) diff --git a/flax/core/axes_scan.py b/flax/core/axes_scan.py index c495186d72..bc64221667 100644 --- a/flax/core/axes_scan.py +++ b/flax/core/axes_scan.py @@ -13,16 +13,17 @@ # limitations under the License. """Wrapper around jax.lax.scan with in_axes/out_axes API.""" +from collections.abc import Callable import functools from typing import Any, Optional -from collections.abc import Callable import jax -import jax.numpy as jnp -import numpy as np -from jax import core, lax +from jax import core +from jax import lax from jax.extend import linear_util as lu from jax.interpreters import partial_eval as pe +import jax.numpy as jnp +import numpy as np ScanAxis = Optional[int] @@ -35,13 +36,14 @@ class _Broadcast: def scan( - fn: Callable[..., Any], - in_axes: Any, - out_axes: Any, - length: int | None = None, - reverse: bool = False, - unroll: int = 1, - _split_transpose: bool = False + fn: Callable[..., Any], + in_axes: Any, + out_axes: Any, + length: int | None = None, + reverse: bool = False, + unroll: int = 1, + _split_transpose: bool = False, + check_constancy_invariants: bool = True, ): """A wrapper around `jax.lax.scan` with in_axes/out_axes api. @@ -78,6 +80,11 @@ def body_fn(b, c, x): iteration of a loop (default: 1). _split_transpose: An experimental feature to split the transpose of scan into a scan and a map, backed by an experimental Jax lax.scan() feature. + check_constancy_invariants: If true, the scan will verify that the + broadcast constants are true loop invariants, and further supports + broadcast function (non-carry) outputs. This requires an extra jax + tracing step however, so setting to false can reduce trace time on larger + models. Returns: the function that performs the scan of the form: (broadcast_in, carry_in, *args) -> (broadcast_out, carry_out, scan_out). @@ -114,39 +121,43 @@ def trans(x): return jax.tree_util.tree_map(trans, xs) def scan_fn(broadcast_in, init, *args): + # Requires one extra tracing operation to test invariants: + # Verifies that broadcast constants are true loop invariants, and further + # supports broadcast function (non-carry) outputs. + xs = jax.tree_util.tree_map(transpose_to_front, in_axes, args) def body_fn(c, xs, init_mode=False): # inject constants xs = jax.tree_util.tree_map( - lambda ax, arg, x: (arg if ax is broadcast else x), in_axes, args, xs + lambda ax, arg, x: (arg if ax is broadcast else x), in_axes, args, xs ) broadcast_out, c, ys = fn(broadcast_in, c, *xs) if init_mode: ys = jax.tree_util.tree_map( - lambda ax, y: (y if ax is broadcast else ()), out_axes, ys + lambda ax, y: (y if ax is broadcast else ()), out_axes, ys ) return broadcast_out, ys else: ys = jax.tree_util.tree_map( - lambda ax, y: (() if ax is broadcast else y), out_axes, ys + lambda ax, y: (() if ax is broadcast else y), out_axes, ys ) return c, ys broadcast_body = functools.partial(body_fn, init_mode=True) carry_avals = jax.tree_util.tree_map( - lambda x: core.ShapedArray(jnp.shape(x), jnp.result_type(x)), init + lambda x: core.ShapedArray(jnp.shape(x), jnp.result_type(x)), init ) scan_avals = jax.tree_util.tree_map( - lambda x: core.ShapedArray(jnp.shape(x)[1:], jnp.result_type(x)), xs + lambda x: core.ShapedArray(jnp.shape(x)[1:], jnp.result_type(x)), xs ) input_avals = (carry_avals, scan_avals) in_avals, in_tree = jax.tree_util.tree_flatten(input_avals) f_flat, out_tree = jax.api_util.flatten_fun_nokwargs( - lu.wrap_init(broadcast_body), in_tree + lu.wrap_init(broadcast_body), in_tree ) in_pvals = list(map(pe.PartialVal.unknown, in_avals)) _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals) @@ -155,29 +166,63 @@ def body_fn(c, xs, init_mode=False): for pv, const in out_pvals: if pv is not None: raise ValueError( - 'broadcasted variable has a data dependency on the scan body.' + 'broadcasted variable has a data dependency on the scan body.' ) out_flat.append(const) broadcast_in, constants_out = jax.tree_util.tree_unflatten( - out_tree(), out_flat + out_tree(), out_flat ) if jax.version.__version_info__ > (0, 4, 25): c, ys = lax.scan( - body_fn, init, xs, length=length, reverse=reverse, unroll=unroll, - _split_transpose=_split_transpose + body_fn, init, xs, length=length, reverse=reverse, unroll=unroll, + _split_transpose=_split_transpose ) else: c, ys = lax.scan( - body_fn, init, xs, length=length, reverse=reverse, unroll=unroll + body_fn, init, xs, length=length, reverse=reverse, unroll=unroll ) ys = jax.tree_util.tree_map(transpose_from_front, out_axes, ys) ys = jax.tree_util.tree_map( - lambda ax, const, y: (const if ax is broadcast else y), - out_axes, - constants_out, - ys, + lambda ax, const, y: (const if ax is broadcast else y), + out_axes, + constants_out, + ys, ) return broadcast_in, c, ys - return scan_fn + def simple_scan_fn(broadcast_in, init, *args): + # Saves an extra tracing operation. + # No verification of constancy, and no support for non-carry broadcast + # function outputs. + xs = jax.tree_util.tree_map(transpose_to_front, in_axes, args) + + if broadcast in jax.tree_util.tree_leaves(out_axes): + raise ValueError(f"nn.scan run with check_constancy_invariants=False " + f"does not support broadcast non-carry function " + f"outputs. out_axes was given as {out_axes}") + + def body_fn(c, xs): + # inject constants + xs = jax.tree_util.tree_map( + lambda ax, arg, x: (arg if ax is broadcast else x), in_axes, args, xs + ) + _, c, ys = fn(broadcast_in, c, *xs) + return c, ys + + if jax.version.__version_info__ > (0, 4, 25): + c, ys = lax.scan( + body_fn, init, xs, length=length, reverse=reverse, unroll=unroll, + _split_transpose=_split_transpose + ) + else: + c, ys = lax.scan( + body_fn, init, xs, length=length, reverse=reverse, unroll=unroll + ) + ys = jax.tree_util.tree_map(transpose_from_front, out_axes, ys) + return broadcast_in, c, ys + + if check_constancy_invariants: + return scan_fn + else: + return simple_scan_fn diff --git a/flax/core/lift.py b/flax/core/lift.py index f7b7bfb739..98a929be0c 100644 --- a/flax/core/lift.py +++ b/flax/core/lift.py @@ -879,6 +879,7 @@ def scan( _split_transpose: bool = False, data_transform: Callable[..., Any] | None = None, metadata_params: dict[Any, Any] = {}, + check_constancy_invariants: bool = True, ) -> Callable[..., Any]: """A lifted version of ``jax.lax.scan``. @@ -946,6 +947,11 @@ def body_fn(scope, c, x): intended for inline SPMD annotations. metadata_params: arguments dict passed to AxisMetadata instances in the variable tree. + check_constancy_invariants: If true, the scan will verify that the + broadcast constants are true loop invariants, and further supports + broadcast function (non-carry) outputs. This requires an extra jax + tracing step however, so setting to false can reduce trace time on larger + models. Returns: The scan function with the signature @@ -1000,7 +1006,8 @@ def find_length(axis, x): length=length, reverse=reverse, unroll=unroll, - _split_transpose=_split_transpose + _split_transpose=_split_transpose, + check_constancy_invariants=check_constancy_invariants, ) def scanned(broadcast_vars, carry, scan_variable_groups, rng_groups, args): carry_vars, c = carry diff --git a/flax/linen/transforms.py b/flax/linen/transforms.py index 80c44f9946..06a130f10e 100644 --- a/flax/linen/transforms.py +++ b/flax/linen/transforms.py @@ -1165,6 +1165,7 @@ def scan( metadata_params: Mapping[Any, Any] = {}, methods=None, _split_transpose: bool = False, + check_constancy_invariants: bool = True, ) -> Target: """A lifted version of ``jax.lax.scan``. @@ -1304,6 +1305,11 @@ def scan( methods: If ``target`` is a ``Module``, the methods of ``Module`` to scan over. _split_transpose: An experimental feature to split the transpose of a scan into a scan and a map, backed by an experimental Jax lax.scan() feature. + check_constancy_invariants: If true, the scan will verify that the + broadcast constants are true loop invariants, and further supports + broadcast function (non-carry) outputs. This requires an extra jax + tracing step however, so setting to false can reduce trace time on larger + models. Returns: The scan function with the signature ``(module, carry, *xs) -> (carry, @@ -1326,6 +1332,7 @@ def scan( data_transform=data_transform, metadata_params=metadata_params, methods=methods, + check_constancy_invariants=check_constancy_invariants, ) diff --git a/tests/linen/linen_transforms_test.py b/tests/linen/linen_transforms_test.py index d5634a011f..8154a2c349 100644 --- a/tests/linen/linen_transforms_test.py +++ b/tests/linen/linen_transforms_test.py @@ -2715,6 +2715,53 @@ def __call__(self, x): params = foo.init(key, x) foo.apply(params, x) + @parameterized.named_parameters( + ('retracing scan', True), ('simple scan', False) + ) + def test_jit_scan_retracing(self, retracing_scan: bool): + num_blocks = 4 + num_patterns = 4 + features = 4 + trace_counts = [0, 0] + + class Block(nn.Module): + def setup(self): + self.dense = nn.Dense(features, use_bias=False) + @nn.jit + def __call__(self, x): + nonlocal trace_counts + trace_counts[1] += 1 + return self.dense(x) + + class BlockSequence(nn.Module): + def setup(self): + self.blocks = [Block() for _ in range(num_blocks)] + @nn.jit + def __call__(self, carry, inputs): + nonlocal trace_counts + trace_counts[0] += 1 + for block in self.blocks: + carry = block(carry) + return carry, inputs + + class Transformer(nn.Module): + retracing_scan: bool = True + def setup(self): + self.scan = nn.scan( + BlockSequence, + variable_axes={'params': 0}, + split_rngs={'params': False}, + length=num_patterns, + check_constancy_invariants=retracing_scan, + )() + def __call__(self, inputs): + return self.scan(jnp.zeros_like(inputs), inputs) + + model = Transformer(retracing_scan=retracing_scan) + _ = model.init(random.key(0), jnp.ones((num_patterns, features,))) + self.assertEqual(trace_counts[0], 2 if retracing_scan else 1) + self.assertEqual(trace_counts[1], 2 if retracing_scan else 1) + if __name__ == '__main__': absltest.main()