
Make float32_qk_product and float32_logits apply during inference #1225

Merged: 3 commits, Feb 11, 2025
Changes from 2 commits
4 changes: 4 additions & 0 deletions MaxText/configs/base.yml
@@ -192,6 +192,10 @@ final_logits_soft_cap: 0.0
use_post_attn_norm: False
use_post_ffw_norm: False

+# In dot_product attention, whether to upcast the qk product and attention logits to fp32
Collaborator: nit: Can we move these next to similar options in line 123?

Contributor Author: Will do

+float32_qk_product: False
+float32_logits: False


# Combine matmuls for QKV and MLP
fused_qkv: False
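For reference, a minimal sketch of what the two new flags control inside dot-product attention (illustrative shapes and helper names, not the MaxText implementation — the real logic lives in MaxText/layers/attentions.py):

```python
import jax
import jax.numpy as jnp

def toy_dot_product_attention(query, key, value,
                              float32_qk_product=False, float32_logits=False):
  """Toy attention showing where the two upcasts happen (mask/dropout omitted)."""
  if float32_qk_product:
    # Upcast q/k so the qk product is produced in fp32.
    query = query.astype(jnp.float32)
    key = key.astype(jnp.float32)
  attn_weights = jnp.einsum("bthd,bshd->bhts", query, key)

  if float32_logits:
    # Upcast the attention logits so the softmax runs in fp32.
    attn_weights = attn_weights.astype(jnp.float32)
  probs = jax.nn.softmax(attn_weights, axis=-1)
  return jnp.einsum("bhts,bshd->bthd", probs, value)

b, t, h, d = 1, 8, 2, 4
q = k = v = jnp.ones((b, t, h, d), jnp.bfloat16)
out = toy_dot_product_attention(q, k, v, float32_qk_product=True, float32_logits=True)
```

Before this PR, apply_attention_dot only honored these flags when model_mode was MODEL_MODE_TRAIN; the attentions.py change later in this diff makes them apply during inference as well.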
4 changes: 3 additions & 1 deletion MaxText/configs/models/gemma-2b.yml
Collaborator: I'd prefer that the quantization settings are not part of the /models configs - the precision isn't a property of the model; the user can still run the Gemma model with different precision settings.

Contributor Author: Makes sense, I'll remove these from the model configs.

@@ -24,4 +24,6 @@ mlp_activations: ["gelu","linear"]
vocab_size: 256128
decoder_block: "gemma"
normalization_layer_epsilon: 1.e-06
-logits_via_embedding: True
+logits_via_embedding: True
+float32_qk_product: True
+float32_qk_logits: True
4 changes: 3 additions & 1 deletion MaxText/configs/models/gemma-7b.yml
@@ -24,4 +24,6 @@ mlp_activations: ["gelu","linear"]
vocab_size: 256128
decoder_block: "gemma"
normalization_layer_epsilon: 1.e-06
-logits_via_embedding: True
+logits_via_embedding: True
+float32_qk_product: True
+float32_qk_logits: True
2 changes: 2 additions & 0 deletions MaxText/configs/models/gemma2-27b.yml
@@ -30,3 +30,5 @@ attn_logits_soft_cap: 50.0
sliding_window_size: 4096
use_post_attn_norm: True
use_post_ffw_norm: True
+float32_qk_product: True
+float32_qk_logits: True
2 changes: 2 additions & 0 deletions MaxText/configs/models/gemma2-2b.yml
@@ -30,3 +30,5 @@ attn_logits_soft_cap: 50.0
sliding_window_size: 4096
use_post_attn_norm: True
use_post_ffw_norm: True
+float32_qk_product: True
+float32_qk_logits: True
2 changes: 2 additions & 0 deletions MaxText/configs/models/gemma2-9b.yml
@@ -30,3 +30,5 @@ attn_logits_soft_cap: 50.0
sliding_window_size: 4096
use_post_attn_norm: True
use_post_ffw_norm: True
+float32_qk_product: True
+float32_qk_logits: True
4 changes: 2 additions & 2 deletions MaxText/layers/attentions.py
@@ -477,7 +477,7 @@ def apply_attention_dot(
"""Apply Attention."""
validate_compute_axis_order(self.compute_axis_order)
# Casting qk_product and softmaxt computation for float32 for model stability.
-if model_mode == common_types.MODEL_MODE_TRAIN and self.float32_qk_product:
+if self.float32_qk_product:
Collaborator: I think you have to set precision as well for float32 to actually take effect (https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision) - we have this option in MaxText:

matmul_precision: "default"
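(Illustrative aside, not part of the PR: the snippet below shows the knob the comment above is referring to — the precision argument of a matmul/einsum, which raises the flop precision itself regardless of the input dtypes. Shapes here are made up for the example.)

```python
import jax
import jax.numpy as jnp

q = jnp.ones((1, 8, 2, 4), jnp.bfloat16)
k = jnp.ones((1, 8, 2, 4), jnp.bfloat16)

# Default precision: on TPU the multiply may run with bf16 flops even if
# one or both inputs are fp32.
qk_default = jnp.einsum("bthd,bshd->bhts", q, k,
                        precision=jax.lax.Precision.DEFAULT)

# Highest precision: ask XLA for full-precision flops.
qk_highest = jnp.einsum("bthd,bshd->bhts", q, k,
                        precision=jax.lax.Precision.HIGHEST)
```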

Contributor Author:

That's a good point which I hadn't considered. I suppose there are two moving parts here: are the flops performed in bf16 or fp32, and are the results accumulated in bf16 or fp32. I think these are generally controlled with the precision and preferred_element_type arguments, respectively.

It appears that on default precision the flops happen in bf16 even when one or more of the inputs are in fp32. However, the accumulation can still happen in fp32, and that seems to have been enough to solve our particular problem. In particular, the compiler seems to recognize that even though the Python says to upcast to fp32, it can elide that cast because it is going to do the computation in bf16 anyway; however, it still outputs fp32.

This is the qk product with float32_qk_product=False:

[screenshot: compiled graph of the qk product]

And this is with float32_qk_product=True (note the output type is now f32):

[screenshot: compiled graph of the qk product, f32 output]

I'm not 100% confident in my interpretation of those graphs, but this would explain why it takes longer even without changing the precision parameter.

Separately, it looks like matmul_precision consistently gets routed into DenseGeneral usages, but not into the raw einsums used in qk_product and wv_product. When I change matmul_precision in the config it does not affect the runtime of those operations, but if I add it explicitly to the einsums then the wv_product does take longer, which makes sense. Is that something we should fix just by adding those arguments to the einsums?
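(Editorial sketch of the "two moving parts" described above — illustrative shapes only, not a MaxText patch: preferred_element_type requests fp32 accumulation/output while the inputs stay bf16, whereas upcasting the inputs or raising precision changes what the flops themselves operate on.)

```python
import jax
import jax.numpy as jnp

q = jnp.ones((1, 8, 2, 4), jnp.bfloat16)
k = jnp.ones((1, 8, 2, 4), jnp.bfloat16)

# bf16 inputs, with fp32 accumulation and fp32 output requested from XLA:
qk_accum_fp32 = jnp.einsum("bthd,bshd->bhts", q, k,
                           preferred_element_type=jnp.float32)

# Upcasting the inputs: at default precision the compiler may still use bf16
# flops, but the result comes out in fp32 (the behaviour described above).
qk_upcast = jnp.einsum("bthd,bshd->bhts",
                       q.astype(jnp.float32), k.astype(jnp.float32))

# Threading a precision argument into the raw einsum, as the question suggests:
qk_full = jnp.einsum("bthd,bshd->bhts", q, k,
                     precision=jax.lax.Precision.HIGHEST,
                     preferred_element_type=jnp.float32)
```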

if isinstance(key, KVTensor):
key = key.dequant()
query = query.astype(jnp.float32)
@@ -491,7 +491,7 @@ def apply_attention_dot(
attn_weights = attn_weights * self.attn_logits_soft_cap

# Casting softmaxt computation for float32 for model stability.
-if model_mode == common_types.MODEL_MODE_TRAIN and self.float32_logits:
+if self.float32_logits:
attn_weights = attn_weights.astype(jnp.float32)
attn_mask = self.generate_attention_mask(query, key, decoder_segment_ids, model_mode)
if attn_mask is not None:
4 changes: 2 additions & 2 deletions MaxText/layers/gemma.py
@@ -91,8 +91,8 @@ def __call__(
weight_dtype=cfg.weight_dtype,
dropout_rate=cfg.dropout_rate,
name="self_attention",
-float32_qk_product=True,
-float32_logits=True,
+float32_qk_product=cfg.float32_qk_product,
+float32_logits=cfg.float32_logits,
quant=self.quant,
kv_quant=quantizations.configure_kv_quant(cfg),
use_ragged_attention=cfg.use_ragged_attention,
4 changes: 2 additions & 2 deletions MaxText/layers/gemma2.py
@@ -91,8 +91,8 @@ def __call__(
weight_dtype=cfg.weight_dtype,
dropout_rate=cfg.dropout_rate,
name="self_attention_local",
-float32_qk_product=True,
-float32_logits=True,
+float32_qk_product=cfg.float32_qk_product,
+float32_logits=cfg.float32_logits,
quant=self.quant,
kv_quant=quantizations.configure_kv_quant(cfg),
attention_type=attentions.AttentionType.LOCAL_SLIDING,
2 changes: 2 additions & 0 deletions MaxText/layers/gpt3.py
@@ -312,6 +312,8 @@ def __call__(
mesh=mesh,
dropout_rate=cfg.dropout_rate,
name="self_attention",
+float32_qk_product=cfg.float32_qk_product,
+float32_logits=cfg.float32_logits,
fused_qkv=cfg.fused_qkv,
use_bias=True,
quant=self.quant,
2 changes: 2 additions & 0 deletions MaxText/layers/llama2.py
@@ -105,6 +105,8 @@ def __call__(
weight_dtype=cfg.weight_dtype,
dropout_rate=cfg.dropout_rate,
name="self_attention",
+float32_qk_product=cfg.float32_qk_product,
+float32_logits=cfg.float32_logits,
quant=self.quant,
kv_quant=quantizations.configure_kv_quant(cfg),
prefill_cache_axis_order=tuple([int(i) for i in cfg.prefill_cache_axis_order.split(",")]),
2 changes: 2 additions & 0 deletions MaxText/layers/mistral.py
@@ -97,6 +97,8 @@ def __call__(
weight_dtype=cfg.weight_dtype,
dropout_rate=cfg.dropout_rate,
name="self_attention",
+float32_qk_product=cfg.float32_qk_product,
+float32_logits=cfg.float32_logits,
quant=self.quant,
kv_quant=quantizations.configure_kv_quant(cfg),
)
2 changes: 2 additions & 0 deletions MaxText/layers/models.py
@@ -92,6 +92,8 @@ def __call__(
weight_dtype=cfg.weight_dtype,
dropout_rate=cfg.dropout_rate,
name="self_attention",
+float32_qk_product=cfg.float32_qk_product,
+float32_logits=cfg.float32_logits,
quant=self.quant,
kv_quant=quantizations.configure_kv_quant(cfg),
prefill_cache_axis_order=tuple([int(i) for i in cfg.prefill_cache_axis_order.split(",")]),