pass gather/input_is_parallel to build_module from mlp
Signed-off-by: Abhinav Khattar <[email protected]>
aklife97 committed Oct 4, 2023
1 parent 7ab6a29 commit 2e30ced
Showing 2 changed files with 3 additions and 7 deletions.
8 changes: 1 addition & 7 deletions megatron/core/tensor_parallel/layers.py
@@ -837,13 +837,7 @@ def __init__(
         self.gradient_accumulation_fusion = config.gradient_accumulation_fusion
         self.sequence_parallel = config.sequence_parallel
         if self.sequence_parallel and not self.input_is_parallel:
-            # raise RuntimeError("To enable `sequence_parallel`, `input_is_parallel` must be `True`")
-            print(
-                'WARNING: To enable `sequence_parallel`',
-                '`input_is_parallel` must be `True ',
-                flush=True,
-            )
-            self.input_is_parallel = True
+            raise RuntimeError("To enable `sequence_parallel`, `input_is_parallel` must be `True`")

         # Parameters.
         # Note: torch.nn.functional.linear performs XA^T + b and as a result
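For context, here is a minimal, self-contained sketch of the behaviour this hunk restores. It is not the Megatron code itself; RowParallelGuardStub is a hypothetical stand-in for the relevant part of the layer's __init__. With this commit, requesting sequence parallelism while the input is not already split across tensor-parallel ranks fails fast instead of printing a warning and silently overriding the flag.

# Minimal sketch of the restored guard (hypothetical stand-in class,
# not the actual RowParallelLinear).
class RowParallelGuardStub:
    def __init__(self, sequence_parallel: bool, input_is_parallel: bool):
        self.sequence_parallel = sequence_parallel
        self.input_is_parallel = input_is_parallel
        if self.sequence_parallel and not self.input_is_parallel:
            raise RuntimeError(
                "To enable `sequence_parallel`, `input_is_parallel` must be `True`"
            )


RowParallelGuardStub(sequence_parallel=True, input_is_parallel=True)   # constructs fine
try:
    RowParallelGuardStub(sequence_parallel=True, input_is_parallel=False)
except RuntimeError as err:
    print(err)  # To enable `sequence_parallel`, `input_is_parallel` must be `True`

The mlp.py change below makes the MLP satisfy this stricter guard by passing the parallelism flags explicitly.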
2 changes: 2 additions & 0 deletions megatron/core/transformer/mlp.py
@@ -53,6 +53,7 @@ def __init__(
             ffn_hidden_size,
             config=self.config,
             init_method=self.config.init_method,
+            gather_output=False,
             bias=self.config.add_bias_linear,
             skip_bias_add=True,
             is_expert=is_expert,
@@ -75,6 +76,7 @@ def glu(x):
             config=self.config,
             init_method=self.config.output_layer_init_method,
             bias=self.config.add_bias_linear,
+            input_is_parallel=True,
             skip_bias_add=True,
             is_expert=is_expert,
         )
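To see why the two new arguments pair up, here is a small shape sketch with assumed sizes and illustrative variable names (not Megatron code): the column-parallel linear_fc1 keeps its output partitioned across tensor-parallel ranks when gather_output=False, and the row-parallel linear_fc2 can consume that partitioned activation directly when input_is_parallel=True, which is exactly what the restored guard in layers.py requires under sequence parallelism.

# Shape sketch with assumed sizes (hidden=1024, ffn_hidden=4096, tp_size=2);
# names are illustrative only.
hidden_size, ffn_hidden_size, tp_size = 1024, 4096, 2

# Column-parallel fc1: weight split along the output (ffn) dimension.
fc1_weight_per_rank = (ffn_hidden_size // tp_size, hidden_size)   # (2048, 1024)
# gather_output=False -> the activation stays partitioned per rank.
fc1_output_width_per_rank = ffn_hidden_size // tp_size            # 2048

# Row-parallel fc2: weight split along the input (ffn) dimension, so it
# accepts the partitioned activation as-is when input_is_parallel=True.
fc2_weight_per_rank = (hidden_size, ffn_hidden_size // tp_size)   # (1024, 2048)

# The partitioned fc1 output lines up with the per-rank fc2 input,
# so no all-gather is needed between the two layers.
assert fc1_output_width_per_rank == fc2_weight_per_rank[1]
print("fc1 per-rank weight shape:", fc1_weight_per_rank)
print("fc2 per-rank weight shape:", fc2_weight_per_rank)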
