[JAX] Bugfix for softmax primitives accepting invalid input sharding (#664)

* Softmax now forces XLA to unshard the hidden dimension and emits a warning. Unit tests updated to check numerics and the warning under bad sharding (illustrated in the sketch below)

Signed-off-by: Alp Dener <[email protected]>

* correcting cudnn-frontend version

Signed-off-by: Alp Dener <[email protected]>

* fixed mismatched output sharding

Signed-off-by: Alp Dener <[email protected]>

* combined softmax tests and fixed code style/linting issues

Signed-off-by: Alp Dener <[email protected]>

---------

Signed-off-by: Alp Dener <[email protected]>
denera authored Feb 28, 2024
1 parent 8bba5ee commit 0404095
Showing 2 changed files with 144 additions and 78 deletions.
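
What "bad sharding" means here, in a minimal sketch (the mesh layout, array shape, and variable names below are illustrative assumptions, not taken from the commit): the softmax custom primitives support sharding the batch and heads dimensions of the logits, but not the last (hidden) dimension. With this fix, a spec that places a mesh axis on the last dimension makes the primitive warn and hand XLA an unsharded hidden dimension, instead of silently accepting a layout the underlying kernel does not support.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Assumed 2-device mesh with axes ('dp', 'tp'); logits are [batch, heads, q_seqlen, k_seqlen].
devices = np.asarray(jax.devices()[:2]).reshape(1, 2)
mesh = Mesh(devices, ('dp', 'tp'))

good_spec = PartitionSpec('dp', 'tp', None, None)   # supported: hidden dim replicated
bad_spec = PartitionSpec('dp', None, None, 'tp')    # unsupported: hidden dim sharded

x = jax.random.normal(jax.random.PRNGKey(0), (32, 12, 128, 128), dtype=jnp.float16)
x_good = jax.device_put(x, NamedSharding(mesh, good_spec))
# Feeding x_bad to the patched softmax primitives triggers the
# "Sharding the hidden dimension is not supported" warning and an extra gather.
x_bad = jax.device_put(x, NamedSharding(mesh, bad_spec))
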
80 changes: 55 additions & 25 deletions tests/jax/test_distributed_softmax.py
@@ -2,7 +2,9 @@
#
# See LICENSE for license information.

import warnings
import pytest
from functools import partial

import jax
import jax.numpy as jnp
@@ -26,7 +28,7 @@ def generate_collectives_count_ref(self):
all_reduce_loss_bytes = 4 # 1 * FP32
return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0)

def generate_inputs(self, shape, mesh_resource, softmax_type, dtype):
def generate_inputs(self, shape, mesh_resource, softmax_type, dtype, bad_sharding):
batch, _, sqelen, _ = shape

x = random.normal(random.PRNGKey(1124), shape, dtype=dtype)
@@ -35,49 +37,77 @@ def generate_inputs(self, shape, mesh_resource, softmax_type, dtype):
else:
mask = make_self_mask(batch, sqelen)

x_pspec = PartitionSpec(mesh_resource.dp_resource, mesh_resource.tp_resource, None, None)
if not bad_sharding:
x_pspec = PartitionSpec(mesh_resource.dp_resource, mesh_resource.tp_resource,
None, None)
else:
x_pspec = PartitionSpec(mesh_resource.dp_resource, None,
None, mesh_resource.tp_resource)
mask_pspec = PartitionSpec(mesh_resource.dp_resource, None, None, None)

return (x, mask), (x_pspec, mask_pspec)

@staticmethod
def target_func(x, mask, scale_factor=1.0, softmax_type=SoftmaxType.SCALED):
return jnp.mean(softmax(x, mask, scale_factor=scale_factor, softmax_type=softmax_type))

@staticmethod
def ref_func(x, mask, scale_factor=1.0, dtype=jnp.float16):
bias = None
if mask is not None:
bias = jax.lax.select(mask > 0,
jnp.full(mask.shape, -1e10).astype(dtype),
jnp.full(mask.shape, 0.).astype(dtype))
if bias is not None:
x = x + bias.astype(dtype)
output = jax.nn.softmax(x * scale_factor)
return jnp.mean(output)

@pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs())
@pytest.mark.parametrize('data_shape', [[32, 12, 128, 128], [64, 16, 1024, 1024]])
@pytest.mark.parametrize(
'softmax_type',
[SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED])
@pytest.mark.parametrize('scale_factor', [1.0, 3.0])
@pytest.mark.parametrize('dtype', DTYPES)
@pytest.mark.parametrize('bad_sharding', [False, True])
def test_softmax(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape,
softmax_type, scale_factor, dtype):

def target_func(x, mask):
return jnp.mean(softmax(x, mask, scale_factor=scale_factor, softmax_type=softmax_type))
softmax_type, scale_factor, dtype, bad_sharding):

def ref_func(x, mask):
bias = None
if mask is not None:
bias = jax.lax.select(mask > 0,
jnp.full(mask.shape, -1e10).astype(dtype),
jnp.full(mask.shape, 0.).astype(dtype))
if bias is not None:
x = x + bias.astype(dtype)
output = jax.nn.softmax(x * scale_factor)
return jnp.mean(output)
target_func = partial(self.target_func,
scale_factor=scale_factor,
softmax_type=softmax_type)
ref_func = partial(self.ref_func, scale_factor=scale_factor, dtype=dtype)

(x, mask), (x_pspec, mask_pspec) = \
self.generate_inputs(data_shape, mesh_resource, softmax_type, dtype)
self.generate_inputs(data_shape, mesh_resource, softmax_type, dtype, bad_sharding)
collective_count_ref = self.generate_collectives_count_ref()
devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
mesh = Mesh(devices, mesh_axes)
with mesh, fp8_autocast(mesh_resource=mesh_resource):
x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
mask_ = jax.device_put(mask, NamedSharding(mesh, mask_pspec))

compare_ops(target_func,
ref_func, [x_, mask_],
collective_count_ref,
grad_args=(0,),
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(x_pspec, mask_pspec),
out_shardings=(None, (x_pspec,)))
with warnings.catch_warnings(record=True) as warns:
try:
compare_ops(target_func,
ref_func, [x_, mask_],
collective_count_ref,
grad_args=(0,),
metric_fwd_dtype=dtype,
metric_bwd_dtype=dtype,
in_shardings=(x_pspec, mask_pspec),
out_shardings=(None, (x_pspec,)))
except AssertionError as err:
# Softmax should still produce the correct numerical result with
# bad sharding. However, the collective count may not be the same
# when XLA is forced to unshard the hidden dimension. We can catch
# and ignore that specific error here.
if not bad_sharding or "Expected collective count" not in str(err):
raise err
finally:
for w in warns:
assert "Sharding the hidden dimension is not supported" in str(w), (
"Softmax primitive did not raise the correct warning for "
"unsupported sharding in the hidden dimension."
)
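
The reasoning in the test comment above boils down to: numerics must always match, the collective count is allowed to differ only when the sharding was deliberately bad, and in that case the specific warning must have fired. A condensed sketch of that control flow, with run_case standing in for the compare_ops call (a hypothetical helper, shown only to isolate the pattern):

import warnings

def check_softmax_case(run_case, bad_sharding):
    # Assumed helper mirroring the updated test's control flow; not the test itself.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        try:
            run_case()  # numerics are compared inside and must match in every case
        except AssertionError as err:
            # Tolerate only a collective-count mismatch, and only under bad sharding,
            # since forcing XLA to unshard the hidden dim adds gather collectives.
            if not bad_sharding or "Expected collective count" not in str(err):
                raise
    if bad_sharding:
        assert any("Sharding the hidden dimension is not supported" in str(w.message)
                   for w in caught)
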
142 changes: 89 additions & 53 deletions transformer_engine/jax/cpp_extensions.py
@@ -1029,6 +1029,7 @@ class SoftmaxPrimitive(BasePrimitive):
Softmax Primitive
"""
max_k_seqlen_supported = 4096
name = "te_softmax_internal_placeholder"

@staticmethod
@abstractmethod
@@ -1118,26 +1119,37 @@ def forward_batcher(primitive, batched_args, batch_dims, *, scale_factor):
out_bdims = logits_bdim
return primitive.bind(logits, scale_factor=scale_factor), out_bdims

@staticmethod
def forward_infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
@classmethod
def forward_infer_sharding_from_operands(cls, scale_factor, mesh, arg_infos, result_infos):
"""
softmax_forward infer_sharding_from_operands
"""
del scale_factor, result_infos # Unused.
logits_spec = get_padded_spec(arg_infos[0])
out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec))
if logits_spec[-1] is not None:
warnings.warn(
f"Sharding the hidden dimension is not supported in {cls.name}! " \
f"Forcing XLA to not shard the hidden dim, which might introduce extra " \
f"collective ops and hurt performance."
)
out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec[:-1], None))
return out_sharding

@staticmethod
def forward_partition(impl, scale_factor, mesh, arg_infos, result_infos):
@classmethod
def forward_partition(cls, impl, scale_factor, mesh, arg_infos, result_infos):
"""
softmax_forward partitioning
"""
del result_infos
logits_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0])))
out_spec = logits_spec
arg_shardings = (logits_spec,)
out_shardings = out_spec
logits_spec = get_padded_spec(arg_infos[0])
if logits_spec[-1] is not None:
warnings.warn(
f"Sharding the hidden dimension is not supported in {cls.name}! " \
f"Forcing XLA to not shard the hidden dim, which might introduce extra " \
f"collective ops and hurt performance."
)
out_shardings = NamedSharding(mesh, PartitionSpec(*logits_spec[:-1], None))
arg_shardings = (out_shardings,)
impl = partial(impl, scale_factor=scale_factor)
return mesh, impl, out_shardings, arg_shardings
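
Every sharding rule touched in this file now applies the same three steps: read the operand's padded spec, warn if the last dimension carries a mesh axis, and return shardings with that axis dropped so XLA gathers the hidden dimension before the custom kernel runs. A condensed sketch of the pattern (the helper name is assumed; the real code inlines these steps per rule, as shown above):

import warnings
from jax.sharding import NamedSharding, PartitionSpec

def _unshard_hidden_dim(spec, mesh, primitive_name):
    # Assumed helper: drop any mesh axis from the last dimension of `spec`.
    if spec[-1] is not None:
        warnings.warn(f"Sharding the hidden dimension is not supported in {primitive_name}! "
                      f"Forcing XLA to not shard the hidden dim, which might introduce extra "
                      f"collective ops and hurt performance.")
    return NamedSharding(mesh, PartitionSpec(*spec[:-1], None))

# infer_sharding_from_operands: the output sharding is the corrected operand spec.
# partition: the argument shardings and the output sharding both use the corrected
# spec, so XLA inserts the gather of the hidden dim outside the custom kernel.
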

@@ -1154,7 +1166,7 @@ def backward_abstract(dz_aval, softmax_out_aval, scale_factor=None): # pylint

assert dz_aval.shape == softmax_out_aval.shape

dx_aval = core.raise_to_shaped(softmax_out_aval)
dx_aval = core.raise_to_shaped(dz_aval)
return dx_aval

@staticmethod
@@ -1177,7 +1189,7 @@ def backward_lowering(name, ctx, dz, softmax_out, *, scale_factor):
softmax_out_type = ir.RankedTensorType(softmax_out.type)
softmax_out_shape = softmax_out_type.shape

out_types = [ir.RankedTensorType.get(softmax_out_shape, softmax_out_type.element_type)]
out_types = [ir.RankedTensorType.get(dz_shape, dz_type.element_type)]
operands = [dz, softmax_out]
operand_shapes = [dz_shape, softmax_out_shape]
args = CustomCallArgsWrapper(out_types, operands, operand_shapes)
@@ -1211,27 +1223,44 @@ def backward_batcher(primitive, batched_args, batch_dims, *, scale_factor):
out_bdims = softmax_out_bdim
return primitive.bind(dz, softmax_out, scale_factor=scale_factor), out_bdims

@staticmethod
def backward_infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
@classmethod
def backward_infer_sharding_from_operands(cls, scale_factor, mesh, arg_infos, result_infos):
"""
softmax_backward infer_sharding_from_operands
"""
del scale_factor, result_infos # Unused.
softmax_out_spec = get_padded_spec(arg_infos[1])
dx_sharding = NamedSharding(mesh, PartitionSpec(*softmax_out_spec))
dz_spec = get_padded_spec(arg_infos[0])
if dz_spec[-1] is not None:
warnings.warn(
f"Sharding the hidden dimension is not supported in {cls.name}! " \
f"Forcing XLA to not shard the hidden dim, which might introduce extra " \
f"collective ops and hurt performance."
)
dx_sharding = NamedSharding(mesh, PartitionSpec(*dz_spec[:-1], None))
return dx_sharding

@staticmethod
def backward_partition(impl, scale_factor, mesh, arg_infos, result_infos):
@classmethod
def backward_partition(cls, impl, scale_factor, mesh, arg_infos, result_infos):
"""
softmax_backward partition
"""
del result_infos
dz_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0])))
softmax_out_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
dx_spec = softmax_out_spec
arg_shardings = (dz_spec, softmax_out_spec)
out_shardings = dx_spec

dz_spec = get_padded_spec(arg_infos[0])
softmax_out_spec = get_padded_spec(arg_infos[1])
if dz_spec[-1] is not None or softmax_out_spec[-1] is not None:
warnings.warn(
f"Sharding the hidden dimension is not supported in {cls.name}! " \
f"Forcing XLA to not shard the hidden dim, which might introduce extra " \
f"collective ops and hurt performance."
)

dz_sharding = NamedSharding(mesh, PartitionSpec(*dz_spec[:-1], None))
softmax_out_sharding = NamedSharding(mesh, PartitionSpec(*softmax_out_spec[:-1], None))
dx_sharding = dz_sharding
arg_shardings = (dz_sharding, softmax_out_sharding)
out_shardings = dx_sharding

impl = partial(impl, scale_factor=scale_factor)
return mesh, impl, out_shardings, arg_shardings
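
The change from @staticmethod to @classmethod on these shared rules is what makes the warning name the right primitive: each subclass below now calls the rule through its own class, so cls.name resolves to that subclass's registered name rather than the base-class placeholder added at the top of the file. A toy illustration of that dispatch (class and name strings here are illustrative only):

import warnings

class BasePrim:
    name = "te_softmax_internal_placeholder"

    @classmethod
    def forward_partition(cls):
        # `cls` is whichever class the call went through, so the warning
        # identifies the concrete primitive rather than the base class.
        warnings.warn(f"Sharding the hidden dimension is not supported in {cls.name}!")

class ScaledSoftmaxFwd(BasePrim):
    name = "te_scaled_softmax_forward"   # illustrative name

ScaledSoftmaxFwd.forward_partition()  # warning mentions "te_scaled_softmax_forward"
BasePrim.forward_partition()          # warning falls back to the placeholder name
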

@@ -1296,13 +1325,15 @@ def batcher(batched_args, batch_dims, *, scale_factor):

@staticmethod
def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
return SoftmaxPrimitive.forward_infer_sharding_from_operands(scale_factor, mesh, arg_infos,
result_infos)
return ScaledSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
scale_factor, mesh, arg_infos, result_infos
)

@staticmethod
def partition(scale_factor, mesh, arg_infos, result_infos):
return SoftmaxPrimitive.forward_partition(ScaledSoftmaxFwdPrimitive.impl, scale_factor,
mesh, arg_infos, result_infos)
return ScaledSoftmaxFwdPrimitive.forward_partition(
ScaledSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
)


register_primitive(ScaledSoftmaxFwdPrimitive)
@@ -1370,13 +1401,15 @@ def batcher(batched_args, batch_dims, *, scale_factor):

@staticmethod
def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
return SoftmaxPrimitive.backward_infer_sharding_from_operands(scale_factor, mesh, arg_infos,
result_infos)
return ScaledSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
scale_factor, mesh, arg_infos, result_infos
)

@staticmethod
def partition(scale_factor, mesh, arg_infos, result_infos):
return SoftmaxPrimitive.backward_partition(ScaledSoftmaxBwdPrimitive.impl, scale_factor,
mesh, arg_infos, result_infos)
return ScaledSoftmaxBwdPrimitive.backward_partition(
ScaledSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
)


register_primitive(ScaledSoftmaxBwdPrimitive)
@@ -1505,20 +1538,15 @@ def batcher(batched_args, batch_dims, *, scale_factor):

@staticmethod
def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
del scale_factor, result_infos # Unused.
logits_spec = get_padded_spec(arg_infos[0])
out_sharding = NamedSharding(mesh, PartitionSpec(*logits_spec))
return out_sharding
return ScaledMaskedSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
scale_factor, mesh, arg_infos, result_infos
)

@staticmethod
def partition(scale_factor, mesh, arg_infos, result_infos):
del result_infos
logits_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0])))
mask_spec = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1])))
arg_shardings = (logits_spec, mask_spec)
out_shardings = logits_spec
impl = partial(ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor=scale_factor)
return mesh, impl, out_shardings, arg_shardings
return ScaledMaskedSoftmaxFwdPrimitive.backward_partition(
ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
)


register_primitive(ScaledMaskedSoftmaxFwdPrimitive)
@@ -1589,13 +1617,15 @@ def batcher(batched_args, batch_dims, *, scale_factor):

@staticmethod
def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
return SoftmaxPrimitive.backward_infer_sharding_from_operands(scale_factor, mesh, arg_infos,
result_infos)
return ScaledMaskedSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
scale_factor, mesh, arg_infos, result_infos
)

@staticmethod
def partition(scale_factor, mesh, arg_infos, result_infos):
return SoftmaxPrimitive.backward_partition(ScaledMaskedSoftmaxBwdPrimitive.impl,
scale_factor, mesh, arg_infos, result_infos)
return ScaledMaskedSoftmaxBwdPrimitive.backward_partition(
ScaledMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos
)


register_primitive(ScaledMaskedSoftmaxBwdPrimitive)
@@ -1676,13 +1706,16 @@ def batcher(batched_args, batch_dims, *, scale_factor):

@staticmethod
def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
return SoftmaxPrimitive.forward_infer_sharding_from_operands(scale_factor, mesh, arg_infos,
result_infos)
return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.forward_infer_sharding_from_operands(
scale_factor, mesh, arg_infos, result_infos
)

@staticmethod
def partition(scale_factor, mesh, arg_infos, result_infos):
return SoftmaxPrimitive.forward_partition(ScaledUpperTriangMaskedSoftmaxFwdPrimitive.impl,
scale_factor, mesh, arg_infos, result_infos)
return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.forward_partition(
ScaledUpperTriangMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh,
arg_infos, result_infos
)


register_primitive(ScaledUpperTriangMaskedSoftmaxFwdPrimitive)
@@ -1753,13 +1786,16 @@ def batcher(batched_args, batch_dims, *, scale_factor):

@staticmethod
def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos):
return SoftmaxPrimitive.backward_infer_sharding_from_operands(scale_factor, mesh, arg_infos,
result_infos)
return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.backward_infer_sharding_from_operands(
scale_factor, mesh, arg_infos, result_infos
)

@staticmethod
def partition(scale_factor, mesh, arg_infos, result_infos):
return SoftmaxPrimitive.backward_partition(ScaledUpperTriangMaskedSoftmaxBwdPrimitive.impl,
scale_factor, mesh, arg_infos, result_infos)
return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.backward_partition(
ScaledUpperTriangMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh,
arg_infos, result_infos
)


register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive)