Commit 05abc90

Merge pull request #2706 from AI-Hypercomputer:mohit/tokamax_quant_gmm
PiperOrigin-RevId: 834605168
2 parents 7701bd2 + 67dc1f3 commit 05abc90

6 files changed: +219 -76 lines changed

src/MaxText/configs/base.yml

Lines changed: 23 additions & 5 deletions
@@ -128,7 +128,9 @@ save_quantized_params_path: ""
 model_call_mode: ""
 use_qwix_quantization: False # Whether to use qwix for quantization. If set to True, the model will be quantized using qwix.
 # Quantization calibration method used for weights and activations. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
-quantization_calibration_method: "absmax"
+weight_quantization_calibration_method: "absmax"
+act_quantization_calibration_method: "absmax"
+bwd_quantization_calibration_method: "absmax"
 # Shard the range finding operation for quantization. By default this is set to number of slices.
 quantization_local_shard_count: -1
 
@@ -177,10 +179,26 @@ load_balance_loss_weight: 0.01 # weight for the load balance loss
 use_random_routing: False # whether to use random routing for debug/test purpose
 use_custom_sort_vjp: True # whether to use a custom sort vjp for sparse matmul ops
 use_ring_of_experts: False # whether to use ring of experts for sparse matmul expert parallelism
-# Tunable tiling dimensions used for Megablox
-tile_batch_seq: 512
-tile_embed_dim: 1024
-tile_mlp_dim: 1024
+# Tunable tiling dimensions used for MLP GMM, includes Tokamax ragged_dot and Megablox
+wi_tile_fwd_batch_seq: 512
+wi_tile_fwd_embed_dim: 1024
+wi_tile_fwd_mlp_dim: 1024
+wi_tile_dlhs_batch_seq: 512
+wi_tile_dlhs_embed_dim: 1024
+wi_tile_dlhs_mlp_dim: 1024
+wi_tile_drhs_batch_seq: 512
+wi_tile_drhs_embed_dim: 1024
+wi_tile_drhs_mlp_dim: 1024
+
+wo_tile_fwd_batch_seq: 512
+wo_tile_fwd_embed_dim: 1024
+wo_tile_fwd_mlp_dim: 1024
+wo_tile_dlhs_batch_seq: 512
+wo_tile_dlhs_embed_dim: 1024
+wo_tile_dlhs_mlp_dim: 1024
+wo_tile_drhs_batch_seq: 512
+wo_tile_drhs_embed_dim: 1024
+wo_tile_drhs_mlp_dim: 1024
 norm_topk_prob: False # Boolean to enable the top-k probability normalization. Qwen3-specific normalization of router weights.
 
 # How the expert axis is used to shard attention weights and activations
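
Each MoE projection (wi and wo) now gets its own tile sizes for the forward pass and for the two backward-pass matmuls (dlhs and drhs), replacing the single tile_batch_seq / tile_embed_dim / tile_mlp_dim triple. Below is a hedged Python sketch of how these knobs could be packed into the nine-element tiling tuple that ops.gmm now takes; the cfg object and the ordering inside each triple are illustrative assumptions, while the slice-to-pass mapping (tiling[:3] forward, tiling[3:6] dlhs, tiling[-3:] drhs) follows the ops.py change further down.

# Hypothetical packing of the new base.yml knobs for the wi projection.
def wi_tiling(cfg):
  return (
      cfg.wi_tile_fwd_batch_seq, cfg.wi_tile_fwd_embed_dim, cfg.wi_tile_fwd_mlp_dim,     # tiling[:3]  -> forward gmm
      cfg.wi_tile_dlhs_batch_seq, cfg.wi_tile_dlhs_embed_dim, cfg.wi_tile_dlhs_mlp_dim,  # tiling[3:6] -> dlhs gmm
      cfg.wi_tile_drhs_batch_seq, cfg.wi_tile_drhs_embed_dim, cfg.wi_tile_drhs_mlp_dim,  # tiling[-3:] -> drhs tgmm
  )

An analogous wo_tiling(cfg) would pack the wo_* values for the second MoE projection.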

src/MaxText/configs/types.py

Lines changed: 28 additions & 7 deletions
@@ -341,9 +341,17 @@ class Quantization(BaseModel):
   kv_quant_dtype: Literal["int8", "int4"] = Field("int8", description="Data type for KV cache quantization.")
   quantization_local_shard_count: int = Field(-1, description="Shards the range finding operation for quantization.")
   use_qwix_quantization: bool = Field(False, description="Whether to use qwix for quantization.")
-  quantization_calibration_method: str = Field(
+  weight_quantization_calibration_method: str = Field(
       "absmax",
-      description="Quantization calibration method used for weights and activations.",
+      description="Quantization calibration method used for weights.",
+  )
+  act_quantization_calibration_method: str = Field(
+      "absmax",
+      description="Quantization calibration method used for activations.",
+  )
+  bwd_quantization_calibration_method: str = Field(
+      "absmax",
+      description="Quantization calibration method used for gradients.",
   )
 
 
@@ -547,9 +555,24 @@ class MoEKernels(BaseModel):
 
   megablox: bool = Field(True, description="Whether to use Megablox kernels for MoE.")
   sparse_matmul: bool = Field(True, description="Whether to use sparse matmul kernels for MoE.")
-  tile_batch_seq: int = Field(512, description="Tunable tiling dimension for batch/sequence in Megablox.")
-  tile_embed_dim: int = Field(1024, description="Tunable tiling dimension for embedding in Megablox.")
-  tile_mlp_dim: int = Field(1024, description="Tunable tiling dimension for MLP in Megablox.")
+  wi_tile_fwd_batch_seq: int = Field(512, description="forward pass tiling dimension for batch/sequence in GMM for wi.")
+  wi_tile_fwd_embed_dim: int = Field(1024, description="forward pass tiling dimension for embedding in GMM for wi.")
+  wi_tile_fwd_mlp_dim: int = Field(1024, description="forward pass tiling dimension for MLP in GMM for wi.")
+  wi_tile_dlhs_batch_seq: int = Field(512, description="bwd pass dlhs tiling dimension for batch/sequence in GMM for wi.")
+  wi_tile_dlhs_embed_dim: int = Field(1024, description="bwd pass dlhs tiling dimension for embedding in GMM for wi.")
+  wi_tile_dlhs_mlp_dim: int = Field(1024, description="bwd pass dlhs tiling dimension for MLP in GMM for wi.")
+  wi_tile_drhs_batch_seq: int = Field(512, description="bwd pass drhs tiling dimension for batch/sequence in GMM for wi.")
+  wi_tile_drhs_embed_dim: int = Field(1024, description="bwd pass drhs tiling dimension for embedding in GMM for wi.")
+  wi_tile_drhs_mlp_dim: int = Field(1024, description="bwd pass drhs tiling dimension for MLP in GMM for wi.")
+  wo_tile_fwd_batch_seq: int = Field(512, description="forward pass tiling dimension for batch/sequence in GMM for wo.")
+  wo_tile_fwd_embed_dim: int = Field(1024, description="forward pass tiling dimension for embedding in GMM for wo.")
+  wo_tile_fwd_mlp_dim: int = Field(1024, description="forward pass tiling dimension for MLP in GMM for wo.")
+  wo_tile_dlhs_batch_seq: int = Field(512, description="bwd pass dlhs tiling dimension for batch/sequence in GMM for wo.")
+  wo_tile_dlhs_embed_dim: int = Field(1024, description="bwd pass dlhs tiling dimension for embedding in GMM for wo.")
+  wo_tile_dlhs_mlp_dim: int = Field(1024, description="bwd pass dlhs tiling dimension for MLP in GMM for wo.")
+  wo_tile_drhs_batch_seq: int = Field(512, description="bwd pass drhs tiling dimension for batch/sequence in GMM for wo.")
+  wo_tile_drhs_embed_dim: int = Field(1024, description="bwd pass drhs tiling dimension for embedding in GMM for wo.")
+  wo_tile_drhs_mlp_dim: int = Field(1024, description="bwd pass drhs tiling dimension for MLP in GMM for wo.")
 
 
 class DeepSeekMoE(BaseModel):
@@ -1400,7 +1423,6 @@ class DerivedValues(BaseModel):
       None, description="Increment for global batch size during rampup."
   )
   rampup_samples_per_increment_to_load: None | float = Field(None, description="Samples per increment for rampup.")
-  tile_fwd_batch_seq: None | int = Field(None, description="Legacy alias for tile_batch_seq.")
 
 
 # ----------------------------------------------------------------------------
@@ -1721,7 +1743,6 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
     if self.expert_shard_attention_option == "context":
       cp_size *= self.ici_expert_parallelism * self.dcn_expert_parallelism
     self.context_parallel_size = cp_size
-    self.tile_fwd_batch_seq = self.tile_batch_seq # Legacy alias.
     if self.pipeline_parallel_layers == -1:
       if self.decoder_block == DecoderBlockType.DEEPSEEK:
         moe_layers = self.num_decoder_layers - self.first_num_dense_layers
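
The single quantization_calibration_method field is split into three so that weights, activations, and gradients can be calibrated independently. A standalone pydantic sketch mirroring the new fields (illustration only, not the MaxText class; defaults match base.yml):

from pydantic import BaseModel, Field

# Minimal stand-in for the three per-tensor calibration fields.
class QuantizationCalibrationSketch(BaseModel):
  weight_quantization_calibration_method: str = Field("absmax", description="Calibration method for weights.")
  act_quantization_calibration_method: str = Field("absmax", description="Calibration method for activations.")
  bwd_quantization_calibration_method: str = Field("absmax", description="Calibration method for gradients.")

cfg = QuantizationCalibrationSketch()
print(cfg.model_dump())  # all three default to "absmax"

The strings are passed through to qwix here; ops.py only inspects whether the weight method starts with "fixed" when deciding to all-gather quantized expert weights.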

src/MaxText/kernels/megablox/backend.py

Lines changed: 13 additions & 0 deletions
@@ -20,6 +20,7 @@
 import dataclasses
 import functools
 from typing import Any, Optional
+import json
 
 import jax
 from jax import lax
@@ -514,6 +515,11 @@ def out_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset):
   bytes_accessed = (lhs_bytes * tiles_n) + (rhs_bytes * max_active_tiles) + out_bytes
   flops = 2 * m * k * n
   cost_estimate = pl.CostEstimate(flops=flops, bytes_accessed=bytes_accessed, transcendentals=0)
+  metadata = {
+      "preferred_element_type": jnp.dtype(preferred_element_type).name,
+      "tiling": {"tile_m": tm, "tile_k": tk, "tile_n": tn},
+      "transpose_rhs": transpose_rhs,
+  }
   call_gmm = qpl.pallas_call(
       kernel,
       out_shape=jax.ShapeDtypeStruct((m, n), preferred_element_type),
@@ -532,6 +538,7 @@ def out_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset):
       compiler_params=pltpu.CompilerParams(dimension_semantics=("parallel", "arbitrary", "arbitrary")),
       interpret=interpret,
       cost_estimate=cost_estimate,
+      metadata={"xprof_metadata": json.dumps(metadata)},
   )
 
   out = call_gmm(
@@ -761,6 +768,11 @@ def out_transform_indices(n_i, k_i, grid_id, group_metadata, group_offset):
   flops = 2 * m * k * n
   cost_estimate = pl.CostEstimate(flops=flops, bytes_accessed=bytes_accessed, transcendentals=0)
   lhs = lhs.swapaxes(0, 1)
+  metadata = {
+      "tiling": {"tile_m": tm, "tile_k": tk, "tile_n": tn},
+      "prefer_element_type": jnp.dtype(preferred_element_type).name,
+      "num_actual_groups": num_actual_groups,
+  }
   call_gmm = qpl.pallas_call(
       kernel,
       out_shape=jax.ShapeDtypeStruct((num_actual_groups, k, n), preferred_element_type),
@@ -779,6 +791,7 @@ def out_transform_indices(n_i, k_i, grid_id, group_metadata, group_offset):
       compiler_params=pltpu.CompilerParams(dimension_semantics=("parallel", "arbitrary", "arbitrary")),
       interpret=interpret,
       cost_estimate=cost_estimate,
+      metadata={"xprof_metadata": json.dumps(metadata)},
   )
 
   out = call_gmm(
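
This change is purely observability: each GMM/TGMM pallas_call now carries an xprof_metadata JSON string so the chosen tile sizes and dtypes show up in profiler traces. A small standalone sketch of the payload construction; the helper name is hypothetical, and the kernels build the dict inline as in the diff above.

import json

import jax.numpy as jnp

def gmm_xprof_metadata(tm, tk, tn, preferred_element_type=jnp.float32, transpose_rhs=False):
  # Mirrors the dict passed as metadata={"xprof_metadata": json.dumps(...)}.
  payload = {
      "preferred_element_type": jnp.dtype(preferred_element_type).name,
      "tiling": {"tile_m": tm, "tile_k": tk, "tile_n": tn},
      "transpose_rhs": transpose_rhs,
  }
  return {"xprof_metadata": json.dumps(payload)}

print(gmm_xprof_metadata(512, 1024, 1024))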

src/MaxText/kernels/megablox/ops.py

Lines changed: 93 additions & 37 deletions
@@ -17,10 +17,12 @@
 # pylint: disable=too-many-positional-arguments
 
 import functools
+import dataclasses
 from typing import Literal
 import jax
 import jax.numpy as jnp
 from MaxText.kernels.megablox import backend
+from tokamax._src.ops.ragged_dot import pallas_mosaic_tpu_kernel as tokamax_backend
 import qwix
 import qwix.pallas as qpl
 
@@ -30,14 +32,16 @@ def gmm(
     rhs: jnp.ndarray,
     group_sizes: jnp.ndarray,
     preferred_element_type: jnp.dtype = jnp.float32,
-    tiling: tuple[int, int, int] = (128, 128, 128),
+    tiling: tuple[int, int, int, int, int, int, int, int, int] = (128, 128, 128, 128, 128, 128, 128, 128, 128),
     group_offset: jnp.ndarray | None = None,
     existing_out: jnp.ndarray | None = None,
     transpose_rhs: bool = False,
     interpret: bool = False,
     lhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None = None,
     rhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None = None,
     use_qwix_quantization: bool = False,
+    use_tokamax_backend: bool = False,
+    is_fsdp_shard_on_exp: bool = False,
 ):
   """Grouped matrix multiplication operation."""
   quantization_rule = None
@@ -57,7 +61,7 @@
     )
 
   gmm_fwd_bwd = lambda *args: _gmm_fwd(*args)[0] # pylint: disable=C3001
-  gmm_fwd_bwd = jax.custom_vjp(gmm_fwd_bwd, nondiff_argnums=(3, 4, 7, 8, 9))
+  gmm_fwd_bwd = jax.custom_vjp(gmm_fwd_bwd, nondiff_argnums=(3, 4, 7, 8, 9, 10, 11))
   gmm_fwd_bwd.defvjp(_gmm_fwd, functools.partial(_gmm_bwd, lhs.dtype, rhs.dtype))
   return gmm_fwd_bwd(
       lhs,
@@ -70,6 +74,8 @@
       transpose_rhs,
       interpret,
       quantization_rule,
+      use_tokamax_backend,
+      is_fsdp_shard_on_exp,
   )
 
 
@@ -78,12 +84,14 @@ def _gmm_fwd(
     rhs: jnp.ndarray,
     group_sizes: jnp.ndarray,
     preferred_element_type: jnp.dtype = jnp.float32,
-    tiling: tuple[int, int, int] = (128, 128, 128),
+    tiling: tuple[int, int, int, int, int, int, int, int, int] = (128, 128, 128, 128, 128, 128, 128, 128, 128),
     group_offset: jnp.ndarray | None = None,
     existing_out: jnp.ndarray | None = None,
     transpose_rhs: bool = False,
     interpret: bool = False,
     quantization_rule: qwix.QtRule | None = None,
+    use_tokamax_backend: bool = False,
+    is_fsdp_shard_on_exp: bool = False,
 ) -> tuple[
     jnp.ndarray,
     tuple[
@@ -114,29 +122,52 @@ def _gmm_fwd(
         calibration_method=quantization_rule.weight_calibration_method,
         scale_dtype=jnp.float32,
     )
-
-  out = backend.gmm(
-      lhs,
-      rhs,
-      group_sizes,
-      preferred_element_type,
-      tiling,
-      group_offset,
-      existing_out,
-      transpose_rhs=transpose_rhs,
-      interpret=interpret,
-  )
+  # QAG is only supported for following conditions
+  if use_tokamax_backend:
+    if quantization_rule and quantization_rule.bwd_qtype:
+      if (
+          quantization_rule.weight_calibration_method.startswith("fixed")
+          and isinstance(rhs, qpl.QArray)
+          and is_fsdp_shard_on_exp
+      ):
+        rhs_qvalue = jax.lax.all_gather(rhs.qvalue, "fsdp", axis=0, tiled=True)
+        rhs = dataclasses.replace(rhs, qvalue=rhs_qvalue)
+    out = tokamax_backend.gmm(
+        lhs=lhs,
+        rhs=rhs,
+        group_sizes=group_sizes,
+        precision=jax.lax.Precision.DEFAULT,
+        out_dtype=preferred_element_type,
+        tiling=tiling[:3],
+        group_offset=group_offset,
+        transpose_rhs=transpose_rhs,
+        interpret=interpret,
+    )
+  else:
+    out = backend.gmm(
+        lhs,
+        rhs,
+        group_sizes,
+        preferred_element_type,
+        tiling[:3],
+        group_offset,
+        existing_out,
+        transpose_rhs=transpose_rhs,
+        interpret=interpret,
+    )
   return out, (lhs, rhs, group_sizes, group_offset)
 
 
 def _gmm_bwd(
     lhs_dtype: jax.typing.DTypeLike,
     rhs_dtype: jax.typing.DTypeLike,
     preferred_element_type: jnp.dtype,
-    tiling: tuple[int, int, int],
+    tiling: tuple[int, int, int, int, int, int, int, int, int],
     transpose_rhs: bool,
     interpret: bool,
     quantization_rule: qwix.QtRule | None,
+    use_tokamax_backend: bool,
+    is_fsdp_shard_on_exp: bool,
     residual: tuple[
         jnp.ndarray | qpl.QArray,
         jnp.ndarray | qpl.QArray,
@@ -187,27 +218,52 @@ def _gmm_bwd(
         calibration_method=quantization_rule.bwd_calibration_method,
         scale_dtype=jnp.float32,
     )
-
-  dlhs = backend.gmm(
-      dlhs_dout,
-      rhs,
-      group_sizes,
-      lhs_dtype,
-      tiling,
-      group_offset,
-      transpose_rhs=not transpose_rhs,
-      interpret=interpret,
-  )
-  drhs = backend.tgmm(
-      lhs.swapaxes(0, 1),
-      drhs_dout,
-      group_sizes,
-      rhs_dtype,
-      tiling,
-      group_offset,
-      num_actual_groups,
-      interpret=interpret,
-  )
+  if use_tokamax_backend:
+    dlhs = tokamax_backend.gmm(
+        lhs=dlhs_dout,
+        rhs=rhs,
+        group_sizes=group_sizes,
+        precision=jax.lax.Precision.DEFAULT,
+        out_dtype=lhs_dtype,
+        tiling=tiling[3:6],
+        group_offset=group_offset,
+        transpose_rhs=not transpose_rhs,
+        interpret=interpret,
+    )
+    drhs = tokamax_backend.tgmm(
+        lhs=lhs.swapaxes(0, 1),
+        rhs=drhs_dout,
+        group_sizes=group_sizes,
+        precision=jax.lax.Precision.DEFAULT,
+        out_dtype=rhs_dtype,
+        tiling=tiling[-3:],
+        group_offset=group_offset,
+        num_actual_groups=num_actual_groups,
+        interpret=interpret,
+    )
+    if quantization_rule and quantization_rule.bwd_qtype and is_fsdp_shard_on_exp:
+      drhs = jax.lax.psum_scatter(drhs, "fsdp", scatter_dimension=0, tiled=True)
+  else:
+    dlhs = backend.gmm(
+        dlhs_dout,
+        rhs,
+        group_sizes,
+        lhs_dtype,
+        tiling[3:6],
+        group_offset,
+        transpose_rhs=not transpose_rhs,
+        interpret=interpret,
+    )
+    drhs = backend.tgmm(
+        lhs.swapaxes(0, 1),
+        drhs_dout,
+        group_sizes,
+        rhs_dtype,
+        tiling[-3:],
+        group_offset,
+        num_actual_groups,
+        interpret=interpret,
+    )
 
   # NOTE: If the rhs transposition is fused into the forward pass we need to
   # return the transpose of the rhs gradient that we calculated above.
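
Taken together, a caller now selects the backend per call and supplies one tile triple per pass: tiling[:3] feeds the forward GMM, tiling[3:6] the dlhs GMM, and tiling[-3:] the drhs TGMM. A hedged usage sketch of the new signature follows; shapes, dtypes, and flag values are illustrative, and running the kernels requires a TPU-capable setup.

import jax.numpy as jnp

from MaxText.kernels.megablox import ops

# Illustrative shapes: 8 experts, 1024 tokens, 2048 embed dim, 8192 mlp dim.
lhs = jnp.zeros((1024, 2048), dtype=jnp.bfloat16)
rhs = jnp.zeros((8, 2048, 8192), dtype=jnp.bfloat16)
group_sizes = jnp.full((8,), 128, dtype=jnp.int32)  # sums to the 1024 tokens

out = ops.gmm(
    lhs,
    rhs,
    group_sizes,
    preferred_element_type=jnp.bfloat16,
    # (fwd_m, fwd_k, fwd_n, dlhs_m, dlhs_k, dlhs_n, drhs_m, drhs_k, drhs_n)
    tiling=(512, 1024, 1024, 512, 1024, 1024, 512, 1024, 1024),
    use_tokamax_backend=True,    # route fwd and bwd through the Tokamax ragged_dot kernels
    is_fsdp_shard_on_exp=False,  # set True only when the expert dimension is FSDP-sharded
)

When the QAG path applies (Tokamax backend, a bwd quantization dtype, fixed-range weight calibration, and FSDP-sharded experts), the forward pass all-gathers the quantized rhs over the "fsdp" axis and the backward pass psum-scatters drhs back across it.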
