-
Notifications
You must be signed in to change notification settings - Fork 352
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Regression tests for custom ops sharding with both xmap and custom_pa…
…rtitioning. Xma-based sharding tests are functional, while custom_partitioning tests are awaiting the custom ops migration to be merged in. Coverage: - layernorm: fwd/grad, zero_centered_gamma, DP, TP_COL, DP_TP_COL - rmsnorm: fwd/grad, DP, TP_COL, DP_TP_COL - softmax: fwd/grad, SCALED, SCALED_MASKED, SCALED_UPPER_TRIANG_MASKED, DP, TP_COL, TP_ROW, DP_TP_COL, DP_TP_ROW - self_fused_attn: fwd/grad, NO_BIAS, PRE_SCALE_BIAS, POST_SCALE_BIAS, NO_MASK, CAUSAL_MASK, PADDING_MASK, DP, TP_COL, DP_TP_COL - cross_fused_attn: fwd/grad, NO_BIAS, NO_MASK, PADDING_MASK, DP, TP_COL, DP_TP_COL Signed-off-by: Alp Dener <[email protected]>
- Loading branch information
Showing
5 changed files
with
841 additions
and
301 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,287 @@ | ||
import pytest | ||
import numpy as np | ||
from dataclasses import dataclass | ||
from typing import Tuple | ||
from enum import Enum | ||
from functools import partial | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
from jax import random | ||
from jax.experimental.pjit import pjit, _UNSPECIFIED | ||
from jax.sharding import PartitionSpec | ||
|
||
import flax | ||
|
||
from transformer_engine.jax.sharding import ShardingType, ShardingResource | ||
try: | ||
# temporary workaround to be removed after jax.experimental.custom_partitioning migration | ||
from transformer_engine.jax.sharding import MeshResource | ||
except ImportError: | ||
pytest.skip("Need working MeshResource implementation to test " + | ||
"jax.experimental.custom_partitioning sharding.") | ||
from transformer_engine.jax.layernorm import layernorm | ||
from transformer_engine.jax.softmax import SoftmaxType, softmax | ||
from transformer_engine.jax.fused_attn import \ | ||
AttnBiasType, AttnMaskType, is_fused_attn_kernel_available, self_fused_attn, cross_fused_attn | ||
|
||
class FusedAttnBackend(Enum): | ||
Max512 = "0" | ||
Arbitrary = "1" | ||
|
||
@dataclass | ||
class CustomOpsTestHelper: | ||
qkv_shape: Tuple[int,int,int,int] = (32, 128, 16, 64) | ||
pad_ratio: float = 0.3 | ||
dropout_prob: float = 0.1 | ||
dtype: type = jnp.float16 | ||
|
||
@staticmethod | ||
def get_sharding_spec(mesh_names, sharding_type): | ||
P = PartitionSpec | ||
if sharding_type is ShardingType.DP: | ||
return P(mesh_names[0], None), P(None), P(None) | ||
elif sharding_type is ShardingType.DP_TP_COL: | ||
return P(mesh_names[0], mesh_names[1]), P(None), P(None) | ||
else: | ||
raise NotImplementedError | ||
|
||
@staticmethod | ||
def get_sharding_resource(mesh_names, sharding_type): | ||
dp_r = None | ||
tp_r = None | ||
if sharding_type in (ShardingType.DP, ShardingType.DP_TP_COL, ShardingType.DP_TP_ROW): | ||
dp_r = mesh_names[0] | ||
if sharding_type in (ShardingType.TP_COL, ShardingType.TP_ROW): | ||
tp_r = mesh_names[0] | ||
if sharding_type in (ShardingType.DP_TP_COL, ShardingType.DP_TP_ROW): | ||
tp_r = mesh_names[1] | ||
xmap_resource = ShardingResource(dp_r, tp_r) | ||
cp_resource = MeshResource(dp_r, tp_r) if MeshResource is not None else None | ||
return xmap_resource, cp_resource | ||
|
||
@staticmethod | ||
def make_mask(q_tokens, kv_tokens, mask_type, dtype=jnp.uint8): | ||
if mask_type == AttnMaskType.CAUSAL_MASK: | ||
causal = flax.linen.make_causal_mask(q_tokens, dtype=dtype) | ||
padding = flax.linen.make_attention_mask(q_tokens > 0, kv_tokens > 0, dtype=dtype) | ||
return flax.linen.combine_masks(causal, padding) | ||
else: | ||
return flax.linen.make_attention_mask(q_tokens > 0, kv_tokens > 0, dtype=dtype) | ||
@staticmethod | ||
def count_collectives(hlo): | ||
tmp = hlo.splitlines() | ||
symb = "-start" | ||
result = { | ||
"all-reduce" : 0, | ||
"other" : 0 | ||
} | ||
for line in tmp: | ||
txt = line.split() | ||
if len(txt) > 0 and symb in txt[0]: | ||
if "all-reduce" in txt[0]: | ||
result["all-reduce"] += 1 | ||
else: | ||
result["other"] += 1 | ||
return result | ||
|
||
def compare_ops(self, custom_func, ref_func, ref_count, | ||
*args, grad_args=None, | ||
in_shardings=_UNSPECIFIED, out_shardings=_UNSPECIFIED, | ||
**kwargs): | ||
if isinstance(custom_func, partial): | ||
func_name = custom_func.func.__name__ | ||
else: | ||
func_name = custom_func.__name__ | ||
func_name = func_name.removeprefix('custom_') | ||
if grad_args is None: | ||
grad_args = tuple(range(len(args))) | ||
|
||
custom_gradded = jax.value_and_grad(custom_func, argnums=grad_args) | ||
test_fwd, test_grads = custom_gradded(*args, **kwargs) | ||
custom_pjitter = pjit(custom_gradded, | ||
in_shardings=in_shardings, | ||
out_shardings=out_shardings) | ||
custom_hlo = custom_pjitter.lower(*args, **kwargs).compile().as_text() | ||
custom_count = self.count_collectives(custom_hlo) | ||
if ref_count is not None: | ||
assert custom_count==ref_count, \ | ||
f"`{func_name}`: Expected collective count is {ref_count}, but got {custom_count}." | ||
else: | ||
print(f"`{func_name}`: Output collective count is {custom_count}.") | ||
|
||
ref_gradded = jax.value_and_grad(ref_func, argnums=grad_args) | ||
ref_fwd, ref_grads = ref_gradded(*args, **kwargs) | ||
fwd_tol = max(np.finfo(jnp.float16).eps, np.spacing(jnp.float16(ref_fwd))) ** (2./3.) | ||
assert jnp.allclose(test_fwd, ref_fwd, rtol=0.0, atol=fwd_tol), \ | ||
f"`{func_name}`: Output (fwd) error {jnp.max(jnp.abs(test_fwd - ref_fwd))}" + \ | ||
f" exceeds tolerance ({fwd_tol})." | ||
|
||
num_grads = len(ref_grads) if isinstance(ref_grads, tuple) else 1 | ||
if num_grads > 1: | ||
failed_grads = {} | ||
for i, grads in enumerate(zip(test_grads, ref_grads)): | ||
test_grad, ref_grad = grads | ||
if test_grad is None and ref_grad is None: | ||
continue | ||
bwd_tol = max(np.finfo(jnp.float32).eps, | ||
np.spacing(jnp.max(jnp.abs(ref_grad)).astype(jnp.float32))) ** (2./3.) | ||
if not jnp.allclose(test_grad, ref_grad, rtol=0.0, atol=bwd_tol): | ||
failed_grads[i] = jnp.max(jnp.abs(test_grad - ref_grad)) | ||
assert len(failed_grads) == 0, \ | ||
f"`{func_name}`: Gradient (bwd) max errors" + \ | ||
f" [{', '.join([f'Arg{k}={v}' for k,v in failed_grads.items()])}]" + \ | ||
f" exceed tolerance ({bwd_tol})." | ||
else: | ||
bwd_tol = max(np.finfo(jnp.float32).eps, | ||
np.spacing(jnp.max(jnp.abs(ref_grads)).astype(jnp.float32))) ** (2./3.) | ||
assert jnp.allclose(test_grads, ref_grads, rtol=0.0, atol=bwd_tol), \ | ||
f"`{func_name}`: Gradient (bwd) max error" + \ | ||
f" {jnp.max(jnp.abs(test_grads - ref_grads))} exceeds tolerance ({bwd_tol})." | ||
|
||
def check_fused_attn_inputs(self, q_seq, kv_seq, head_dim, pad_ratio, dropout_probability, | ||
attn_bias_type, attn_mask_type, backend): | ||
if (q_seq > 512 or kv_seq > 512 or backend == FusedAttnBackend.Arbitrary) \ | ||
and pad_ratio != 0: | ||
pytest.skip( | ||
"`fused_attention`: Arbitrary seqlen backend does not support padded input.") | ||
|
||
if not is_fused_attn_kernel_available( | ||
self.dtype, self.dtype, attn_bias_type, attn_mask_type, | ||
dropout_probability, q_seq, kv_seq, head_dim): | ||
pytest.skip( | ||
"`fused_attention`: Unsupported inputs combination or device compute capability.") | ||
|
||
def fused_attn_core(self, query, key, value, bias, mask, scale_factor, | ||
attn_bias_type, attn_mask_type, dropout_rng, dropout_prob): | ||
# Q*K matmul | ||
query = jnp.squeeze(query) | ||
key = jnp.squeeze(key) | ||
value = jnp.squeeze(value) | ||
attn_weights = jnp.einsum("...qhd,...khd->...hqk", query, key) | ||
# scale and bias | ||
if attn_bias_type == AttnBiasType.PRE_SCALE_BIAS: | ||
attn_weights = scale_factor * (attn_weights + bias) | ||
elif attn_bias_type == AttnBiasType.POST_SCALE_BIAS: | ||
attn_weights = scale_factor * attn_weights + bias | ||
else: | ||
attn_weights = scale_factor * attn_weights | ||
# padding mask | ||
if attn_mask_type != AttnMaskType.NO_MASK and mask is not None: | ||
big_neg = jnp.finfo(self.dtype).min | ||
attn_weights = jnp.where(mask, attn_weights, big_neg) | ||
# softmax | ||
attn_weights = jax.nn.softmax(attn_weights).astype(self.dtype) | ||
# dropout | ||
if dropout_prob == 1.0: | ||
attn_weights = jnp.zeros_like(attn_weights) | ||
elif dropout_prob > 0.0: | ||
keep_prob = 1.0 - dropout_prob | ||
keep = random.bernoulli(dropout_rng, p=keep_prob, shape=attn_weights.shape) | ||
multiplier = keep.astype(self.dtype) / jnp.asarray(keep_prob, dtype=self.dtype) | ||
attn_weights = attn_weights * multiplier | ||
# QK*V matmul | ||
result = jnp.einsum('...hqk,...khd->...qhd', attn_weights, value) | ||
return jnp.mean(result) | ||
|
||
@staticmethod | ||
def custom_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, sharding_type): | ||
result = layernorm(x, gamma, beta, | ||
layernorm_type='layernorm', | ||
zero_centered_gamma=zero_centered_gamma, | ||
epsilon=epsilon, | ||
sharding_type=sharding_type, | ||
dp_dim_index=0) | ||
return jnp.mean(result) | ||
|
||
def reference_layernorm(self, x, gamma, beta, zero_centered_gamma, epsilon): | ||
x_ = jnp.asarray(x, jnp.float32) | ||
mean = jnp.mean(x_, axis=-1, keepdims=True) | ||
var = jnp.mean(jnp.square(x_ - mean), axis=-1, keepdims=True) | ||
normed_input = (x_ - mean) * jax.lax.rsqrt(var + epsilon) | ||
if zero_centered_gamma: | ||
result = jnp.asarray(normed_input * (gamma + 1) + beta).astype(self.dtype) | ||
else: | ||
result = jnp.asarray(normed_input * gamma + beta).astype(self.dtype) | ||
return jnp.mean(result) | ||
|
||
@staticmethod | ||
def custom_rmsnorm(x, gamma, epsilon, sharding_type): | ||
result = layernorm(x, gamma, None, | ||
layernorm_type='rmsnorm', | ||
zero_centered_gamma=False, | ||
epsilon=epsilon, | ||
sharding_type=sharding_type, | ||
dp_dim_index=0) | ||
return jnp.mean(result) | ||
|
||
def reference_rmsnorm(self, x, gamma, epsilon): | ||
x = jnp.asarray(x, jnp.float32) | ||
mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True) | ||
y = jnp.asarray(x * jax.lax.rsqrt(mean2 + epsilon), self.dtype) | ||
result = y * gamma | ||
return jnp.mean(result) | ||
|
||
@staticmethod | ||
def custom_softmax(x, mask, scale_factor, softmax_type, sharding_type): | ||
result = softmax(x, mask, | ||
scale_factor=scale_factor, | ||
softmax_type=softmax_type, | ||
sharding_type=sharding_type) | ||
return jnp.mean(result) | ||
|
||
def reference_softmax(self, x, mask, scale_factor, softmax_type): | ||
attn_weights = scale_factor * x | ||
if softmax_type != SoftmaxType.SCALED: | ||
big_neg = jnp.finfo(self.dtype).min | ||
attn_weights = jnp.where(mask, attn_weights, big_neg) | ||
result = jax.nn.softmax(attn_weights).astype(self.dtype) | ||
return jnp.mean(result) | ||
|
||
@staticmethod | ||
def custom_self_fused_attn(qkv, bias, mask, rng_key, dropout_prob, | ||
attn_bias_type, attn_mask_type, | ||
scaling_factor, sharding_type): | ||
mask = (mask == 0) # invert mask | ||
bias_ = None if attn_bias_type == AttnBiasType.NO_BIAS else bias | ||
result = self_fused_attn(qkv, bias_, mask, | ||
seed=rng_key, | ||
attn_bias_type=attn_bias_type, | ||
attn_mask_type=attn_mask_type, | ||
scaling_factor=scaling_factor, | ||
dropout_probability=dropout_prob, | ||
is_training=True, | ||
sharding_type=sharding_type) | ||
return jnp.mean(result) | ||
|
||
def reference_self_fused_attn(self, qkv, bias, mask, rng_key, dropout_prob, | ||
attn_bias_type, attn_mask_type, | ||
scaling_factor): | ||
# split interleaved QKV into separate matrices | ||
query, key, value = jnp.split(qkv, [1, 2], axis=-3) | ||
return self.fused_attn_core( | ||
query, key, value, bias, mask, scaling_factor, | ||
attn_bias_type, attn_mask_type, | ||
rng_key, dropout_prob) | ||
|
||
@staticmethod | ||
def custom_cross_fused_attn(query, key_value, mask, rng_key, dropout_prob, | ||
attn_mask_type, scaling_factor, sharding_type): | ||
mask = (mask == 0) # invert mask | ||
result = cross_fused_attn(query, key_value, mask, | ||
seed=rng_key, | ||
attn_bias_type=AttnBiasType.NO_BIAS, | ||
attn_mask_type=attn_mask_type, | ||
scaling_factor=scaling_factor, | ||
dropout_probability=dropout_prob, | ||
is_training=True, | ||
sharding_type=sharding_type) | ||
return jnp.mean(result) | ||
|
||
def reference_cross_fused_attn(self, query, key_value, mask, rng_key, dropout_prob, | ||
attn_mask_type, scaling_factor): | ||
key, value = jnp.split(key_value, [1], axis=-3) | ||
return self.fused_attn_core( | ||
query, key, value, None, mask, scaling_factor, | ||
AttnBiasType.NO_BIAS, attn_mask_type, | ||
rng_key, dropout_prob) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
import jax | ||
from dataclasses import dataclass | ||
from itertools import product | ||
from transformer_engine.jax.sharding import ShardingType | ||
from transformer_engine.jax.softmax import SoftmaxType | ||
from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType\ | ||
|
||
|
||
class ShardingConfigs(object): | ||
|
||
def __init__(self, num_gpus=jax.device_count('gpu')): | ||
super().__init__() | ||
if num_gpus < 2: | ||
raise ValueError(f"ShardingConfig: Need at least 2 GPUs, but got {num_gpus}.") | ||
|
||
self.device_count = min(num_gpus, 8) | ||
mesh_configs = [ | ||
((self.device_count, 1), ("dp", None), ShardingType.DP), | ||
((self.device_count, 1), ("tp", None), ShardingType.TP_COL), | ||
((self.device_count, 1), ("tp", None), ShardingType.TP_ROW), | ||
] | ||
if self.device_count > 2: | ||
mesh_configs += [ | ||
((self.device_count//2, 2), ("dp", "tp"), ShardingType.DP_TP_COL), | ||
((self.device_count//2, 2), ("dp", "tp"), ShardingType.DP_TP_ROW), | ||
] | ||
if self.device_count > 4: | ||
mesh_configs += [ | ||
((2, self.device_count//2), ("dp", "tp"), ShardingType.DP_TP_COL), | ||
((2, self.device_count//2), ("dp", "tp"), ShardingType.DP_TP_ROW), | ||
] | ||
|
||
layernorm_collectives = { | ||
ShardingType.DP : {'all-reduce': 2, 'other': 0}, | ||
ShardingType.TP_COL : {'all-reduce': 0, 'other': 0}, | ||
ShardingType.DP_TP_COL : {'all-reduce': 2, 'other': 0} | ||
} | ||
self.layernorm_refs = [ | ||
mesh_config + (layernorm_collectives[mesh_config[2]], ) \ | ||
for mesh_config in mesh_configs \ | ||
if mesh_config[2] not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW) | ||
] | ||
|
||
self.softmax_types = [ | ||
SoftmaxType.SCALED, | ||
SoftmaxType.SCALED_MASKED, | ||
SoftmaxType.SCALED_UPPER_TRIANG_MASKED | ||
] | ||
softmax_collectives = { | ||
ShardingType.DP : {'all-reduce': 1, 'other': 0}, | ||
ShardingType.TP_COL : {'all-reduce': 1, 'other': 0}, | ||
ShardingType.TP_ROW : {'all-reduce': 1, 'other': 0}, | ||
ShardingType.DP_TP_COL : {'all-reduce': 1, 'other': 0}, | ||
ShardingType.DP_TP_ROW : {'all-reduce': 1, 'other': 0} | ||
} | ||
self.softmax_refs = [ | ||
mesh_config + (softmax_collectives[mesh_config[2]], ) for mesh_config in mesh_configs | ||
] | ||
|
||
self.self_attn_bias_types = [ | ||
AttnBiasType.NO_BIAS, | ||
AttnBiasType.PRE_SCALE_BIAS, | ||
AttnBiasType.POST_SCALE_BIAS | ||
] | ||
self.self_attn_mask_types = [ | ||
AttnMaskType.CAUSAL_MASK, | ||
AttnMaskType.PADDING_MASK, | ||
AttnMaskType.NO_MASK | ||
] | ||
self_attn_collectives = { | ||
ShardingType.DP : { | ||
AttnBiasType.NO_BIAS : {'all-reduce': 1, 'other': 0}, | ||
AttnBiasType.PRE_SCALE_BIAS : {'all-reduce': 2, 'other': 0}, | ||
AttnBiasType.POST_SCALE_BIAS : {'all-reduce': 2, 'other': 0}, | ||
}, | ||
ShardingType.TP_COL : { | ||
AttnBiasType.NO_BIAS : {'all-reduce': 1, 'other': 0}, | ||
AttnBiasType.PRE_SCALE_BIAS : {'all-reduce': 1, 'other': 0}, | ||
AttnBiasType.POST_SCALE_BIAS : {'all-reduce': 1, 'other': 0} | ||
}, | ||
ShardingType.DP_TP_COL : { | ||
AttnBiasType.NO_BIAS : {'all-reduce': 1, 'other': 0}, | ||
AttnBiasType.PRE_SCALE_BIAS : {'all-reduce': 2, 'other': 0}, | ||
AttnBiasType.POST_SCALE_BIAS : {'all-reduce': 2, 'other': 0} | ||
}, | ||
} | ||
self.self_attn_refs = [ | ||
mesh_config + (bias_type, self_attn_collectives[mesh_config[2]][bias_type]) \ | ||
for mesh_config, bias_type in product(mesh_configs, self.self_attn_bias_types) \ | ||
if mesh_config[2] not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW) | ||
] | ||
|
||
self.cross_attn_mask_types = [ | ||
AttnMaskType.PADDING_MASK, | ||
AttnMaskType.NO_MASK | ||
] | ||
self.cross_attn_refs = [ | ||
mesh_config + ({'all-reduce': 1, 'other': 0}, ) \ | ||
for mesh_config in mesh_configs \ | ||
if mesh_config[2] not in (ShardingType.TP_ROW, ShardingType.DP_TP_ROW) | ||
] |
Oops, something went wrong.