@@ -2076,14 +2076,22 @@ def _pjit_linearization(nzs, *primals_in, jaxpr,
2076
2076
donated_invars , ctx_mesh , name , keep_unused , inline ,
2077
2077
compiler_options_kvs ):
2078
2078
primal_jaxpr , num_residuals , nzs_out , tangent_jaxpr = ad .linearize_jaxpr (jaxpr , nzs )
2079
- # constvars will become residuals. Move them to the end of the ordinary args.
2080
2079
res_shardings = (UNSPECIFIED ,) * num_residuals
2081
2080
res_layouts = (None ,) * num_residuals
2082
2081
res_donated = (False ,) * num_residuals
2082
+
2083
+ in_fwd = pe ._jaxpr_forwarding (primal_jaxpr .jaxpr )
2084
+ in_fwd , _ = split_list (in_fwd , [num_residuals ])
2085
+ keep = tuple (f is None for f in in_fwd ) + (True ,) * len (out_shardings )
2086
+ primal_jaxpr = pe .prune_closed_jaxpr_outputs (primal_jaxpr , keep )
2087
+ num_residuals = sum (f is None for f in in_fwd )
2088
+
2083
2089
def tangent_fun (consts_ , * tangents ):
2090
+ consts_it = iter (consts_ )
2091
+ res = [next (consts_it ) if f is None else primals_in [f ] for f in in_fwd ]
2092
+ assert next (consts_it , None ) is None
2084
2093
tangents_nz = _filter_zeros (nzs , tangents )
2085
- assert len (consts_ ) == num_residuals
2086
- nz_tangents_out = pjit_p .bind (* (* tangents_nz , * consts_ ),
2094
+ nz_tangents_out = pjit_p .bind (* (* tangents_nz , * res ),
2087
2095
jaxpr = tangent_jaxpr ,
2088
2096
in_shardings = _filter_zeros (nzs , in_shardings ) + res_shardings ,
2089
2097
out_shardings = _filter_zeros (nzs_out , out_shardings ),
@@ -2106,9 +2114,9 @@ def _filter_zeros(is_nz_l, l):
2106
2114
2107
2115
ans = pjit_p .bind (* primals_in , jaxpr = primal_jaxpr ,
2108
2116
in_shardings = in_shardings ,
2109
- out_shardings = (* res_shardings , * out_shardings ),
2117
+ out_shardings = (* res_shardings [: num_residuals ] , * out_shardings ),
2110
2118
in_layouts = in_layouts ,
2111
- out_layouts = (* res_layouts , * out_layouts ),
2119
+ out_layouts = (* res_layouts [: num_residuals ] , * out_layouts ),
2112
2120
donated_invars = donated_invars ,
2113
2121
ctx_mesh = ctx_mesh ,
2114
2122
name = name ,
0 commit comments