[feat] LlavaNext add feature size check to avoid CUDA Runtime Error #33608
Conversation
Hey @laurentd-lunit ! Raising an error looks good to me, as a general way to prevent errors. Would be super nice if you can get a reproducer to see why this happens, maybe a script that fails once in 10 runs?
This part can error out only if the input ids are prepared incorrectly in the first place (i.e. the number of image tokens was not calculated correctly), or if you are generating without caching and for some reason your model generates more image tokens. The second case is possible only if you are tuning the model from scratch, so it hasn't learned what to generate yet.
Please let me know if you can provide reproducers :)
And regarding the changes, could you raise the same error in all llava models if possible? I think Paligemma and Qwen2-VL also use masked scatter to replace image tokens.
@zucchini-nlp thanks for checking this PR!
I've encountered this issue using a large custom dataset but I haven't pinned down which sample(s) it happens on yet. I suspect it's an edge case of image dimensions that causes this issue. I'll try to find some time to investigate it and make a reproducer, but this PR was really more about working around it by simply preventing CUDA from crashing if the dimensions do not match.
Right, I'll check the other Llava models as well as Paligemma and Qwen2-VL and see if I can make the same changes!
Force-pushed from 44d9dfc to e303474
@zucchini-nlp However, I think it could still be useful to check feature sizes and return an explicit error. Let me know if there are any other changes required and if you think it's still worth merging this PR.
Force-pushed from e303474 to 8b53700
I see, thanks for investigating it! Llava-next is a bit peculiar because of all the pad/unpad operations going on.
Yeah, agreed that we can get a better error message, so I'm okay with merging this PR. Can you also add the same error in Qwen2-VL and Chameleon please? And rebase on main because we added tests recently, so they should run as part of CI now.
Otherwise looks good to me, so feel free to tag core maintainer @LysandreJik after making the changes :)
Force-pushed from 48f4a91 to 53d407d
@zucchini-nlp Thanks for your review again! I incorporated your changes and suggestions. Hopefully it should be good to merge now. @LysandreJik it would be good if you could check this when you get the chance, especially since I had to modify some of the reference values in the tests. Using the llava processor logic to compute the expected number of image tokens: here we have height, width = 30 and patch_size = 2, which gives 15*15 + 1 = 226; then we subtract 1 because the CLS token is ignored, so it should be 225 (not 224).
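For reference, that arithmetic can be reproduced in a few lines of Python (a minimal sketch; the one-token-per-patch-plus-CLS formula is taken from the description above, not from the library's code):

```python
# Expected number of image tokens for a 30x30 input with patch_size = 2.
height = width = 30
patch_size = 2
num_patches = (height // patch_size) * (width // patch_size)  # 15 * 15 = 225
expected_tokens = num_patches + 1 - 1  # +1 for the CLS token, -1 because it is ignored
print(expected_tokens)  # 225, not 224
```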
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Thanks for the ping! @ArthurZucker is much more comfortable with this code as he wrote a bunch of it, so his review will be much better than what mine would be, pinging him to review.
Thanks a lot for the PR 🙌
Hey! For legacy paths (chameleon) that don't have the processing in the processor, sounds good!
For the new path, I am not sure if this is a great addition, especially if it breaks our support for fullgraph compile!
WDYT? 🤗 we can wait for our benchmarks as well!
    @@ -1343,6 +1343,12 @@ def forward(
        if pixel_values is not None:
            image_tokens = self.get_image_tokens(pixel_values)
            n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum().item()
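For context, a hedged sketch of how such a check plausibly continues after the quoted lines (the variable names beyond those shown above are assumptions, not the verbatim diff):

```python
# Compare the number of image tokens in the text against the number of image
# features actually produced, and fail early with a readable error instead of
# letting masked_scatter hit a device-side assert.
n_image_features = image_tokens.shape[0] * image_tokens.shape[1]
if n_image_tokens_in_text != n_image_features:
    raise ValueError(
        f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, "
        f"features {n_image_features}"
    )
```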
Mmmm in general I don't mind, as this should help our users, but the .item() might break compile compatibility (well only full graph).
@McPatate that's where and when we would need to see how much we are losing from this small change ! 🤗 (FYI @LysandreJik )
We can always wrap these in `is_torchdynamo_compiling`, the same way as we wrap all warnings/logging now in generation code. So we ask users to make sure the code works without compilation, to see all warnings etc., and then compile the code, which will not show the exact reason why/where this CUDA-side error was triggered.
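As an illustration, a minimal sketch of that wrapping pattern (assuming `is_torchdynamo_compiling` is importable from `transformers.utils`; the function and tensor names here are placeholders):

```python
from transformers.utils import is_torchdynamo_compiling

def check_image_feature_count(input_ids, image_features, image_token_id):
    # Skip the data-dependent check under torch.compile (fullgraph), where the
    # .item() host sync would not be allowed; run it in eager mode only.
    if is_torchdynamo_compiling():
        return
    n_image_tokens = (input_ids == image_token_id).sum().item()
    n_image_features = image_features.shape[0] * image_features.shape[1]
    if n_image_tokens != n_image_features:
        raise ValueError(
            f"Image features and image tokens do not match: tokens: {n_image_tokens}, "
            f"features {n_image_features}"
        )
```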
Okay, that makes sense. Just 🥶 about adding more checks, but this one is most probably cached, so it should be alright.
The thing is, these `is_compiling` checks are unrelated to normal users -> we'd be exposing them to unrelated code.
I see what you mean. Yes, the processing should maybe check this, but we cannot perform any checks before getting the image hidden states. My main idea was to bring the same check we had earlier in the merge_inputs method, so that after moving to the new logic we can still trace down bugs related to shape mismatches easily, or let users track that down.
Also, we won't do the sum() and item() on every forward; for generation it is only needed for the prefill stage, after which we'll have the image states in the cache. But anyway, if you think this is too many checks (given we now support both the old and new logic in VLMs for a few minor releases), I am okay with not adding it. I don't see it as a major blocker or anything, more like a nice addition for users :D
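To illustrate the prefill-only cost, a hedged sketch of a greedy generation loop with caching (all names are placeholders, not the library's generation code):

```python
# Multimodal inputs are consumed once, at the prefill step; every following
# decode step feeds a single new token and reuses the KV cache, so the
# token/feature count check (and its .item() host sync) never runs again.
past_key_values = None
tokens = input_ids
for step in range(max_new_tokens):
    out = model(
        input_ids=tokens,
        pixel_values=pixel_values if step == 0 else None,  # images only at prefill
        past_key_values=past_key_values,
        use_cache=True,
    )
    past_key_values = out.past_key_values
    tokens = out.logits[:, -1:].argmax(dim=-1)  # greedy next token
```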
Okay let's add it then 🤗
    n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
    n_image_features = image_features.shape[1]
    if n_image_tokens != n_image_features:
        raise ValueError(
            f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
        )
I don't know why we are adding this here, as the processor is supposed to check this for the non-legacy path!
Yes, it is supposed to. There was only one edge case with llava-next, which uses the pad/unpad technique, and since we used tensors in modeling there were minor numerical inconsistencies.
Right now it should work, but in general IMO it's a good idea to help users pinpoint what went wrong in their code.
Not in the forward pass IMO; we are adding extra processing, .sum() and .item() as seen above, which are run for every single forward pass. The biggest issue for me is duplicated work!
The test should have been fixed in the
…uggingface#33608)

* [feat] add feature size check to avoid CUDA Runtime Error
* [minor] add error handling to all llava models
* [minor] avoid nested if else
* [minor] add error message to Qwen2-vl and chameleon
* [fix] token dimension for check
* [minor] add feature dim check for videos too
* [fix] dimension check
* [fix] test reference values

---------

Co-authored-by: Raushan Turganbay <[email protected]>
What does this PR do?
In `LlavaNextForConditionalGeneration`, the forward pass scatters the image features into the text embeddings with a `masked_scatter` operation (roughly as sketched below). It happened to me that for some edge cases the `special_image_mask` and `image_features` shapes are incompatible, which leads to a `RuntimeError: CUDA error: device-side assert triggered` in the `masked_scatter` operation. The problem is that it only happens very rarely and I couldn't pinpoint yet what causes the mismatch.
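A runnable toy version of that pattern, with made-up shapes and an assumed image token id (this is an illustration, not the model's actual forward code):

```python
import torch

hidden_size = 8
input_ids = torch.tensor([[1, 32000, 32000, 2]])  # 32000: assumed image token id
inputs_embeds = torch.zeros(1, 4, hidden_size)
image_features = torch.randn(2, hidden_size)      # one feature vector per image token

special_image_mask = (input_ids == 32000).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

# If image_features provided fewer elements than the mask selects (e.g. features
# for 1 token while the text contains 2 image tokens), masked_scatter fails; on
# CUDA this surfaces as a device-side assert rather than a regular Python exception.
```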
Rather than fixing the issue for those edge cases, in this PR I'm suggesting to first check the respective sizes of the image tokens and image features and ensure they match before applying the `masked_scatter` operation. This allows raising a `ValueError` instead of getting the CUDA Runtime error, which is useful because one can then handle the exception as they see fit and keep using the GPU. The CUDA Runtime error, on the other hand, even if caught as an exception, will resurface on any subsequent CUDA operation; in other words, it can't really be handled and necessarily breaks the running script.

In short, this PR allows nicer error handling in some edge cases in LlavaNext when a GPU is used.
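For example, with the check in place a training or inference loop can skip the offending batch and continue (a minimal sketch; `model` and `dataloader` are assumed to exist and are not part of this PR):

```python
for batch in dataloader:
    try:
        outputs = model(**batch)
    except ValueError as err:
        # The mismatch is reported before any CUDA kernel asserts, so the CUDA
        # context stays usable and the loop can simply move on.
        print(f"Skipping sample with mismatched image tokens/features: {err}")
        continue
```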
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
@zucchini-nlp