[feat] LlavaNext add feature size check to avoid CUDA Runtime Error #33608

Conversation

laurentd-lunit (Contributor) commented Sep 20, 2024

What does this PR do?

In LlavaNextForConditionalGeneration, the forward pass applies the following:

special_image_mask = (
    (input_ids == self.config.image_token_index).unsqueeze(-1).expand_as(inputs_embeds)
)
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)

In some edge cases I encountered, the special_image_mask and image_features shapes are incompatible, which leads to a RuntimeError: CUDA error: device-side assert triggered in the masked_scatter operation.
The problem is that it only happens very rarely, and I haven't yet been able to pinpoint what causes the mismatch.

Rather than fixing the issue for each edge case, this PR suggests first checking that the number of image tokens and the number of image features match before applying the masked_scatter operation, and raising a ValueError if they don't. This is useful because a ValueError can be handled as the caller sees fit while the GPU remains usable, whereas the CUDA runtime error, even if caught as an exception, resurfaces on any subsequent CUDA operation; in other words, it cannot really be handled and necessarily breaks the running script.

In short, this PR allows nicer error handling in some edge cases in LlavaNext when a GPU is used; a sketch of the proposed check is shown below.
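
For illustration, a minimal sketch of the kind of check being proposed. The variable names follow the snippet above, but the shape indexing is an assumption (a packed (num_patches, hidden_dim) feature tensor) and this is not the final diff:

# Sketch only: compare the number of image-token positions in input_ids with the
# number of image feature vectors before scattering them into inputs_embeds.
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
n_image_features = image_features.shape[0]  # assumes packed (num_patches, hidden_dim)
if n_image_tokens != n_image_features:
    raise ValueError(
        f"Image features and image tokens do not match: tokens: {n_image_tokens}, "
        f"features: {n_image_features}"
    )
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)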

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@zucchini-nlp

zucchini-nlp (Member) left a comment

Hey @laurentd-lunit ! Raising an error looks good to me, as a general way to prevent errors. Would be super nice if you can get a reproducer to see why this happens, maybe a script that fails once in 10 runs?

This part can error out only if the input ids are prepared incorrectly in the first place (i.e. the number of image tokens was not calculated correctly), or if you are generating without caching and for some reason your model generates more image tokens. The second case is possible only if you are tuning the model from scratch, so it hasn't learned what to generate yet.

Please let me know if you can provide reproducers :)

And regarding the changes, can we raise the same error in all Llava models, if you can? I think Paligemma and Qwen2-VL also use masked_scatter to replace image tokens.

laurentd-lunit (Contributor Author)

@zucchini-nlp thanks for checking this PR!

Would be super nice if you can get a reproducer to see why this happens, maybe a script that fails once in 10 runs?

I've encountered this issue using a large custom dataset, but I haven't pinned down which sample(s) it happens on yet. I suspect an edge case of image dimensions may be causing it. I'll try to find some time to investigate and make a reproducer, but this PR is really more about working around it by simply preventing the CUDA crash when the dimensions do not match.

And regarding the changes, can we raise the same error in all Llava models, if you can? I think Paligemma and Qwen2-VL also use masked_scatter to replace image tokens.

Right, I'll check other Llava models and Paligemma and Qwen2-VL and see if I can make the same changes!

laurentd-lunit force-pushed the feat/llava-next-feature-token-mismatch-error-handling branch from 44d9dfc to e303474 on September 26, 2024 06:26
laurentd-lunit (Contributor Author)

@zucchini-nlp
After doing some digging, I was able to pin down a sample in my custom dataset + custom vision backbone where it happened, but it seems the issue has been fixed by #33564. So I don't think a reproducer is needed here, since the bug itself has been fixed.

However, I think it could still be useful to check the feature sizes and raise a ValueError before doing the masked_scatter, in case a new bug arises in the future. I added the check to all Llava architectures and found that Paligemma already had it implemented.

Let me know if there are any other changes required and if you think it's still worth merging this PR.

laurentd-lunit force-pushed the feat/llava-next-feature-token-mismatch-error-handling branch from e303474 to 8b53700 on September 27, 2024 07:05
zucchini-nlp (Member) left a comment

I see, thanks for investigating it! Llava-next is a bit peculiar because of all the padding/unpadding going on.

Yeah, agreed that we can give a better error message, so I'm okay with merging this PR. Can you also add the same error in Qwen2-VL and Chameleon pls? And rebase on main, because we added tests recently, so they should run as part of CI now.

Otherwise looks good to me, so feel free to tag the core maintainer @ LysandreJik after making the changes :)

src/transformers/models/llava/modeling_llava.py (outdated review thread, resolved)
laurentd-lunit force-pushed the feat/llava-next-feature-token-mismatch-error-handling branch from 48f4a91 to 53d407d on September 30, 2024 05:53
laurentd-lunit (Contributor Author) commented Sep 30, 2024

@zucchini-nlp Thanks for your review again! I incorporated your changes and suggestions. Hopefully it should be good to merge now.

@LysandreJik it would be good if you could check this when you get the chance; in particular, I had to modify some of the reference values in test_modeling_llava.py and test_modeling_vip_llava because it seems the assumed number of tokens was 1 too low:

Using the llava processor logic to compute the expected number of image tokens:
num_image_tokens = (height // self.patch_size) * (width // self.patch_size) + 1

Here we have height, width = 30 and patch_size = 2, which gives 15 * 15 + 1 = 226; then we subtract 1 because the CLS token is ignored, so the expected value is 225 (not 224).
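
A quick sanity check of that arithmetic in plain Python (values taken from the test config mentioned above):

patch_size = 2
height = width = 30
num_image_tokens = (height // patch_size) * (width // patch_size) + 1  # 15 * 15 + 1 = 226
num_image_tokens -= 1  # the CLS token is ignored
print(num_image_tokens)  # 225, so the reference value of 224 was one too low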

LysandreJik (Member) left a comment

Thanks for the ping! @ArthurZucker is much more comfortable with this code as he wrote a bunch of it, so his review will be much better than mine would be; pinging him to review.

Thanks a lot for the PR 🙌

ArthurZucker (Collaborator) left a comment

Hey! For legacy paths (Chameleon) that don't have the processing in the processor, sounds good!
For the new path, I am not sure this is a great addition, especially if it breaks our support for fullgraph compile!

WDYT? 🤗 we can wait for our benchmarks as well!

@@ -1343,6 +1343,12 @@ def forward(

if pixel_values is not None:
    image_tokens = self.get_image_tokens(pixel_values)
    n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum().item()
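
(For context, a reconstructed sketch of how the check around these lines could continue; the comparison against image_tokens below is an assumption for illustration, not a verbatim copy of the diff:)

n_image_tokens_in_text = (input_ids == self.vocabulary_mapping.image_token_id).sum().item()
# Hypothetical continuation: count the image features produced for this batch
n_image_features = image_tokens.shape[0] * image_tokens.shape[1]
if n_image_tokens_in_text != n_image_features:
    raise ValueError(
        f"Image features and image tokens do not match: tokens: {n_image_tokens_in_text}, "
        f"features: {n_image_features}"
    )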
Collaborator

Mmmm in general I don't mind, as this should help our users, but the .item() might break compile compatibility (well only full graph).

@McPatate that's where and when we would need to see how much we are losing from this small change! 🤗 (FYI @LysandreJik)

Member

We can always wrap these in is_torchdynamo_compiling, the same way we wrap all warnings/logging now in the generation code. So we ask users to make sure the code works w/o compilation, to see all warnings etc., and then compile the code, which will not show the exact reason why/where this CUDA-side error was triggered.
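
Something along these lines, as a sketch (it assumes the is_torchdynamo_compiling helper from transformers.utils and reuses the variable names from the Llava snippet above; the shape indexing is illustrative):

from transformers.utils import is_torchdynamo_compiling

if not is_torchdynamo_compiling():
    # .item() and data-dependent control flow would break fullgraph compilation,
    # so the sanity check only runs in eager mode.
    n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
    n_image_features = image_features.shape[0]  # assumes packed (num_patches, hidden_dim)
    if n_image_tokens != n_image_features:
        raise ValueError(
            f"Image features and image tokens do not match: tokens: {n_image_tokens}, "
            f"features: {n_image_features}"
        )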

Collaborator

Okay, that makes sense. Just 🥶 to more checks, but this one is most probably cached, so it should be alright.

Collaborator

The thing is, these is_compiling checks are unrelated to normal users -> we expose them to unrelated code.

Member

I see what you mean. Yes, the processing should maybe check this, but we cannot perform any checks before getting the image hidden states. My main idea was to bring back the same check we had earlier in the merge_inputs method, so that after moving to the new logic we can still trace down bugs related to shape mismatches easily, or let users track them down.

Also, we won't do the sum() and item() on every forward; for generation it only runs in the prefill stage, after which we'll have the image states in the cache. But anyway, if you think this is too many checks (given we now support both the old and new logic in VLMs for a few minor releases), I am okay with not adding it. I don't see it as a major blocker or anything, more like a nice addition for users :D

Collaborator

Okay let's add it then 🤗

Comment on lines +514 to +519
n_image_tokens = (input_ids == self.config.image_token_index).sum(dim=-1)[0].item()
n_image_features = image_features.shape[1]
if n_image_tokens != n_image_features:
    raise ValueError(
        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
    )
ArthurZucker (Collaborator) commented Oct 1, 2024

I don't know why we are adding this here, as the processor is supposed to check this for the non-legacy path!

Member

Yes, it is supposed to. There was only one edge case with llava-next, which uses the pad/unpad technique, and since we used tensors in modeling, there were minor numerical inconsistencies.

Right now it should work, but in general IMO it's a good idea to help users pinpoint what went wrong in their code.

Collaborator

Not in the forward pass IMO; we are adding extra processing (.sum() and .item(), as seen above) that runs on every single forward pass. The biggest issue for me is the duplicated work!

zucchini-nlp (Member)

The test should have been fixed on main; @laurentd-lunit can you rebase on main?

zucchini-nlp merged commit 0f49dea into huggingface:main on Oct 15, 2024
16 of 17 checks passed
NielsRogge pushed a commit to NielsRogge/transformers that referenced this pull request Oct 21, 2024
…uggingface#33608)

* [feat] add feature size check to avoid CUDA Runtime Error

* [minor] add error handling to all llava models

* [minor] avoid nested if else

* [minor] add error message to Qwen2-vl and chameleon

* [fix] token dimension for check

* [minor] add feature dim check for videos too

* [fix] dimension check

* [fix] test reference values

---------

Co-authored-by: Raushan Turganbay <[email protected]>