
[PyTorch] Refactor parameter splitting in Linear and LayerNormLinear #590

Merged 6 commits into NVIDIA:main on Jan 8, 2024

Conversation

@timmoon10 (Collaborator) commented on Jan 5, 2024

#533 reports that TransformerLayer doesn't work out-of-the-box with tensor parallelism. The root cause is that the logic for parameter splitting (e.g. for QKV matrices) does not handle tensor parallelism. We've also had another user run into trouble when trying to set parameters_split in Linear, because it currently expects the split names to end with exactly one underscore (so mysplit and my_split__ would both fail).
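For context, here is a hypothetical invocation of the affected API. The module sizes and split names are illustrative only, and the snippet assumes transformer_engine is installed with a CUDA device available:

```python
# Illustrative only: a fused QKV projection whose single weight tensor is
# exposed as separately named parameters via parameters_split.
import transformer_engine.pytorch as te

hidden = 1024
qkv = te.Linear(hidden, 3 * hidden, bias=True, parameters_split=("q_", "k_", "v_"))

# Before this PR, split names had to end with exactly one underscore;
# afterwards, names like "q" or "q__" are normalized to the same result.
print(sorted(name for name, _ in qkv.named_parameters()))
# Expected to include: q_weight, k_weight, v_weight, q_bias, k_bias, v_bias
```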

I think this is a good opportunity to refactor this logic:

  • Adjust parameter split sizes as needed for tensor parallelism
  • Generalize support for split names. To maintain backward compatibility, we now strip all trailing underscores before appending _weight or _bias, resulting in parameter names like q_weight, etc. (see the sketch after this list)
  • Separate the noop_cat operation so it is independent of the TE modules
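A minimal sketch of the first two rules above (hypothetical helpers, not Transformer Engine's actual implementation):

```python
# Hypothetical helpers illustrating the naming and tensor-parallel rules
# described in this PR; not the actual Transformer Engine code.

def normalize_param_name(split_name: str, suffix: str) -> str:
    """Strip all trailing underscores, then append the suffix."""
    return f"{split_name.rstrip('_')}_{suffix}"

def local_split_sizes(global_sizes, tp_size: int):
    """Scale each split's size down to a single tensor-parallel rank."""
    for size in global_sizes:
        if size % tp_size != 0:
            raise ValueError(f"split size {size} is not divisible by tp_size {tp_size}")
    return [size // tp_size for size in global_sizes]

print(normalize_param_name("q", "weight"))               # q_weight
print(normalize_param_name("my_split__", "bias"))        # my_split_bias
print(local_split_sizes([1024, 1024, 1024], tp_size=4))  # [256, 256, 256]
```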

Closes #533.

Remove module state from noop_cat. Support arbitrary names in parameter split. Handle tensor parallelism.

Signed-off-by: Tim Moon <[email protected]>
@timmoon10 timmoon10 added the bug Something isn't working label Jan 5, 2024
@timmoon10 timmoon10 requested review from cyanguwa and ksivaman January 5, 2024 08:02
@timmoon10 (Collaborator, Author) commented:

/te-ci pytorch

@timmoon10 (Collaborator, Author) commented:

/te-ci pytorch

Signed-off-by: Tim Moon <[email protected]>
@cyanguwa (Collaborator) left a comment:

Would these changes affect people trying to load existing checkpoints?

@timmoon10 (Collaborator, Author) replied:

I don't think so. The resulting param names (q_weight, k_weight, v_weight) are unchanged, and I don't think the actual values in parameters_split are involved in checkpointing.
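One hedged way to sanity-check this (the checkpoint file name below is hypothetical):

```python
# Hypothetical sanity check: the state_dict keys of a split Linear should
# match the keys stored in a checkpoint written before this refactor.
import torch
import transformer_engine.pytorch as te

layer = te.Linear(1024, 3 * 1024, parameters_split=("q_", "k_", "v_"))
ckpt = torch.load("pre_refactor_linear_qkv.pt")  # hypothetical checkpoint file
missing = set(ckpt.keys()) - set(layer.state_dict().keys())
print("keys missing from the new module:", missing)  # expect an empty set
```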

@timmoon10 timmoon10 merged commit bb759ad into NVIDIA:main Jan 8, 2024
9 checks passed
@timmoon10 timmoon10 deleted the parameters_split_refactor branch January 8, 2024 19:40
Labels: bug (Something isn't working)
Linked issue that may be closed by merging this pull request: multi-gpu example with >1 GPU crashes without fuse_qkv=True
2 participants