Tensor parallelism #374
Great work!
I think there is a problem regarding TP + GELU, see my comment in src/modalities/models/model_factory.py.
Looks good to me, I just requested some minor changes in the test.
I have now added GELU support for TP (including unit tests).
Changes look good! Only some minor issues left, see comments.
LGTM :)
Co-authored-by: Felix Stollenwerk <[email protected]>
…lelism' into tensor_parallelism
LGTM
What does this PR do?
This PR adds support for Tensor Parallelism (including Sequence Parallelism).
Additionally, this PR adds a debugging toolkit that tracks the input and output tensors during the forward pass, the gradients during the backward pass, and the weight tensors.
The tracked tensors can be either regular Tensors or DTensors.
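For readers new to the idea, here is a minimal sketch of how tensor parallelism interacts with an element-wise activation such as GELU in a two-layer MLP. It is not the code from this PR: it uses plain Python lists instead of PyTorch or DTensors, and all function names (`mlp_reference`, `mlp_tensor_parallel`, ...) are hypothetical. The assumption is the common scheme where the first linear layer is sharded by output features and the second by input features, so GELU can be applied locally on each rank and a single sum (the all-reduce) recovers the full output.

```python
import math


def gelu(x):
    # Exact (erf-based) GELU for a single float.
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def matvec(w, x):
    # w is a list of rows with shape [out_features, in_features]; returns w @ x.
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in w]


def mlp_reference(w1, w2, x):
    # Unsharded MLP: y = w2 @ gelu(w1 @ x).
    hidden = [gelu(v) for v in matvec(w1, x)]
    return matvec(w2, hidden)


def mlp_tensor_parallel(w1, w2, x, world_size):
    # Simulates `world_size` ranks in a single process (illustration only).
    hidden_dim = len(w1)
    assert hidden_dim % world_size == 0, "hidden dim must divide evenly"
    shard = hidden_dim // world_size

    partial_outputs = []
    for rank in range(world_size):
        lo, hi = rank * shard, (rank + 1) * shard
        # First layer: each rank owns a slice of the output features.
        w1_shard = w1[lo:hi]
        # Second layer: each rank owns the matching slice of the input features.
        w2_shard = [row[lo:hi] for row in w2]
        # GELU is element-wise, so it is valid on the local hidden slice.
        hidden_local = [gelu(v) for v in matvec(w1_shard, x)]
        partial_outputs.append(matvec(w2_shard, hidden_local))

    # Stand-in for an all-reduce (sum) across ranks.
    return [sum(values) for values in zip(*partial_outputs)]
```

The key design point in this sketch is that the non-linearity is applied before the partial results are summed, on activations that are complete per hidden unit. Applying GELU to partial sums instead would not give the same result, because GELU is not linear.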
Checklist before submitting final PR
python tests/tests.py
CHANGELOG_DEV.md