Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PyTorch] Store module extra state in tensor #1335

Merged

Conversation

timmoon10
Copy link
Collaborator

@timmoon10 timmoon10 commented Nov 15, 2024

Description

TE modules implement get_extra_state/set_extra_state in order to include FP8 state in checkpoints. We currently pickle the FP8 state and store in a io.BtyesIO object, but this is problematic because PyTorch makes no guarantees if the extra state is not a torch.Tensor. This has resulted in problems with ONNX export and Hugging Face Transformers.

#363 changed from storing the extra state in a torch.Tensor to io.BytesIO in order to reduce the overhead from GPU-CPU memory transfers. This PR restores the original torch.Tensor format, but performs the memory transfers asynchronously to reduce overhead. It's similar to the approach used in the operation-based API (#1063). It should be backward compatible and I've been able to load existing checkpoints. The attention docs mention extra state, but I don't think this PR affects it.

Fixes #1317.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refractor

Changes

  • Store module extra state in tensor instead of io.BytesIO

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@timmoon10 timmoon10 added the bug Something isn't working label Nov 15, 2024
@timmoon10 timmoon10 requested a review from ksivaman November 15, 2024 00:14
@timmoon10
Copy link
Collaborator Author

/te-ci pytorch L0 L1

@timmoon10 timmoon10 merged commit 8c00424 into NVIDIA:main Dec 5, 2024
14 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
1.14.0 bug Something isn't working
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[bug] Failed to load pretrained model with huggingface transformers
1 participant