[PyTorch] Store module extra state in tensor #1335
Merged
Description
TE modules implement `get_extra_state`/`set_extra_state` in order to include FP8 state in checkpoints. We currently pickle the FP8 state and store it in an `io.BytesIO` object, but this is problematic because PyTorch makes no guarantees if the extra state is not a `torch.Tensor`. This has resulted in problems with ONNX export and Hugging Face Transformers.

#363 changed from storing the extra state in a `torch.Tensor` to an `io.BytesIO` in order to reduce the overhead from GPU-CPU memory transfers. This PR restores the original `torch.Tensor` format, but performs the memory transfers asynchronously to reduce the overhead. It's similar to the approach used in the operation-based API (#1063). It should be backward compatible, and I've been able to load existing checkpoints. The attention docs mention extra state, but I don't think this PR affects it.

Fixes #1317.
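To illustrate the idea, here is a minimal sketch of a module whose extra state is exposed as a serialized `torch.Tensor`, with the GPU-CPU copy issued asynchronously. The module, attribute names, and serialization details below are hypothetical stand-ins, not the actual Transformer Engine implementation:

```python
import io

import torch


class _FP8StateModule(torch.nn.Module):
    """Toy module with tensor-based extra state (illustrative only)."""

    def __init__(self):
        super().__init__()
        # Stand-in for the FP8 metadata that TE keeps on the GPU
        # (scaling factors, amax history, ...).
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.fp8_meta = {
            "scale": torch.ones(1, device=device),
            "amax_history": torch.zeros(16, device=device),
        }

    def get_extra_state(self) -> torch.Tensor:
        # Kick off device-to-host copies without blocking so they can
        # overlap with other work, then synchronize once before serializing.
        cpu_meta = {k: v.to("cpu", non_blocking=True) for k, v in self.fp8_meta.items()}
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        buffer = io.BytesIO()
        torch.save(cpu_meta, buffer)
        # Expose the serialized bytes as a plain uint8 tensor so that
        # checkpointing, ONNX export, etc. only ever see a torch.Tensor.
        return torch.frombuffer(bytearray(buffer.getvalue()), dtype=torch.uint8)

    def set_extra_state(self, state) -> None:
        # Accept both the new tensor format and the old io.BytesIO format
        # so that existing checkpoints remain loadable.
        if isinstance(state, torch.Tensor):
            buffer = io.BytesIO(state.cpu().numpy().tobytes())
        else:
            buffer = state
            buffer.seek(0)
        self.fp8_meta = torch.load(buffer)


# Round trip through state_dict(), as checkpointing would do.
module = _FP8StateModule()
checkpoint = module.state_dict()  # contains "_extra_state" as a uint8 tensor
module.load_state_dict(checkpoint)
```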
Type of change
Changes
- Store module extra state in a `torch.Tensor` instead of an `io.BytesIO` object
Checklist: