
[Core][VLM] Add support for prefix caching for multi-modal models #8348

Closed

Conversation

petersalas
Contributor

@petersalas petersalas commented Sep 10, 2024

This adds support for prefix caching with multi-modal models -- in particular it enables it for Ultravox which uses the precise placeholders added in #8346.

With this change, SelfAttnBlockSpaceManager et al. now pass a TokenIds type around instead of List[int] to represent token ids. This new type can also contain TokenRangeAnnotations which capture the contents that will ultimately replace the placeholder tokens. The Sequence calculates these by hashing multi-modal content that supports it (currently only implemented for audio).

FIX #9790
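For orientation, the rough shape of these two types is sketched below. The field names are pieced together from snippets quoted later in the review (content_hash, token_index, TokenIds.token_ids); token_count, the defaults, and the use of NamedTuple for TokenIds are assumptions for illustration, not the exact API.

```python
from typing import NamedTuple, Tuple


class TokenRangeAnnotation(NamedTuple):
    """Marks a span of placeholder tokens backed by multi-modal content."""
    content_hash: int    # hash of the multi-modal item (e.g. audio) behind the span
    content_offset: int  # where this span starts within that item
    token_index: int     # index of the first placeholder token it covers
    token_count: int     # number of placeholder tokens it covers (assumed name)


class TokenIds(NamedTuple):
    """Token ids plus optional annotations, replacing bare List[int]."""
    token_ids: Tuple[int, ...] = ()
    annotations: Tuple[TokenRangeAnnotation, ...] = ()
```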


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of CI tests to catch errors quickly. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

@ywang96 ywang96 self-assigned this Sep 11, 2024
@petersalas petersalas changed the title [WIP] [Core][VLM] Add support for placeholder token content hashes [Core][VLM] Add support for placeholder token content hashes Sep 12, 2024
@petersalas petersalas marked this pull request as ready for review September 12, 2024 22:51
from vllm.sequence import Sequence


class TokenRangeAnnotation(NamedTuple):
Contributor Author

Very open to suggestions on naming! This is a pretty bulky name.

@ywang96
Member

ywang96 commented Sep 16, 2024

Sorry for the delay - I was busy with the Pixtral release last week but will review this PR this week!

Comment on lines 126 to 155
token_annotations: NotRequired[Optional[List["TokenRangeAnnotation"]]]
"""
Optional token annotations to capture content that will replace portions
of the token IDs list.
"""

Contributor Author

Given how #8346 has evolved (placeholder ranges are now propagated instead of inlined into MM data) I'll likely remove this and instead compute the annotations downstream once that change lands. But the rest of the change should still be applicable :)

Contributor

Just want to clarify, what is the difference between the placeholder range data structures in #8346 and the TokenRangeAnnotation data structures in this PR?

Given that you are not creating any new example scripts in this PR, am I correct that the placeholder range data structures in #8346 are more "frontend-oriented" and serve to align placeholder tokens with multimodal input within the prompt (in a way that is model- and workload-specific), while the TokenRangeAnnotation data structures in this PR are more "backend-oriented" and serve to bridge multimodal data into core engine functionality such as the prefix cache, block management, etc.? With the idea being that the TokenRangeAnnotations will be computed from the placeholder token range data structures?

Contributor Author

Yup, that frontend/backend distinction was how I was thinking about it, but I could be convinced to combine them too.

Comment on lines 32 to 33
def adjusted(self, tokens_start: int,
tokens_end: int) -> Optional["TokenRangeAnnotation"]:
Collaborator

@Isotr0py Isotr0py Sep 29, 2024

I think clip would be a clearer name than adjusted here for a "range". WDYT?

Contributor

Also, "slice" could work (this is even the term which is used in the method comment.)

Comment on lines 216 to 221
        if key.start is None:
            start = 0
        elif key.start < 0:
            start = len(self) + key.start
        else:
            start = key.start
Collaborator

Suggested change
-        if key.start is None:
-            start = 0
-        elif key.start < 0:
-            start = len(self) + key.start
-        else:
-            start = key.start
+        start = key.start or 0
+        start += len(self) if start < 0 else 0

Contributor

@afeldman-nm afeldman-nm left a comment

Hi @petersalas, I had a few nits and some clarifying questions. Overall very excited for this - very cool how multimodal is integrated into prefix caching. Thanks for the PR!


def adjusted(self, tokens_start: int,
tokens_end: int) -> Optional["TokenRangeAnnotation"]:
"""
Contributor

Nit - overall having a little trouble understanding what this method does & why the formulae are as they are; might benefit from explanatory comments for each argument & a few-sentence example on how the token range & content offset get adjusted.

Contributor Author

Good suggestion! I added some examples in the docstring which hopefully clarify things a bit.
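For readers who hit the same confusion, here is a standalone sketch of what clipping an annotation to a token window could look like, reusing the TokenRangeAnnotation sketch from the PR description above. It is not the PR's actual adjusted implementation, and real code may measure the content offset in content units (e.g. audio frames) rather than in tokens.

```python
from typing import Optional


def clip_annotation(ann: TokenRangeAnnotation, tokens_start: int,
                    tokens_end: int) -> Optional[TokenRangeAnnotation]:
    """Clip `ann` to the token window [tokens_start, tokens_end).

    Example: an annotation covering tokens [4, 8) clipped to the window
    [6, 16) keeps only tokens [6, 8); its token_index is re-based to the
    window (6 - 6 = 0) and its content_offset advances by the two leading
    tokens that were cut off.
    """
    new_start = max(ann.token_index, tokens_start)
    new_end = min(ann.token_index + ann.token_count, tokens_end)
    if new_start >= new_end:
        return None  # the annotation lies entirely outside the window
    return TokenRangeAnnotation(
        content_hash=ann.content_hash,
        content_offset=ann.content_offset + (new_start - ann.token_index),
        token_index=new_start - tokens_start,
        token_count=new_end - new_start,
    )
```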

vllm/core/block/token_ids.py (outdated review comment, resolved)
key=lambda a: a.token_index)
return TokenIds(token_ids, sorted_annotations)

def chunks(self,
Contributor

Two thoughts:

  1. What about renaming this to_chunks or get_chunks, in order to make it a little clearer that this method performs a relatively involved process in order to extract chunks?

  2. It looks like chunks() is invoked at least twice within the engine code; I'm wondering whether it makes sense to cache the result?

Contributor Author

  1. Good suggestion -- renamed it to to_chunks.
  2. I think it'd be tricky to cache, but I did add a fast path in the slice operation since the structure gets sliced for each decoded token and those will always be after the last annotation (if there is any).
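To make the rename concrete, here is a sketch of what a to_chunks-style method could do, building on the clip_annotation sketch earlier in the thread (assumed names, not the PR's code). The fast path mentioned in point 2 would simply skip the annotation scan when the requested slice starts after the last annotation.

```python
from typing import Iterator, List


def to_chunks(tokens: TokenIds, chunk_size: int) -> Iterator[TokenIds]:
    """Split into block-sized chunks, clipping annotations to each chunk."""
    for start in range(0, len(tokens.token_ids), chunk_size):
        end = min(start + chunk_size, len(tokens.token_ids))
        clipped: List[TokenRangeAnnotation] = []
        for ann in tokens.annotations:
            piece = clip_annotation(ann, start, end)
            if piece is not None:
                clipped.append(piece)
        yield TokenIds(tokens.token_ids[start:end], tuple(clipped))
```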

@@ -852,7 +856,9 @@ def hash_block_tokens(is_first_block: bool, prev_block_hash: Optional[int],
- int: The computed hash value for the block.
"""
assert (prev_block_hash is None) == is_first_block
-    return hash((is_first_block, prev_block_hash, *cur_block_token_ids))
+    return hash(
Contributor

So just want to make sure I understand correctly. Regarding prefix caching -

  • It used to be that a prefix caching block hash was derived from is_first_block, the prev block hash, and the current block token ids

  • Now, the block hash is additionally derived from the annotations associated with the token ids.

One question I had when I started reviewing this PR was: how does prefix caching match a prefix that includes multimodal data, i.e. an image? Is it based on matching the hash of the raw image data?

Since the annotations include multimodal content hashes, it would appear that my guess is correct? So for an image (for example), the TokenRangeAnnotation content hashes might be computed from the raw tokens?

Contributor Author

You got it! (With one nit w.r.t. your last question: the hashes are specifically computed for anything that can't be mapped to tokens.)
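Put differently, the block hash now also folds in whatever annotations overlap the block, so two blocks with identical placeholder token ids but different underlying images or audio hash differently. A hedged reconstruction of the idea is below; the actual diff is truncated above and the real signature may differ.

```python
from typing import Optional, Sequence, Tuple


def hash_block_tokens(is_first_block: bool,
                      prev_block_hash: Optional[int],
                      cur_block_token_ids: Sequence[int],
                      cur_block_annotations: Tuple[TokenRangeAnnotation, ...] = (),
                      ) -> int:
    """Hash the previous block hash, this block's token ids, and the
    multi-modal annotations (and thus content hashes) covering the block."""
    assert (prev_block_hash is None) == is_first_block
    return hash((is_first_block, prev_block_hash,
                 *cur_block_token_ids, *cur_block_annotations))
```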

replace them.
"""

content_hash: int
Contributor

This might be a failure on my end, but where are these hashes actually computed? Are these hashes derived from the unprocessed multimodal data (i.e. raw image pixels for images)?

Will there need to be/is there already a way for the engine to automatically choose the appropriate hash function for a given modality?

Are all of these questions contingent on how #8346 gets integrated with this PR?

Contributor Author

Yup, they were -- originally #8346 took an approach of hiding more multi-modal logic away in each multi-modal model and I was going to do the same for hashing (i.e. delegate it to the model). But since I ended up propagating the placeholder ranges explicitly to the Sequence I updated this change to do the MM hashing there as well.
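As a rough illustration of what "do the MM hashing there" could look like for audio (a hypothetical helper, not the PR's code): the Sequence would hash the raw multi-modal item once and store the result as the annotation's content_hash.

```python
import hashlib

import numpy as np


def audio_content_hash(audio: np.ndarray, sampling_rate: int) -> int:
    """Hypothetical helper: derive a stable 64-bit hash from raw audio
    samples so identical clips produce identical prefix-cache keys."""
    digest = hashlib.sha256()
    digest.update(str(sampling_rate).encode())
    digest.update(np.ascontiguousarray(audio).tobytes())
    return int.from_bytes(digest.digest()[:8], "little", signed=True)
```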


mergify bot commented Oct 29, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. @petersalas please rebase it. https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@petersalas petersalas changed the title [Core][VLM] Add support for placeholder token content hashes [Core][VLM] Add support for prefix caching for multi-modal models Nov 8, 2024
Comment on lines +30 to +39
supports_chunked_prefill: ClassVar[bool] = False
"""
A flag that indicates this model supports chunked prefill.
"""

supports_prefix_caching: ClassVar[bool] = False
"""
A flag that indicates this model supports prefix caching.
"""

Contributor Author

It's a little weird that these are on SupportsMultiModal but the alternative that came to mind was to require tagging every non-multi-modal model as well. Happy to do whatever reviewers think is best here :)
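For illustration, opting a model in might look roughly like this. This is a sketch of the flags quoted in the diff above, not vLLM's actual interface module, and the Ultravox-like class name is just an example.

```python
from typing import ClassVar


class SupportsMultiModal:
    # Sketch of the interface flags quoted in the diff above.
    supports_chunked_prefill: ClassVar[bool] = False
    supports_prefix_caching: ClassVar[bool] = False


class UltravoxLikeModel(SupportsMultiModal):
    # A model that can safely participate in chunked prefill and
    # prefix caching opts in by overriding the class-level flags.
    supports_chunked_prefill = True
    supports_prefix_caching = True
```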



class TokenIds:
token_ids: Tuple[int, ...]
Contributor Author

I don't particularly love TokenIds.token_ids. Maybe the type should just be Tokens? Maybe BlockTokens?

@petersalas petersalas force-pushed the psalas/annotated-token-ids branch from 342d3d0 to edf4a55 Compare November 8, 2024 23:36

mergify bot commented Nov 13, 2024

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @petersalas.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@cooleel
Contributor

cooleel commented Dec 10, 2024

Hi, thanks for the great work! I was wondering if there’s any update on its status or an estimated timeline for its review/merge?

@ywang96
Member

ywang96 commented Dec 10, 2024

@cooleel We decided to work on adding prefix caching for multimodal models on V1 instead, since there are some fundamental changes in how the cache manager is designed. Stay tuned and feel free to check our multimodality roadmap at #4194!

@DarkLight1337
Member

Closing as superseded by #11187

Successfully merging this pull request may close these issues.

[Usage]: prefix caching support for multimodal models
6 participants