-
Notifications
You must be signed in to change notification settings - Fork 337
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[JAX] Collective GEMM custom op with nvte_cublas_gemm
(no comm. overlap)
#1307
Open
denera
wants to merge
19
commits into
NVIDIA:main
Choose a base branch
from
denera:jax-collective-gemm
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
ad6bf2a
added XLA custom op defs for TE GEMM
denera c9774d8
fixed batching rules to accommodated batched RHS operand for GEMM
denera e523018
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 2c3dbf1
re-applied bug fixes to working older version, updated backward pass,…
denera 448eaa9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] cb6ae3c
batched operands for GEMM custom op seem to be working now
denera 6f67355
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 4b2b2d4
fixed batch size 1 issue and enabled FSDP sharding for RHS operand
denera 2b2753e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 969f597
fixed FSDP+TP w/ DP=1 and TP+DP, but FSDP+TP w/ DP>1 still crashes
denera ce86dcb
fixed logic to remove FSDP sharding
denera b215f20
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] cbab16c
retained FSDP dims and pushed FSDP all-gather of weight array to outs…
denera 0ea55c0
Added useful warning about DGRAD sharding not matching sequence/conte…
denera 2acb92f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] b07bb2d
documentation fixes
denera 765b844
added unit test, both AG+GEMM and GEMM+AR passing with FSDP+TP, DP+TP…
denera 2ce4377
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] f68d71e
restored old test_custom_call_compute.py to remove erroneous changes
denera File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,302 @@ | ||
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# | ||
# See LICENSE for license information. | ||
import pytest | ||
from functools import partial | ||
from collections.abc import Iterable | ||
|
||
import numpy as np | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
from jax.sharding import Mesh, NamedSharding, PartitionSpec | ||
from jax.experimental import mesh_utils | ||
|
||
import transformer_engine.jax as te | ||
from transformer_engine.jax.gemm import gemm | ||
|
||
from utils import assert_allclose | ||
|
||
|
||
jax.config.update("jax_enable_compilation_cache", False) | ||
|
||
|
||
# AG+GEMM: (4, 32/P, 128) ----(AG)----> (4, 32, 128) x (128, 256/P) ----------> (4, 32, 256/P) | ||
# - DGRAD: (4, 32, 256/P) x (128, 256/P)^T --(AR)--> (4, 32, 128) | ||
# - WGRAD: (4, 32/P, 128)^T --(AG)--> (4, 32, 128)^T x (4, 32, 256/P) --------> (128, 256/P) | ||
|
||
# GEMM+AR: (4, 32, 256/P) x (256/P, 128) --(AR)--> (4, 32, 128) | ||
# - DGRAD: (4, 32, 128) x (256/P, 128)^T ------> (4, 32, 256/P) | ||
# - WGRAD: (4, 32, 256/P)^T --(AG)--> (4, 32, 256)^T x (4, 32, 128) --------> (256, 128) | ||
|
||
BATCH = 4 | ||
BASE_SIZE = 16 | ||
SEQ_LEN = BASE_SIZE * 8 | ||
HIDDEN_SIZE = BASE_SIZE * 6 | ||
FFN_HIDDEN_SIZE = BASE_SIZE * 16 | ||
|
||
COMM_TYPES = ["ALL_GATHER", "ALL_REDUCE"] | ||
MESH_TYPES = ["FSDP_TP", "DP_TP", "TP"] | ||
NUM_DEVICES = 4 | ||
|
||
is_fp8_supported, no_fp8_reason = te.fp8.is_fp8_available() | ||
|
||
|
||
def _get_mesh(parallel_dist): | ||
jax.clear_caches() | ||
|
||
batched = False | ||
fsdp = False | ||
mesh_shape = dict(tp=NUM_DEVICES) | ||
resources = dict(cp_resource="tp", tp_resource="tp") | ||
if parallel_dist in ["DP_TP", "FSDP_TP"]: | ||
batched = True | ||
mesh_shape.update(dict(tp=NUM_DEVICES // 2, dp=NUM_DEVICES // 2)) | ||
resources.update(dict(dp_resource="dp")) | ||
if parallel_dist == "FSDP_TP": | ||
fsdp = True | ||
mesh_shape.update(dict(tp=NUM_DEVICES // 2, dp=1, zp=NUM_DEVICES // 2)) | ||
resources.update(dict(fsdp_resource="zp")) | ||
mesh_resource = te.MeshResource(**resources) | ||
|
||
devices = mesh_utils.create_device_mesh((NUM_DEVICES,), devices=jax.devices()[:NUM_DEVICES]) | ||
|
||
mesh = Mesh(np.array(devices).reshape(tuple(mesh_shape.values())), tuple(mesh_shape.keys())) | ||
|
||
return mesh, mesh_resource, batched, fsdp | ||
|
||
|
||
def _get_inputs(mesh, mesh_resource, dtype, fwd_comm_type, batched, fsdp, fwd_bwd=False): | ||
fp8_gemm = dtype in [jnp.float8_e4m3fn, jnp.float8_e5m2] | ||
|
||
# Operand and output shapes | ||
lhs_shape = ( | ||
[SEQ_LEN, HIDDEN_SIZE] if fwd_comm_type == "ALL_GATHER" else [SEQ_LEN, FFN_HIDDEN_SIZE] | ||
) | ||
rhs_shape = ( | ||
[HIDDEN_SIZE, FFN_HIDDEN_SIZE] | ||
if fwd_comm_type == "ALL_GATHER" | ||
else [FFN_HIDDEN_SIZE, HIDDEN_SIZE] | ||
) | ||
out_shape = [lhs_shape[0], rhs_shape[1]] | ||
|
||
if batched: | ||
lhs_shape = [BATCH] + lhs_shape | ||
out_shape = [BATCH] + out_shape | ||
|
||
# Operand and output partition specs | ||
lhs_spec = ( | ||
[mesh_resource.tp_resource, None] | ||
if fwd_comm_type == "ALL_GATHER" | ||
else [None, mesh_resource.tp_resource] | ||
) | ||
rhs_spec = ( | ||
[None, mesh_resource.tp_resource] | ||
if fwd_comm_type == "ALL_GATHER" | ||
else [mesh_resource.tp_resource, None] | ||
) | ||
out_spec = [None, rhs_spec[-1]] | ||
|
||
# Modify RHS operand for FP8 | ||
fsdp_gathered_rhs_spec = rhs_spec.copy() | ||
if fp8_gemm: | ||
rhs_shape = list(reversed(rhs_shape)) | ||
rhs_spec = list(reversed(rhs_spec)) | ||
fsdp_gathered_rhs_spec = list(reversed(fsdp_gathered_rhs_spec)) | ||
|
||
# Add batch dimensions and specs | ||
if batched: | ||
if fsdp: | ||
lhs_spec = [(mesh_resource.dp_resource, mesh_resource.fsdp_resource)] + lhs_spec | ||
rhs_spec = [mesh_resource.fsdp_resource if spec is None else spec for spec in rhs_spec] | ||
out_spec = [(mesh_resource.dp_resource, mesh_resource.fsdp_resource)] + out_spec | ||
else: | ||
lhs_spec = [mesh_resource.dp_resource] + lhs_spec | ||
out_spec = [mesh_resource.dp_resource] + out_spec | ||
|
||
# Allocate global operands on device | ||
key = jax.random.PRNGKey(42) | ||
split_keys = jax.random.split(key, 3 if fwd_bwd else 2) | ||
mu = 0.0 | ||
sigma = 0.023 | ||
shapes = (lhs_shape, rhs_shape) | ||
if fwd_bwd: | ||
shapes += (out_shape,) | ||
global_operands = list( | ||
map( | ||
lambda key, shape: jax.device_put( | ||
mu + (sigma * jax.random.normal(key, shape, dtype=dtype)), | ||
NamedSharding(mesh, PartitionSpec(None)), | ||
), | ||
split_keys, | ||
shapes, | ||
) | ||
) | ||
|
||
# Allocate sharded operands on device | ||
partition_axes = (lhs_spec, rhs_spec) | ||
if fwd_bwd: | ||
partition_axes += (out_spec,) | ||
local_operands = list( | ||
map( | ||
lambda x, spec: jax.device_put(x, NamedSharding(mesh, PartitionSpec(*spec))), | ||
global_operands, | ||
partition_axes, | ||
) | ||
) | ||
|
||
# Tranpose global RHS back to non-transpoosed orientation if it was originally allocated | ||
# for FP8 GEMM | ||
if fp8_gemm: | ||
rhs_global = jnp.matrix_transpose(global_operands[1]) | ||
global_operands = (global_operands[0], rhs_global, *global_operands[2:]) | ||
|
||
return ( | ||
local_operands, | ||
global_operands, | ||
(out_shape, out_spec), | ||
fsdp_gathered_rhs_spec, | ||
) | ||
|
||
|
||
def _check_output(mesh, expected_out_shape, expected_out_specs, *tensors, fwd_bwd=False): | ||
num_operands = 3 if fwd_bwd else 2 | ||
ref_operands = tensors[:num_operands] | ||
test_outputs = tensors[num_operands:] | ||
|
||
# Check number of dimensions | ||
assert test_outputs[0].ndim == len(expected_out_shape), ( | ||
f"Output has different number of dimensions ({test_outputs[0].ndim}) than expected " | ||
+ f"({len(expected_out_shape)})" | ||
) | ||
|
||
# Pad test output spec for unsharded dimensions | ||
test_spec = te.sharding.get_padded_spec(test_outputs[0].sharding.spec, test_outputs[0].ndim) | ||
|
||
for i in range(test_outputs[0].ndim): | ||
# Check shape | ||
assert test_outputs[0].shape[i] == expected_out_shape[i], ( | ||
f"Output with shape {test_outputs[0].shape} does not match expected shape " | ||
+ f"{expected_out_shape} in dimension index {i}." | ||
) | ||
|
||
# Check shardings (with padded output spec) | ||
spec_mismatch = False | ||
if isinstance(expected_out_specs[i], str): | ||
if test_spec[i] != expected_out_specs[i]: | ||
spec_mismatch = True | ||
elif isinstance(expected_out_specs[i], Iterable): | ||
if not isinstance(test_spec[i], type(expected_out_specs[i])): | ||
if test_spec[i] not in expected_out_specs[i]: | ||
spec_mismatch = True | ||
elif len(test_spec[i]) != len(expected_out_specs[i]): | ||
spec_mismatch = True | ||
else: | ||
for j in range(len(expected_out_specs[i])): | ||
if test_spec[i][j] != expected_out_specs[i][j]: | ||
spec_mismatch = True | ||
break | ||
elif expected_out_specs[i] == None: | ||
if test_spec[i] != None: | ||
spec_mismatch = True | ||
else: | ||
raise RuntimeError("Internal TE error: Unrecognized reference partition spec type.") | ||
if spec_mismatch: | ||
raise AssertionError( | ||
f"Output sharding {test_spec} does not match expected sharding " | ||
+ f"{expected_out_specs} in dimension index {i}." | ||
) | ||
|
||
def _native_gemm_fwd_bwd(lhs, rhs, grad): | ||
fwd_out, vjp_fn = jax.vjp(jnp.dot, lhs, rhs) | ||
lhs_grad, rhs_grad = vjp_fn(grad) | ||
return fwd_out, lhs_grad, rhs_grad | ||
|
||
ref_fn = jax.jit(_native_gemm_fwd_bwd if fwd_bwd else jnp.dot) | ||
|
||
out_names = ["output"] | ||
ref_outputs = ref_fn(*ref_operands) | ||
if not fwd_bwd: | ||
ref_outputs = [ref_outputs] | ||
else: | ||
out_names += ["dgrad", "wgrad"] | ||
|
||
for i, (test_out, ref_out) in enumerate(zip(test_outputs, ref_outputs)): | ||
test_out_global = jax.lax.with_sharding_constraint( | ||
test_out, NamedSharding(mesh, PartitionSpec(None)) | ||
) | ||
try: | ||
assert_allclose(ref_out, test_out_global) | ||
except AssertionError as err: | ||
raise AssertionError(f"Numerical mismatch in {out_names[i]}:\n" + str(err)) | ||
|
||
|
||
@pytest.mark.parametrize("comm_type", COMM_TYPES) | ||
@pytest.mark.parametrize("mesh_type", MESH_TYPES) | ||
def test_gemm_impl(comm_type, mesh_type): | ||
mesh, mesh_resource, batched, fsdp = _get_mesh(mesh_type) | ||
|
||
( | ||
local_operands, | ||
global_operands, | ||
output_info, | ||
fsdp_gathered_rhs_spec, | ||
) = _get_inputs(mesh, mesh_resource, jnp.bfloat16, comm_type, batched, fsdp) | ||
|
||
@jax.jit | ||
def _test_fn(lhs, rhs): | ||
rhs_no_fsdp = jax.lax.with_sharding_constraint( | ||
rhs, NamedSharding(mesh, PartitionSpec(*fsdp_gathered_rhs_spec)) | ||
) | ||
return te.cpp_extensions.gemm_impl(lhs, rhs_no_fsdp, batched_output=batched) | ||
|
||
with te.sharding.global_shard_guard(mesh_resource): | ||
output, *_ = _test_fn(*local_operands) | ||
|
||
_check_output(mesh, *output_info, *global_operands, output) | ||
|
||
|
||
@pytest.mark.parametrize("comm_type", COMM_TYPES) | ||
@pytest.mark.parametrize("mesh_type", MESH_TYPES) | ||
def test_gemm_fwd_bwd(comm_type, mesh_type): | ||
mesh, mesh_resource, batched, fsdp = _get_mesh(mesh_type) | ||
|
||
( | ||
local_operands, | ||
global_operands, | ||
output_info, | ||
fsdp_gathered_rhs_spec, | ||
) = _get_inputs(mesh, mesh_resource, jnp.bfloat16, comm_type, batched, fsdp, fwd_bwd=True) | ||
|
||
@jax.jit | ||
def _test_fn(lhs, rhs, grad): | ||
# Gather weights in FSDP axis | ||
rhs_no_fsdp = jax.lax.with_sharding_constraint( | ||
rhs, NamedSharding(mesh, PartitionSpec(*fsdp_gathered_rhs_spec)) | ||
) | ||
|
||
# FWD pass | ||
fwd_out, vjp_fn = jax.vjp(gemm, lhs, rhs_no_fsdp) | ||
|
||
# BWD pass | ||
lhs_grad, rhs_grad = vjp_fn(grad) | ||
|
||
return fwd_out, lhs_grad, rhs_grad | ||
|
||
print( | ||
f"INPUTS: {local_operands[0].shape} x {local_operands[1].shape}\n" | ||
+ f" LHS sharding: {local_operands[0].sharding.spec}\n" | ||
+ f" RHS sharding: {local_operands[1].sharding.spec}\n" | ||
) | ||
|
||
with te.sharding.global_shard_guard(mesh_resource): | ||
output, dgrad, wgrad = _test_fn(*local_operands) | ||
|
||
print( | ||
f"{'AG + GEMM' if comm_type == 'AG' else 'GEMM + AR'} output: " | ||
+ f"{output.shape} | {output.sharding.spec}\n" | ||
+ f"DGRAD: {dgrad.shape} | {dgrad.sharding.spec}\n" | ||
+ f"WGRAD: {wgrad.shape} | {wgrad.sharding.spec}\n" | ||
) | ||
|
||
_check_output(mesh, *output_info, *global_operands, output, dgrad, wgrad, fwd_bwd=True) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This mesh shape calculation is incorrect. Suggested revision: