
[PyTorch] Refactor parameter splitting in Linear and LayerNormLinear#590

Merged
timmoon10 merged 6 commits into NVIDIA:main from timmoon10:parameters_split_refactor
Jan 8, 2024

Conversation

@timmoon10
Collaborator

@timmoon10 timmoon10 commented Jan 5, 2024

#533 reports that TransformerLayer doesn't work out-of-the-box with tensor parallelism. The root cause is that the logic for parameter splitting (e.g. for QKV matrices) does not handle tensor parallelism. We've also had another user run into trouble when trying to set parameters_split in Linear, because it currently expects each split name to end with exactly one underscore (so mysplit and my_split_ would both fail).

I think this is a good opportunity to refactor this logic:

  • Adjust parameter split size as needed for tensor parallelism
  • Generalize support for split names. To maintain backward compatibility, we now strip all trailing underscores before appending _weight or _bias, resulting in parameter names like q_weight.
  • Separate the noop_cat operation so it is independent from the TE modules.
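The naming convention from the second bullet can be sketched as follows. This is an illustrative stand-alone function, not the actual Transformer Engine implementation; the helper name split_param_name is hypothetical:

```python
def split_param_name(split_name: str, suffix: str) -> str:
    """Build a parameter name from a user-supplied split name.

    Sketch of the convention described above: strip all trailing
    underscores from the split name, then append "_weight" or "_bias".
    This keeps legacy names like "q_" working while also accepting
    names with zero or many trailing underscores.
    """
    return split_name.rstrip("_") + "_" + suffix

# "q", "q_", and "q__" all map to the same parameter name:
print([split_param_name(s, "weight") for s in ("q", "q_", "q__")])
# ['q_weight', 'q_weight', 'q_weight']
```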

Closes #533.
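The tensor-parallelism adjustment from the first bullet amounts to shrinking each split by the tensor-parallel world size, since each rank only holds its shard of the fused weight. A minimal sketch (the function name local_split_sizes is hypothetical, and sizes are assumed evenly divisible):

```python
def local_split_sizes(global_sizes: list[int], tp_size: int) -> list[int]:
    """Adjust per-split sizes for tensor parallelism.

    When a fused weight is partitioned across tp_size ranks, each rank
    holds 1/tp_size of every split, so the per-rank split sizes shrink
    by that factor. Assumes each size is divisible by tp_size.
    """
    for size in global_sizes:
        assert size % tp_size == 0, "split size must be divisible by tp_size"
    return [size // tp_size for size in global_sizes]

# e.g. a fused QKV weight with three equal splits on 4 tensor-parallel ranks:
print(local_split_sizes([1024, 1024, 1024], 4))
# [256, 256, 256]
```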

Remove module state from noop_cat. Support arbitrary names in parameter split. Handle tensor parallelism.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10 timmoon10 added the bug label Jan 5, 2024
@timmoon10 timmoon10 requested review from cyanguwa and ksivaman January 5, 2024 08:02
@timmoon10
Collaborator Author

/te-ci pytorch

Fix pylint complaints.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
@timmoon10
Collaborator Author

/te-ci pytorch

Collaborator

@cyanguwa cyanguwa left a comment


Would these changes affect people when they try to load existing checkpoints?

@timmoon10
Collaborator Author

I don't think so. The resulting param names (q_weight, k_weight, v_weight) are unchanged, and I don't think the actual values in parameters_split are involved in checkpointing.

@timmoon10 timmoon10 merged commit bb759ad into NVIDIA:main Jan 8, 2024
@timmoon10 timmoon10 deleted the parameters_split_refactor branch January 8, 2024 19:40

Labels

bug Something isn't working


Development

Successfully merging this pull request may close these issues.

multi-gpu example with >1 GPU crashes without fuse_qkv=True

2 participants