From 7d528d68564f0ecd3d5089a0f8298919ffed7e65 Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Mon, 11 Nov 2024 11:29:40 +0800
Subject: [PATCH] sync after all2all for device memory saving (#93)

---
 yunchang/comm/all_to_all.py   | 19 +++++++++++++------
 yunchang/hybrid/attn_layer.py | 21 +++++++++++++--------
 2 files changed, 26 insertions(+), 14 deletions(-)

diff --git a/yunchang/comm/all_to_all.py b/yunchang/comm/all_to_all.py
index 8a6764a..901b78a 100644
--- a/yunchang/comm/all_to_all.py
+++ b/yunchang/comm/all_to_all.py
@@ -13,7 +13,7 @@
 
 
 def all_to_all_4D(
-    input: torch.tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None
+    input: torch.tensor, scatter_idx: int = 2, gather_idx: int = 1, group=None, use_sync: bool = False
 ) -> torch.tensor:
     """
     all-to-all for QKV
@@ -23,6 +23,7 @@ def all_to_all_4D(
         scatter_idx (int): default 1
         gather_idx (int): default 2
         group : torch process group
+        use_sync (bool): whether to synchronize after all-to-all
 
     Returns:
         torch.tensor: resharded tensor (bs, seqlen/P, hc, hs)
@@ -50,9 +51,11 @@ def all_to_all_4D(
         output = torch.empty_like(input_t)
         # https://pytorch.org/docs/stable/distributed.html#torch.distributed.all_to_all_single
         # (P, seq_len/P, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, bs, hc/P, hs) scatter head
+
         if seq_world_size > 1:
             dist.all_to_all_single(output, input_t, group=group)
-            torch.cuda.synchronize()
+            if use_sync:
+                torch.cuda.synchronize()
         else:
             output = input_t
         # if scattering the seq-dim, transpose the heads back to the original dimension
@@ -85,7 +88,8 @@ def all_to_all_4D(
         # (P, bs x hc/P, seqlen/P, hs) scatter seqlen -all2all-> (P, bs x seq_len/P, hc/P, hs) scatter head
         if seq_world_size > 1:
             dist.all_to_all_single(output, input_t, group=group)
-            torch.cuda.synchronize()
+            if use_sync:
+                torch.cuda.synchronize()
         else:
             output = input_t
 
@@ -129,7 +133,7 @@ def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None, None]:
 
 
 def all_to_all_5D(
-    input: torch.tensor, scatter_idx: int = 3, gather_idx: int = 1, group=None
+    input: torch.tensor, scatter_idx: int = 3, gather_idx: int = 1, group=None, use_sync: bool = False
 ) -> torch.tensor:
     """
     all-to-all for QKV
@@ -140,6 +144,7 @@ def all_to_all_5D(
         scatter_idx (int): default 1
         gather_idx (int): default 2
         group : torch process group
+        use_sync: whether to synchronize after all-to-all
 
     Returns:
         torch.tensor: resharded tensor (bs, seqlen/P, 3, hc, hs)
@@ -171,7 +176,8 @@ def all_to_all_5D(
         # (P, seq_len/P, 3, bs, hc/P, hs) scatter seqlen -all2all-> (P, seq_len/P, 3, bs, hc/P, hs) scatter head
        if seq_world_size > 1:
             dist.all_to_all_single(output, input_t, group=group)
-            # torch.cuda.synchronize()
+            if use_sync:
+                torch.cuda.synchronize()
         else:
             output = input_t
 
@@ -204,7 +210,8 @@ def all_to_all_5D(
         # (P, bs x hc/P, seqlen/P, hs) scatter seqlen -all2all-> (P, bs x seq_len/P, hc/P, hs) scatter head
         if seq_world_size > 1:
             dist.all_to_all_single(output, input_t, group=group)
-            # torch.cuda.synchronize()
+            if use_sync:
+                torch.cuda.synchronize()
         else:
             output = input_t
 
diff --git a/yunchang/hybrid/attn_layer.py b/yunchang/hybrid/attn_layer.py
index c81feef..ae75857 100644
--- a/yunchang/hybrid/attn_layer.py
+++ b/yunchang/hybrid/attn_layer.py
@@ -18,6 +18,7 @@ class LongContextAttention(torch.nn.Module):
         ring_pg (ProcessGroup): ring process group
         scatter_idx (int): scatter_idx for all2all comm
         gather_idx (int): gather_idx for all2all comm
+        use_sync (bool): whether to synchronize after all-to-all
     """
 
     def __init__(
@@ -26,6 +27,7 @@ def __init__(
         self,
         scatter_idx: int = 2,
         gather_idx: int = 1,
         ring_impl_type: str = "basic",
         use_pack_qkv: bool = False,
+        use_sync: bool = False,
     ) -> None:
 
         super(LongContextAttention, self).__init__()
@@ -33,6 +35,7 @@ def __init__(
         self.ulysses_pg = PROCESS_GROUP.ULYSSES_PG
 
         self.use_pack_qkv = use_pack_qkv
+        self.use_sync = use_sync
         assert (
             self.ulysses_pg is not None or self.ring_pg is not None
         ), f"use set_seq_parallel_pg() first. Now ulysses pg {self.ulysses_pg} and ring pg {self.ring_pg}"
@@ -74,7 +77,7 @@ def forward(
             qkv = torch.cat([query, key, value]).continous()
             # (3*bs, seq_len, head_cnt/N, head_size)
             qkv = SeqAllToAll4D.apply(
-                self.ulysses_pg, qkv, self.scatter_idx, self.gather_idx
+                self.ulysses_pg, qkv, self.scatter_idx, self.gather_idx, use_sync=self.use_sync
             )
             qkv = torch.chunk(qkv, 3, dim=0)
             out = self.ring_attn_fn(
@@ -93,13 +96,13 @@ def forward(
             )
         else:
             query_layer = SeqAllToAll4D.apply(
-                self.ulysses_pg, query, self.scatter_idx, self.gather_idx
+                self.ulysses_pg, query, self.scatter_idx, self.gather_idx, use_sync=self.use_sync
             )
             key_layer = SeqAllToAll4D.apply(
-                self.ulysses_pg, key, self.scatter_idx, self.gather_idx
+                self.ulysses_pg, key, self.scatter_idx, self.gather_idx, use_sync=self.use_sync
             )
             value_layer = SeqAllToAll4D.apply(
-                self.ulysses_pg, value, self.scatter_idx, self.gather_idx
+                self.ulysses_pg, value, self.scatter_idx, self.gather_idx, use_sync=self.use_sync
             )
 
             out = self.ring_attn_fn(
@@ -125,7 +128,7 @@ def forward(
         # (bs, seq_len, head_cnt/N, head_size) -> (bs, seq_len/N, head_cnt, head_size)
         # scatter 1, gather 2
         output = SeqAllToAll4D.apply(
-            self.ulysses_pg, context_layer, self.gather_idx, self.scatter_idx
+            self.ulysses_pg, context_layer, self.gather_idx, self.scatter_idx, use_sync=self.use_sync
         )
 
         # out e.g., [s/p::h]
@@ -140,6 +143,7 @@ class LongContextAttentionQKVPacked(torch.nn.Module):
         ring_pg (ProcessGroup): ring process group
         scatter_idx (int): scatter_idx for all2all comm
         gather_idx (int): gather_idx for all2all comm
+        use_sync (bool): whether to synchronize after all-to-all
     """
 
     def __init__(
@@ -147,6 +151,7 @@ def __init__(
         self,
         scatter_idx: int = 3,
         gather_idx: int = 1,
         ring_impl_type: str = "basic",
+        use_sync: bool = False,
     ) -> None:
 
         super(LongContextAttentionQKVPacked, self).__init__()
@@ -159,7 +164,7 @@ def __init__(
         ), f"use set_seq_parallel_pg() first. Now ulysses pg {self.ulysses_pg} and ring pg {self.ring_pg}"
         self.scatter_idx = scatter_idx
         self.gather_idx = gather_idx
-
+        self.use_sync = use_sync
         self.ring_attn_fn = RING_IMPL_QKVPACKED_DICT[ring_impl_type]
 
     def forward(
@@ -193,7 +198,7 @@ def forward(
 
         if world_size > 1:
             qkv = SeqAllToAll5D.apply(
-                self.ulysses_pg, qkv, self.scatter_idx, self.gather_idx
+                self.ulysses_pg, qkv, self.scatter_idx, self.gather_idx, use_sync=self.use_sync
             )
 
         out = self.ring_attn_fn(
@@ -219,7 +224,7 @@ def forward(
 
         if world_size > 1:
             out = SeqAllToAll4D.apply(
-                self.ulysses_pg, out, self.gather_idx, self.scatter_idx - 1
+                self.ulysses_pg, out, self.gather_idx, self.scatter_idx - 1, use_sync=self.use_sync
             )
         # out e.g., [s/p::h]
         return out
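Usage sketch (illustrative, not part of the patch). The snippet below shows where the new use_sync flag surfaces for a caller. Only the use_sync keyword itself comes from this diff; the torchrun launch, the top-level imports, and the set_seq_parallel_pg() argument order are assumptions based on the yunchang README.

# Hypothetical usage sketch, assuming a torchrun launch, the yunchang
# top-level exports, and the set_seq_parallel_pg() signature from the README.
# Only the use_sync keyword is introduced by this patch.
import torch
import torch.distributed as dist
from yunchang import LongContextAttention, set_seq_parallel_pg

dist.init_process_group(backend="nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank % torch.cuda.device_count())

# Split the ranks into a 2-way Ulysses group and a ring group over the rest;
# assumes world_size is a multiple of 2 (argument order assumed from the README).
sp_ulysses_degree = 2
sp_ring_degree = world_size // sp_ulysses_degree
set_seq_parallel_pg(sp_ulysses_degree, sp_ring_degree, rank, world_size)

# use_sync=True keeps a torch.cuda.synchronize() right after each all-to-all,
# the behaviour this patch makes opt-in (it now defaults to False).
attn = LongContextAttention(ring_impl_type="basic", use_sync=True)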