[ExecuTorch][BE] Split kv cache and SDPA for better code sharing #7413
base: gh/kimishpatel/149/base
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/7413
Note: Links to docs will display an error until the docs builds have been completed.
❌ 7 New Failures as of commit 275144b with merge base 49cc399.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Update: make all the backend-specific KV cache and SDPA implementations abide by the new API.
…sharing" Summary: Why? We have coupled SDPA with kv cache for a while. Initially this was done as we implemented sdpa_with_kv_cache custom op to reduce multiple copy overheads from kv cache update. (This could have been done by having separate custom kv cache update and custom sdpa op. Recent changes enabled this.) As a result of SDPA module owning kv cache, we get a) non-composable implementation and b) harder to reuse model definition and components from repos like tune. Output of this is that we have multiple definition of the same model, llama, lying around in ET, TorchChat and Tune. This diff and subsequent ones will try to move in the direction where custom kv cache and custom sdpa become decoupled and composable, making it more module-swap friendly with tune's model definition. How. Earlier PRs decoupled kv cache update from sdpa. So now 1. Decouple SDPA nn.Module from KV cache. 2. Standardize on KVCache and SDPA interface. That is KVCache and SDPA both operate on q, k, v in [B, # heads, seq_len, head_dim] formatted tensors. 3. 2 will introduce multiple tranposes when KVCache and SDPA are replaced by custom modules, but we will write graph pass to undo those. Test Plan: Existing tests. Make sure perf doesnt regress [ghstack-poisoned]
Summary: Why? We have coupled SDPA with kv cache for a while. Initially this was done as we implemented sdpa_with_kv_cache custom op to reduce multiple copy overheads from kv cache update. (This could have been done by having separate custom kv cache update and custom sdpa op. Recent changes enabled this.) As a result of SDPA module owning kv cache, we get a) non-composable implementation and b) harder to reuse model definition and components from repos like tune. Output of this is that we have multiple definition of the same model, llama, lying around in ET, TorchChat and Tune. This diff and subsequent ones will try to move in the direction where custom kv cache and custom sdpa become decoupled and composable, making it more module-swap friendly with tune's model definition. How. Earlier PRs decoupled kv cache update from sdpa. So now 1. Decouple SDPA nn.Module from KV cache. 2. Standardize on KVCache and SDPA interface. That is KVCache and SDPA both operate on q, k, v in [B, # heads, seq_len, head_dim] formatted tensors. 3. 2 will introduce multiple tranposes when KVCache and SDPA are replaced by custom modules, but we will write graph pass to undo those. Test Plan: Existing tests. Make sure perf doesnt regress ghstack-source-id: 6289ce22a2c190da7e38e098ba8a5d0254d6bf9d Pull Request resolved: #7413
@@ -212,6 +215,13 @@ def export(self) -> "LLMEdgeManager":

        return self

    def run_canonical_optimizations(self):
@@ -47,20 +37,21 @@ def forward(
        seqlen,
        mask,
    ):
        q = q.transpose(1, 2)  # (bs, seqlen, n_local_heads, head_dim)
I just thought about it again: adding this transpose here, and also earlier in llama_transformer.py, so that we can share code for kv_cache.py (this is the reason, right?) doesn't really make sense, since we are already using a custom export-friendly KV cache anyway: https://github.com/pytorch/executorch/blob/main/extension/llm/modules/kv_cache.py#L13
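For context, a quick shape check of the transpose under discussion (a sketch with illustrative values, not the actual diff):

```python
# Shape bookkeeping behind the comment above; values are illustrative.
import torch

bs, seqlen, n_local_heads, head_dim = 1, 16, 8, 64
q = torch.randn(bs, seqlen, n_local_heads, head_dim)

# The transpose moves heads ahead of the sequence dim so the
# [B, n_heads, seq_len, head_dim] KVCache/SDPA contract applies.
q = q.transpose(1, 2)
assert q.shape == (bs, n_local_heads, seqlen, head_dim)
```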
Stack from ghstack (oldest at bottom):
Summary:
Why?
We have coupled SDPA with the KV cache for a while. Initially this was done
when we implemented the sdpa_with_kv_cache custom op to reduce the multiple
copy overheads of KV cache updates. (This could have been done by having a
separate custom KV cache update op and a custom SDPA op; recent changes
enabled this.)
As a result of the SDPA module owning the KV cache, we get (a) a
non-composable implementation and (b) model definitions and components that
are harder to reuse from repos like tune. The upshot is that we have
multiple definitions of the same model, llama, lying around in ET, TorchChat,
and Tune. This diff and subsequent ones move in the direction of decoupling
the custom KV cache from the custom SDPA so the two compose, making the
result more module-swap friendly with tune's model definition.
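To make the coupling concrete, a minimal "before" sketch with hypothetical class and method names (not the real ExecuTorch modules): when attention owns the cache, neither piece can be swapped or reused on its own.

```python
# Hypothetical names; a sketch of the coupled design, not the real module.
import torch
import torch.nn.functional as F


class CoupledSDPA(torch.nn.Module):
    """'Before': attention owns the KV cache, so the two are inseparable."""

    def __init__(self, kv_cache: torch.nn.Module):
        super().__init__()
        self.kv_cache = kv_cache  # cache state buried inside attention

    def forward(self, start_pos, q, k, v, mask=None):
        # Cache update and attention are fused in one forward, which is
        # what the sdpa_with_kv_cache custom op expressed as a single node.
        k, v = self.kv_cache.update(start_pos, k, v)
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
```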
How?
Earlier PRs decoupled the KV cache update from SDPA. So now:
1. Decouple the SDPA nn.Module from the KV cache.
2. Standardize the KVCache and SDPA interface: both operate on q, k, v
tensors in [B, # heads, seq_len, head_dim] format (see the sketch after
this list).
3. Step 2 will introduce multiple transposes when KVCache and SDPA are
replaced by custom modules, but we will write a graph pass to undo those
(a toy version is sketched after the Test Plan).
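A minimal sketch of the decoupled contract in points 1–2, assuming the [B, # heads, seq_len, head_dim] layout above (class names and signatures are illustrative, not the exact ExecuTorch API):

```python
# Illustrative decoupled modules; names and signatures are hypothetical.
import torch
import torch.nn.functional as F


class KVCache(torch.nn.Module):
    """Holds k/v as [B, n_heads, max_seq_len, head_dim] buffers."""

    def __init__(self, bsz, n_heads, max_seq_len, head_dim):
        super().__init__()
        shape = (bsz, n_heads, max_seq_len, head_dim)
        self.register_buffer("k_cache", torch.zeros(shape))
        self.register_buffer("v_cache", torch.zeros(shape))

    def update(self, start_pos, k, v):
        # k, v: [B, n_heads, seq_len, head_dim]
        seqlen = k.shape[2]
        self.k_cache[:, :, start_pos : start_pos + seqlen] = k
        self.v_cache[:, :, start_pos : start_pos + seqlen] = v
        return self.k_cache, self.v_cache


class SDPA(torch.nn.Module):
    """Pure attention: stateless, same tensor layout as KVCache."""

    def forward(self, q, k, v, mask=None):
        return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)


# Either module can now be swapped independently (eager, custom op, etc.).
# Causal/valid-length masking is elided for brevity.
cache = KVCache(bsz=1, n_heads=8, max_seq_len=128, head_dim=64)
sdpa = SDPA()
q = torch.randn(1, 8, 1, 64)  # one decode step
k, v = torch.randn_like(q), torch.randn_like(q)
k_all, v_all = cache.update(start_pos=0, k=k, v=v)
out = sdpa(q, k_all, v_all)  # [1, 8, 1, 64]
```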
Test Plan:
Existing tests.
Make sure perf doesn't regress.
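For the graph pass mentioned in step 3, a toy torch.fx illustration of the idea (hypothetical; the real pass would run on the exported program rather than a symbolically traced module): two back-to-back transposes of the same dims are the identity and can be erased.

```python
# Toy transpose-cancellation pass over a torch.fx graph; the actual
# ExecuTorch pass would operate on the exported program instead.
import torch
import torch.fx


def cancel_adjacent_transposes(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    for node in list(gm.graph.nodes):
        if node.op == "call_method" and node.target == "transpose":
            parent = node.args[0]
            if (
                isinstance(parent, torch.fx.Node)
                and parent.op == "call_method"
                and parent.target == "transpose"
                and parent.args[1:] == node.args[1:]  # same dims => identity
            ):
                node.replace_all_uses_with(parent.args[0])
                gm.graph.erase_node(node)
                if not parent.users:
                    gm.graph.erase_node(parent)
    gm.graph.lint()
    gm.recompile()
    return gm


class M(torch.nn.Module):
    def forward(self, x):
        # The redundant pair that a custom-module swap might introduce:
        return x.transpose(1, 2).transpose(1, 2) + 1


gm = cancel_adjacent_transposes(torch.fx.symbolic_trace(M()))
print(gm.code)  # both transposes are gone
```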