sync after all2all for device memory saving (#93)
feifeibear authored Nov 11, 2024
1 parent a6d31fd · commit 7d528d6
Showing 2 changed files with 26 additions and 14 deletions.
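This commit makes the torch.cuda.synchronize() that follows each dist.all_to_all_single optional: all_to_all_4D and all_to_all_5D gain a use_sync flag (default False), and LongContextAttention / LongContextAttentionQKVPacked accept the same flag and forward it through SeqAllToAll4D / SeqAllToAll5D. Below is a minimal usage sketch, not part of the commit: the import path follows the changed file, the process groups are assumed to have been set up with set_seq_parallel_pg(), and the forward arguments beyond query/key/value are assumed to keep their defaults.

import torch
from yunchang.hybrid.attn_layer import LongContextAttention  # module path taken from this diff

# Assumes torch.distributed is initialized (e.g. via torchrun) and that
# set_seq_parallel_pg(...) has already populated the ulysses/ring process groups.
attn = LongContextAttention(
    scatter_idx=2,
    gather_idx=1,
    ring_impl_type="basic",
    use_pack_qkv=False,
    use_sync=True,  # new flag from this commit: synchronize after each all-to-all
)

# query/key/value are sharded along the sequence dim: (bs, seqlen/N, head_cnt, head_size);
# head_cnt must be divisible by the Ulysses degree.
q = torch.randn(2, 1024, 32, 128, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)
out = attn(q, k, v)  # remaining attention kwargs (dropout, causal, ...) assumed default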
yunchang/comm/all_to_all.py (19 changes: 13 additions & 6 deletions)
@@ -13,7 +13,7 @@
 
 
 def all_to_all_4D(
-    input: torch.tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None
+    input: torch.tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None, use_sync: bool = False
 ) -> torch.tensor:
     """
     all-to-all for QKV
@@ -23,6 +23,7 @@ def all_to_all_4D(
         scatter_idx (int): default 1
         gather_idx (int): default 2
         group : torch process group
+        use_sync (bool): whether to synchronize after all-to-all
 
     Returns:
         torch.tensor: resharded tensor (bs, seqlen/P, hc, hs)
@@ -50,9 +51,11 @@
         output = torch.empty_like(input_t)
         # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
         # (P, seq_len/P, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, bs, hc/P, hs) scatter head
+
         if seq_world_size > 1:
             dist.all_to_all_single(output, input_t, group=group)
-            torch.cuda.synchronize()
+            if use_sync:
+                torch.cuda.synchronize()
         else:
             output = input_t
         # if scattering the seq-dim, transpose the heads back to the original dimension
@@ -85,7 +88,8 @@ def all_to_all_4D(
         # (P, bs x hc/P, seqlen/P, hs) scatter seqlen -all2all-> (P, bs x seq_len/P, hc/P, hs) scatter head
         if seq_world_size > 1:
             dist.all_to_all_single(output, input_t, group=group)
-            torch.cuda.synchronize()
+            if use_sync:
+                torch.cuda.synchronize()
         else:
             output = input_t
 
@@ -129,7 +133,7 @@ def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
 
 
 def all_to_all_5D(
-    input: torch.tensor, scatter_idx: int = 3, gather_idx: int = 1, group=None
+    input: torch.tensor, scatter_idx: int = 3, gather_idx: int = 1, group=None, use_sync: bool = False
 ) -> torch.tensor:
     """
     all-to-all for QKV
@@ -140,6 +144,7 @@ def all_to_all_5D(
         scatter_idx (int): default 1
         gather_idx (int): default 2
         group : torch process group
+        use_sync: whether to synchronize after all-to-all
 
     Returns:
         torch.tensor: resharded tensor (bs, seqlen/P, 3, hc, hs)
@@ -171,7 +176,8 @@
         # (P, seq_len/P, 3, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, 3, bs, hc/P, hs) scatter head
         if seq_world_size > 1:
             dist.all_to_all_single(output, input_t, group=group)
-            # torch.cuda.synchronize()
+            if use_sync:
+                torch.cuda.synchronize()
         else:
             output = input_t
 
@@ -204,7 +210,8 @@
         # (P, bs x hc/P, seqlen/P, hs) scatter seqlen -all2all-> (P, bs x seq_len/P, hc/P, hs) scatter head
         if seq_world_size > 1:
             dist.all_to_all_single(output, input_t, group=group)
-            # torch.cuda.synchronize()
+            if use_sync:
+                torch.cuda.synchronize()
         else:
             output = input_t
 
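At the communication layer the flag only gates the post-collective synchronize inside all_to_all_4D / all_to_all_5D. A minimal sketch of a direct call, assuming a torchrun-style launch and that group spans the P sequence-parallel ranks (sizes below are illustrative, not from the commit):

import torch
import torch.distributed as dist
from yunchang.comm.all_to_all import all_to_all_4D  # module path taken from this diff

dist.init_process_group(backend="nccl")              # env vars supplied by torchrun
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
group = dist.new_group(list(range(dist.get_world_size())))
P = dist.get_world_size(group)

bs, seqlen, hc, hs = 2, 4096, 32, 128                # seqlen and hc divisible by P
x = torch.randn(bs, seqlen // P, hc, hs, device="cuda", dtype=torch.float16)

# scatter_idx=2 / gather_idx=1 reshards (bs, seqlen/P, hc, hs) -> (bs, seqlen, hc/P, hs);
# use_sync=True keeps the torch.cuda.synchronize() right after dist.all_to_all_single.
y = all_to_all_4D(x, scatter_idx=2, gather_idx=1, group=group, use_sync=True)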
yunchang/hybrid/attn_layer.py (21 changes: 13 additions & 8 deletions)
@@ -18,6 +18,7 @@ class LongContextAttention(torch.nn.Module):
         ring_pg (ProcessGroup): ring process group
         scatter_idx (int): scatter_idx for all2all comm
         gather_idx (int): gather_idx for all2all comm
+        use_sync (bool): whether to synchronize after all-to-all
     """
 
     def __init__(
@@ -26,13 +27,15 @@ def __init__(
         gather_idx: int = 1,
         ring_impl_type: str = "basic",
         use_pack_qkv: bool = False,
+        use_sync: bool = False,
     ) -> None:
 
         super(LongContextAttention, self).__init__()
         self.ring_pg = PROCESS_GROUP.RING_PG
         self.ulysses_pg = PROCESS_GROUP.ULYSSES_PG
 
         self.use_pack_qkv = use_pack_qkv
+        self.use_sync = use_sync
         assert (
             self.ulysses_pg is not None or self.ring_pg is not None
         ), f"use set_seq_parallel_pg() first. Now ulysses pg {self.ulysses_pg} and ring pg {self.ring_pg}"
@@ -74,7 +77,7 @@ def forward(
             qkv = torch.cat([query, key, value]).continous()
             # (3*bs, seq_len, head_cnt/N, head_size)
             qkv = SeqAllToAll4D.apply(
-                self.ulysses_pg, qkv, self.scatter_idx, self.gather_idx
+                self.ulysses_pg, qkv, self.scatter_idx, self.gather_idx, use_sync=self.use_sync
             )
             qkv = torch.chunk(qkv, 3, dim=0)
             out = self.ring_attn_fn(
@@ -93,13 +96,13 @@ def forward(
             )
         else:
             query_layer = SeqAllToAll4D.apply(
-                self.ulysses_pg, query, self.scatter_idx, self.gather_idx
+                self.ulysses_pg, query, self.scatter_idx, self.gather_idx, use_sync=self.use_sync
             )
             key_layer = SeqAllToAll4D.apply(
-                self.ulysses_pg, key, self.scatter_idx, self.gather_idx
+                self.ulysses_pg, key, self.scatter_idx, self.gather_idx, use_sync=self.use_sync
             )
             value_layer = SeqAllToAll4D.apply(
-                self.ulysses_pg, value, self.scatter_idx, self.gather_idx
+                self.ulysses_pg, value, self.scatter_idx, self.gather_idx, use_sync=self.use_sync
             )
 
             out = self.ring_attn_fn(
@@ -125,7 +128,7 @@ def forward(
         # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size)
         # scatter 1, gather 2
         output = SeqAllToAll4D.apply(
-            self.ulysses_pg, context_layer, self.gather_idx, self.scatter_idx
+            self.ulysses_pg, context_layer, self.gather_idx, self.scatter_idx, use_sync=self.use_sync
         )
 
         # out e.g., [s/p::h]
@@ -140,13 +143,15 @@ class LongContextAttentionQKVPacked(torch.nn.Module):
         ring_pg (ProcessGroup): ring process group
         scatter_idx (int): scatter_idx for all2all comm
         gather_idx (int): gather_idx for all2all comm
+        use_sync (bool): whether to synchronize after all-to-all
     """
 
     def __init__(
         self,
         scatter_idx: int = 3,
         gather_idx: int = 1,
         ring_impl_type: str = "basic",
+        use_sync: bool = False,
     ) -> None:
 
         super(LongContextAttentionQKVPacked, self).__init__()
@@ -159,7 +164,7 @@ def __init__(
         ), f"use set_seq_parallel_pg() first. Now ulysses pg {self.ulysses_pg} and ring pg {self.ring_pg}"
         self.scatter_idx = scatter_idx
         self.gather_idx = gather_idx
-
+        self.use_sync = use_sync
         self.ring_attn_fn = RING_IMPL_QKVPACKED_DICT[ring_impl_type]
 
     def forward(
@@ -193,7 +198,7 @@ def forward(
 
         if world_size > 1:
             qkv = SeqAllToAll5D.apply(
-                self.ulysses_pg, qkv, self.scatter_idx, self.gather_idx
+                self.ulysses_pg, qkv, self.scatter_idx, self.gather_idx, use_sync=self.use_sync
             )
 
         out = self.ring_attn_fn(
@@ -219,7 +224,7 @@ def forward(
 
         if world_size > 1:
             out = SeqAllToAll4D.apply(
-                self.ulysses_pg, out, self.gather_idx, self.scatter_idx - 1
+                self.ulysses_pg, out, self.gather_idx, self.scatter_idx - 1, use_sync=self.use_sync
            )
         # out e.g., [s/p::h]
         return out
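For the packed variant, use_sync is threaded through SeqAllToAll5D in the same way. A sketch under the same assumptions as above, with the packed-QKV shape taken from the comments in all_to_all_5D and any remaining forward arguments assumed to keep their defaults:

import torch
from yunchang.hybrid.attn_layer import LongContextAttentionQKVPacked  # module path taken from this diff

attn = LongContextAttentionQKVPacked(
    scatter_idx=3,
    gather_idx=1,
    ring_impl_type="basic",
    use_sync=True,  # synchronize after each all-to-all, as added by this commit
)

# packed qkv sharded along the sequence dim: (bs, seqlen/P, 3, head_cnt, head_size)
qkv = torch.randn(2, 1024, 3, 32, 128, device="cuda", dtype=torch.float16)
out = attn(qkv)  # other forward arguments assumed to keep their defaults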
