@@ -852,8 +852,9 @@ def _triton_calculate_scale(x, axis):
     scale_e8m0_unbiased = extracted_pow2.to(tl.bfloat16)

     # Clamp to exponents that can be represented in e8m0
+    # Add 1 to capture NaNs
     scale_e8m0_unbiased = tl.clamp(
-        scale_e8m0_unbiased, -1 * e8m0_exponent_bias, e8m0_exponent_bias
+        scale_e8m0_unbiased, -1 * e8m0_exponent_bias, e8m0_exponent_bias + 1
     )

     # Create the biased e8m0 representation and cast it to 8 bits
@@ -863,15 +864,18 @@ def _triton_calculate_scale(x, axis):
     # TODO(future PR): add NaN handling here,
     # https://github.com/pytorch/pytorch/pull/100572 will likely be useful to
     # get proper NaN propagation working
-
     # Calculate the scale in floating point.
     scale_fp = (scale_e8m0_biased.to(tl.int32) << fp32_mbits).to(
         tl.float32, bitcast=True
     )

+    fp32_exp_bias = 127.0
+    fp32_min_normal = tl.exp2(-fp32_exp_bias + 1)
+    scale_fp = tl.clamp(scale_fp, min=fp32_min_normal, max=float("inf"))
+
     return scale_fp, scale_e8m0_biased

-def _get_mxfp8_dim1_kernel_autotune_configs():
+def _get_mxfp8_quant_autotune_configs():
     # Values to sweep over here were determined by a manual
     # sweep over a small set of shapes, it's likely that this
     # can be improved in the future.
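
For intuition, the scale math in this hunk maps a blockwise amax to a power-of-two scale whose exponent is stored in e8m0: the unbiased exponent is clamped to [-127, 128] (the +1 leaves headroom for the NaN encoding) and the reconstructed fp32 scale is clamped up to the smallest normal float32, so the quantization divide never sees a zero or denormal scale. Below is a rough PyTorch sketch of the same arithmetic, illustration only: `max_abs` and `e8m0_scale_sketch` are made-up names, and the kernel appears to derive the exponent from the value's exponent bits (`extracted_pow2`) rather than via log2.

# Illustration of the e8m0 scale math, not code from this PR.
import torch

def e8m0_scale_sketch(max_abs: torch.Tensor):
    E8M0_EXPONENT_BIAS = 127
    FP32_MBITS = 23
    # Approximate the kernel's exponent extraction with floor(log2(amax)).
    unbiased_exp = torch.floor(torch.log2(max_abs.to(torch.float32)))
    # Clamp to representable e8m0 exponents; the +1 keeps room for the NaN encoding.
    unbiased_exp = torch.clamp(unbiased_exp, -E8M0_EXPONENT_BIAS, E8M0_EXPONENT_BIAS + 1)
    biased_e8m0 = (unbiased_exp + E8M0_EXPONENT_BIAS).to(torch.uint8)
    # Rebuild the fp32 scale from the biased exponent bits, then clamp away
    # zero/denormal scales, mirroring the new fp32_min_normal clamp above.
    scale_fp = (biased_e8m0.to(torch.int32) << FP32_MBITS).view(torch.float32)
    scale_fp = torch.clamp(scale_fp, min=torch.finfo(torch.float32).tiny)
    return scale_fp, biased_e8m0
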
@@ -890,7 +894,7 @@ def _get_mxfp8_dim1_kernel_autotune_configs():
     return results

 @triton.autotune(
-    configs=_get_mxfp8_dim1_kernel_autotune_configs(),
+    configs=_get_mxfp8_quant_autotune_configs(),
     key=["n_cols", "INNER_BLOCK_SIZE"],
 )
 @triton.jit
@@ -1039,110 +1043,87 @@ def to_mxfp8_dim1_kernel(
     tl.store(col_scale_start_ptr + col_scale_indices, col_scale_e8m0)

 @triton.autotune(
-    configs=_get_mxfp8_dim1_kernel_autotune_configs(),
-    key=["n_cols", "INNER_BLOCK_SIZE"],
+    configs=_get_mxfp8_quant_autotune_configs(),
+    key=["n_cols", "SCALE_BLOCK_SIZE"],
 )
 @triton.jit
 def to_mxfp8_dim0_kernel(
-    x_ptr,  # pointer to input tensor
-    output_ptr,  # pointer to output tensor (row-normalized)
-    row_scale_ptr,  # pointer to store row-wise maximum absolute values
-    n_rows,  # number of rows in the tensor
-    n_cols,  # number of columns in the tensor
+    x_ptr,
+    output_ptr,
+    scale_ptr,
+    n_rows,
+    n_cols,
     ROW_TILE_SIZE: tl.constexpr,
     COL_TILE_SIZE: tl.constexpr,
-    INNER_BLOCK_SIZE: tl.constexpr,  # should be 32 for MX
+    SCALE_BLOCK_SIZE: tl.constexpr,  # should be 32 for MX
 ):
     """
     Quantizes a high precision tensor to mxfp8 rowwise (1x32 scaling granularity).
-
-    This is the counterpart to to_mxfp8_dim1_kernel which does columnwise quantization.
-    Instead of transposing and scaling across columns, this kernel scales across rows.
     """

-    BLOCKS_PER_COL_TILE: tl.constexpr = COL_TILE_SIZE // INNER_BLOCK_SIZE
+    SCALE_BLOCKS_PER_COL_TILE: tl.constexpr = COL_TILE_SIZE // SCALE_BLOCK_SIZE

     # Get program ID
     pid_row = tl.program_id(0)
     pid_col = tl.program_id(1)

-    # Calculate starting row and column for this tile
     start_row = pid_row * ROW_TILE_SIZE
     start_col = pid_col * COL_TILE_SIZE
-
-    # Create offsets for the block
-    row_offsets = tl.arange(0, ROW_TILE_SIZE)
-    col_offsets = tl.arange(0, COL_TILE_SIZE)
-
-    # Compute global row/col positions
-    rows = start_row + row_offsets[:, None]
-    cols = start_col + col_offsets[None, :]
-
-    # Create masks for out-of-bounds accesses
-    row_mask = rows < n_rows
-    col_mask = cols < n_cols
-    mask = row_mask & col_mask
+    row_offs = start_row + tl.arange(0, ROW_TILE_SIZE)[:, None]
+    col_offs = start_col + tl.arange(0, COL_TILE_SIZE)[None, :]

     # Compute memory offsets for row-major layout (rows, cols)
-    row_major_offsets = (rows * n_cols + cols).to(tl.int32)
+    row_major_offsets = (row_offs * n_cols + col_offs).to(tl.int32)

     # Load the entire block in a single operation
     # shape: (ROW_TILE_SIZE, COL_TILE_SIZE)
+    mask = (row_offs < n_rows) & (col_offs < n_cols)
     x_block = tl.load(x_ptr + row_major_offsets, mask=mask)

     # Reshape to inner tile size for rowwise scaling
-    # shape: (ROW_TILE_SIZE, COL_TILE_SIZE) -> (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE)
+    # shape: (ROW_TILE_SIZE, COL_TILE_SIZE) -> (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, SCALE_BLOCK_SIZE)
     x_block_r = x_block.reshape(
-        ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE
+        ROW_TILE_SIZE * SCALE_BLOCKS_PER_COL_TILE, SCALE_BLOCK_SIZE
     )

     # Calculate the absolute values of elements in the block
     x_block_abs_r = tl.abs(x_block_r)

     # Find the maximum absolute value for each row (across columns)
     # shape: (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,)
-    row_scale_r, row_scale_e8m0_r = _triton_calculate_scale(x_block_abs_r, axis=1)
+    scale_fp32_r, scale_e8m0_r = _triton_calculate_scale(x_block_abs_r, axis=1)

     # Divide each row by scale
-    # Broadcasting row_scale to match x_block's shape
-    # x_block_r shape (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, INNER_BLOCK_SIZE)
-    # row_scale shape (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE,) -> (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, 1)
-    row_normalized_r = x_block_r / row_scale_r[:, None]
+    # Broadcasting scale to match x_block's shape
+    # x_block_r shape:
+    # (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, SCALE_BLOCK_SIZE)
+    # scale[:, None] shape:
+    # (ROW_TILE_SIZE * BLOCKS_PER_COL_TILE, 1)
+    scaled_data_r = x_block_r / scale_fp32_r[:, None]

     # Reshape back to original tile size
-    row_normalized = tl.reshape(row_normalized_r, ROW_TILE_SIZE, COL_TILE_SIZE)
-
-    # Quantize to float8
-    row_normalized = row_normalized.to(tl.float8e4nv)
+    e4m3_data_2d = tl.reshape(scaled_data_r, ROW_TILE_SIZE, COL_TILE_SIZE).to(
+        tl.float8e4nv
+    )

     # Store the row-normalized result in row-major format
-    tl.store(output_ptr + row_major_offsets, row_normalized, mask=mask)
-
-    # For rowwise quantization, scale tensor has shape (n_rows, n_cols // INNER_BLOCK_SIZE)
-    # Calculate base offset for this tile's scales
-    scales_per_row = n_cols // INNER_BLOCK_SIZE
+    tl.store(output_ptr + row_major_offsets, e4m3_data_2d, mask=mask)

-    # Create row and column indices for scale storage
-    scale_row_indices = tl.arange(0, ROW_TILE_SIZE)[:, None] + (
-        pid_row * ROW_TILE_SIZE
+    # Calculate scale offsets to write to
+    scales_per_row = n_cols // SCALE_BLOCK_SIZE
+    scale_row_indices = (
+        pid_row * ROW_TILE_SIZE + tl.arange(0, ROW_TILE_SIZE)[:, None]
     )
-    scale_col_indices = tl.arange(0, BLOCKS_PER_COL_TILE)[None, :] + (
-        pid_col * BLOCKS_PER_COL_TILE
+    scale_col_indices = (
+        pid_col * SCALE_BLOCKS_PER_COL_TILE
+        + tl.arange(0, SCALE_BLOCKS_PER_COL_TILE)[None, :]
     )
-
-    # Calculate linear indices into scale tensor
     scale_offsets = scale_row_indices * scales_per_row + scale_col_indices

-    # Create masks for valid scale indices
-    scale_row_mask = scale_row_indices < n_rows
-    scale_col_mask = scale_col_indices < scales_per_row
-    scale_mask = scale_row_mask & scale_col_mask
-
-    # Reshape scale values and masks to match the flattened layout
-    row_scale_e8m0_2d = row_scale_e8m0_r.reshape(ROW_TILE_SIZE, BLOCKS_PER_COL_TILE)
-
-    # Store the scales with proper masking
-    tl.store(row_scale_ptr + scale_offsets, row_scale_e8m0_2d, mask=scale_mask)
+    # Store e8m0 scales
+    scale_mask = (scale_row_indices < n_rows) & (scale_col_indices < scales_per_row)
+    scale_e8m0_2d = scale_e8m0_r.reshape(ROW_TILE_SIZE, SCALE_BLOCKS_PER_COL_TILE)
+    tl.store(scale_ptr + scale_offsets, scale_e8m0_2d, mask=scale_mask)

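To make the scale indexing above concrete, here is a tiny standalone calculation, purely illustrative: the tile sizes below are made-up values rather than the autotuned ones. It shows where each 1x32 block's e8m0 scale lands in the flat (n_rows, n_cols // SCALE_BLOCK_SIZE) scale tensor.

# Illustration of the scale offset layout, not code from this PR.
n_cols = 128
SCALE_BLOCK_SIZE = 32
ROW_TILE_SIZE, COL_TILE_SIZE = 2, 64                            # hypothetical tile sizes
SCALE_BLOCKS_PER_COL_TILE = COL_TILE_SIZE // SCALE_BLOCK_SIZE   # 2
scales_per_row = n_cols // SCALE_BLOCK_SIZE                     # 4 scales per row
pid_row, pid_col = 1, 1                                         # the (1, 1) tile
for r in range(ROW_TILE_SIZE):
    for b in range(SCALE_BLOCKS_PER_COL_TILE):
        row = pid_row * ROW_TILE_SIZE + r                 # global row index
        block = pid_col * SCALE_BLOCKS_PER_COL_TILE + b   # scale block within that row
        print(row * scales_per_row + block)               # prints 10, 11, 14, 15
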
 @triton_op("torchao::triton_to_mxfp8_dim0", mutates_args={})
 def triton_to_mxfp8_dim0(
@@ -1155,7 +1136,7 @@ def triton_to_mxfp8_dim0(

     Output:
     * `output`: the `float8_e4m3fn` values of `x` cast to mxfp8 across dim0 (rowwise)
-    * `row_scale`: the `e8m0` values of `x_scale` used to cast `x` to mxfp8 across dim0
+    * `scale`: the `e8m0` values of `x_scale` used to cast `x` to mxfp8 across dim0
     """
     assert x.is_contiguous(), "`x` must be contiguous"
     assert inner_block_size <= 32
@@ -1175,7 +1156,7 @@ def triton_to_mxfp8_dim0(
     )

     # Create scale tensors for rowwise scaling
-    row_scale = torch.empty(
+    scale = torch.empty(
         (n_rows, n_cols // inner_block_size),
         dtype=torch.uint8,
         device=x.device,
@@ -1191,19 +1172,19 @@ def triton_to_mxfp8_dim0(
     wrap_triton(to_mxfp8_dim0_kernel)[grid](
         x_ptr=x,
         output_ptr=output,
-        row_scale_ptr=row_scale,
+        scale_ptr=scale,
         n_rows=n_rows,
         n_cols=n_cols,
-        INNER_BLOCK_SIZE=inner_block_size,
+        SCALE_BLOCK_SIZE=inner_block_size,
     )

     # Reshape output back to original shape
     output = output.reshape(x_orig_shape)
-    row_scale = row_scale.reshape(*x_orig_shape[:-1], row_scale.shape[-1])
+    scale = scale.reshape(*x_orig_shape[:-1], scale.shape[-1])

     return (
         output,
-        row_scale.view(torch.float8_e8m0fnu),
+        scale.view(torch.float8_e8m0fnu),
     )

 @triton_op("torchao::triton_to_mxfp8_dim1", mutates_args={})
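
Finally, a rough usage sketch of the new dim0 entry point. This is illustration only: the wrapper's full signature is not shown in this diff, so the call form `triton_to_mxfp8_dim0(x, inner_block_size=32)` and the import path are assumptions.

# Assumed usage; import path and call signature are guesses based on this diff.
import torch
from torchao.prototype.mx_formats.kernels import triton_to_mxfp8_dim0

x = torch.randn(256, 512, dtype=torch.bfloat16, device="cuda")
data_e4m3, scale_e8m0 = triton_to_mxfp8_dim0(x, inner_block_size=32)
# Rowwise (1x32) quantization: the data keeps x's shape, and there is one
# e8m0 scale per 32 contiguous elements along the last dimension.
assert data_e4m3.shape == x.shape and data_e4m3.dtype == torch.float8_e4m3fn
assert scale_e8m0.shape == (256, 512 // 32)
assert scale_e8m0.dtype == torch.float8_e8m0fnu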