[Core][VLM] Add support for prefix caching for multi-modal models #8348
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
from vllm.sequence import Sequence


class TokenRangeAnnotation(NamedTuple):
Very open to suggestions on naming! This is a pretty bulky name.
Sorry for the delay - I was busy with the Pixtral release last week but will review this PR this week!
vllm/inputs/data.py
Outdated
    token_annotations: NotRequired[Optional[List["TokenRangeAnnotation"]]]
    """
    Optional token annotations to capture content that will replace portions
    of the token IDs list.
    """
Given how #8346 has evolved (placeholder ranges are now propagated instead of inlined into MM data), I'll likely remove this and instead compute the annotations downstream once that change lands. But the rest of the change should still be applicable :)
Just want to clarify: what is the difference between the placeholder range data structures in #8346 and the TokenRangeAnnotation data structures in this PR?
Given that you are not creating any new example scripts in this PR, am I correct that the placeholder range data structures in #8346 are more "frontend-oriented" and serve to align placeholder tokens with multimodal input within the prompt (in a way that is model- and workload-specific), while the TokenRangeAnnotation data structures in this PR are more "backend-oriented" and serve to bridge multimodal data into core engine functionality such as the prefix cache, block management, etc.? With the idea being that the TokenRangeAnnotations will be computed from the placeholder token range data structures?
Yup, that frontend/backend distinction was how I was thinking about it, but I could be convinced to combine them too.
vllm/core/block/token_ids.py
Outdated
    def adjusted(self, tokens_start: int,
                 tokens_end: int) -> Optional["TokenRangeAnnotation"]:
I think clip would express this better than adjusted for a "range". WDYT?
Also, "slice" could work (this is even the term which is used in the method comment.)
vllm/core/block/token_ids.py
Outdated
        if key.start is None:
            start = 0
        elif key.start < 0:
            start = len(self) + key.start
        else:
            start = key.start
Suggested change:
-        if key.start is None:
-            start = 0
-        elif key.start < 0:
-            start = len(self) + key.start
-        else:
-            start = key.start
+        start = key.start or 0
+        start += len(self) if start < 0 else 0
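(Not part of the PR: a quick standalone check that the suggested two-liner normalizes a slice start the same way as the original branching, using a hypothetical sequence length.)

```python
# Standalone sanity check (not from the PR): the suggested two-liner should
# normalize a slice start exactly like the original if/elif/else chain.
def normalize_branching(start, length):
    if start is None:
        return 0
    elif start < 0:
        return length + start
    return start


def normalize_suggested(start, length):
    start = start or 0
    start += length if start < 0 else 0
    return start


length = 10  # hypothetical len(self)
for start in (None, 0, 3, 9, -1, -10):
    assert normalize_branching(start, length) == normalize_suggested(start, length)
```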
Hi @petersalas, I had a few nits and some clarifying questions. Overall, very excited for this - very cool how multimodal data is integrated into prefix caching. Thanks for the PR!
vllm/core/block/token_ids.py
Outdated
    def adjusted(self, tokens_start: int,
                 tokens_end: int) -> Optional["TokenRangeAnnotation"]:
        """
Nit - overall I'm having a little trouble understanding what this method does and why the formulae are as they are; it might benefit from explanatory comments for each argument and a few-sentence example of how the token range and content offset get adjusted.
Good suggestion! I added some examples in the docstring which hopefully clarify things a bit.
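For readers without the diff open, here is a rough, hypothetical sketch of the kind of clipping being discussed; the field names and exact semantics below are assumptions, not the PR's actual code. The idea is that an annotation covering a sub-range of the token list gets shifted and truncated when the token list is sliced, and dropped entirely when it falls outside the slice:

```python
from typing import NamedTuple, Optional


class TokenRangeAnnotation(NamedTuple):
    # Hypothetical field set inferred from the discussion; the real class
    # in vllm/core/block/token_ids.py may differ.
    content_hash: int     # hash of the multi-modal content backing this range
    content_offset: int   # offset into that content covered by this range
    token_index: int      # first token covered by the annotation
    token_count: int      # number of tokens covered

    def adjusted(self, tokens_start: int,
                 tokens_end: int) -> Optional["TokenRangeAnnotation"]:
        """Clip this annotation to the token window [tokens_start, tokens_end)."""
        new_start = max(self.token_index, tokens_start)
        new_end = min(self.token_index + self.token_count, tokens_end)
        if new_start >= new_end:
            return None  # annotation lies entirely outside the window
        return TokenRangeAnnotation(
            content_hash=self.content_hash,
            # The content offset advances by however many tokens were
            # clipped from the front of the annotation.
            content_offset=self.content_offset + (new_start - self.token_index),
            # The token index is re-expressed relative to the window start.
            token_index=new_start - tokens_start,
            token_count=new_end - new_start,
        )


# Example: an 8-token annotation starting at token 4, clipped to tokens [6, 10).
ann = TokenRangeAnnotation(content_hash=123, content_offset=0,
                           token_index=4, token_count=8)
print(ann.adjusted(6, 10))
# -> TokenRangeAnnotation(content_hash=123, content_offset=2, token_index=0, token_count=4)
```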
vllm/core/block/token_ids.py
Outdated
            key=lambda a: a.token_index)
        return TokenIds(token_ids, sorted_annotations)

    def chunks(self,
Two thoughts:
- What about renaming this to_chunks or get_chunks, in order to make it a little clearer that this method performs a relatively involved process to extract chunks?
- It looks like chunks() is invoked at least twice within the engine code; I'm wondering whether it makes sense to cache the result?
- Good suggestion -- renamed it to to_chunks.
- I think it'd be tricky to cache, but I did add a fast path in the slice operation, since the structure gets sliced for each decoded token and those slices will always be after the last annotation (if there is any).
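A small, hypothetical illustration of the fast path described above (not the PR's code; TokenIds here is stripped down to the bare minimum): once the slice starts past the last annotated token, no annotation can overlap the result, so the per-annotation clipping is skipped.

```python
from typing import NamedTuple, Tuple


class Annotation(NamedTuple):
    # Simplified stand-in for TokenRangeAnnotation.
    token_index: int
    token_count: int


class TokenIds:
    def __init__(self, token_ids: Tuple[int, ...],
                 annotations: Tuple[Annotation, ...] = ()):
        self.token_ids = token_ids
        self.annotations = annotations

    def slice_from(self, start: int) -> "TokenIds":
        last_annotated = max(
            (a.token_index + a.token_count for a in self.annotations),
            default=0)
        if start >= last_annotated:
            # Fast path: decode-time slices start after the last annotation,
            # so the annotations can simply be dropped.
            return TokenIds(self.token_ids[start:])
        # Slow path: each annotation would be clipped to the new window.
        raise NotImplementedError("annotation clipping elided in this sketch")
```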
@@ -852,7 +856,9 @@ def hash_block_tokens(is_first_block: bool, prev_block_hash: Optional[int],
     - int: The computed hash value for the block.
     """
     assert (prev_block_hash is None) == is_first_block
-    return hash((is_first_block, prev_block_hash, *cur_block_token_ids))
+    return hash(
So I just want to make sure I understand correctly. Regarding prefix caching:
- It used to be that a prefix-caching block hash was derived from is_first_block, the previous block hash, and the current block token ids.
- Now, the block hash is additionally derived from the annotations associated with the token ids.
One question I had when I started reviewing this PR was: how does prefix caching match a prefix that includes multimodal data, i.e. an image? Is it based on matching the hash of the raw image data?
Since annotations includes multimodal content hashes, it would appear that my guess is correct? So for an image (for example), the TokenRangeAnnotation content hashes might be computed from the raw tokens?
You got it! (With one nit w.r.t. your last question: the hashes are specifically computed for anything that can't be mapped to tokens.)
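To make the hashing change concrete, here is a rough sketch of what an annotation-aware block hash could look like; the extra parameter name and the way annotations are folded in are assumptions, not necessarily what the PR does:

```python
from typing import Optional, Sequence, Tuple


def hash_block_tokens(is_first_block: bool,
                      prev_block_hash: Optional[int],
                      cur_block_token_ids: Sequence[int],
                      cur_block_annotations: Tuple = ()) -> int:
    # Folding the annotations (which carry multi-modal content hashes) into
    # the block hash means two blocks with identical placeholder tokens but
    # different images/audio no longer collide in the prefix cache.
    assert (prev_block_hash is None) == is_first_block
    return hash((is_first_block, prev_block_hash,
                 *cur_block_token_ids, *cur_block_annotations))
```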
vllm/inputs/data.py
Outdated
    replace them.
    """

    content_hash: int
This might be a failure on my end, but where are these hashes actually computed? Are these hashes derived from the unprocessed multimodal data (i.e. raw image pixels for images)?
Will there need to be/is there already a way for the engine to automatically choose the appropriate hash function for a given modality?
Are all of these questions contingent on how #8346 gets integrated with this PR?
Yup, they were -- originally #8346 took the approach of hiding more multi-modal logic away in each multi-modal model, and I was going to do the same for hashing (i.e. delegate it to the model). But since I ended up propagating the placeholder ranges explicitly to the Sequence, I updated this change to do the MM hashing there as well.
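As a concrete but hypothetical example of the kind of content hashing this implies for audio, the raw samples could be hashed once when the Sequence is created; the helper below is an illustration, not the PR's implementation:

```python
import hashlib

import numpy as np


def hash_audio_content(audio: np.ndarray) -> int:
    """Hypothetical sketch: derive a stable content hash from raw audio
    samples so that identical audio can hit the same cached KV blocks."""
    digest = hashlib.sha256(np.ascontiguousarray(audio).tobytes()).digest()
    # Truncate to 64 bits so it composes cheaply with Python's hash().
    return int.from_bytes(digest[:8], "little", signed=True)
```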
This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed from 49c8c91 to 342d3d0
    supports_chunked_prefill: ClassVar[bool] = False
    """
    A flag that indicates this model supports chunked prefill.
    """

    supports_prefix_caching: ClassVar[bool] = False
    """
    A flag that indicates this model supports prefix caching.
    """
It's a little weird that these are on SupportsMultiModal, but the alternative that came to mind was to require tagging every non-multi-modal model as well. Happy to do whatever reviewers think is best here :)
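For illustration, a minimal sketch of how the opt-in would look on a model class; the flag names come from the diff above, but the Ultravox wiring here is an assumption:

```python
from typing import ClassVar


class SupportsMultiModal:
    # Flags from the diff above: multi-modal models opt in explicitly.
    supports_chunked_prefill: ClassVar[bool] = False
    supports_prefix_caching: ClassVar[bool] = False


class UltravoxModel(SupportsMultiModal):
    # Hypothetical: whether Ultravox sets both flags is not shown in this
    # excerpt, but the PR description says prefix caching is enabled for it.
    supports_prefix_caching: ClassVar[bool] = True
```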
class TokenIds:
    token_ids: Tuple[int, ...]
I don't particularly love TokenIds.token_ids. Maybe the type should just be Tokens? Maybe BlockTokens?
Signed-off-by: Peter Salas <[email protected]>
Force-pushed from 342d3d0 to edf4a55
This pull request has merge conflicts that must be resolved before it can be merged.
Hi, thanks for the great work! I was wondering if there’s any update on its status or an estimated timeline for its review/merge?
Closing as superseded by #11187
This adds support for prefix caching with multi-modal models -- in particular it enables it for Ultravox which uses the precise placeholders added in #8346.
With this change, SelfAttnBlockSpaceManager et al. now pass a TokenIds type around instead of List[int] to represent token ids. This new type can also contain TokenRangeAnnotations, which capture the contents that will ultimately replace the placeholder tokens. The Sequence calculates these by hashing multi-modal content that supports it (currently only implemented for audio).

FIX #9790
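As a rough, end-to-end sketch of the flow described above (names and field layouts are assumptions, not the PR's API): the Sequence takes the precise placeholder ranges from #8346 plus per-item content hashes and turns them into TokenRangeAnnotations that travel with the token ids into the block manager.

```python
from typing import List, NamedTuple


class PlaceholderRange(NamedTuple):
    # Stand-in for the placeholder range structure from #8346.
    offset: int   # first placeholder token in the prompt
    length: int   # number of placeholder tokens


class TokenRangeAnnotation(NamedTuple):
    content_hash: int
    content_offset: int
    token_index: int
    token_count: int


def annotations_from_placeholders(
        placeholder_ranges: List[PlaceholderRange],
        content_hashes: List[int]) -> List[TokenRangeAnnotation]:
    # One annotation per multi-modal item: the placeholder range says which
    # tokens it replaces, the content hash says what it replaces them with.
    return [
        TokenRangeAnnotation(content_hash=h, content_offset=0,
                             token_index=r.offset, token_count=r.length)
        for r, h in zip(placeholder_ranges, content_hashes)
    ]
```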