
Commit

Corrected logic when generating mesh configurations for DP_TP sharding tests, eliminated redundant assert in compare_ops, fixed missing license comments

Signed-off-by: Alp Dener <[email protected]>
denera committed Oct 23, 2023
1 parent a1d0744 commit 77d4c46
Showing 4 changed files with 32 additions and 24 deletions.
38 changes: 18 additions & 20 deletions tests/jax/custom_ops_helper.py
@@ -1,3 +1,6 @@
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import pytest
import numpy as np
from dataclasses import dataclass
@@ -117,27 +120,22 @@ def compare_ops(self, custom_func, ref_func, ref_count,
             f"`{func_name}`: Output (fwd) error {jnp.max(jnp.abs(test_fwd - ref_fwd))}" + \
             f" exceeds tolerance ({fwd_tol})."

-        num_grads = len(ref_grads) if isinstance(ref_grads, tuple) else 1
-        if num_grads > 1:
-            failed_grads = {}
-            for i, grads in enumerate(zip(test_grads, ref_grads)):
-                test_grad, ref_grad = grads
-                if test_grad is None and ref_grad is None:
-                    continue
-                bwd_tol = max(np.finfo(jnp.float32).eps,
-                              np.spacing(jnp.max(jnp.abs(ref_grad)).astype(jnp.float32))) ** (2./3.)
-                if not jnp.allclose(test_grad, ref_grad, rtol=0.0, atol=bwd_tol):
-                    failed_grads[i] = jnp.max(jnp.abs(test_grad - ref_grad))
-            assert len(failed_grads) == 0, \
-                f"`{func_name}`: Gradient (bwd) max errors" + \
-                f" [{', '.join([f'Arg{k}={v}' for k,v in failed_grads.items()])}]" + \
-                f" exceed tolerance ({bwd_tol})."
-        else:
+        if len(grad_args) == 1:
+            ref_grads = (ref_grads, )
+            test_grads = (test_grads, )
+        failed_grads = {}
+        for i, grads in enumerate(zip(test_grads, ref_grads)):
+            test_grad, ref_grad = grads
+            if test_grad is None and ref_grad is None:
+                continue
             bwd_tol = max(np.finfo(jnp.float32).eps,
-                          np.spacing(jnp.max(jnp.abs(ref_grads)).astype(jnp.float32))) ** (2./3.)
-            assert jnp.allclose(test_grads, ref_grads, rtol=0.0, atol=bwd_tol), \
-                f"`{func_name}`: Gradient (bwd) max error" + \
-                f" {jnp.max(jnp.abs(test_grads - ref_grads))} exceeds tolerance ({bwd_tol})."
+                          np.spacing(jnp.max(jnp.abs(ref_grad)).astype(jnp.float32))) ** (2./3.)
+            if not jnp.allclose(test_grad, ref_grad, rtol=0.0, atol=bwd_tol):
+                failed_grads[i] = jnp.max(jnp.abs(test_grad - ref_grad))
+        assert len(failed_grads) == 0, \
+            f"`{func_name}`: Gradient (bwd) max errors" + \
+            f" [{', '.join([f'Arg{k}={v}' for k,v in failed_grads.items()])}]" + \
+            f" exceed tolerance ({bwd_tol})."

     def check_fused_attn_inputs(self, q_seq, kv_seq, head_dim, pad_ratio, dropout_probability,
                                 attn_bias_type, attn_mask_type, backend):
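For readers skimming the diff, here is a minimal standalone sketch of the unified gradient check that compare_ops performs after this change. The helper name check_grads and its arguments are illustrative, not repository code, and the tuple normalization keys off an isinstance check instead of the helper's len(grad_args) == 1 test.

# Illustrative sketch only -- not the repository helper. Mirrors the unified
# per-gradient comparison that compare_ops applies after this commit.
import numpy as np
import jax.numpy as jnp

def check_grads(test_grads, ref_grads, func_name="op"):
    # Normalize single-gradient results to 1-tuples (the repository code keys
    # this off `len(grad_args) == 1`; an isinstance check is used here instead).
    if not isinstance(ref_grads, tuple):
        test_grads, ref_grads = (test_grads,), (ref_grads,)
    failed_grads = {}
    for i, (test_grad, ref_grad) in enumerate(zip(test_grads, ref_grads)):
        if test_grad is None and ref_grad is None:
            continue
        # Absolute tolerance scales with the reference gradient's magnitude:
        # np.spacing() gives the float32 ULP at max|ref_grad|, and the 2/3
        # exponent loosens it to allow for accumulated rounding error.
        bwd_tol = max(np.finfo(jnp.float32).eps,
                      np.spacing(jnp.max(jnp.abs(ref_grad)).astype(jnp.float32))) ** (2. / 3.)
        if not jnp.allclose(test_grad, ref_grad, rtol=0.0, atol=bwd_tol):
            failed_grads[i] = jnp.max(jnp.abs(test_grad - ref_grad))
    assert not failed_grads, \
        f"`{func_name}`: gradient max errors {failed_grads} exceed tolerance."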
10 changes: 6 additions & 4 deletions tests/jax/sharding_configs.py
@@ -1,5 +1,7 @@
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import jax
from dataclasses import dataclass
from itertools import product
from transformer_engine.jax.sharding import ShardingType
from transformer_engine.jax.softmax import SoftmaxType
@@ -8,7 +10,7 @@

 class ShardingConfigs(object):

-    def __init__(self, num_gpus=jax.device_count('gpu')):
+    def __init__(self, num_gpus=len(jax.devices())):
         super().__init__()
         if num_gpus < 2:
             raise ValueError(f"ShardingConfig: Need at least 2 GPUs, but got {num_gpus}.")
@@ -19,12 +21,12 @@ def __init__(self, num_gpus=jax.device_count('gpu')):
             ((self.device_count, 1), ("tp", None), ShardingType.TP_COL),
             ((self.device_count, 1), ("tp", None), ShardingType.TP_ROW),
         ]
-        if self.device_count > 2:
+        if self.device_count >= 4:
             mesh_configs += [
                 ((self.device_count//2, 2), ("dp", "tp"), ShardingType.DP_TP_COL),
                 ((self.device_count//2, 2), ("dp", "tp"), ShardingType.DP_TP_ROW),
             ]
-        if self.device_count > 4:
+        if self.device_count >= 6:
             mesh_configs += [
                 ((2, self.device_count//2), ("dp", "tp"), ShardingType.DP_TP_COL),
                 ((2, self.device_count//2), ("dp", "tp"), ShardingType.DP_TP_ROW),
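As a quick sanity check of the corrected thresholds, the sketch below reproduces just the DP_TP mesh-shape selection outside the test class; dp_tp_mesh_shapes is a hypothetical helper, not repository code.

# Illustrative sketch only -- reproduces the corrected DP_TP mesh-shape logic.
def dp_tp_mesh_shapes(device_count):
    shapes = []
    if device_count >= 4:
        # A (N//2, 2) mesh needs at least 4 devices to give a dp axis of size >= 2.
        shapes.append((device_count // 2, 2))
    if device_count >= 6:
        # (2, N//2) only adds a new shape once N//2 > 2, i.e. 6+ devices;
        # at exactly 4 devices it would duplicate the (2, 2) mesh above.
        shapes.append((2, device_count // 2))
    return shapes

# Examples: 2 GPUs -> [] (pure DP/TP configs only), 4 GPUs -> [(2, 2)],
# 8 GPUs -> [(4, 2), (2, 4)].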
4 changes: 4 additions & 0 deletions tests/jax/test_custom_ops_cpar.py
@@ -1,6 +1,10 @@
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
#
# Regression tests for TE-JAX custom ops with cpar-based (Custom PARtitioning) sharding
# https://jax.readthedocs.io/en/latest/jax.experimental.custom_partitioning.html
#
import os
import pytest
import numpy as np
4 changes: 4 additions & 0 deletions tests/jax/test_custom_ops_xmap.py
@@ -1,6 +1,10 @@
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
#
# Regression tests for TE-JAX custom ops with xmap-based sharding
# https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html
#
import os
import pytest
import numpy as np
