Paligemma: fix static cache test #33941
Conversation
@@ -378,7 +378,7 @@ def _update_causal_mask(
        if is_training:
            causal_mask = torch.triu(causal_mask, diagonal=1)
        else:
-           causal_mask = torch.zeros_like(causal_mask)
+           causal_mask[:, :sequence_length] = 0.0
This was the cause: it was not masking the dummy tokens from the static cache, so we always ended up with no mask on those token positions (see the sketch after this thread).
aah gotcha. good catch
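A minimal sketch of the masking difference discussed above, using hypothetical 2D shapes and made-up sizes (the real `_update_causal_mask` takes more inputs and dimensions): with a static cache longer than the prompt, zeroing the whole mask also unmasks the unused cache slots, while zeroing only the first `sequence_length` columns keeps them masked.

```python
# Hypothetical, simplified illustration; not the actual PaliGemma code.
import torch

sequence_length = 4   # real prompt tokens
target_length = 8     # static cache size, so 4 trailing "dummy" slots
min_dtype = torch.finfo(torch.float32).min

# Prefix-LM style mask over the whole static cache, fully masked to start.
causal_mask = torch.full((sequence_length, target_length), min_dtype)

# Old behaviour: unmask everything, including the dummy cache positions.
old_mask = torch.zeros_like(causal_mask)

# New behaviour: only unmask the columns holding real tokens.
new_mask = causal_mask.clone()
new_mask[:, :sequence_length] = 0.0

assert (old_mask[:, sequence_length:] == 0).all()          # dummy slots attended to
assert (new_mask[:, sequence_length:] == min_dtype).all()  # dummy slots stay masked
```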
@@ -604,8 +603,6 @@ def prepare_inputs_for_generation(
            min_dtype=min_dtype,
            cache_position=cache_position,
            batch_size=batch_size,
-           is_training=is_training,
If we get here to prepare the static cache, we cannot be in training mode. I don't think it is common to pass labels through generation, right?
I'm not seeing many use-cases indeed, except for maybe constrained generation and RL?
guess so, let's see what generation master (gante) thinks 😄
If labels in paligemma has the usual meaning (= the tensor with which we compute the loss, with no further uses), then generate will never use labels :D
nice, yes those are normal labels :)
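For context, here is a hedged sketch of the convention this thread relies on: in the usual causal-LM pattern, labels only feed the loss inside forward, and generate never passes them, so a training-only mask branch cannot be reached from generation. The names and shapes below are illustrative, not PaliGemma's actual implementation.

```python
# Illustrative only: the standard "labels drive the loss and nothing else" pattern.
import torch
import torch.nn.functional as F

def toy_forward(logits, labels=None):
    loss = None
    if labels is not None:
        # Shift so that tokens < n predict token n, as in causal LM training.
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].contiguous()
        loss = F.cross_entropy(
            shift_logits.view(-1, shift_logits.size(-1)),
            shift_labels.view(-1),
        )
    return loss, logits

# Generation never supplies labels, so the loss branch (and any
# training-only masking tied to it) is never taken.
logits = torch.randn(1, 5, 32)
loss, _ = toy_forward(logits)                                        # loss is None
loss, _ = toy_forward(logits, labels=torch.randint(0, 32, (1, 5)))   # loss computed
```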
LGTM, added comment on training case for generation :)
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
LGTM, thank you for fixing 🤗
Thanks 🤗
* fix
* not flaky anymore + style
What does this PR do?
Fixes the flaky test on paligemma from #33630