Hierarchical CP implementation (Ulysses + Ring) (#1209)
* change API for hierarchical CP

Signed-off-by: Xiaowei Ren <[email protected]>

* move fp8 code before qkv reshape

Signed-off-by: Xiaowei Ren <[email protected]>

* try to insert A2A for hierarchical CP

Signed-off-by: Xiaowei Ren <[email protected]>

* make fwd work

Signed-off-by: Xiaowei Ren <[email protected]>

* remove a redundant sync

Signed-off-by: Xiaowei Ren <[email protected]>

* make bwd of hierarchical CP work

Signed-off-by: Xiaowei Ren <[email protected]>

* fix dout a2a in bwd

Signed-off-by: Xiaowei Ren <[email protected]>

* fix q_f16 with fp8

Signed-off-by: Xiaowei Ren <[email protected]>

* assert hierarchical CP implementation does not support THD format

Signed-off-by: Xiaowei Ren <[email protected]>

* bug fix

Signed-off-by: Xiaowei Ren <[email protected]>

* assert hierarchical CP does not support attn bias

Signed-off-by: Xiaowei Ren <[email protected]>

* add unit test for hierarchical CP

Signed-off-by: Xiaowei Ren <[email protected]>

* fix cp_comm_type in unit test

Signed-off-by: Xiaowei Ren <[email protected]>

* bug fix and code cleaning

Signed-off-by: Xiaowei Ren <[email protected]>

* minor change

Signed-off-by: Xiaowei Ren <[email protected]>

* update an assertion message

Signed-off-by: Xiaowei Ren <[email protected]>

* dout shape fix

Signed-off-by: Xiaowei Ren <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* move function definitions before their first call

Signed-off-by: Xiaowei Ren <[email protected]>

* fix tensor view comments

Signed-off-by: Xiaowei Ren <[email protected]>

* refine CP unit test

Signed-off-by: Xiaowei Ren <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* typo fix

Signed-off-by: Xiaowei Ren <[email protected]>

* typo fix

Signed-off-by: Xiaowei Ren <[email protected]>

* save cp_size_a2a and rank_a2a in fwd

Signed-off-by: Xiaowei Ren <[email protected]>

* add more explanation of cp_group in the docstring

Signed-off-by: Xiaowei Ren <[email protected]>

---------

Signed-off-by: Xiaowei Ren <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
xrennvidia and pre-commit-ci[bot] authored Oct 7, 2024
1 parent 60f738f commit c24a4c4
Showing 4 changed files with 390 additions and 202 deletions.
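
The commit's central addition is the hierarchical "a2a+p2p" mode, which runs Ulysses-style all-to-all inside small inner sub-groups nested within an outer ring of point-to-point KV exchanges. Below is a minimal sketch of the rank decomposition this relies on; the values world_size = 8 and a2a_size = 2 are illustrative (the unit test hard-codes an A2A size of 2), and the variable names are not taken from the diff.

# Rank layout behind the hierarchical "a2a+p2p" (Ulysses + Ring) mode.
world_size, a2a_size = 8, 2  # illustrative; the test below assumes a2a_size == 2

# Inner Ulysses (A2A) sub-groups: consecutive ranks exchange QKV/O via all-to-all.
a2a_groups = [
    list(range(i * a2a_size, (i + 1) * a2a_size)) for i in range(world_size // a2a_size)
]
# Outer Ring (P2P) sub-groups: strided ranks circulate KV via point-to-point sends.
p2p_groups = [list(range(i, world_size, a2a_size)) for i in range(a2a_size)]

print(a2a_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
print(p2p_groups)  # [[0, 2, 4, 6], [1, 3, 5, 7]]

Each rank belongs to exactly one A2A sub-group and one P2P sub-group; the test changes below build NCCL process groups along exactly these two axes.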
23 changes: 15 additions & 8 deletions tests/pytorch/fused_attn/run_fused_attn_with_cp.py
@@ -59,6 +59,17 @@ def run_dpa_with_cp(
cp_comm_ranks = range(world_size)
assert rank in cp_comm_ranks
cp_comm_group = dist.new_group(cp_comm_ranks, backend="nccl")
if cp_comm_type == "a2a+p2p":
assert (
world_size % 2 == 0
), "Assuming CP size for A2A is 2, and CP size for P2P is (world_size // 2)!"
cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)]
cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)]
cp_comm_sub_groups = []
for sub_ranks in cp_comm_sub_ranks:
sub_group = dist.new_group(sub_ranks, backend="nccl")
if rank in sub_ranks:
cp_comm_sub_groups.append(sub_group)

if dtype == "fp8":
fp8_recipe = DelayedScaling(fp8_dpa=True)
@@ -167,13 +178,6 @@ def run_dpa_with_cp(
else:
bias = None

# make sure all GPU ranks have same inputs
for x in [q, k, v, dout] + ([] if bias is None else [bias]):
dist.broadcast(x, 0, group=cp_comm_group)
if qkv_format == "thd":
for x in [cu_seqlens_q, cu_seqlens_q_padded, cu_seqlens_kv, cu_seqlens_kv_padded]:
dist.broadcast(x, 0, group=cp_comm_group)

# run core_attn without CP
for x in [q, k, v]:
x.requires_grad = True
@@ -239,7 +243,10 @@ def run_dpa_with_cp(
bias_ = bias_.index_select(2, seq_idx)
bias_ = bias_.view(*bias_.shape[:2], -1, bias_.shape[-1])
core_attn.set_context_parallel_group(
cp_comm_group, cp_comm_ranks, torch.cuda.Stream(), cp_comm_type
cp_comm_sub_groups if cp_comm_type == "a2a+p2p" else cp_comm_group,
cp_comm_ranks,
torch.cuda.Stream(),
cp_comm_type,
)

if dtype == "fp8":
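For context on how a model consumes the new mode, here is a condensed sketch (not part of this commit): it rebuilds the same sub-groups as the test above and hands them, as a pair, to set_context_parallel_group. It assumes a 4-GPU launch (e.g. torchrun --nproc-per-node=4); the DotProductAttention constructor arguments are illustrative.

import torch
import torch.distributed as dist
from transformer_engine.pytorch import DotProductAttention

dist.init_process_group(backend="nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank % torch.cuda.device_count())

cp_comm_ranks = list(range(world_size))
# Inner A2A (Ulysses) groups of size 2, followed by outer P2P (Ring) groups.
cp_comm_sub_ranks = [range(i * 2, (i + 1) * 2) for i in range(world_size // 2)]
cp_comm_sub_ranks += [range(i, world_size, 2) for i in range(2)]
cp_comm_sub_groups = []
for sub_ranks in cp_comm_sub_ranks:
    sub_group = dist.new_group(list(sub_ranks), backend="nccl")
    if rank in sub_ranks:
        cp_comm_sub_groups.append(sub_group)  # this rank's [A2A group, P2P group]

core_attn = DotProductAttention(num_attention_heads=16, kv_channels=64).cuda()
# For "a2a+p2p", the first argument is the pair of sub-groups rather than a single group.
core_attn.set_context_parallel_group(
    cp_comm_sub_groups, cp_comm_ranks, torch.cuda.Stream(), "a2a+p2p"
)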
29 changes: 18 additions & 11 deletions tests/pytorch/fused_attn/test_fused_attn_with_cp.py
@@ -36,8 +36,13 @@
}


def get_bash_arguments(**kwargs):
args = ["python", "-m", "torch.distributed.launch", "--nproc-per-node=2"]
def get_bash_arguments(num_gpus_per_node, **kwargs):
args = [
"python",
"-m",
"torch.distributed.launch",
"--nproc-per-node=" + str(num_gpus_per_node),
]
te_path = os.getenv("TE_PATH", "/opt/transformerengine")
script_path = os.path.join(te_path, "tests/pytorch/fused_attn/run_fused_attn_with_cp.py")
args.append(script_path)
@@ -51,27 +56,28 @@ def get_bash_arguments(**kwargs):
@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("model", model_configs_flash_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a"])
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"])
def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
config = model_configs_flash_attn[model]
if cp_comm_type == "p2p" and config.window_size != (-1, 0) and config.window_size != (-1, -1):
if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1):
pytest.skip("CP implementation with KV P2P does not support sliding window yet!")
if cp_comm_type == "all_gather" and qkv_format == "thd":
pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with KV all-gather does not support bias yet!")
if cp_comm_type == "a2a" and qkv_format == "thd":
if "a2a" in cp_comm_type and qkv_format == "thd":
pytest.skip("CP implementation with QKVO A2A does not support THD format yet!")
if cp_comm_type == "a2a" and config.attn_bias_type != "no_bias":
if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with QKVO A2A does not support bias yet!")
if cp_comm_type == "a2a" and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0):
if "a2a" in cp_comm_type and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0):
pytest.skip(
f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and"
f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!"
)

subprocess.run(
get_bash_arguments(
num_gpus_per_node=4 if cp_comm_type == "a2a+p2p" else 2,
dtype=dtype,
model=model,
qkv_format=qkv_format,
@@ -106,7 +112,7 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type):
@pytest.mark.parametrize("dtype", ["bf16", "fp16", "fp8"])
@pytest.mark.parametrize("model", model_configs_fused_attn.keys())
@pytest.mark.parametrize("qkv_format", ["bshd", "sbhd", "thd"])
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a"])
@pytest.mark.parametrize("cp_comm_type", ["p2p", "all_gather", "a2a", "a2a+p2p"])
def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
if qkv_format == "thd" and get_device_compute_capability() < (9, 0):
pytest.skip("THD format is only supported on sm90+!")
@@ -122,7 +128,7 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
pytest.skip("THD format does not support post_scale_bias yet!")
if qkv_format == "thd" and cp_comm_type == "all_gather":
pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
if qkv_format == "thd" and cp_comm_type == "a2a":
if qkv_format == "thd" and "a2a" in cp_comm_type:
pytest.skip("CP implementation with QKVO A2A does not support THD format yet!")
if config.window_size != (-1, 0) and config.window_size != (-1, -1) and cp_comm_type != "a2a":
pytest.skip(
Expand All @@ -140,16 +146,17 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type):
pytest.skip("FP8 attention cannot work with sliding window yet!")
if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with KV all-gather does not support bias yet!")
if cp_comm_type == "a2a" and config.attn_bias_type != "no_bias":
if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias":
pytest.skip("CP implementation with QKVO A2A does not support bias yet!")
if cp_comm_type == "a2a" and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0):
if "a2a" in cp_comm_type and (config.num_heads % 2 != 0 or config.num_gqa_groups % 2 != 0):
pytest.skip(
f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and"
f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!"
)

subprocess.run(
get_bash_arguments(
num_gpus_per_node=4 if cp_comm_type == "a2a+p2p" else 2,
dtype=dtype,
model=model,
qkv_format=qkv_format,
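As a usage note, the "a2a+p2p" cases now launch on four ranks instead of two. Inside the test module the invocation effectively becomes the call below; only num_gpus_per_node, dtype, model, and qkv_format appear in the shown lines, cp_comm_type is presumably forwarded the same way, the model key is hypothetical, and any further keyword arguments the real test passes are omitted.

# Sketch only -- assumes it runs inside test_fused_attn_with_cp.py, where
# subprocess and get_bash_arguments are already in scope.
subprocess.run(
    get_bash_arguments(
        num_gpus_per_node=4,   # "a2a+p2p" = 2 (A2A) x 2 (P2P) ranks
        dtype="bf16",
        model="cp_1_0",        # hypothetical key from model_configs_flash_attn
        qkv_format="bshd",
        cp_comm_type="a2a+p2p",
    ),
    check=True,
)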
(Diffs for the remaining two changed files are not shown.)