Skip to content

Commit b682a04

Browse files
deneracyanguwa
authored andcommitted
Incorrect use of extend_fsdp_sharding_meta() in cross_fused_attn() (NVIDIA#482)
fixed incorrect of extend_fsdp_sharding_meta() in cross_fused_attn() Signed-off-by: Alp Dener <[email protected]> Signed-off-by: Charlene Yang <[email protected]>
1 parent f52de71 commit b682a04

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

transformer_engine/jax/fused_attn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def cross_fused_attn(q: jnp.ndarray,
206206
tp_dims=([2, 3, None, None], [2]),
207207
dp_axis_name=dp_axis_name,
208208
tp_axis_name=tp_axis_name)
209-
sharding_meta = extend_fsdp_sharding_meta(sharding_meta, {0: 0, 2: 0})
209+
sharding_meta, _ = extend_fsdp_sharding_meta(sharding_meta, {0: 0, 2: 0})
210210

211211
inputs_ = tuple(
212212
jnp.reshape(x, new_shape) if x is not None else None

0 commit comments

Comments
 (0)