From 33d14e6901ae1549ef0d0bb8133a0bf7241c349e Mon Sep 17 00:00:00 2001
From: thomasw21 <24695242+thomasw21@users.noreply.github.com>
Date: Thu, 11 May 2023 14:20:58 +0200
Subject: [PATCH 1/3] Don't need to run communication for kv

---
 megatron/model/transformer.py   | 37 +++++++++++++++------------------
 megatron/optimizer/optimizer.py |  1 -
 2 files changed, 17 insertions(+), 21 deletions(-)

diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index 2e437a901a..798499f372 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -32,6 +32,7 @@
 from .glu_activations import GLU_ACTIVATIONS
+from ..mpu import copy_to_tensor_model_parallel_region, LinearWithGradAccumulationAndAsyncCommunication
 
 # flags required to enable jit fusion kernels
 torch._C._jit_set_profiling_mode(False)
@@ -565,6 +566,11 @@ def __init__(self, init_method,
                 args.hidden_size,
                 2 * args.kv_channels,
                 init_method=init_method)
+
+            self.async_tensor_model_parallel_allreduce = args.async_tensor_model_parallel_allreduce and world_size > 1
+            self.sequence_parallel = args.sequence_parallel and world_size > 1
+            self.gradient_accumulation_fusion = args.gradient_accumulation_fusion
+
         elif attention_type == AttnType.cross_attn and self.attention_head_type == 'multihead':
             assert attention_type == AttnType.cross_attn
             self.query = mpu.ColumnParallelLinear(
@@ -686,28 +692,17 @@ def forward(self, hidden_states, attention_mask,
              key_layer,
              value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
         elif self.attention_type == AttnType.self_attn and self.attention_head_type == 'multiquery':
-            kv_input=hidden_states
-            # Attention heads [sq, b, h] --> [sq, b, (2 * hn)]
-            mixed_kv_layer = self.key_value(kv_input)
+            kv_input = hidden_states
 
-            # Reduce the KV gradients in the tensor-parallel direction.
-            # This is different from multi-head attention which reduces the KV input,
-            # because the sum over attn heads happens in the attn weight gradient instead of the KV layer:
-            #   A [b, n * sq, sk] = Q [b, n * sq, hn] x K^T [b, hn, sk]
-            #   G_K [b, sk, hn] = G_A [b, sk, n * sq] x Q [b, n * sq, hn]
-            #                   = sum_p (G_Ap [b, sk, np * sq] x Q_p [b, np * sq, hn])
-            if get_args().sequence_parallel:
-                # We switch to the tensor parallel regime here instead of at the KV input
-                # so that the KV layer is done in parallel instead of just duplicated.
-                mixed_kv_layer = mpu.gather_from_sequence_parallel_region(mixed_kv_layer, tensor_parallel_output_grad=True)
+            # Manually handle communication of kv_input
+            if self.async_tensor_model_parallel_allreduce or \
+                    self.sequence_parallel:
+                kv_input = kv_input
             else:
-                mixed_kv_layer = mpu.copy_to_tensor_model_parallel_region(mixed_kv_layer)
+                kv_input = copy_to_tensor_model_parallel_region(kv_input)
 
-            # [sq, b, (2 * hn)] --> [sq, b, np (expanded), 2 * hn]
-            # new_tensor_shape = mixed_kv_layer.size()[:-1] + \
-            #     (self.num_attention_heads_per_partition,
-            #      2 * self.hidden_size_per_attention_head)
-            # mixed_kv_layer = mixed_kv_layer.unsqueeze(2).expand(*new_tensor_shape)
+            # Attention heads [sq, b, h] --> [sq, b, (2 * hn)]
+            mixed_kv_layer = self.key_value(kv_input)
 
             # [sq, b, (2 * hn)] --> [sq, b, 1, 2 * hn]
             new_tensor_shape = mixed_kv_layer.size()[:-1] + \
@@ -720,7 +715,9 @@ def forward(self, hidden_states, attention_mask,
              value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)
 
             # Attention head [sq, b, h] --> [sq, b, np * hn]
-            query_layer, _ = self.query(hidden_states)
+            query_layer = LinearWithGradAccumulationAndAsyncCommunication.apply(
+                kv_input, self.query.weight, self.query.bias, self.gradient_accumulation_fusion,
+                self.async_tensor_model_parallel_allreduce, self.sequence_parallel)
             # [sq, b, np * hn] --> [sq, b, np, hn]
             new_tensor_shape = query_layer.size()[:-1] + \
                 (self.num_attention_heads_per_partition,
diff --git a/megatron/optimizer/optimizer.py b/megatron/optimizer/optimizer.py
index efa1bd36f8..10e4a69615 100644
--- a/megatron/optimizer/optimizer.py
+++ b/megatron/optimizer/optimizer.py
@@ -338,7 +338,6 @@ def reduce_model_grads(self, args, timers):
         if (
             args.attention_head_type == "multiquery"
             and mpu.get_tensor_model_parallel_world_size() > 1
-            and args.sequence_parallel
         ):
             timers('backward-key-value-all-reduce').start()
             self.allreduce_key_value_grads(args)
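Note on PATCH 1/3: the removed comment states the identity G_K = G_A x Q = sum_p (G_Ap x Q_p), i.e. with multiquery attention each tensor-parallel partition of heads only produces a partial gradient for the shared key/value weights, which is why the optimizer change makes the key-value grad all-reduce run whenever tensor parallelism is enabled rather than only with sequence parallelism. The snippet below is a single-process, plain-PyTorch check of that identity; it is illustrative only and not part of the patch, and all shapes are arbitrary.

# Sanity check (illustrative, not part of the patch) of the gradient identity
# quoted in the removed comment: the K-gradient decomposes into per-head-
# partition terms, G_K = sum_p (G_Ap x Q_p), so each tensor-parallel rank holds
# only a partial K-gradient and the partials must be summed (all-reduced).
import torch

b, n, p, sq, sk, hn = 2, 8, 4, 5, 5, 16  # batch, heads, partitions, seq lens, head dim
q = torch.randn(b, n, sq, hn)
k = torch.randn(b, sk, hn, requires_grad=True)  # multiquery: one K shared by all heads

# Full attention-score computation A = Q x K^T over all heads at once.
torch.einsum('bnqh,bkh->bnqk', q, k).sum().backward()
full_grad = k.grad.clone()

# Same thing partition by partition: each chunk of heads plays the role of one
# tensor-parallel rank and produces only its partial K-gradient.
partial_sum = torch.zeros_like(full_grad)
for q_p in q.chunk(p, dim=1):
    k_p = k.detach().requires_grad_(True)
    torch.einsum('bnqh,bkh->bnqk', q_p, k_p).sum().backward()
    partial_sum += k_p.grad  # what the key-value grad all-reduce accumulates

torch.testing.assert_close(full_grad, partial_sum)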
From 67f733cdd67f8c7851edd0a6d528157c704698b6 Mon Sep 17 00:00:00 2001
From: thomasw21 <24695242+thomasw21@users.noreply.github.com>
Date: Thu, 11 May 2023 14:28:26 +0200
Subject: [PATCH 2/3] Woops

---
 megatron/model/transformer.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index 798499f372..bc5aa5b0d2 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -553,10 +553,9 @@ def __init__(self, init_method,
                 init_method=init_method)
         elif attention_type == AttnType.self_attn and self.attention_head_type == 'multiquery':
             # TODO: Find a way to merge the query and key-value computations?
-            self.query = mpu.ColumnParallelLinear(
+            self.query = get_linear_layer(
                 args.hidden_size,
                 projection_size,
-                gather_output=False,
                 init_method=init_method)
             # In MultiQuery attention, keys and values are shared across heads
             # Use args.kv_channels instead of projection_size

From 080d1c0af466c67e99a4311f489612ce7c371242 Mon Sep 17 00:00:00 2001
From: thomasw21 <24695242+thomasw21@users.noreply.github.com>
Date: Fri, 12 May 2023 10:48:49 +0200
Subject: [PATCH 3/3] I think this fixes sequence parallel

---
 megatron/model/transformer.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index bc5aa5b0d2..d6e8c1110d 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -700,8 +700,16 @@ def forward(self, hidden_states, attention_mask,
             else:
                 kv_input = copy_to_tensor_model_parallel_region(kv_input)
 
+            # TODO @thomasw21: This is stupid because `LinearWithGradAccumulationAndAsyncCommunication` also all_gathers the activations
+            if self.sequence_parallel:
+                kv_input_gathered = mpu.gather_from_sequence_parallel_region(
+                    kv_input,
+                    tensor_parallel_output_grad=True)
+            else:
+                kv_input_gathered = kv_input
+
             # Attention heads [sq, b, h] --> [sq, b, (2 * hn)]
-            mixed_kv_layer = self.key_value(kv_input)
+            mixed_kv_layer = self.key_value(kv_input_gathered)
 
             # [sq, b, (2 * hn)] --> [sq, b, 1, 2 * hn]
             new_tensor_shape = mixed_kv_layer.size()[:-1] + \
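Taken together, the three patches leave the multiquery key/value and query projections with the communication pattern sketched below. This is a reconstruction from the diffs, not code from the patch: it assumes the same Megatron-LM fork and its mpu helpers, attn stands for the attention module whose attributes are set in the patched __init__, and it is not runnable outside that codebase.

# Sketch only: the kv/query path after PATCH 3/3, pulled out of the attention
# module's forward for readability. Relies on the fork's mpu helpers imported
# in PATCH 1/3.
import torch

from megatron import mpu
from megatron.mpu import (
    copy_to_tensor_model_parallel_region,
    LinearWithGradAccumulationAndAsyncCommunication,
)


def multiquery_kv_and_query(attn, hidden_states: torch.Tensor):
    kv_input = hidden_states

    # Forward communication for kv_input is handled manually: with async
    # tensor-model-parallel all-reduce or sequence parallelism the input is
    # left as-is, otherwise it is copied into the tensor-parallel region.
    if not (attn.async_tensor_model_parallel_allreduce or attn.sequence_parallel):
        kv_input = copy_to_tensor_model_parallel_region(kv_input)

    # PATCH 3/3: with sequence parallelism, gather the full sequence before
    # the shared key/value projection (see the TODO about the all_gather that
    # LinearWithGradAccumulationAndAsyncCommunication repeats for the query).
    if attn.sequence_parallel:
        kv_input_gathered = mpu.gather_from_sequence_parallel_region(
            kv_input, tensor_parallel_output_grad=True)
    else:
        kv_input_gathered = kv_input

    # [sq, b, h] --> [sq, b, 2 * hn]: keys/values shared across all heads.
    mixed_kv_layer = attn.key_value(kv_input_gathered)

    # Query projection goes through the fused autograd function so gradient
    # accumulation fusion and async communication still apply to it.
    query_layer = LinearWithGradAccumulationAndAsyncCommunication.apply(
        kv_input, attn.query.weight, attn.query.bias,
        attn.gradient_accumulation_fusion,
        attn.async_tensor_model_parallel_allreduce, attn.sequence_parallel)

    return mixed_kv_layer, query_layer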