
Commit d2fae7a

[mxfp8 moe training] update 3d quant colwise scaling kernel to use single input/output TMA descriptors (#3034)
1 parent ae12e42 commit d2fae7a

File tree: 3 files changed, +148 -127 lines changed


test/prototype/moe_training/test_kernels.py

Lines changed: 2 additions & 3 deletions
@@ -325,8 +325,8 @@ def test_triton_mx_block_rearrange_2d_K_groups(
     reason="MXFP8 requires CUDA capability 10.0 or greater",
 )
 @pytest.mark.parametrize("E", (1, 2, 4, 8))
-@pytest.mark.parametrize("N", (32, 64, 8192))
-@pytest.mark.parametrize("K", (32, 64, 8192))
+@pytest.mark.parametrize("N", (32, 1536, 5120, 7168, 8192))
+@pytest.mark.parametrize("K", (32, 1536, 5120, 7168, 8192))
 @pytest.mark.parametrize("input_dtype", (torch.bfloat16,))
 @pytest.mark.parametrize("scaling_mode", (ScaleCalculationMode.FLOOR,))
 def test_cuda_mx_dim1_3d_numerics(E, N, K, input_dtype, scaling_mode):

@@ -361,7 +361,6 @@ def test_cuda_mx_dim1_3d_numerics(E, N, K, input_dtype, scaling_mode):
     y_d1, s_d1 = mxfp8_cuda.quantize_3d(
         x, scale_dim_n=block_size, scaling_mode=scaling_mode_str
     )
-
     # Check scales
     torch.testing.assert_close(s_d1, s_d1_ref, rtol=0, atol=0)
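
For reference, a minimal sketch of how the quantize_3d path exercised by this test can be called. The import path mxfp8_cuda, block_size=32, and the "floor" scaling-mode string are assumptions inferred from the surrounding test code, not shown in this diff:

import torch
import mxfp8_cuda  # assumed import path for the compiled extension

E, N, K = 2, 1536, 5120  # one of the newly added (E, N, K) test shapes
block_size = 32          # MX block size along the scaled dim (assumed)
x = torch.randn(E, N, K, dtype=torch.bfloat16, device="cuda")

# Column-wise (dim1) 3D quantization, mirroring test_cuda_mx_dim1_3d_numerics.
y_d1, s_d1 = mxfp8_cuda.quantize_3d(
    x, scale_dim_n=block_size, scaling_mode="floor"  # scaling_mode string assumed
)
print(y_d1.shape, s_d1.shape)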

torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp

Lines changed: 3 additions & 2 deletions
@@ -54,7 +54,8 @@ mxfp8_quantize(torch::Tensor input, bool rowwise, bool colwise,
 
   // Validate inputs
   TORCH_CHECK(!rowwise, "rowwise scaling is not supported yet");
-  check_cuda_tensor(input, "input");
+  TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
+  TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
   TORCH_CHECK(input.dim() == 2, "input must be 2D");
   TORCH_CHECK(input.scalar_type() == torch::kFloat32 ||
               input.scalar_type() == torch::kFloat16 ||

@@ -130,6 +131,7 @@ mxfp8_quantize_3d(torch::Tensor input, int64_t scale_dim_n,
 
   // Validate inputs
   TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
+  TORCH_CHECK(input.is_contiguous(), "input must be contiguous");
   // Note: We don't check contiguous for 3D as it may have column major strides
   TORCH_CHECK(input.dim() == 3, "input must be 3D");
   TORCH_CHECK(input.scalar_type() == torch::kFloat32 ||

@@ -148,7 +150,6 @@ mxfp8_quantize_3d(torch::Tensor input, int64_t scale_dim_n,
   TORCH_CHECK((N >= 32) && (N % 32 == 0), "N must be a multiple of 32");
   TORCH_CHECK((K >= 32) && (K % 32 == 0), "K must be a multiple of 32");
 
-  // The kernel should work with any stride pattern - no layout requirements
 
   c10::cuda::CUDAGuard device_guard(input.device());
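
The validation change above makes both entry points require contiguous CUDA inputs. A minimal sketch of the expected Python-side behavior, assuming the TORCH_CHECK failure surfaces as a RuntimeError and that the import path and "floor" scaling-mode string are as in the sketch above (both assumptions):

import torch
import mxfp8_cuda  # assumed import path for the compiled extension

x = torch.randn(2, 64, 64, dtype=torch.bfloat16, device="cuda")
x_noncontig = x.transpose(1, 2)  # non-contiguous view of the same data

try:
    mxfp8_cuda.quantize_3d(x_noncontig, scale_dim_n=32, scaling_mode="floor")
except RuntimeError as err:
    print(err)  # expected to mention: "input must be contiguous"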
