
Commit

Introduce a faster nn.scan mode that avoids extra jax retracing.

This adds a new keyword option to linen nn.scan, `check_constancy_invariants`, which defaults
to True and preserves the existing behavior. Setting it to False avoids an extra jax trace that
hoists scan loop constants out of the loop and checks that broadcast variables and body-function
outputs marked constant are not data-dependent. Skipping this extra trace and static check can
save considerable time when tracing and compiling larger models.
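For illustration, a minimal sketch of the new option in use (the Block module, shapes, and names below are invented for the example, not part of this change):

import jax
import jax.numpy as jnp
import flax.linen as nn

class Block(nn.Module):
  features: int = 4

  @nn.compact
  def __call__(self, carry, x):
    # One scan step: update the carry; pass the scanned input through.
    carry = nn.Dense(self.features, use_bias=False)(carry)
    return carry, x

# Scan 4 Blocks along a layer axis; False skips the extra constancy-checking
# trace (broadcast non-carry outputs are unsupported in this mode).
ScannedBlock = nn.scan(
    Block,
    variable_axes={'params': 0},
    split_rngs={'params': True},
    length=4,
    check_constancy_invariants=False,
)

model = ScannedBlock()
carry, xs = jnp.zeros((4,)), jnp.zeros((4, 4))
params = model.init(jax.random.key(0), carry, xs)
carry_out, ys = model.apply(params, carry, xs)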

PiperOrigin-RevId: 705869200
levskaya authored and Flax Authors committed Dec 13, 2024
1 parent 207966e commit fc38f21
Showing 4 changed files with 134 additions and 28 deletions.
flax/core/axes_scan.py: 72 additions, 27 deletions
@@ -13,16 +13,17 @@
# limitations under the License.

"""Wrapper around jax.lax.scan with in_axes/out_axes API."""
import functools
from typing import Any, Optional
from collections.abc import Callable

import jax
from jax import core
from jax import lax
from jax.extend import linear_util as lu
from jax.interpreters import partial_eval as pe
import jax.numpy as jnp
import numpy as np

ScanAxis = Optional[int]

@@ -35,13 +35,14 @@ class _Broadcast:


def scan(
  fn: Callable[..., Any],
  in_axes: Any,
  out_axes: Any,
  length: int | None = None,
  reverse: bool = False,
  unroll: int = 1,
  _split_transpose: bool = False,
  check_constancy_invariants: bool = True,
):
"""A wrapper around `jax.lax.scan` with in_axes/out_axes api.
@@ -78,6 +80,11 @@ def body_fn(b, c, x):
iteration of a loop (default: 1).
_split_transpose: An experimental feature to split the transpose of scan
into a scan and a map, backed by an experimental Jax lax.scan() feature.
check_constancy_invariants: If True, the scan verifies that the broadcast
  constants are true loop invariants, and additionally supports broadcast
  (non-carry) function outputs. This requires an extra jax tracing step,
  so setting it to False can reduce trace time for larger models.
Returns:
the function that performs the scan of the form:
(broadcast_in, carry_in, *args) -> (broadcast_out, carry_out, scan_out).
@@ -114,39 +121,43 @@ def trans(x):
    return jax.tree_util.tree_map(trans, xs)

  def scan_fn(broadcast_in, init, *args):
    # Requires one extra tracing operation to test invariants:
    # Verifies that broadcast constants are true loop invariants, and further
    # supports broadcast function (non-carry) outputs.

    xs = jax.tree_util.tree_map(transpose_to_front, in_axes, args)

    def body_fn(c, xs, init_mode=False):
      # inject constants
      xs = jax.tree_util.tree_map(
        lambda ax, arg, x: (arg if ax is broadcast else x), in_axes, args, xs
      )
      broadcast_out, c, ys = fn(broadcast_in, c, *xs)

      if init_mode:
        ys = jax.tree_util.tree_map(
          lambda ax, y: (y if ax is broadcast else ()), out_axes, ys
        )
        return broadcast_out, ys
      else:
        ys = jax.tree_util.tree_map(
          lambda ax, y: (() if ax is broadcast else y), out_axes, ys
        )
        return c, ys

    broadcast_body = functools.partial(body_fn, init_mode=True)

    carry_avals = jax.tree_util.tree_map(
      lambda x: core.ShapedArray(jnp.shape(x), jnp.result_type(x)), init
    )
    scan_avals = jax.tree_util.tree_map(
      lambda x: core.ShapedArray(jnp.shape(x)[1:], jnp.result_type(x)), xs
    )
    input_avals = (carry_avals, scan_avals)

    in_avals, in_tree = jax.tree_util.tree_flatten(input_avals)
    f_flat, out_tree = jax.api_util.flatten_fun_nokwargs(
      lu.wrap_init(broadcast_body), in_tree
    )
    in_pvals = list(map(pe.PartialVal.unknown, in_avals))
    _, out_pvals, _ = pe.trace_to_jaxpr_nounits(f_flat, in_pvals)
@@ -155,29 +166,63 @@ def body_fn(c, xs, init_mode=False):
    for pv, const in out_pvals:
      if pv is not None:
        raise ValueError(
          'broadcasted variable has a data dependency on the scan body.'
        )
      out_flat.append(const)
    broadcast_in, constants_out = jax.tree_util.tree_unflatten(
      out_tree(), out_flat
    )

    if jax.version.__version_info__ > (0, 4, 25):
      c, ys = lax.scan(
        body_fn, init, xs, length=length, reverse=reverse, unroll=unroll,
        _split_transpose=_split_transpose
      )
    else:
      c, ys = lax.scan(
        body_fn, init, xs, length=length, reverse=reverse, unroll=unroll
      )
    ys = jax.tree_util.tree_map(transpose_from_front, out_axes, ys)
    ys = jax.tree_util.tree_map(
      lambda ax, const, y: (const if ax is broadcast else y),
      out_axes,
      constants_out,
      ys,
    )
    return broadcast_in, c, ys

  def simple_scan_fn(broadcast_in, init, *args):
    # Saves an extra tracing operation.
    # No verification of constancy, and no support for non-carry broadcast
    # function outputs.
    xs = jax.tree_util.tree_map(transpose_to_front, in_axes, args)

    if broadcast in jax.tree_util.tree_leaves(out_axes):
      raise ValueError(
        'nn.scan run with check_constancy_invariants=False '
        'does not support broadcast non-carry function '
        f'outputs. out_axes was given as {out_axes}'
      )

    def body_fn(c, xs):
      # inject constants
      xs = jax.tree_util.tree_map(
        lambda ax, arg, x: (arg if ax is broadcast else x), in_axes, args, xs
      )
      _, c, ys = fn(broadcast_in, c, *xs)
      return c, ys

    if jax.version.__version_info__ > (0, 4, 25):
      c, ys = lax.scan(
        body_fn, init, xs, length=length, reverse=reverse, unroll=unroll,
        _split_transpose=_split_transpose
      )
    else:
      c, ys = lax.scan(
        body_fn, init, xs, length=length, reverse=reverse, unroll=unroll
      )
    ys = jax.tree_util.tree_map(transpose_from_front, out_axes, ys)
    return broadcast_in, c, ys

  if check_constancy_invariants:
    return scan_fn
  else:
    return simple_scan_fn
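For orientation, a small usage sketch of the core wrapper this file implements (values are illustrative; the body follows the (broadcast_in, carry_in, *args) -> (broadcast_out, carry_out, scan_out) contract from the docstring):

import jax.numpy as jnp
from flax.core import axes_scan

def body_fn(b, c, x):
  # b: loop-invariant broadcast input; c: carry; x: one slice of the scanned axis.
  c = c + b * x
  return b, c, c  # (broadcast_out, carry_out, per-step output)

scan_fn = axes_scan.scan(body_fn, in_axes=(0,), out_axes=0)
b_out, c_out, ys = scan_fn(2.0, 0.0, jnp.arange(4.0))
# b_out == 2.0, c_out == 12.0, ys == [0., 2., 6., 12.]

# The faster mode skips the invariant-checking trace; broadcast (non-carry)
# outputs are not supported in this mode.
fast_fn = axes_scan.scan(
    body_fn, in_axes=(0,), out_axes=0, check_constancy_invariants=False
)
b_out, c_out, ys = fast_fn(2.0, 0.0, jnp.arange(4.0))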
flax/core/lift.py: 8 additions, 1 deletion
@@ -879,6 +879,7 @@ def scan(
_split_transpose: bool = False,
data_transform: Callable[..., Any] | None = None,
metadata_params: dict[Any, Any] = {},
check_constancy_invariants: bool = True,
) -> Callable[..., Any]:
"""A lifted version of ``jax.lax.scan``.
@@ -946,6 +947,11 @@ def body_fn(scope, c, x):
intended for inline SPMD annotations.
metadata_params: arguments dict passed to AxisMetadata instances in the
variable tree.
check_constancy_invariants: If True, the scan verifies that the broadcast
  constants are true loop invariants, and additionally supports broadcast
  (non-carry) function outputs. This requires an extra jax tracing step,
  so setting it to False can reduce trace time for larger models.
Returns:
The scan function with the signature
@@ -1000,7 +1006,8 @@ def find_length(axis, x):
length=length,
reverse=reverse,
unroll=unroll,
_split_transpose=_split_transpose,
check_constancy_invariants=check_constancy_invariants,
)
def scanned(broadcast_vars, carry, scan_variable_groups, rng_groups, args):
carry_vars, c = carry
flax/linen/transforms.py: 7 additions, 0 deletions
@@ -1165,6 +1165,7 @@ def scan(
metadata_params: Mapping[Any, Any] = {},
methods=None,
_split_transpose: bool = False,
check_constancy_invariants: bool = True,
) -> Target:
"""A lifted version of ``jax.lax.scan``.
@@ -1304,6 +1305,11 @@ def scan(
methods: If ``target`` is a ``Module``, the methods of ``Module`` to scan over.
_split_transpose: An experimental feature to split the transpose of a scan
into a scan and a map, backed by an experimental Jax lax.scan() feature.
check_constancy_invariants: If True, the scan verifies that the broadcast
  constants are true loop invariants, and additionally supports broadcast
  (non-carry) function outputs. This requires an extra jax tracing step,
  so setting it to False can reduce trace time for larger models.
Returns:
The scan function with the signature ``(module, carry, *xs) -> (carry,
@@ -1326,6 +1332,7 @@ def scan(
data_transform=data_transform,
metadata_params=metadata_params,
methods=methods,
check_constancy_invariants=check_constancy_invariants,
)


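A rough, illustrative way to observe the trace-time difference this flag targets (the module, sizes, and timing harness here are assumptions for the sketch; absolute numbers vary by model and machine):

import time

import jax
import jax.numpy as jnp
import flax.linen as nn

class Block(nn.Module):
  @nn.compact
  def __call__(self, carry, x):
    return nn.Dense(carry.shape[-1], use_bias=False)(carry), x

def init_seconds(check: bool) -> float:
  scanned = nn.scan(
      Block,
      variable_axes={'params': 0},
      split_rngs={'params': True},
      length=16,
      check_constancy_invariants=check,
  )()
  carry, xs = jnp.zeros((128,)), jnp.zeros((16, 128))
  start = time.perf_counter()
  scanned.init(jax.random.key(0), carry, xs)
  return time.perf_counter() - start

for check in (True, False):
  print(f'check_constancy_invariants={check}: {init_seconds(check):.3f}s')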
tests/linen/linen_transforms_test.py: 47 additions, 0 deletions
@@ -2715,6 +2715,53 @@ def __call__(self, x):
params = foo.init(key, x)
foo.apply(params, x)

  @parameterized.named_parameters(
      ('retracing scan', True), ('simple scan', False)
  )
  def test_jit_scan_retracing(self, retracing_scan: bool):
    num_blocks = 4
    num_patterns = 4
    features = 4
    trace_counts = [0, 0]

    class Block(nn.Module):
      def setup(self):
        self.dense = nn.Dense(features, use_bias=False)

      @nn.jit
      def __call__(self, x):
        nonlocal trace_counts
        trace_counts[1] += 1
        return self.dense(x)

    class BlockSequence(nn.Module):
      def setup(self):
        self.blocks = [Block() for _ in range(num_blocks)]

      @nn.jit
      def __call__(self, carry, inputs):
        nonlocal trace_counts
        trace_counts[0] += 1
        for block in self.blocks:
          carry = block(carry)
        return carry, inputs

    class Transformer(nn.Module):
      retracing_scan: bool = True

      def setup(self):
        self.scan = nn.scan(
            BlockSequence,
            variable_axes={'params': 0},
            split_rngs={'params': False},
            length=num_patterns,
            check_constancy_invariants=retracing_scan,
        )()

      def __call__(self, inputs):
        return self.scan(jnp.zeros_like(inputs), inputs)

    model = Transformer(retracing_scan=retracing_scan)
    _ = model.init(random.key(0), jnp.ones((num_patterns, features)))
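    # The constancy-checking mode traces the scan body an extra time (the
    # invariant check plus the actual lax.scan trace), so each module is
    # traced twice; the simple mode traces each once.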
    self.assertEqual(trace_counts[0], 2 if retracing_scan else 1)
    self.assertEqual(trace_counts[1], 2 if retracing_scan else 1)


if __name__ == '__main__':
  absltest.main()
