-
Notifications
You must be signed in to change notification settings - Fork 327
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make float32_qk_product and float32_logits apply during inference #1225
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd prefer that the quantization settings are are not part of the /models - the precision isn't a property of the model, the user can still run the gemma model with different precision settings There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Makes sense, I'll remove these from the model configs |
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
|
@@ -477,7 +477,7 @@ def apply_attention_dot( | |||
"""Apply Attention.""" | ||||
validate_compute_axis_order(self.compute_axis_order) | ||||
# Casting qk_product and softmaxt computation for float32 for model stability. | ||||
if model_mode == common_types.MODEL_MODE_TRAIN and self.float32_qk_product: | ||||
if self.float32_qk_product: | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you have to set precision as well for float32 to actually take effect https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision - we have this option in maxtext maxtext/MaxText/configs/base.yml Line 79 in d33821f
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's a good point which I hadn't considered. I suppose there are two moving parts here: are the flops performed in bf16 or fp32, and are the results accumulated in bf16 or fp32. I think these are generally controlled with the It appears that on default precision the flops happen in bf16 even when one or more of the inputs are in fp32. However, the accumulation can still happen in fp32, and that seems to have been enough to solve our particular problem. In particular, the compiler seems to recognize that even though the python says to upcast to fp32, it can elide that because it's going to do the computation. However, it still outputs fp32. This is the qk product with And this is with I'm not 100% confident in my interpretation of those graphs, but this would explain why it takes longer even without changing the precision parameter. Separately, it looks like |
||||
if isinstance(key, KVTensor): | ||||
key = key.dequant() | ||||
query = query.astype(jnp.float32) | ||||
|
@@ -491,7 +491,7 @@ def apply_attention_dot( | |||
attn_weights = attn_weights * self.attn_logits_soft_cap | ||||
|
||||
# Casting softmaxt computation for float32 for model stability. | ||||
if model_mode == common_types.MODEL_MODE_TRAIN and self.float32_logits: | ||||
if self.float32_logits: | ||||
attn_weights = attn_weights.astype(jnp.float32) | ||||
attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode) | ||||
if attn_mask is not None: | ||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Can we move these next to similar options in line 123?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will do