
[QUESTION] Adding new parameters in ColumnParallelLinear/RowParallelLinear raises AssertionError (Communication call has not been issued for this bucket) when using overlap-grad-reduce #1150

Open · haolibai opened this issue Sep 19, 2024 · 6 comments

haolibai commented Sep 19, 2024

Hi, I am trying to add some new learnable parameters inside ColumnParallelLinear/RowParallelLinear. Here is an example code snippet:

from typing import Callable, List, Optional

import torch
from torch.nn.parameter import Parameter

from megatron.core.model_parallel_config import ModelParallelConfig


class ColumnParallelLinear(torch.nn.Module):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Args:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias
        gather_output: If true, call all-gather on output and make Y available to all GPUs; otherwise, every GPU will have its own output, which is Y_i = XA_i
        init_method: method to initialize weights. Note that bias is always set to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: This was added for testing and should be set to False. It returns the master weights used for initialization.
        skip_bias_add: If True, do not add the bias term; instead return it to be added by the caller. This enables performance optimizations where bias can be fused with other elementwise operations.
        skip_weight_param_allocation: If True, weight parameter is not allocated and must be passed as a keyword argument `weight` during the forward pass. Note that this does not affect bias, which will be allocated if bias is True. Defaults to False.
        embedding_activation_buffer: This buffer holds the input activations of the final embedding linear layer on the last pipeline stage when defer_embedding_wgrad_compute is enabled.
        grad_output_buffer: This buffer holds the gradient outputs of the final embedding linear layer on the last pipeline stage when defer_embedding_wgrad_compute is enabled.
        is_expert: If True, the layer is treated as an MoE expert layer.
        config: ModelParallelConfig object
        tp_comm_buffer_name: Communication buffer name. Not used in non-Transformer-Engine modules.
    """

    def __init__(
        self,
        input_size,
        output_size,
        *,
        config: ModelParallelConfig,
        init_method: Callable,
        bias=True,
        gather_output=False,
        stride=1,
        keep_master_weight_for_test=False,
        skip_bias_add=False,
        skip_weight_param_allocation: bool = False,
        embedding_activation_buffer: Optional[List[torch.Tensor]] = None,
        grad_output_buffer: Optional[List[torch.Tensor]] = None,
        is_expert: bool = False,
        tp_comm_buffer_name: str = None,  # Not used
    ):
        super(ColumnParallelLinear, self).__init__()
        ...
        # NOTE: a new parameter defined here (for example)
        self.new_param = Parameter(
            torch.randn(config.hidden_size, dtype=config.params_dtype, device=torch.cuda.current_device())
        )

    def forward(self, input_: torch.Tensor, weight: Optional[torch.Tensor] = None):
        ...
        output = output * self.new_param
        return output

However, this gives me the following error during training.

Traceback (most recent call last):
  File "/home/ma-user/work/haoli/code/PanGu/pretrain_gpt.py", line 343, in main
    pretrain(train_valid_test_datasets_provider,
  File "/home/ma-user/work/haoli/code/PanGu/pangu/training/training.py", line 271, in pretrain
    iteration, num_floating_point_operations_so_far = train(
  File "/home/ma-user/work/haoli/code/PanGu/pangu/training/training.py", line 441, in train
    train_step(forward_step_func,
  File "/home/ma-user/work/haoli/code/third_party/Megatron-LM/megatron/training/training.py", line 553, in train_step
    losses_reduced = forward_backward_func(
  File "/home/ma-user/work/haoli/code/third_party/Megatron-LM/megatron/core/pipeline_parallel/schedules.py", line 395, in forward_backward_no_pipelining
    config.finalize_model_grads_func([model])
  File "/home/ma-user/work/haoli/code/third_party/Megatron-LM/megatron/core/distributed/finalize_model_grads.py", line 135, in finalize_model_grads
    model_chunk.finish_grad_sync()
  File "/home/ma-user/work/haoli/code/third_party/Megatron-LM/megatron/core/distributed/distributed_data_parallel.py", line 242, in finish_grad_sync
    buffer.finish_grad_sync()
  File "/home/ma-user/work/haoli/code/third_party/Megatron-LM/megatron/core/distributed/param_and_grad_buffer.py", line 512, in finish_grad_sync
    bucket.finish_grad_sync()
  File "/home/ma-user/work/haoli/code/pangu_sophon_pytorch/third_party/Megatron-LM/megatron/core/distributed/param_and_grad_buffer.py", line 157, in finish_grad_sync
    assert self.communication_handle is not None and self.communication_issued, (
AssertionError: Communication call has not been issued for this bucket (1/2 params have grad available)

This happens when the following arguments are passed for training:

--overlap-param-gather
--overlap-grad-reduce

It seems the newly added parameter is not counted in self.params_with_grad.

However, training runs normally when I do the same thing in other places, e.g., in the __init__ function of ParallelAttention or ParallelMLP; no such errors occur.
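
A hedged diagnostic sketch for anyone hitting this: the assertion fires in Bucket.finish_grad_sync when a bucket's gradient all-reduce was never issued because not every parameter in that bucket reported a gradient. The snippet below prints how parameters were packed into buckets and which ones never became ready. The attribute names (buffers, buckets, params, params_with_grad, module) match the megatron-core version shown in the traceback, but they may differ in other releases, so treat this as illustrative only.

def dump_bucket_assignment(ddp_model):
    # Map each parameter object back to its name for readable output.
    param_names = {p: name for name, p in ddp_model.module.named_parameters()}
    for buf_idx, buf in enumerate(ddp_model.buffers):
        for bkt_idx, bucket in enumerate(buf.buckets):
            packed = [param_names.get(p, '<unnamed>') for p in bucket.params]
            # Parameters whose grad hook never marked them ready:
            missing = [param_names.get(p, '<unnamed>')
                       for p in bucket.params - bucket.params_with_grad]
            print(f'buffer {buf_idx}, bucket {bkt_idx}: '
                  f'params={packed}, missing grads={missing}')

Calling this right before model_chunk.finish_grad_sync() (or from the assertion site) should show whether new_param shares a bucket with a parameter whose gradient never arrives.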

@haolibai haolibai changed the title [QUESTION] Defining a new parameter in ColumnParallelLinear/RowParallelLinear raises Error [QUESTION] Adding a new parameter in ColumnParallelLinear/RowParallelLinear raises Error Sep 19, 2024
@liu09114

Did you solve this problem? I am hitting the same issue.

@haolibai (Author)

Not yet. But several colleagues of mine have run into the same problem.

@KookHoiKim

Same issue when using

--overlap-param-gather
--overlap-grad-reduce

@jpilaul
Copy link

jpilaul commented Nov 15, 2024

Also getting this error with

--overlap-param-gather
--overlap-grad-reduce

@haolibai haolibai changed the title [QUESTION] Adding a new parameter in ColumnParallelLinear/RowParallelLinear raises Error [QUESTION] Adding new parameters in ColumnParallelLinear/RowParallelLinear raises Error (Communication call has not been issued for this bucket (1/2 params have grad available)) when using overlap-grad-reduce Nov 18, 2024
@haolibai haolibai changed the title [QUESTION] Adding new parameters in ColumnParallelLinear/RowParallelLinear raises Error (Communication call has not been issued for this bucket (1/2 params have grad available)) when using overlap-grad-reduce [QUESTION] Adding new parameters in ColumnParallelLinear/RowParallelLinear raises AssertionError (Communication call has not been issued for this bucket) when using overlap-grad-reduce Nov 18, 2024
@tfigliolia

I am having exactly the same issue: using the distributed optimizer with expert parallelism (EP) and enabling

--overlap-param-gather
--overlap-grad-reduce

I get the same bucket problem.

@SeunghyunSEO

Hey all, my issue may be related; if I'm right, you can fix this by declaring your custom layers in forward (execution) order.
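
A minimal sketch of that suggestion, as I read it (all names below are illustrative, and this is untested against the setups above): Megatron's ParamAndGradBuffer walks the model's parameters in reverse registration order as an approximation of backward-pass order when packing buckets, so a custom layer or parameter registered out of execution order can land in a bucket whose gradients never all become ready together under --overlap-grad-reduce. Declaring submodules in __init__ in the same order forward() uses them keeps that approximation valid:

import torch

class CustomScale(torch.nn.Module):
    """Toy stand-in for a custom layer that owns a learnable parameter."""
    def __init__(self, hidden_size: int):
        super().__init__()
        self.scale = torch.nn.Parameter(torch.ones(hidden_size))

    def forward(self, x):
        return x * self.scale

class Block(torch.nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        # Declare submodules in the order forward() executes them, so that
        # iterating parameters() in reverse roughly follows backprop order:
        self.linear = torch.nn.Linear(hidden_size, hidden_size)  # used first
        self.post_scale = CustomScale(hidden_size)               # used second

    def forward(self, x):
        return self.post_scale(self.linear(x))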
