Skip to content

Commit 4880e2b

Browse files
committed
Apply forwarding in pjit linearization rule to avoid intermediate copies.
1 parent b0a920d commit 4880e2b

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

jax/_src/pjit.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2076,14 +2076,22 @@ def _pjit_linearization(nzs, *primals_in, jaxpr,
20762076
donated_invars, ctx_mesh, name, keep_unused, inline,
20772077
compiler_options_kvs):
20782078
primal_jaxpr, num_residuals, nzs_out, tangent_jaxpr = ad.linearize_jaxpr(jaxpr, nzs)
2079-
# constvars will become residuals. Move them to the end of the ordinary args.
20802079
res_shardings = (UNSPECIFIED,) * num_residuals
20812080
res_layouts = (None,) * num_residuals
20822081
res_donated = (False,) * num_residuals
2082+
2083+
in_fwd = pe._jaxpr_forwarding(primal_jaxpr.jaxpr)
2084+
in_fwd, _ = split_list(in_fwd, [num_residuals])
2085+
keep = tuple(f is None for f in in_fwd) + (True,) * len(out_shardings)
2086+
primal_jaxpr = pe.prune_closed_jaxpr_outputs(primal_jaxpr, keep)
2087+
num_residuals = sum(f is None for f in in_fwd)
2088+
20832089
def tangent_fun(consts_, *tangents):
2090+
consts_it = iter(consts_)
2091+
res = [next(consts_it) if f is None else primals_in[f] for f in in_fwd]
2092+
assert next(consts_it, None) is None
20842093
tangents_nz = _filter_zeros(nzs, tangents)
2085-
assert len(consts_) == num_residuals
2086-
nz_tangents_out = pjit_p.bind(*(*tangents_nz, *consts_),
2094+
nz_tangents_out = pjit_p.bind(*(*tangents_nz, *res),
20872095
jaxpr=tangent_jaxpr,
20882096
in_shardings=_filter_zeros(nzs, in_shardings) + res_shardings,
20892097
out_shardings=_filter_zeros(nzs_out, out_shardings),
@@ -2106,9 +2114,9 @@ def _filter_zeros(is_nz_l, l):
21062114

21072115
ans = pjit_p.bind(*primals_in, jaxpr=primal_jaxpr,
21082116
in_shardings=in_shardings,
2109-
out_shardings=(*res_shardings, *out_shardings),
2117+
out_shardings=(*res_shardings[:num_residuals], *out_shardings),
21102118
in_layouts=in_layouts,
2111-
out_layouts=(*res_layouts, *out_layouts),
2119+
out_layouts=(*res_layouts[:num_residuals], *out_layouts),
21122120
donated_invars=donated_invars,
21132121
ctx_mesh=ctx_mesh,
21142122
name=name,

0 commit comments

Comments
 (0)