
[PyTorch] Refactor caching of cumulative sequence lengths #630

Merged

Conversation

timmoon10 (Collaborator)

To avoid redundantly calculating cumulative sequence lengths for attention, we currently compute them only when layer_number == 1 and otherwise use a cached value. However, pipeline parallelism breaks this optimization since most ranks never have a layer with layer_number == 1. This PR removes the caching logic's dependency on layer_number, which unfortunately reintroduces redundant calculation except in cases where the sequence lengths are fixed.

This is a quick bugfix, but discussion is welcome on how best to avoid the redundant calculation. Maybe we could split layer_number into two things: the local layer number and the scaling factor used at:

scale *= self.layer_number
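
For reference, here is a minimal sketch of how the cumulative sequence lengths can be derived from a padding mask. The helper name and the (batch, seqlen) boolean mask convention are assumptions for illustration, not Transformer Engine's actual API.

```python
import torch
import torch.nn.functional as F

def cu_seqlens_from_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    """Cumulative sequence lengths from a (batch, seqlen) boolean padding mask.

    Assumes True marks valid (non-padding) tokens; the real code may use the
    opposite convention. Returns an int32 tensor of shape (batch + 1,) where
    entry i is the start offset of sample i in the packed (varlen) layout.
    """
    seqlens = attention_mask.sum(dim=1, dtype=torch.int32)  # valid tokens per sample
    return F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
```

Because the result depends on the data in the mask, it can only be cached safely when the sequence lengths are fixed, which is why the layer_number-based cache is dropped here.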

@timmoon10 added the bug (Something isn't working) label on Jan 25, 2024
timmoon10 (Collaborator, Author)

/te-ci pytorch

@timmoon10 changed the title from "[PyTroch] Refactor caching of cumulative sequence lengths" to "[PyTorch] Refactor caching of cumulative sequence lengths" on Jan 25, 2024
if cu_seqlens_q is None or cu_seqlens_kv is None:
    assert (
        attention_mask is not None
    ), "Please provide attention_mask for padding!"
    cu_seqlens_q, indices_q = get_cu_seqlens_and_indices(
Member

Ok, so here we are basically not doing caching anymore, right? This makes sense I guess since it depends on the contents of the mask.

timmoon10 (Collaborator, Author)

Yeah, I think we'll need a new API to safely support this optimization.
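
One possible shape for such an API, purely as an illustration and not an existing Transformer Engine interface, would be to let the caller precompute the offsets once per batch and hand them to every layer:

```python
# Hypothetical usage sketch: compute the offsets once per batch and pass them to
# every layer, so no attention-level cache (and no layer_number dependence) is needed.
cu_seqlens = cu_seqlens_from_mask(attention_mask)  # illustrative helper sketched above
for layer in transformer_layers:
    hidden_states = layer(
        hidden_states,
        attention_mask=attention_mask,
        cu_seqlens_q=cu_seqlens,    # assumed keyword names; the real API may differ
        cu_seqlens_kv=cu_seqlens,
    )
```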

minitu (Contributor) commented Feb 5, 2024

Just noting that MLPerf LLM training is currently using the release_v1.3 branch because this change was merged into that release branch but not into main, so please push this forward.

ksivaman (Member) left a comment

LGTM.

ksivaman (Member) commented Feb 5, 2024

/te-ci pytorch

ksivaman (Member) commented Feb 5, 2024

I just saw @ptrendx's comment on #635 regarding the issue with pipeline parallelism, where layer_number == 1 is not set for any layer on most ranks. To re-enable caching and avoid this issue, wouldn't a simple fix be to assume/set the layer number to 1 when it is not provided, or to always recompute the cu_seqlens* tensors and indices when the layer number is not set?

@timmoon10
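
A minimal sketch of the second suggestion (illustrative only; the attribute and helper names are assumptions, not the actual Transformer Engine code):

```python
# Recompute whenever no layer number is available; otherwise keep the existing
# "compute on layer 1, reuse elsewhere" behavior.
if (
    self.layer_number is None
    or self.layer_number == 1
    or self._cu_seqlens_cache is None
):
    self._cu_seqlens_cache = cu_seqlens_from_mask(attention_mask)
cu_seqlens_q = cu_seqlens_kv = self._cu_seqlens_cache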

ptrendx (Member) commented Feb 5, 2024

@ksivaman The problem is that NeMo does set the layer number, but it counts layers "globally" across the full model, which is then split across pipeline-parallel ranks. That behavior makes sense, so we should not assume that layer_number == 1 will ever be true on a given GPU.
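
To make the numbering concrete, here is a small illustration (the layer count and pipeline size are invented for the example):

```python
# With 32 globally numbered layers split evenly over 4 pipeline stages,
# only stage 0 ever owns layer_number == 1.
num_layers, pp_size = 32, 4
per_stage = num_layers // pp_size
for stage in range(pp_size):
    layer_numbers = range(stage * per_stage + 1, (stage + 1) * per_stage + 1)
    print(stage, list(layer_numbers))  # stage 1 -> [9, 10, ..., 16]
```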

timmoon10 merged commit da30634 into NVIDIA:main on Feb 6, 2024 (9 checks passed)
parthmannan

To add another observation on why layer_number-based caching does not work: there are model implementations where attention is called multiple times within a transformer block, with different sequence-length and batch-size shapes. Relying on layer-number caching breaks this because the cache is set by the last attention call in the layer with layer_number == 1; when the first attention in the next layer runs, it ends up with incorrect cu_seqlens_q/kv.
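
As a toy sketch of that hazard (the module structure, shapes, and the cu_seqlens_from_mask helper from the earlier sketch are invented for illustration):

```python
import torch

_cache = {}  # single slot keyed only on layer_number == 1

def toy_attention(layer_number: int, attention_mask: torch.Tensor) -> torch.Tensor:
    if layer_number == 1:
        _cache["cu_seqlens"] = cu_seqlens_from_mask(attention_mask)
    return _cache["cu_seqlens"]

self_attn_mask = torch.ones(2, 128, dtype=torch.bool)   # 2 samples x 128 tokens
cross_attn_mask = torch.ones(2, 64, dtype=torch.bool)   # 2 samples x 64 tokens

toy_attention(1, self_attn_mask)    # cache matches the 128-token shape
toy_attention(1, cross_attn_mask)   # cache overwritten with the 64-token shape
toy_attention(2, self_attn_mask)    # wrong: reuses the 64-token offsets
```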
