Commit b644211

[mxfp8 moe training] integrate mxfp8 dim0 cast triton kernel (#3186)
* [mxfp8 moe training] integrate mxfp8 dim0 cast triton kernel
* handle nans
1 parent e8471c6 commit b644211
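This commit swaps the eager `to_mx` cast for the fused `triton_to_mxfp8_dim0` Triton kernel on the dim0 (rowwise) quantization path of the mxfp8 MoE training grouped GEMM, and hardens the scale calculation so all-zero blocks no longer produce NaNs. The sketch below shows the op's contract as it appears in the diffs that follow; the input shape is borrowed from the new zeros test and is illustrative only.

import torch

from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0

# Example shape taken from the new zeros test; any 2D bf16 CUDA tensor works the same way.
x = torch.randn(8192, 5120, dtype=torch.bfloat16, device="cuda")

# Rowwise (dim0) mxfp8 cast with 1x32 scaling blocks.
data, scale = triton_to_mxfp8_dim0(x, inner_block_size=32)

assert data.dtype == torch.float8_e4m3fn    # quantized values
assert scale.dtype == torch.float8_e8m0fnu  # one e8m0 scale per 32 elements along dim0
assert scale.shape == (8192, 5120 // 32)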

4 files changed: 86 additions, 84 deletions

test/prototype/moe_training/test_scaled_grouped_mm.py

Lines changed: 11 additions & 4 deletions
@@ -8,7 +8,7 @@
 import torch
 from torch.nn import functional as F

-from torchao.utils import torch_version_at_least
+from torchao.utils import is_sm_version, torch_version_at_least

 # We need to skip before doing any imports which would use triton, since
 # triton won't be available on CPU builds and torch < 2.5
@@ -28,6 +28,7 @@
 from torchao.float8.float8_linear import matmul_with_hp_or_float8_args
 from torchao.float8.float8_training_tensor import LinearMMConfig
 from torchao.float8.float8_utils import compute_error, tensor_to_scale, to_fp8_saturated
+from torchao.prototype.moe_training.conversion_utils import MoEScalingType
 from torchao.prototype.moe_training.scaled_grouped_mm import (
     _emulated_mxfp8_scaled_grouped_mm_2d_2d,
     _emulated_mxfp8_scaled_grouped_mm_2d_3d,
@@ -43,10 +44,15 @@


 @skip_if_rocm("ROCm not supported")
-def test_valid_scaled_grouped_mm_2d_3d():
+@pytest.mark.parametrize("m", [131072])
+@pytest.mark.parametrize("n", [8192])
+@pytest.mark.parametrize("k", [5120])
+@pytest.mark.parametrize("n_groups", [1, 2, 4, 8])
+def test_valid_scaled_grouped_mm_2d_3d(m, n, k, n_groups):
+    if not is_sm_version(9, 0):
+        pytest.skip("Skipping FP8 rowwise test, requires sm90")
     out_dtype = torch.bfloat16
     device = "cuda"
-    m, n, k, n_groups = 16, 32, 16, 4
     a = torch.randn(
         m * n_groups,
         k,
@@ -72,6 +78,7 @@ def test_valid_scaled_grouped_mm_2d_3d():
         b_t,
         offs=offs,
         out_dtype=out_dtype,
+        scaling_type=MoEScalingType.FP8_ROWWISE,
     )

     # Validate result.
@@ -307,7 +314,7 @@ def test_emulate_mxfp8_grouped_gemm_2d_2d(M, N, num_experts):

 @skip_if_rocm("ROCm not supported")
 @pytest.mark.parametrize(
-    "M,K,N", [(1024, 5120, 8192), (2048, 5120, 8192), (16640, 5120, 8192)]
+    "M,K,N", [(16640, 5120, 8192), (131072, 5120, 8192), (131072, 8192, 5120)]
 )
 @pytest.mark.parametrize("num_experts", (2, 4, 8, 16))
 def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(M, K, N, num_experts):

test/prototype/mx_formats/test_kernels.py

Lines changed: 16 additions & 2 deletions
@@ -486,8 +486,8 @@ def test_triton_mxfp8_dim1_randn(M, K):

 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
 @pytest.mark.skipif(
-    not is_sm_at_least_89(),
-    reason="float8 in triton requires CUDA capability 8.9 or greater",
+    not is_sm_at_least_100(),
+    reason="mxfp8 requires CUDA capability 10.0 or greater",
 )
 @pytest.mark.parametrize("M", (256, 2048, 131072))
 @pytest.mark.parametrize("K", (256, 5120, 7168))
@@ -499,6 +499,20 @@ def test_triton_mxfp8_dim0_randn(M, K):
     torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)


+@pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
+@pytest.mark.skipif(
+    not is_sm_at_least_100(),
+    reason="mxfp8 requires CUDA capability 10.0 or greater",
+)
+def test_triton_mxfp8_dim0_zeros():
+    x = torch.zeros(8192, 5120, dtype=torch.bfloat16, device="cuda")
+    x_mx_ref, x_s_ref = triton_to_mxfp8_dim0_reference(x, block_size=32)
+    x_mx_t, x_s_t = triton_to_mxfp8_dim0(x, inner_block_size=32)
+    assert not x_mx_t.isnan().any(), "quantized tensor should not contain NaNs"
+    torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
+    torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.parametrize(
     "shape",

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 8 additions & 8 deletions
@@ -32,7 +32,7 @@
     MXGemmKernelChoice,
     ScaleCalculationMode,
 )
-from torchao.prototype.mx_formats.mx_tensor import to_mx
+from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0
 from torchao.prototype.mx_formats.utils import _to_mxfp8_dim1_kernel_wrapper

 logger: logging.Logger = logging.getLogger(__name__)
@@ -303,16 +303,16 @@ def forward(

         # A_data shape: (M, K)
         # A_scale shape: (M, K//block_size)
-        A_scale, A_data = to_mx(
-            A, elem_dtype=torch.float8_e4m3fn, block_size=block_size
+        A_data, A_scale = triton_to_mxfp8_dim0(
+            A,
+            inner_block_size=block_size,
         )

         # B_data shape: (E, N, K)
         # B_scale shape: (E, N, K//block_size)
-        B_scales, B_data = to_mx(
+        B_data, B_scales = triton_to_mxfp8_dim0(
             B_t.transpose(-2, -1),
-            elem_dtype=torch.float8_e4m3fn,
-            block_size=block_size,
+            inner_block_size=block_size,
         )

         # Convert scales to blocked format for 2d-3d grouped mm
@@ -351,8 +351,8 @@ def backward(ctx, grad_out: torch.Tensor):

         # grad_out_data shape: (M, N)
         # grad_out_scale shape: (M, N//block_size)
-        grad_out_scale, grad_out_data = to_mx(
-            grad_out, elem_dtype=torch.float8_e4m3fn, block_size=block_size
+        grad_out_data, grad_out_scale = triton_to_mxfp8_dim0(
+            grad_out, inner_block_size=block_size
         )

         # Quantize 3d expert weights along N (contraction dimension for next grouped gemm)
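Note the swapped unpacking order at each call site above: the eager `to_mx` reference returns `(scale, data)`, while `triton_to_mxfp8_dim0` returns `(data, scale)`. A minimal side-by-side sketch of that difference, where the tensor `A` and `block_size` are illustrative placeholders, not values from this file:

import torch

from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0
from torchao.prototype.mx_formats.mx_tensor import to_mx

block_size = 32
A = torch.randn(1024, 5120, dtype=torch.bfloat16, device="cuda")  # placeholder input

# Previous path: eager cast, which returns (scale, data).
A_scale_ref, A_data_ref = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)

# New path: fused triton kernel, which returns (data, scale).
A_data, A_scale = triton_to_mxfp8_dim0(A, inner_block_size=block_size)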

torchao/prototype/mx_formats/kernels.py

Lines changed: 51 additions & 70 deletions
@@ -852,8 +852,9 @@ def _triton_calculate_scale(x, axis):
     scale_e8m0_unbiased = extracted_pow2.to(tl.bfloat16)

     # Clamp to exponents that can be represented in e8m0
+    # Add 1 to capture NaNs
     scale_e8m0_unbiased = tl.clamp(
-        scale_e8m0_unbiased, -1 * e8m0_exponent_bias, e8m0_exponent_bias
+        scale_e8m0_unbiased, -1 * e8m0_exponent_bias, e8m0_exponent_bias + 1
     )

     # Create the biased e8m0 representation and cast it to 8 bits
@@ -863,15 +864,18 @@ def _triton_calculate_scale(x, axis):
     # TODO(future PR): add NaN handling here,
     # https://github.com/pytorch/pytorch/pull/100572 will likely be useful to
     # get proper NaN propagation working
-
     # Calculate the scale in floating point.
     scale_fp = (scale_e8m0_biased.to(tl.int32) << fp32_mbits).to(
         tl.float32, bitcast=True
     )

+    fp32_exp_bias = 127.0
+    fp32_min_normal = tl.exp2(-fp32_exp_bias + 1)
+    scale_fp = tl.clamp(scale_fp, min=fp32_min_normal, max=float("inf"))
+
     return scale_fp, scale_e8m0_biased

-def _get_mxfp8_dim1_kernel_autotune_configs():
+def _get_mxfp8_quant_autotune_configs():
     # Values to sweep over here were determined by a manual
     # sweep over a small set of shapes, it's likely that this
     # can be improved in the future.
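The scale-calculation hunks above are the "handle nans" part of the commit: raising the exponent clamp to `e8m0_exponent_bias + 1` lets a NaN amax reach the e8m0 NaN encoding (biased 0xFF) instead of saturating to a finite scale, and flooring `scale_fp` at the smallest normal fp32 value keeps an all-zero block from dividing 0 by 0, which is what the new `test_triton_mxfp8_dim0_zeros` exercises. A minimal sketch of the zero-block case in plain tensor arithmetic (not the triton kernel):

import torch

# A hypothetical all-zeros 1x32 scaling block.
block = torch.zeros(32, dtype=torch.float32)

# Without the floor, an amax of 0 yields a zero fp32 scale, and 0 / 0 is NaN.
zero_scale = torch.tensor(0.0)
print((block / zero_scale).isnan().all())      # tensor(True)

# With the floor at the smallest normal fp32 value, zeros quantize to exact zeros.
fp32_min_normal = 2.0 ** -126                  # tl.exp2(-127.0 + 1) in the kernel
print((block / fp32_min_normal).eq(0).all())   # tensor(True)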
@@ -890,7 +894,7 @@ def _get_mxfp8_dim1_kernel_autotune_configs():
     return results

 @triton.autotune(
-    configs=_get_mxfp8_dim1_kernel_autotune_configs(),
+    configs=_get_mxfp8_quant_autotune_configs(),
     key=["n_cols", "INNER_BLOCK_SIZE"],
 )
 @triton.jit
@@ -1039,110 +1043,87 @@ def to_mxfp8_dim1_kernel(
     tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0)

 @triton.autotune(
-    configs=_get_mxfp8_dim1_kernel_autotune_configs(),
-    key=["n_cols", "INNER_BLOCK_SIZE"],
+    configs=_get_mxfp8_quant_autotune_configs(),
+    key=["n_cols", "SCALE_BLOCK_SIZE"],
 )
 @triton.jit
 def to_mxfp8_dim0_kernel(
-    x_ptr,  # pointer to input tensor
-    output_ptr,  # pointer to output tensor (row-normalized)
-    row_scale_ptr,  # pointer to store row-wise maximum absolute values
-    n_rows,  # number of rows in the tensor
-    n_cols,  # number of columns in the tensor
+    x_ptr,
+    output_ptr,
+    scale_ptr,
+    n_rows,
+    n_cols,
     ROW_TILE_SIZE: tl.constexpr,
     COL_TILE_SIZE: tl.constexpr,
-    INNER_BLOCK_SIZE: tl.constexpr,  # should be 32 for MX
+    SCALE_BLOCK_SIZE: tl.constexpr,  # should be 32 for MX
 ):
     """
     Quantizes a high precision tensor to mxfp8 rowwise (1x32 scaling granularity).
-
-    This is the counterpart to to_mxfp8_dim1_kernel which does columnwise quantization.
-    Instead of transposing and scaling across columns, this kernel scales across rows.
     """

-    BLOCKS_PER_COL_TILE: tl.constexpr = COL_TILE_SIZE // INNER_BLOCK_SIZE
+    SCALE_BLOCKS_PER_COL_TILE: tl.constexpr = COL_TILE_SIZE // SCALE_BLOCK_SIZE

     # Get program ID
     pid_row = tl.program_id(0)
     pid_col = tl.program_id(1)

-    # Calculate starting row and column for this tile
     start_row = pid_row * ROW_TILE_SIZE
     start_col = pid_col * COL_TILE_SIZE
-
-    # Create offsets for the block
-    row_offsets = tl.arange(0, ROW_TILE_SIZE)
-    col_offsets = tl.arange(0, COL_TILE_SIZE)
-
-    # Compute global row/col positions
-    rows = start_row + row_offsets[:, None]
-    cols = start_col + col_offsets[None, :]
-
-    # Create masks for out-of-bounds accesses
-    row_mask = rows < n_rows
-    col_mask = cols < n_cols
-    mask = row_mask & col_mask
+    row_offs = start_row + tl.arange(0, ROW_TILE_SIZE)[:, None]
+    col_offs = start_col + tl.arange(0, COL_TILE_SIZE)[None, :]

     # Compute memory offsets for row-major layout (rows, cols)
-    row_major_offsets = (rows * n_cols + cols).to(tl.int32)
+    row_major_offsets = (row_offs * n_cols + col_offs).to(tl.int32)

     # Load the entire block in a single operation
     # shape: (ROW_TILE_SIZE, COL_TILE_SIZE)
+    mask = (row_offs < n_rows) & (col_offs < n_cols)
     x_block = tl.load(x_ptr + row_major_offsets, mask=mask)

     # Reshape to inner tile size for rowwise scaling
-    # shape: (ROW_TILE_SIZE, COL_TILE_SIZE) -> (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE)
+    # shape: (ROW_TILE_SIZE, COL_TILE_SIZE) -> (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, SCALE_BLOCK_SIZE)
     x_block_r = x_block.reshape(
-        ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE
+        ROW_TILE_SIZE * SCALE_BLOCKS_PER_COL_TILE, SCALE_BLOCK_SIZE
     )

     # Calculate the absolute values of elements in the block
     x_block_abs_r = tl.abs(x_block_r)

     # Find the maximum absolute value for each row (across columns)
     # shape: (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,)
-    row_scale_r, row_scale_e8m0_r = _triton_calculate_scale(x_block_abs_r, axis=1)
+    scale_fp32_r, scale_e8m0_r = _triton_calculate_scale(x_block_abs_r, axis=1)

     # Divide each row by scale
-    # Broadcasting row_scale to match x_block's shape
-    # x_block_r shape (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE)
-    # row_scale shape (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,) -> (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, 1)
-    row_normalized_r = x_block_r / row_scale_r[:, None]
+    # Broadcasting scale to match x_block's shape
+    # x_block_r shape:
+    # (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, SCALE_BLOCK_SIZE)
+    # scale[:, None] shape:
+    # (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, 1)
+    scaled_data_r = x_block_r / scale_fp32_r[:, None]

     # Reshape back to original tile size
-    row_normalized = tl.reshape(row_normalized_r, ROW_TILE_SIZE, COL_TILE_SIZE)
-
-    # Quantize to float8
-    row_normalized = row_normalized.to(tl.float8e4nv)
+    e4m3_data_2d = tl.reshape(scaled_data_r, ROW_TILE_SIZE, COL_TILE_SIZE).to(
+        tl.float8e4nv
+    )

     # Store the row-normalized result in row-major format
-    tl.store(output_ptr + row_major_offsets, row_normalized, mask=mask)
-
-    # For rowwise quantization, scale tensor has shape (n_rows, n_cols // INNER_BLOCK_SIZE)
-    # Calculate base offset for this tile's scales
-    scales_per_row = n_cols // INNER_BLOCK_SIZE
+    tl.store(output_ptr + row_major_offsets, e4m3_data_2d, mask=mask)

-    # Create row and column indices for scale storage
-    scale_row_indices = tl.arange(0, ROW_TILE_SIZE)[:, None] + (
-        pid_row * ROW_TILE_SIZE
+    # Calculate scale offsets to write to
+    scales_per_row = n_cols // SCALE_BLOCK_SIZE
+    scale_row_indices = (
+        pid_row * ROW_TILE_SIZE + tl.arange(0, ROW_TILE_SIZE)[:, None]
     )
-    scale_col_indices = tl.arange(0, BLOCKS_PER_COL_TILE)[None, :] + (
-        pid_col * BLOCKS_PER_COL_TILE
+    scale_col_indices = (
+        pid_col * SCALE_BLOCKS_PER_COL_TILE
+        + tl.arange(0, SCALE_BLOCKS_PER_COL_TILE)[None, :]
     )
-
-    # Calculate linear indices into scale tensor
     scale_offsets = scale_row_indices * scales_per_row + scale_col_indices

-    # Create masks for valid scale indices
-    scale_row_mask = scale_row_indices < n_rows
-    scale_col_mask = scale_col_indices < scales_per_row
-    scale_mask = scale_row_mask & scale_col_mask
-
-    # Reshape scale values and masks to match the flattened layout
-    row_scale_e8m0_2d = row_scale_e8m0_r.reshape(ROW_TILE_SIZE, BLOCKS_PER_COL_TILE)
-
-    # Store the scales with proper masking
-    tl.store(row_scale_ptr + scale_offsets, row_scale_e8m0_2d, mask=scale_mask)
+    # Store e8m0 scales
+    scale_mask = (scale_row_indices < n_rows) & (scale_col_indices < scales_per_row)
+    scale_e8m0_2d = scale_e8m0_r.reshape(ROW_TILE_SIZE, SCALE_BLOCKS_PER_COL_TILE)
+    tl.store(scale_ptr + scale_offsets, scale_e8m0_2d, mask=scale_mask)

 @triton_op("torchao::triton_to_mxfp8_dim0", mutates_args={})
 def triton_to_mxfp8_dim0(
@@ -1155,7 +1136,7 @@ def triton_to_mxfp8_dim0(

     Output:
     * `output`: the `float8_e4m3fn` values of `x` cast to mxfp8 across dim0 (rowwise)
-    * `row_scale`: the `e8m0` values of `x_scale` used to cast `x` to mxfp8 across dim0
+    * `scale`: the `e8m0` values of `x_scale` used to cast `x` to mxfp8 across dim0
     """
     assert x.is_contiguous(), "`x` must be contiguous"
     assert inner_block_size <= 32
@@ -1175,7 +1156,7 @@ def triton_to_mxfp8_dim0(
     )

     # Create scale tensors for rowwise scaling
-    row_scale = torch.empty(
+    scale = torch.empty(
         (n_rows, n_cols // inner_block_size),
         dtype=torch.uint8,
         device=x.device,
@@ -1191,19 +1172,19 @@ def triton_to_mxfp8_dim0(
     wrap_triton(to_mxfp8_dim0_kernel)[grid](
         x_ptr=x,
         output_ptr=output,
-        row_scale_ptr=row_scale,
+        scale_ptr=scale,
         n_rows=n_rows,
         n_cols=n_cols,
-        INNER_BLOCK_SIZE=inner_block_size,
+        SCALE_BLOCK_SIZE=inner_block_size,
     )

     # Reshape output back to original shape
     output = output.reshape(x_orig_shape)
-    row_scale = row_scale.reshape(*x_orig_shape[:-1], row_scale.shape[-1])
+    scale = scale.reshape(*x_orig_shape[:-1], scale.shape[-1])

     return (
         output,
-        row_scale.view(torch.float8_e8m0fnu),
+        scale.view(torch.float8_e8m0fnu),
     )

 @triton_op("torchao::triton_to_mxfp8_dim1", mutates_args={})
