[PyTorch] Refactor caching of cumulative sequence lengths #630
Conversation
Signed-off-by: Tim Moon <[email protected]>
/te-ci pytorch
if cu_seqlens_q is None or cu_seqlens_kv is None:
    assert (attention_mask is not None
            ), "Please provide attention_mask for padding!"
    cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(
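For context, a minimal sketch of what a helper like this computes from a padding mask. This is an assumed illustration, not Transformer Engine's actual get_cu_seqlens_and_indices implementation; the function name cu_seqlens_from_mask is made up for the example.

```python
import torch

def cu_seqlens_from_mask(mask: torch.Tensor):
    """mask: [batch, seqlen] bool tensor, True marks real (non-padded) tokens."""
    seqlens = mask.sum(dim=1, dtype=torch.int32)                 # per-sequence lengths
    cu_seqlens = torch.cat(
        [seqlens.new_zeros(1), torch.cumsum(seqlens, dim=0).to(torch.int32)]
    )                                                            # prefix sum with a leading zero
    indices = mask.flatten().nonzero(as_tuple=False).squeeze(1)  # flat positions of kept tokens
    return cu_seqlens, indices

# Example: batch of 2, max length 4, second sequence padded to length 2
mask = torch.tensor([[1, 1, 1, 1],
                     [1, 1, 0, 0]], dtype=torch.bool)
cu_seqlens, indices = cu_seqlens_from_mask(mask)
print(cu_seqlens)  # tensor([0, 4, 6], dtype=torch.int32)
print(indices)     # tensor([0, 1, 2, 3, 4, 5])
```

Because this depends on the contents of the mask, the result can change from batch to batch, which is why caching it is only safe when the sequence lengths are known to be fixed.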
Ok, so here we are basically not doing caching anymore, right? This makes sense I guess since it depends on the contents of the mask.
Yea, I think we'll need a new API to safely support this optimization.
Adding that MLPerf LLM training is currently using the …
LGTM.
/te-ci pytorch
I just saw @ptrendx's comment on #635 regarding the issue of pipeline parallelism and how …
@ksivaman The problem is that NeMo is setting the layer number, but it counts layers "globally" across the full model, which is then cut up by pipeline parallelism. This behavior makes sense, so we should not assume that layer_number==1 will ever be true on a given GPU.
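A small illustration of that numbering (the layer and stage counts below are made up for the example): with global layer numbers, only the first pipeline stage ever sees layer_number == 1.

```python
# 24 layers split evenly across a 4-way pipeline, numbered globally from 1.
num_layers, pp_size = 24, 4
layers_per_stage = num_layers // pp_size
for rank in range(pp_size):
    layer_numbers = [rank * layers_per_stage + i + 1 for i in range(layers_per_stage)]
    print(rank, layer_numbers)
# rank 0 -> [1, 2, 3, 4, 5, 6]       (the caching trigger fires here)
# rank 1 -> [7, 8, 9, 10, 11, 12]    (layer_number == 1 never occurs)
# rank 2 -> [13, 14, 15, 16, 17, 18]
# rank 3 -> [19, 20, 21, 22, 23, 24]
```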
To add another observation on why layer_number-based caching does not work: there are model implementations where attention is called multiple times within a transformer block, with different sequence-length and batch-size shapes. Relying on layer-number caching breaks this, because the cache is set by the last attention call in layer_number=1. In the next layer, the first attention call then ends up with incorrect cu_seqlens_q/kv.
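A minimal sketch of that failure mode, assuming a module-level cache gated on layer_number == 1; the helper below is illustrative, not Transformer Engine code.

```python
import torch

_cu_seqlens_cache = None

def attention_with_layer_number_cache(mask: torch.Tensor, layer_number: int) -> torch.Tensor:
    """Return cu_seqlens, recomputing them only when layer_number == 1."""
    global _cu_seqlens_cache
    if layer_number == 1 or _cu_seqlens_cache is None:
        seqlens = mask.sum(dim=1, dtype=torch.int32)
        _cu_seqlens_cache = torch.cat(
            [seqlens.new_zeros(1), torch.cumsum(seqlens, dim=0).to(torch.int32)]
        )
    return _cu_seqlens_cache

mask_a = torch.ones(2, 8, dtype=torch.bool)  # first attention: batch 2, seqlen 8
mask_b = torch.ones(4, 4, dtype=torch.bool)  # second attention: batch 4, seqlen 4

attention_with_layer_number_cache(mask_a, layer_number=1)  # caches values for mask_a
attention_with_layer_number_cache(mask_b, layer_number=1)  # overwrites cache with mask_b's values
stale = attention_with_layer_number_cache(mask_a, layer_number=2)
print(stale)  # tensor([0, 4, 8, 12, 16]) -- wrong for mask_a, which needs [0, 8, 16]
```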
To avoid redundantly calculating cumulative sequence lengths for attention, we only compute when layer_number == 1 and otherwise use a cached value. However, pipeline parallelism breaks this optimization since most ranks will never have a layer with layer_number == 1. This PR removes the caching logic's dependency on layer_number, which unfortunately reintroduces redundant calculation except in cases where the sequence lengths are fixed.

This is a quick bugfix, but discussion is welcomed on how best to avoid the redundant calculation. Maybe we could split layer_number into two things: the local layer number and the scaling factor used at:

TransformerEngine/transformer_engine/pytorch/attention.py, line 1318 (commit 6c1a8bb)
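A rough sketch of what such a split might look like; the constructor arguments below (local_layer_number, global_layer_number) are hypothetical and not existing Transformer Engine parameters.

```python
class AttentionLayerSketch:
    def __init__(self, local_layer_number: int, global_layer_number: int):
        # The local index (1-based within this pipeline stage) drives the
        # caching decision, so every stage has a layer that refreshes the cache.
        self.recompute_cu_seqlens = local_layer_number == 1
        # The global layer number keeps acting as the scale factor, so the
        # attention numerics match the current behavior.
        self.layer_scale = float(global_layer_number)

# Example: second stage of a 4-way pipeline over 24 layers. Its first local
# layer is global layer 7, so caching is refreshed while scaling still uses 7.
layer = AttentionLayerSketch(local_layer_number=1, global_layer_number=7)
print(layer.recompute_cu_seqlens, layer.layer_scale)  # True 7.0
```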