Commit
Cleaner pytest skip condition for sharding tests.
Signed-off-by: Alp Dener <[email protected]>
denera committed Nov 1, 2023
1 parent bdbcba9 commit c832ea1
Showing 3 changed files with 30 additions and 24 deletions.
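This commit replaces an import-time pytest.skip() call in the shared test helper with a feature-probe flag that both test modules consume through pytest.mark.skipif markers on their test classes. Below is a minimal standalone sketch of that pattern, assuming only that pytest and JAX are installed; the probe target here is JAX's own custom_partitioning API, standing in for the MeshResource probe in the diff, and the class and test names are hypothetical.

import pytest

try:
    # probe for the optional new API once, at import time
    from jax.experimental.custom_partitioning import custom_partitioning
except ImportError:
    custom_partitioning = None

def use_custom_partitioning():
    return custom_partitioning is not None

# skipif with a boolean condition needs an explicit reason string
@pytest.mark.skipif(not use_custom_partitioning(),
                    reason="jax.experimental.custom_partitioning is unavailable")
class TestNewShardingPath:
    def test_probe(self):
        assert custom_partitioning is not None

Unlike a module-level pytest.skip(), which pytest only permits with allow_module_level=True, the marker approach lets every test module import the helper safely and skips each suite independently at collection time.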
35 changes: 28 additions & 7 deletions tests/jax/custom_ops_helper.py
@@ -1,6 +1,7 @@
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
+import os
import pytest
import numpy as np
from dataclasses import dataclass
@@ -16,28 +17,44 @@

import flax

-from transformer_engine.jax.sharding import ShardingType, ShardingResource
+from transformer_engine.jax.sharding import ShardingType
try:
-    # temporary workaround to be removed after jax.experimental.custom_partitioning migration
+    # try importing the new custom partitioning implementation
    from transformer_engine.jax.sharding import MeshResource
+    ShardingResource = None
except ImportError:
-    pytest.skip("Need working MeshResource implementation to test " +
-                "jax.experimental.custom_partitioning sharding.")
+    # must be using an older TE/JAX version so fall back on the xmap sharding implementation
+    from transformer_engine.jax.sharding import ShardingResource
+    MeshResource = None
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.softmax import SoftmaxType, softmax
from transformer_engine.jax.fused_attn import \
    AttnBiasType, AttnMaskType, is_fused_attn_kernel_available, self_fused_attn, cross_fused_attn


class FusedAttnBackend(Enum):
    Max512 = "0"
    Arbitrary = "1"


+@pytest.fixture(name="backend", params=[FusedAttnBackend.Max512, FusedAttnBackend.Arbitrary])
+def fixture_backend(request):
+    backend = request.param
+    os.environ["NVTE_FUSED_ATTN_BACKEND"] = backend.value
+    yield backend
+    os.environ["NVTE_FUSED_ATTN_BACKEND"] = ""


@dataclass
class CustomOpsTestHelper:
    qkv_shape: Tuple[int,int,int,int] = (32, 128, 16, 64)
    pad_ratio: float = 0.3
    dropout_prob: float = 0.1
    dtype: type = jnp.float16

+    @staticmethod
+    def use_custom_partitioning():
+        return (MeshResource is not None)

    @staticmethod
    def get_sharding_spec(mesh_names, sharding_type):
@@ -53,15 +70,18 @@ def get_sharding_spec(mesh_names, sharding_type):
    def get_sharding_resource(mesh_names, sharding_type):
        dp_r = None
        tp_r = None

        if sharding_type in (ShardingType.DP, ShardingType.DP_TP_COL, ShardingType.DP_TP_ROW):
            dp_r = mesh_names[0]
        if sharding_type in (ShardingType.TP_COL, ShardingType.TP_ROW):
            tp_r = mesh_names[0]
        if sharding_type in (ShardingType.DP_TP_COL, ShardingType.DP_TP_ROW):
            tp_r = mesh_names[1]
-        xmap_resource = ShardingResource(dp_r, tp_r)
-        cp_resource = MeshResource(dp_r, tp_r) if MeshResource is not None else None
-        return xmap_resource, cp_resource
+
+        if CustomOpsTestHelper.use_custom_partitioning():
+            return MeshResource(dp_r, tp_r)
+        else:
+            return ShardingResource(dp_r, tp_r)

    @staticmethod
    def make_mask(q_tokens, kv_tokens, mask_type, dtype=jnp.uint8):
@@ -71,6 +91,7 @@ def make_mask(q_tokens, kv_tokens, mask_type, dtype=jnp.uint8):
            return flax.linen.combine_masks(causal, padding)
        else:
            return flax.linen.make_attention_mask(q_tokens > 0, kv_tokens > 0, dtype=dtype)
+
    @staticmethod
    def count_collectives(hlo):
        tmp = hlo.splitlines()
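One aside on the fixture that moved into the helper above: its teardown resets NVTE_FUSED_ATTN_BACKEND to an empty string rather than restoring the variable's previous value. A sketch of an alternative using pytest's built-in monkeypatch fixture, which undoes the change automatically after each parametrized test, is below; the names mirror the diff, but this is an illustration, not what the commit ships.

import pytest
from enum import Enum

class FusedAttnBackend(Enum):
    Max512 = "0"
    Arbitrary = "1"

@pytest.fixture(name="backend", params=[FusedAttnBackend.Max512, FusedAttnBackend.Arbitrary])
def fixture_backend(request, monkeypatch):
    # setenv is reverted at teardown, restoring any pre-existing value
    monkeypatch.setenv("NVTE_FUSED_ATTN_BACKEND", request.param.value)
    return request.param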
9 changes: 1 addition & 8 deletions tests/jax/test_custom_ops_cpar.py
@@ -5,7 +5,6 @@
# Regression tests for TE-JAX custom ops with cpar-based (Custom PARtitioning) sharding
# https://jax.readthedocs.io/en/latest/jax.experimental.custom_partitioning.html
#
-import os
import pytest
import numpy as np
from functools import partial
@@ -24,13 +23,7 @@
configs = ShardingConfigs(num_gpus=8)
helper = CustomOpsTestHelper()

-@pytest.fixture(name="backend", params=[FusedAttnBackend.Max512, FusedAttnBackend.Arbitrary])
-def fixture_backend(request):
-    backend = request.param
-    os.environ["NVTE_FUSED_ATTN_BACKEND"] = backend.value
-    yield backend
-    os.environ["NVTE_FUSED_ATTN_BACKEND"] = ""
-
+@pytest.mark.skipif(not helper.use_custom_partitioning(),
+                    reason='TE/JAX does not support jax.experimental.custom_partitioning')
@pytest.mark.skipif(not is_devices_enough(configs.device_count), reason='Num of GPU is not enough')
class TestCustomPartitioningOpsGenerator:

10 changes: 1 addition & 9 deletions tests/jax/test_custom_ops_xmap.py
@@ -5,7 +5,6 @@
# Regression tests for TE-JAX custom ops with xmap-based sharding
# https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html
#
-import os
import pytest
import numpy as np
from functools import partial
@@ -23,14 +22,7 @@
configs = ShardingConfigs(num_gpus=8)
helper = CustomOpsTestHelper()

-@pytest.fixture(name="backend", params=[FusedAttnBackend.Max512, FusedAttnBackend.Arbitrary])
-def fixture_backend(request):
-    backend = request.param
-    os.environ["NVTE_FUSED_ATTN_BACKEND"] = backend.value
-    yield backend
-    os.environ["NVTE_FUSED_ATTN_BACKEND"] = ""
-
-
+@pytest.mark.skipif(helper.use_custom_partitioning(),
+                    reason='covered by the jax.experimental.custom_partitioning test suite')
@pytest.mark.skipif(not is_devices_enough(configs.device_count), reason='Num of GPU is not enough')
class TestXmapOpsGenerator:

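Taken together, the two class-level markers are complementary: for any given TE/JAX installation, exactly one of TestCustomPartitioningOpsGenerator and TestXmapOpsGenerator is collected. Note also that get_sharding_resource now returns a single resource object instead of the old (xmap_resource, cp_resource) pair, so callers unpack one value. A small usage sketch, assuming it runs from tests/jax with transformer_engine installed ("dp" and "tp" are hypothetical mesh axis names):

from custom_ops_helper import CustomOpsTestHelper
from transformer_engine.jax.sharding import ShardingType

helper = CustomOpsTestHelper()
# returns MeshResource on newer TE/JAX, ShardingResource on older installs
resource = helper.get_sharding_resource(("dp", "tp"), ShardingType.DP_TP_COL)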
