Explore new pipeline that overlaps optimizer with emb_lookup #2916
Open
TroyGarden wants to merge 2 commits into pytorch:main from TroyGarden:export-D64479105
Conversation
Summary:

# context
* This is some BE work and minor refactoring done while working on pipeline optimization.
* The major change is to add a "pin_memory" option to the test_input file for `ModelInput` generation (a short sketch of the idea follows this summary):
```
The `pin_memory()` call for all KJT tensors is important for the training benchmark, and is also a valid argument for the prod training scenario: TrainModelInput should be created on pinned memory for a fast transfer to GPU. For more on pin_memory: https://pytorch.org/tutorials/intermediate/pinmem_nonblock.html#pin-memory
```
* Minor refactoring includes: (1) default parameters for the TrainPipeline benchmark so that the embedding size, batch size, etc. are reasonable; (2) a fix for the batch index error in the trace, which previously used (curr_index+1); (3) splitting the `EmbeddingPipelinedForward` `__call__` function into two parts.
* Trace comparison: the `pin_memory()` call for the ModelInput is critical for a non-blocking CPU-to-GPU data copy.
  Before: copy_batch_to_gpu is the same size as the GPU data transfer. {F1977324224}
  After: copy_batch_to_gpu is hardly visible in the trace. {F1977324220}

Reviewed By: aporialiao

Differential Revision: D73514639
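To illustrate the pin_memory point above, here is a minimal sketch of why pinned host memory enables a truly asynchronous host-to-device copy. A plain dense tensor stands in for the KJT tensors, and `make_batch` is a hypothetical helper for this sketch, not part of the PR:
```
import torch

def make_batch(batch_size: int = 8192, num_features: int = 128) -> torch.Tensor:
    # Create the batch in pageable host memory, then pin it so the H2D copy
    # below can be asynchronous with respect to the host (non_blocking=True).
    batch = torch.randn(batch_size, num_features)
    return batch.pin_memory()

if torch.cuda.is_available():
    copy_stream = torch.cuda.Stream()
    batch = make_batch()
    with torch.cuda.stream(copy_stream):
        # With a pinned source tensor, this copy can overlap with compute on
        # the default stream instead of serializing behind it.
        batch_gpu = batch.to("cuda", non_blocking=True)
```
Without the `pin_memory()` call, a `non_blocking=True` copy from pageable memory silently degrades to a synchronous transfer, which is the "before" trace shown above.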
Summary:

# context
* This workstream started from a training QPS optimization initiated on the PG side (see the doc in the reference section), observing that the embedding lookup can overlap with the optimizer.
* Embedding table weights are updated in the fused backward (fused TBE), so the embedding lookup can start immediately after the backward completes, without any dependency on the optimizer.
* We use a separate stream to run the embedding lookup so that it can overlap with the preceding optimizer (changed, see below).
* There is also an option to reuse the data_dist stream for this embedding lookup; the output_dist won't be blocked but start_sparse_data_dist will be, which results in a smaller memory footprint.

WARNING: This pipeline **DOES NOT** work for EBC/EC with feature processors, because the embedding lookup is started immediately after the TBE backward (where the embedding tables' weights have been updated).

# benchmark readings
* Runtime: SemiSync < FusedSparseDist (lookup after opt) < FusedSparseDist (lookup before opt) < SparseDist
```
TrainPipelineSemiSync         | Runtime (P90): 5447.42 ms | Peak Memory alloc (P90): 61.63 GB | Peak Memory reserved (P90): 64.31 GB
TrainPipelineFusedSparseDist  | Runtime (P90): 5605.63 ms | Peak Memory alloc (P90): 53.23 GB | Peak Memory reserved (P90): 68.61 GB
TrainPipelineFusedSparseDist* | Runtime (P90): 5661.92 ms | Peak Memory alloc (P90): 53.23 GB | Peak Memory reserved (P90): 68.67 GB
TrainPipelineSparseDist       | Runtime (P90): 6034.46 ms | Peak Memory alloc (P90): 51.80 GB | Peak Memory reserved (P90): 62.25 GB
* embedding_lookup_after_opt = False
```
* The traces show that:
  (1) the emb_lookup runs right behind the TBE backward (on the same CUDA stream);
  (2) the output_dist is invoked right after each emb_lookup (there are two: one for the unweighted EBC, one for the weighted);
  (3) the optimizer does **NOT** overlap with the emb_lookup kernel when `embedding_lookup_after_opt = False` {F1977309185};
  (4) the optimizer still does **NOT** overlap with the emb_lookup kernel, but it fills the gap between the `KJTTensorAwaitable.wait()` and the embedding lookup kernel when `embedding_lookup_after_opt = True` {F1977309202};
  (5) if a separate stream is used for the embedding lookup, the following `start_sparse_data_dist` can start immediately; however, this causes extra memory consumption {F1977366363};
  (6) if the data_dist stream is reused for the embedding lookup, the following `start_sparse_data_dist` waits for the embedding lookup to complete, and the measured memory footprint is smaller {F1977366349}.

NOTE: Based on (5) and (6), `use_emb_lookup_stream = False` is the default behavior. A sketch of the stream-overlap idea follows this summary.

# conclusions
* On a simple model (SparseNN), both the "Fused Sparse Dist" pipeline and the "Semi Sync" pipeline are faster than the current default (commonly used) "Sparse Dist" pipeline: -7% (fused sparse dist) and -10% (semi sync) in runtime, respectively.
* In a more realistic scenario, the optimizer step has a longer runtime footprint, which can amplify this optimization.
* The "Semi Sync" pipeline has a larger QPS win but produces slightly different numerical training results, while the "Fused Sparse Dist" pipeline, with a smaller QPS win, should be numerically identical to the default pipeline.
* Which one to use is the user's choice.

# reference
* https://dev-discuss.pytorch.org/t/fsdp-cudacachingallocator-an-outsider-newb-perspective/1486

Differential Revision: D64479105
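For illustration only, a minimal sketch of the stream-overlap idea behind `use_emb_lookup_stream`. The `emb_lookup` and `dense_opt_step` callables are hypothetical stand-ins for this sketch, not the TorchRec pipeline API:
```
import torch

def run_overlapped(dense_opt_step, emb_lookup, use_emb_lookup_stream: bool):
    default_stream = torch.cuda.current_stream()
    if use_emb_lookup_stream:
        lookup_stream = torch.cuda.Stream()
        # Make the lookup stream wait for work already queued on the default
        # stream (the fused TBE backward has updated the embedding weights).
        lookup_stream.wait_stream(default_stream)
        with torch.cuda.stream(lookup_stream):
            out = emb_lookup()  # can overlap with dense_opt_step below
        dense_opt_step()
        # Consumers of `out` on the default stream must wait for the lookup.
        default_stream.wait_stream(lookup_stream)
    else:
        # Single stream: the lookup is serialized against the optimizer step
        # (before or after it, per embedding_lookup_after_opt), avoiding the
        # extra memory that a second stream's allocations incur.
        dense_opt_step()
        out = emb_lookup()
    return out

if torch.cuda.is_available():
    table = torch.randn(1000, 64, device="cuda")
    ids = torch.randint(0, 1000, (4096,), device="cuda")
    pooled = run_overlapped(
        dense_opt_step=lambda: None,  # placeholder for the dense optimizer step
        emb_lookup=lambda: table[ids].sum(dim=0),
        use_emb_lookup_stream=True,
    )
```
The memory/runtime trade-off in readings (5) and (6) corresponds to the two branches: the separate stream buys overlap at the cost of extra allocations, which is why `use_emb_lookup_stream = False` is the default.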
This pull request was exported from Phabricator. Differential Revision: D64479105
Labels
CLA Signed
fb-exported