From 77d4c46301fcec72af9ecc9519b42c632c43514a Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Mon, 23 Oct 2023 08:19:19 -0700 Subject: [PATCH] Corrected logic when generating mesh configurations for DP_TP sharding tests, eliminated redundant assert in compare_ops, fixed missing license comments Signed-off-by: Alp Dener --- tests/jax/custom_ops_helper.py | 38 +++++++++++++++---------------- tests/jax/sharding_configs.py | 10 ++++---- tests/jax/test_custom_ops_cpar.py | 4 ++++ tests/jax/test_custom_ops_xmap.py | 4 ++++ 4 files changed, 32 insertions(+), 24 deletions(-) diff --git a/tests/jax/custom_ops_helper.py b/tests/jax/custom_ops_helper.py index e8dc93abe4..75de3710f3 100644 --- a/tests/jax/custom_ops_helper.py +++ b/tests/jax/custom_ops_helper.py @@ -1,3 +1,6 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. import pytest import numpy as np from dataclasses import dataclass @@ -117,27 +120,22 @@ def compare_ops(self, custom_func, ref_func, ref_count, f"`{func_name}`: Output (fwd) error {jnp.max(jnp.abs(test_fwd - ref_fwd))}" + \ f" exceeds tolerance ({fwd_tol})." - num_grads = len(ref_grads) if isinstance(ref_grads, tuple) else 1 - if num_grads > 1: - failed_grads = {} - for i, grads in enumerate(zip(test_grads, ref_grads)): - test_grad, ref_grad = grads - if test_grad is None and ref_grad is None: - continue - bwd_tol = max(np.finfo(jnp.float32).eps, - np.spacing(jnp.max(jnp.abs(ref_grad)).astype(jnp.float32))) ** (2./3.) - if not jnp.allclose(test_grad, ref_grad, rtol=0.0, atol=bwd_tol): - failed_grads[i] = jnp.max(jnp.abs(test_grad - ref_grad)) - assert len(failed_grads) == 0, \ - f"`{func_name}`: Gradient (bwd) max errors" + \ - f" [{', '.join([f'Arg{k}={v}' for k,v in failed_grads.items()])}]" + \ - f" exceed tolerance ({bwd_tol})." - else: + if len(grad_args) == 1: + ref_grads = (ref_grads, ) + test_grads = (test_grads, ) + failed_grads = {} + for i, grads in enumerate(zip(test_grads, ref_grads)): + test_grad, ref_grad = grads + if test_grad is None and ref_grad is None: + continue bwd_tol = max(np.finfo(jnp.float32).eps, - np.spacing(jnp.max(jnp.abs(ref_grads)).astype(jnp.float32))) ** (2./3.) - assert jnp.allclose(test_grads, ref_grads, rtol=0.0, atol=bwd_tol), \ - f"`{func_name}`: Gradient (bwd) max error" + \ - f" {jnp.max(jnp.abs(test_grads - ref_grads))} exceeds tolerance ({bwd_tol})." + np.spacing(jnp.max(jnp.abs(ref_grad)).astype(jnp.float32))) ** (2./3.) + if not jnp.allclose(test_grad, ref_grad, rtol=0.0, atol=bwd_tol): + failed_grads[i] = jnp.max(jnp.abs(test_grad - ref_grad)) + assert len(failed_grads) == 0, \ + f"`{func_name}`: Gradient (bwd) max errors" + \ + f" [{', '.join([f'Arg{k}={v}' for k,v in failed_grads.items()])}]" + \ + f" exceed tolerance ({bwd_tol})." def check_fused_attn_inputs(self, q_seq, kv_seq, head_dim, pad_ratio, dropout_probability, attn_bias_type, attn_mask_type, backend): diff --git a/tests/jax/sharding_configs.py b/tests/jax/sharding_configs.py index e2dab4857b..095f24939a 100644 --- a/tests/jax/sharding_configs.py +++ b/tests/jax/sharding_configs.py @@ -1,5 +1,7 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
import jax -from dataclasses import dataclass from itertools import product from transformer_engine.jax.sharding import ShardingType from transformer_engine.jax.softmax import SoftmaxType @@ -8,7 +10,7 @@ class ShardingConfigs(object): - def __init__(self, num_gpus=jax.device_count('gpu')): + def __init__(self, num_gpus=len(jax.devices())): super().__init__() if num_gpus < 2: raise ValueError(f"ShardingConfig: Need at least 2 GPUs, but got {num_gpus}.") @@ -19,12 +21,12 @@ def __init__(self, num_gpus=jax.device_count('gpu')): ((self.device_count, 1), ("tp", None), ShardingType.TP_COL), ((self.device_count, 1), ("tp", None), ShardingType.TP_ROW), ] - if self.device_count > 2: + if self.device_count >= 4: mesh_configs += [ ((self.device_count//2, 2), ("dp", "tp"), ShardingType.DP_TP_COL), ((self.device_count//2, 2), ("dp", "tp"), ShardingType.DP_TP_ROW), ] - if self.device_count > 4: + if self.device_count >= 6: mesh_configs += [ ((2, self.device_count//2), ("dp", "tp"), ShardingType.DP_TP_COL), ((2, self.device_count//2), ("dp", "tp"), ShardingType.DP_TP_ROW), diff --git a/tests/jax/test_custom_ops_cpar.py b/tests/jax/test_custom_ops_cpar.py index d3730a942c..823a81a363 100644 --- a/tests/jax/test_custom_ops_cpar.py +++ b/tests/jax/test_custom_ops_cpar.py @@ -1,6 +1,10 @@ # Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. +# +# Regression tests for TE-JAX custom ops with cpar-based (Custom PARtitioning) sharding +# https://jax.readthedocs.io/en/latest/jax.experimental.custom_partitioning.html +# import os import pytest import numpy as np diff --git a/tests/jax/test_custom_ops_xmap.py b/tests/jax/test_custom_ops_xmap.py index 62992f3fe5..0a53f5ad65 100644 --- a/tests/jax/test_custom_ops_xmap.py +++ b/tests/jax/test_custom_ops_xmap.py @@ -1,6 +1,10 @@ # Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. +# +# Regression tests for TE-JAX custom ops with xmap-based sharding +# https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html +# import os import pytest import numpy as np
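
Note for reviewers: below is a minimal, self-contained sketch of the corrected mesh-configuration thresholds. The helper name `build_mesh_configs` and the plain tuple/string return values are illustrative only (the real logic lives in `ShardingConfigs.__init__` and uses `ShardingType`); the sketch just shows why the (N//2, 2) DP_TP mesh now requires at least 4 devices, and why the (2, N//2) mesh requires at least 6, since on exactly 4 devices the latter would merely duplicate the (2, 2) shape.

# Illustrative sketch only -- mirrors the thresholds in ShardingConfigs.__init__,
# but returns plain tuples/strings instead of ShardingType enums.
def build_mesh_configs(device_count):
    """Return (mesh_shape, axis_names, sharding) tuples for a given device count."""
    configs = [
        ((device_count, 1), ("dp", None), "DP"),
        ((device_count, 1), ("tp", None), "TP_COL"),
        ((device_count, 1), ("tp", None), "TP_ROW"),
    ]
    if device_count >= 4:
        # A (N//2, 2) mesh needs at least 4 devices so both axes have size >= 2.
        configs += [
            ((device_count // 2, 2), ("dp", "tp"), "DP_TP_COL"),
            ((device_count // 2, 2), ("dp", "tp"), "DP_TP_ROW"),
        ]
    if device_count >= 6:
        # A (2, N//2) mesh is only distinct once N//2 > 2; on exactly 4 devices it
        # would repeat the (2, 2) shape already added above.
        configs += [
            ((2, device_count // 2), ("dp", "tp"), "DP_TP_COL"),
            ((2, device_count // 2), ("dp", "tp"), "DP_TP_ROW"),
        ]
    return configs


if __name__ == "__main__":
    # 2 devices -> only (2, 1) meshes; 4 -> adds (2, 2); 6 -> adds (3, 2) and (2, 3).
    for n in (2, 4, 6, 8):
        print(n, [shape for shape, _, _ in build_mesh_configs(n)])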
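
Similarly, a hedged sketch of the single-code-path gradient check that replaces the duplicated assert branch in compare_ops. The standalone function, its name `assert_grads_close`, and the `num_grad_args` parameter are hypothetical and use NumPy rather than jax.numpy; only the tolerance formula and the tuple-wrapping of the single-gradient case mirror the patched helper.

import numpy as np

def assert_grads_close(test_grads, ref_grads, num_grad_args, func_name="op"):
    """Compare gradients with a per-argument absolute tolerance derived from
    the spacing at the largest reference magnitude, reporting every failing
    argument at once instead of using separate single/multi-gradient branches."""
    # Normalize the single-gradient case so one loop handles everything.
    if num_grad_args == 1:
        test_grads = (test_grads,)
        ref_grads = (ref_grads,)

    failed = {}
    for i, (test_grad, ref_grad) in enumerate(zip(test_grads, ref_grads)):
        if test_grad is None and ref_grad is None:
            continue
        # Tolerance scales with the float32 spacing at the largest reference value.
        tol = max(np.finfo(np.float32).eps,
                  np.spacing(np.max(np.abs(ref_grad)).astype(np.float32))) ** (2. / 3.)
        if not np.allclose(test_grad, ref_grad, rtol=0.0, atol=tol):
            failed[i] = np.max(np.abs(test_grad - ref_grad))

    assert not failed, (
        f"`{func_name}`: Gradient (bwd) max errors "
        + ", ".join(f"Arg{k}={v}" for k, v in failed.items())
        + " exceed tolerance."
    )


# Example: a perturbation well inside the derived tolerance passes the check.
g_ref = np.array([1.0, 2.0, 3.0], dtype=np.float32)
assert_grads_close(g_ref + 1e-5, g_ref, num_grad_args=1)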