From 1b5b465959e67ffb34e914b36af291e0414f620c Mon Sep 17 00:00:00 2001 From: aoli Date: Wed, 10 Jun 2026 14:22:59 +0000 Subject: [PATCH 01/16] gemm tdm exp --- kernels/gemm_fp8fp4_gfx1250.py | 712 +++++++++++++++++++--- tests/kernels/test_gemm_fp8fp4_gfx1250.py | 166 ++++- 2 files changed, 791 insertions(+), 87 deletions(-) diff --git a/kernels/gemm_fp8fp4_gfx1250.py b/kernels/gemm_fp8fp4_gfx1250.py index 35ffb352..cb534fc4 100644 --- a/kernels/gemm_fp8fp4_gfx1250.py +++ b/kernels/gemm_fp8fp4_gfx1250.py @@ -63,12 +63,61 @@ def _make_tdm_desc(*, early_timeout=False, **kwargs): SCALE_BLOCK = 32 SCALES_PER_WMMA = WMMA_K // SCALE_BLOCK # 4 + +def _vec_chunks(n: int): + """Compile-time split of n contiguous i32 into buffer_load widths (4/2/1). + + Module scope so the ``while`` stays plain Python (the DSL rewrites ``while`` + inside kernel bodies into scf.while). Returns [(start, width), ...]. + """ + chunks = [] + done = 0 + while done < n: + w = 4 if (n - done) >= 4 else (2 if (n - done) >= 2 else 1) + chunks.append((done, w)) + done += w + return chunks + LDS_PAD_A_BYTES = 16 LDS_PAD_D_BYTES = 16 LDS_SEGMENT_BYTES = 64 * 1024 LDS_GFX1250_MAX_BYTES = 5 * LDS_SEGMENT_BYTES +def is_ref_segmented_lds_layout( + *, + data_format, + tile_m, + tile_n, + tile_k, + m_warp, + n_warp, + num_buffers, + split_k, + wave_specialized_tdm, + use_scale_opsel, +): + """Whether this config uses the reference segmented LDS layout. + + Single source of truth shared by the kernel and the test host-side scale + preshuffle: the buffer_load->VGPR scale path uses the legacy lane-major + coalesced layout here (deep-pipeline) and the general coalesced layout + otherwise. 
+ """ + return ( + data_format == "fp8" + and tile_m == 256 + and tile_n == 256 + and tile_k == 128 + and m_warp == 2 + and n_warp == 2 + and num_buffers == 4 + and split_k == 1 + and wave_specialized_tdm + and not use_scale_opsel + ) + + @functools.lru_cache(maxsize=256) def compile_fp8fp4_gemm( *, @@ -97,6 +146,7 @@ def compile_fp8fp4_gemm( b_streaming: bool = False, scale_load_path: str = "tdm", fp8_schedule: str = "auto", + a_load_path: str = "tdm", ): """Compile an FP4/FP8/A8W4 GEMM kernel with TDM async copy. @@ -142,10 +192,23 @@ def compile_fp8fp4_gemm( raise ValueError(f"fp8_schedule={fp8_schedule!r} is only valid for data_format='fp8'") if fp8_schedule != "auto" and b_streaming: raise ValueError("fp8_schedule cannot be combined with b_streaming=True") + a_load_path_modes = ("tdm", "vgpr", "vgpr_ascale") + if a_load_path not in a_load_path_modes: + raise ValueError(f"a_load_path must be one of {a_load_path_modes}, got {a_load_path!r}") + use_a_vgpr = a_load_path != "tdm" + use_ascale_vgpr = a_load_path == "vgpr_ascale" + if use_a_vgpr and scale_load_path != "tdm": + raise ValueError("a_load_path and scale_load_path cannot both bypass TDM") + if use_a_vgpr and not wave_specialized_tdm: + raise ValueError("a_load_path != 'tdm' requires wave_specialized_tdm=True") + if use_a_vgpr and data_format not in ("fp8", "a8w4"): + raise ValueError("a_load_path != 'tdm' requires data_format='fp8' or 'a8w4'") + if use_a_vgpr and is_ptpc: + raise ValueError("a_load_path != 'tdm' requires scale_mode='mxscale'") effective_expert_sched_mode = bool(expert_sched_mode) - if num_buffers not in (2, 3, 4): - raise ValueError(f"num_buffers must be 2, 3, or 4, got {num_buffers}") + if num_buffers not in (2, 3, 4, 5, 6): + raise ValueError(f"num_buffers must be 2, 3, 4, 5, or 6, got {num_buffers}") if split_k < 1: raise ValueError(f"split_k must be >= 1, got {split_k}") @@ -160,7 +223,22 @@ def compile_fp8fp4_gemm( if block_threads > 1024: raise ValueError(f"block_threads must be <= 
1024, got {block_threads}") - _min_wave_spec_warps = 2 if is_ptpc else 4 + # Wave-specialized TDM dedicates one loader wave per TDM tensor. + # Determine which tensors bypass TDM to calculate minimum wave count. + # A data: bypasses TDM when use_a_vgpr + # A_scale: bypasses TDM when use_ascale_vgpr, is_ptpc, or scale_load_path=="vgpr" + # B_scale: bypasses TDM when is_ptpc or scale_load_path=="vgpr" + # Remaining TDM tensors determine wave assignment and min warp count. + _drop_scale_loader_waves = is_ptpc or scale_load_path == "vgpr" or use_ascale_vgpr + _drop_a_loader_wave = use_a_vgpr + if _drop_a_loader_wave and _drop_scale_loader_waves: + _min_wave_spec_warps = 2 # only B + B_scale (or just B for ptpc) + elif _drop_scale_loader_waves: + _min_wave_spec_warps = 2 # only A + B + elif _drop_a_loader_wave: + _min_wave_spec_warps = 4 # B + A_scale + B_scale (wave0 idle) + else: + _min_wave_spec_warps = 4 # A + B + A_scale + B_scale if wave_specialized_tdm and num_warps < _min_wave_spec_warps: raise ValueError(f"wave_specialized_tdm requires at least {_min_wave_spec_warps} waves, got {num_warps}") @@ -234,9 +312,16 @@ def compile_fp8fp4_gemm( _b_frag_loads_per_wn = 2 if is_a8w4 else 4 _a_frag_loads_per_wm = 2 if is_fp4 else 4 - _scale_ds_loads = (wmma_m_rep + 3) // 4 + (b_scale_load_rep + 3) // 4 + # _scale_ds_loads counts scale ds_loads issued alongside A/B fragment loads in + # the streaming schedule (used for the partial-drain s_wait_dscnt bookkeeping). + # The general VGPR scale path holds scales in registers (no ds_load), so it + # contributes zero. Finalized below once use_general_vgpr_scale is known. 
+ _a_scale_ds = 0 if use_ascale_vgpr else (wmma_m_rep + 3) // 4 + _b_scale_ds = (b_scale_load_rep + 3) // 4 + _scale_ds_loads = _a_scale_ds + _b_scale_ds + _a_frag_ds = 0 if use_a_vgpr else wmma_m_rep * _a_frag_loads_per_wm _bs_ds_loads = wmma_n_rep * _b_frag_loads_per_wn + _scale_ds_loads - _as_ds_loads = wmma_m_rep * _a_frag_loads_per_wm + _scale_ds_loads + _as_ds_loads = _a_frag_ds + _scale_ds_loads lds_a_stride_bytes = packed_tile_k_a + LDS_PAD_A_BYTES if scale_load_path == "vgpr_ab_split": @@ -245,12 +330,12 @@ def compile_fp8fp4_gemm( if tile_n % 32 != 0: raise ValueError(f"scale_load_path='vgpr_ab_split' requires tile_n divisible by 32, got {tile_n}") - lds_a_data_bytes = tile_m * lds_a_stride_bytes + lds_a_data_bytes = 0 if use_a_vgpr else tile_m * lds_a_stride_bytes lds_b_data_bytes = tile_n * packed_tile_k_b ab_split_a_rows = tile_m // 2 ab_split_b_groups = tile_n // 32 _scale_guard_bytes = 16 - lds_a_scale_bytes = 0 if is_ptpc else tile_m * scale_k_per_tile + _scale_guard_bytes + lds_a_scale_bytes = 0 if (is_ptpc or use_ascale_vgpr) else tile_m * scale_k_per_tile + _scale_guard_bytes lds_b_scale_bytes = 0 if is_ptpc else tile_n * scale_k_per_tile + _scale_guard_bytes interleaved_scale_cols_a = wmma_m_rep * scale_k_per_tile interleaved_scale_cols_b = b_scale_load_rep * scale_k_per_tile @@ -297,30 +382,51 @@ def _align_up(value: int, align: int) -> int: ), ) - use_ref_segmented_lds_layout = ( - data_format == "fp8" - and tile_m == 256 - and tile_n == 256 - and tile_k == 128 - and m_warp == 2 - and n_warp == 2 - and num_buffers == 4 - and split_k == 1 - and wave_specialized_tdm - and not use_scale_opsel + use_ref_segmented_lds_layout = is_ref_segmented_lds_layout( + data_format=data_format, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + m_warp=m_warp, + n_warp=n_warp, + num_buffers=num_buffers, + split_k=split_k, + wave_specialized_tdm=wave_specialized_tdm, + use_scale_opsel=use_scale_opsel, ) # "vgpr"/"vgpr_ab_split": load scale global->VGPR via 
buffer_load, bypassing - # TDM+LDS entirely. Requires the reference segmented LDS layout. + # TDM+LDS entirely. Two layouts coexist: the reference segmented deep-pipeline + # path (use_ref_segmented_lds_layout, fp8 256x256x128) and the general + # coalesced path (use_general_vgpr_scale) used by the row-major streaming + # schedule (a8w4/fp8, arbitrary warp_tile / tile_k). The full schedule+format + # eligibility check runs once the compute schedule is known (below). use_buffer_vgpr_scale = scale_load_path in ("vgpr", "vgpr_ab_split") - if use_buffer_vgpr_scale and not use_ref_segmented_lds_layout: - raise ValueError( - f"scale_load_path={scale_load_path!r} requires the reference segmented " - "LDS layout (not active for this tile/format configuration)" - ) - # Scale prefetch depth (K-tiles ahead) for the buffer->VGPR path. D=1 is the - # sweet spot; D=2 doubles scale VGPRs -> spill + ~18% regression. - _bvs_D = max(1, int(os.environ.get("FLYDSL_BUFFER_VGPR_SCALE_DEPTH", "1"))) + use_general_vgpr_scale = use_buffer_vgpr_scale and not use_ref_segmented_lds_layout + if use_general_vgpr_scale and scale_load_path == "vgpr_ab_split": + raise ValueError("scale_load_path='vgpr_ab_split' requires the reference segmented LDS layout") + if use_general_vgpr_scale: + # General VGPR scales live in registers: no scale ds_loads to wait on. + _scale_ds_loads = 0 + _bs_ds_loads = wmma_n_rep * _b_frag_loads_per_wn + _as_ds_loads = _a_frag_ds + # Scale prefetch depth (K-tiles ahead) for the buffer->VGPR path. The + # ref-segmented deep-pipeline path is VGPR-bound (D=2 doubles scale VGPRs -> + # spill + ~18% regression), so it stays at D=1. The general coalesced path + # (thin row-major streaming tiles) has spare VGPRs, so it prefetches deeper + # to overlap each scale buffer_load with an earlier tile's TDM wait. 
+ _bvs_D_default = 3 if use_general_vgpr_scale else 1 + _bvs_D = max(1, int(os.environ.get("FLYDSL_BUFFER_VGPR_SCALE_DEPTH", str(_bvs_D_default)))) + # FLYDSL_BUFFER_VGPR_SCALE_PRELOAD=1 (experiment, small-M only): switch the + # general vgpr scale path to the b128 layout (ks innermost) so scales load + # with wide buffer_load_b128 instead of many b32, and -- when the whole K + # runs in the tail (loop_iters == 0) -- load them ALL up front into VGPRs + # (preload) instead of a per-tile ring. The b128 layout must match the host + # (test preshuffle_scale_for_load_path). _bvs_b128 is independent of + # loop_iters so host<->kernel layout always agree; _bvs_preload is the + # all-up-front variant. Full-K scale must fit in VGPRs -- NOT general. + _bvs_b128 = use_general_vgpr_scale and bool(int(os.environ.get("FLYDSL_BUFFER_VGPR_SCALE_PRELOAD", "0"))) + _bvs_preload = _bvs_b128 and loop_iters == 0 # ab_half_split: repurpose the (under "vgpr") idle scale waves 2,3 as the # second halves of A/B, so all 4 waves share the A/B TDM (wave0=A0, wave1=B0, # wave2=A1, wave3=B1). Measured wall-neutral. @@ -328,6 +434,34 @@ def _align_up(value: int, align: int) -> int: # The buffer_load->VGPR scale ring is built only when scale is actually loaded. _bvs_active = use_buffer_vgpr_scale + # A VGPR prefetch: buffer_load A data directly into VGPRs, bypassing TDM/LDS. + # Per-tile A frag count: k_wmma_steps * wmma_m_rep vec<16xi32> (or vec<8xi32> for fp4). + _avr_active = use_a_vgpr + _avr_D = max(1, int(os.environ.get("FLYDSL_A_VGPR_DEPTH", str(num_buffers)))) + _avr_frag_width = 8 if is_fp4 else 16 # vec elements per A fragment + _avr_frags_per_tile = k_wmma_steps * wmma_m_rep + # When vgpr_ascale, A_scale is bundled with the A ring. 
+ _avr_ascale_per_tile = k_wmma_steps * wmma_m_rep if use_ascale_vgpr else 0 + + # B N-split (env FLYDSL_B_KSPLIT, default off): on the VGPR A path wave0 is + # freed from loading A, so wave0 and wave1 co-load B — wave0 the first half of + # the tile's N-groups, wave1 the second — issued in parallel under the normal + # single unified barrier. N is the outer LDS dim, so the two halves write the + # same contiguous tile the full-B descriptor would; the LDS layout, the compute + # loop, and the fence are all unchanged. The ONLY delta vs the plain vgpr path + # is: wave0 is activated and the B load is split across waves 0/1 by N-group. + # Valid only when A is VGPR with TDM scales (wave0 otherwise idle, waves + # 1/2/3 = B/A_scale/B_scale) and tile_n splits into two equal N-group halves. + _b_nsplit = ( + bool(int(os.environ.get("FLYDSL_B_KSPLIT", "0"))) + and use_a_vgpr + and not use_ascale_vgpr + and not is_ptpc + and scale_load_path == "tdm" + and wave_specialized_tdm + and tile_n % 32 == 0 + ) + if use_ref_segmented_lds_layout: # The A/B data pools are no longer packed into the same per-stage # 64KiB segment window. Scale pools keep the reference 0x800 stride so @@ -470,8 +604,28 @@ def _pick_compute_schedule_kind(): use_fp8_quadrant_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_FP8_QUADRANT use_fp8_deep_pipeline_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_FP8_DEEP_PIPELINE use_b_streaming_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_B_STREAMING - if use_buffer_vgpr_scale and not use_fp8_deep_pipeline_schedule: - raise ValueError(f"scale_load_path={scale_load_path!r} is only supported with the FP8 deep-pipeline schedule") + if use_buffer_vgpr_scale: + # General coalesced VGPR scale is supported on the row-major streaming + # schedule (mxscale fp8/a8w4, no scale_opsel, wave-specialized TDM); the + # ref-segmented path keeps the FP8 deep-pipeline schedule. 
+ _vgpr_streaming_ok = ( + use_general_vgpr_scale + and compute_schedule_kind == COMPUTE_SCHEDULE_ROW_MAJOR_STREAMING + and data_format in ("fp8", "a8w4") + and not is_ptpc + and not use_scale_opsel + and wave_specialized_tdm + ) + if not (use_fp8_deep_pipeline_schedule or _vgpr_streaming_ok): + raise ValueError( + f"scale_load_path={scale_load_path!r} requires the FP8 deep-pipeline schedule, or " + "the row-major streaming schedule with mxscale fp8/a8w4, no scale_opsel, and " + "wave_specialized_tdm" + ) + if use_a_vgpr and compute_schedule_kind != COMPUTE_SCHEDULE_ROW_MAJOR_STREAMING: + raise ValueError( + f"a_load_path={a_load_path!r} requires the row-major streaming schedule" + ) use_ws_tdm_split_signal_overlap = ( wave_specialized_tdm and (use_fp8_quadrant_schedule or use_fp8_deep_pipeline_schedule) @@ -558,7 +712,7 @@ def kernel_mxscale_gemm( layout_thr = fx.make_layout((m_warp, n_warp, 2, 16), (WAVE_SIZE, m_warp * WAVE_SIZE, 16, 1)) else: layout_thr = fx.make_layout((m_warp, n_warp, 2, 16), (n_warp * WAVE_SIZE, WAVE_SIZE, 16, 1)) - thr_coord = idx2crd(fx.Int32(tx), layout_thr) + thr_coord = idx2crd(tx, layout_thr) wave_m_idx, wave_n_idx, lane_kgrp, lane16 = ( fx.get(thr_coord, 0), fx.get(thr_coord, 1), @@ -570,18 +724,34 @@ def kernel_mxscale_gemm( warp_n_base = wave_n_idx * arith.index(warp_tile_n) if const_expr(use_buffer_vgpr_scale): - # Direct global->VGPR scale load (no TDM/LDS). Coalesced lane-major - # host layout [M_block(128), K_tile, group(2), lane16(16), 4 i32], so - # each buffer_load_b128's 16 lanes read 256 contiguous bytes: - # i32_off(group) = (mb*Kt + kt)*128 + group*64 + lane16*4 + # Direct global->VGPR scale load (no TDM/LDS). One K-tile's scales for + # all reps land in VGPRs; the loop carries a small prefetch ring. 
+ # + # Two host layouts share the same prefetch/consume plumbing: + # - deep-pipeline (use_ref_segmented_lds_layout): lane-major + # [M_block(128), K_tile, group(2), lane16(16), 4 i32], single + # k_wmma_step, no lane_kgrp shift. i32_off = (mb*Kt+kt)*128 + + # group*64 + lane16*4. + # - general (use_general_vgpr_scale): coalesced + # [mb, K_tile, k_wmma_step, rep_group, lane32(32), j(4), spw] with the + # a8w4 lane_kgrp shift baked per physical lane (lane32 = kgrp*16+L), so + # the address is kgrp-agnostic. i32_off = + # (((mb*Kt+kt)*KS+ks)*NG + grp)*32*4 + lane32*4 + j. Matches the TDM + # path value-for-value (see flydsl_fp8_perf/verify_vgpr_scale_layout.py). _bvs_a_rsrc = buffer_ops.create_buffer_resource(arg_a_scale, max_size=False) _bvs_b_rsrc = buffer_ops.create_buffer_resource(arg_b_scale, max_size=False) _bvs_Kt = K // tile_k # total K-tiles - _bvs_mb_a = blk_m // arith.index(128) + wave_m_idx - _bvs_mb_b = blk_n // arith.index(128) + wave_n_idx + _bvs_mb_a = blk_m // arith.index(warp_tile_m) + wave_m_idx + _bvs_mb_b = blk_n // arith.index(warp_tile_n) + wave_n_idx _bvs_lane4 = lane16 * arith.index(4) + _gvs_lane32 = lane_kgrp * arith.index(16) + lane16 + # Per-tile VGPR scale count (flat, ordered [k_wmma_step][rep]); reduces to + # `rep` for the deep-pipeline path (k_wmma_steps == 1). + _vs_tile_a = k_wmma_steps * wmma_m_rep + _vs_tile_b = k_wmma_steps * b_scale_load_rep def _bvs_load_scales(rsrc, mb, rep, k_base): + # Deep-pipeline lane-major layout (k_wmma_steps == 1). kt = k_base // arith.index(tile_k) tile_i32 = (mb * arith.index(_bvs_Kt) + kt) * arith.index(128) vals = [] @@ -592,10 +762,89 @@ def _bvs_load_scales(rsrc, mb, rep, k_base): vals.append(v[j]) return vals - def _bvs_prefetch(k_base): - # Issue scale buffer_load for one K-tile; returns (a[8], b[8]) VGPR. 
- a = _bvs_load_scales(_bvs_a_rsrc, _bvs_mb_a, wmma_m_rep, k_base) - b = _bvs_load_scales(_bvs_b_rsrc, _bvs_mb_b, b_scale_load_rep, k_base) + def _gvs_load_scales(rsrc, mb, rep, k_base): + # General coalesced layout: k_wmma_steps * rep i32, flat [ks][rep]. + # The per-tile K term (kt) goes in the scalar soffset, NOT the + # per-lane voffset VGPR: that keeps the voffset identical across + # prefetched K-tiles, so the backend CSEs it to one address + # register instead of recomputing it per tile (which forced + # s_wait_xcnt address-drain serialization and left the scale + # buffer_loads only partially hidden). The within-tile ks/grp + # delta stays in voffset, where it folds into the buffer + # instruction's immediate offset. + kt = k_base // arith.index(tile_k) + _NG = (rep + 3) // 4 + _S = k_wmma_steps * _NG * 32 * 4 + base_i32 = mb * arith.index(_bvs_Kt) * arith.index(_S) + kt_soff = arith.index_cast(T.i32, kt * arith.index(_S) * arith.index(4)) + vals = [] + for ks in range_constexpr(k_wmma_steps): + for grp in range_constexpr(_NG): + grp_i32 = base_i32 + arith.index((ks * _NG + grp) * 32 * 4) + _gvs_lane32 * arith.index(4) + off = arith.index_cast(T.i32, grp_i32) + v = fx.Vector( + buffer_ops.buffer_load(rsrc, off, vec_width=4, dtype=T.i32, soffset_bytes=kt_soff) + ) + for j in range_constexpr(4): + if const_expr(grp * 4 + j < rep): + vals.append(v[j]) + return vals + + def _load_contig_i32(rsrc, base_idx, n, soff): + # Load n contiguous i32 from base_idx (element units) via the widest + # buffer_load chunks (b128/b64/b32). Returns a list of n values. 
+ out = [None] * n + _chunks = _vec_chunks(n) + for _ci in range_constexpr(len(_chunks)): + start, w = _chunks[_ci] + off = arith.index_cast(T.i32, base_idx + arith.index(start)) + r = buffer_ops.buffer_load(rsrc, off, vec_width=w, dtype=T.i32, soffset_bytes=soff) + if const_expr(w == 1): + out[start] = r + else: + rv = fx.Vector(r) + for c in range_constexpr(w): + out[start + c] = rv[c] + return out + + def _gvs_load_scales_b128(rsrc, mb, rep, k_base, in_voffset): + # b128 layout [kt, grp, j, lane32, ks, spw]: a lane's KS scale-words + # are contiguous, so each (rep,lane) is one wide load (b128 for + # KS==4). Returns vals flat [ks][rep] to match _scales_for_emit. + # in_voffset=True bakes kt into the voffset VGPR (preload: all tiles' + # voffsets live at once -> distinct regs, no s0 reuse). in_voffset= + # False keeps the per-lane voffset constant across tiles and puts kt + # in the scalar soffset (per-tile ring). + kt = k_base // arith.index(tile_k) + KS = k_wmma_steps + _NG = (rep + 3) // 4 + _S128 = _NG * 4 * 32 * KS + if const_expr(in_voffset): + base = (mb * arith.index(_bvs_Kt) + kt) * arith.index(_S128) + soff = None + else: + base = mb * arith.index(_bvs_Kt) * arith.index(_S128) + soff = arith.index_cast(T.i32, kt * arith.index(_S128) * arith.index(4)) + vals = [None] * (KS * rep) + for _rep in range_constexpr(rep): + rep_off = base + arith.index(_rep * 32 * KS) + _gvs_lane32 * arith.index(KS) + ks_vals = _load_contig_i32(rsrc, rep_off, KS, soff) + for ks in range_constexpr(KS): + vals[ks * rep + _rep] = ks_vals[ks] + return vals + + def _bvs_prefetch(k_base, preload=False): + # Issue scale buffer_load for one K-tile; returns (a, b) VGPR lists, + # each flat [k_wmma_step][rep] (length _vs_tile_a / _vs_tile_b). 
+ if const_expr(_bvs_b128): + a = _gvs_load_scales_b128(_bvs_a_rsrc, _bvs_mb_a, wmma_m_rep, k_base, preload) + b = _gvs_load_scales_b128(_bvs_b_rsrc, _bvs_mb_b, b_scale_load_rep, k_base, preload) + elif const_expr(use_general_vgpr_scale): + a = _gvs_load_scales(_bvs_a_rsrc, _bvs_mb_a, wmma_m_rep, k_base) + b = _gvs_load_scales(_bvs_b_rsrc, _bvs_mb_b, b_scale_load_rep, k_base) + else: + a = _bvs_load_scales(_bvs_a_rsrc, _bvs_mb_a, wmma_m_rep, k_base) + b = _bvs_load_scales(_bvs_b_rsrc, _bvs_mb_b, b_scale_load_rep, k_base) return a, b m_idx = fx.Index(i32_m) @@ -605,6 +854,89 @@ def _bvs_prefetch(k_base): lda_packed = fx.Index(i32_lda) else: lda_packed = fx.Index(i32_lda) / arith.index(PACK_FACTOR_A) + + if const_expr(_avr_active): + # arg_a is dynamically shaped (runtime M), so max_size=False would fall + # back to a max-sized descriptor and disable hardware OOB. Clip + # num_records to M*lda bytes (fp8 A: 1 byte/elem) so rows >= M read 0 + # for non-tile-aligned M, matching the TDM path. + _avr_a_rsrc = buffer_ops.create_buffer_resource(arg_a, num_records_bytes=m_idx * lda_packed) + # buffer_load voffset is in i32 (4-byte) elements (it multiplies by 4 + # internally), so the byte-domain row/K/lane offsets must be divided by + # 4. k_base goes in the byte-domain soffset and stays as-is. + _avr_lda_i32 = lda_packed // arith.index(4) + _avr_lane_kgrp_off = lane_kgrp * arith.index(4) # 16 bytes / 4 + if const_expr(use_ascale_vgpr): + _avr_as_rsrc = buffer_ops.create_buffer_resource(arg_a_scale, max_size=False) + _avr_as_Kt = K // tile_k + _avr_as_mb = blk_m // arith.index(warp_tile_m) + wave_m_idx + _avr_as_lane32 = lane_kgrp * arith.index(16) + lane16 + + def _avr_load_a_tile(k_base): + """Issue buffer_load_b128 for one K-tile of A data. + + Returns list of vec<_avr_frag_width xi32> IR values, length + k_wmma_steps * wmma_m_rep (indexed [ks * wmma_m_rep + wm]). 
+ The K-tile offset goes in soffset (scalar, same for all lanes); + per-lane row/kgrp address stays in voffset (reused across tiles). + """ + kt_soff = arith.index_cast(T.i32, k_base) + frags = [] + for ks in range_constexpr(k_wmma_steps): + for wm in range_constexpr(wmma_m_rep): + row = warp_m_base + arith.index(wm * WMMA_M) + lane16 + row_off = row * _avr_lda_i32 + _avr_lane_kgrp_off + arith.index(ks * WMMA_K // PACK_FACTOR_A // 4) + loads = [] + for i in range_constexpr(DS_LOADS_PER_A_FRAG): + off = arith.index_cast(T.i32, row_off + arith.index(i * 8)) + v = fx.Vector(buffer_ops.buffer_load( + _avr_a_rsrc, off, vec_width=4, dtype=T.i32, soffset_bytes=kt_soff + )) + loads.append(v) + if const_expr(DS_LOADS_PER_A_FRAG == 2): + frag = loads[0].shuffle(loads[1], list(range(8))) + else: + v01 = loads[0].shuffle(loads[1], list(range(8))) + v23 = loads[2].shuffle(loads[3], list(range(8))) + frag = v01.shuffle(v23, list(range(16))) + frags.append(frag.ir_value()) + return frags + + def _avr_load_ascale(k_base): + """Load A_scale for one K-tile via buffer_load (coalesced layout).""" + kt = k_base // arith.index(tile_k) + _NG = (wmma_m_rep + 3) // 4 + _S = k_wmma_steps * _NG * 32 * 4 + base_i32 = _avr_as_mb * arith.index(_avr_as_Kt) * arith.index(_S) + kt_soff = arith.index_cast(T.i32, kt * arith.index(_S) * arith.index(4)) + vals = [] + for ks in range_constexpr(k_wmma_steps): + for grp in range_constexpr(_NG): + grp_i32 = base_i32 + arith.index((ks * _NG + grp) * 32 * 4) + _avr_as_lane32 * arith.index(4) + off = arith.index_cast(T.i32, grp_i32) + v = fx.Vector(buffer_ops.buffer_load( + _avr_as_rsrc, off, vec_width=4, dtype=T.i32, soffset_bytes=kt_soff + )) + for j in range_constexpr(4): + if const_expr(grp * 4 + j < wmma_m_rep): + vals.append(v[j]) + return vals + + def _avr_prefetch(k_base): + """Issue A data (and optionally A_scale) prefetch for one K-tile. 
+ + Returns (a_frags, a_scales) where a_frags is a list of + vec IR values and a_scales is a list of i32 (or empty). + """ + a_frags = _avr_load_a_tile(k_base) + if const_expr(use_ascale_vgpr): + a_scales = _avr_load_ascale(k_base) + else: + a_scales = [] + return a_frags, a_scales + + _a_vgpr_box = [None] + _a_vgpr_ascale_box = [None] n_stride = fx.Index(i32_ldc) c_nrec = m_idx * n_stride * arith.index(elem_bytes_d) c_rsrc = buffer_ops.create_buffer_resource(arg_c, num_records_bytes=c_nrec) @@ -756,8 +1088,8 @@ def _precompute_a_lane_bases(lds_ptr): bases.append(base) return lds_ptr, bases - def load_a_frag(lds_buffer, a_lane_base, ks): - """Load one A-fragment from LDS. + def load_a_frag(lds_buffer, a_lane_base, ks, wm=0): + """Load one A-fragment from LDS (or VGPR box when use_a_vgpr). FP4: vec<8xi32> via 2 × ds_load_b128 (32 bytes per lane). FP8/A8W4: vec<16xi32> via 4 × ds_load_b128 (64 bytes per lane). @@ -765,6 +1097,8 @@ def load_a_frag(lds_buffer, a_lane_base, ks): kgrp0 reads bytes [0:15],[32:47],[64:79],[96:111] (stride=32) kgrp1 reads bytes [16:31],[48:63],[80:95],[112:127] (stride=32) """ + if const_expr(use_a_vgpr): + return _a_vgpr_box[0][ks * wmma_m_rep + wm] k_byte_off = arith.index(ks * WMMA_K // PACK_FACTOR_A) byte_off = a_lane_base + k_byte_off v0 = fx.Vector(lds_load_b128_raw(lds_buffer, byte_off)) @@ -898,6 +1232,12 @@ def load_scale_slice_b128(lds_buffer, scale_base, full_reps, rep_start, rep_coun results.append(vecs[i // 4][i % 4]) return results + # Holds the current tile's prefetched VGPR scales (a_flat, b_flat), each + # ordered [k_wmma_step][rep]. compute_tile sets it before emitting; the + # general-vgpr branch of _scales_for_emit slices it per K-subtile. Set-then- + # consume is sequential at emit time (same pattern as epi_addrs_box). + _vgpr_scale_box = [None] + def _scales_for_emit(as_buf, as_bases, bs_buf, bs_bases, ks): """Load both scale tensors and apply op_sel downsampling per format. 
@@ -906,6 +1246,18 @@ def _scales_for_emit(as_buf, as_bases, bs_buf, bs_bases, ks): """ if const_expr(is_ptpc): return None, None + if const_expr(use_general_vgpr_scale): + # VGPR scales (no op_sel in this path); slice the prefetched ring. + pf_a, pf_b = _vgpr_scale_box[0] + a = pf_a[ks * wmma_m_rep : (ks + 1) * wmma_m_rep] + b = pf_b[ks * b_scale_load_rep : (ks + 1) * b_scale_load_rep] + return a, b + if const_expr(use_ascale_vgpr): + # A_scale from VGPR (bundled with A prefetch ring), B_scale from LDS. + a = _a_vgpr_ascale_box[0][ks * wmma_m_rep : (ks + 1) * wmma_m_rep] + b_all = load_scale_b128(bs_buf, bs_bases[0], b_scale_load_rep, ks) + b = b_all + return a, b a_all = load_scale_b128(as_buf, as_bases[0], wmma_m_rep, ks) b_all = load_scale_b128(bs_buf, bs_bases[0], b_scale_load_rep, ks) if const_expr(use_scale_opsel): @@ -921,7 +1273,7 @@ def _load_b_and_scales(b_buf, b_bases, bs_buf, bs_bases, as_buf, as_bases, ks): return b_frags, b_scales, a_scales def _load_a_and_scales(a_buf, a_bases, as_buf, as_bases, bs_buf, bs_bases, ks): - a_frags = [load_a_frag(a_buf, a_bases[wm], ks) for wm in range_constexpr(wmma_m_rep)] + a_frags = [load_a_frag(a_buf, a_bases[wm], ks, wm=wm) for wm in range_constexpr(wmma_m_rep)] a_scales, b_scales = _scales_for_emit(as_buf, as_bases, bs_buf, bs_bases, ks) return a_frags, a_scales, b_scales @@ -1030,7 +1382,7 @@ def _emit_rows(start_wm, a_frags): wn = (wmma_n_rep - 1 - wn_raw) if (wm % 2 == 1) else wn_raw _emit_wmma(accs, wm, wn, a_frags[frag_i], b_frags[wn], a_scales, b_scales) - a_frags_front = [load_a_frag(a_buf, a_bases[wm], ks) for wm in range_constexpr(_front_wm)] + a_frags_front = [load_a_frag(a_buf, a_bases[wm], ks, wm=wm) for wm in range_constexpr(_front_wm)] _use_partial_drain = next_bs_info is not None and _front_wm * wmma_n_rep >= 4 @@ -1048,7 +1400,7 @@ def _emit_rows(start_wm, a_frags): mid_compute_callback() if const_expr(_back_wm > 0): - a_frags_back = [load_a_frag(a_buf, a_bases[_front_wm + h], ks) for h in 
range_constexpr(_back_wm)] + a_frags_back = [load_a_frag(a_buf, a_bases[_front_wm + h], ks, wm=_front_wm + h) for h in range_constexpr(_back_wm)] _back_drain = _bs_ds_loads if _use_partial_drain else 0 rocdl.s_wait_dscnt(_back_drain) _emit_rows(_front_wm, a_frags_back) @@ -1115,8 +1467,35 @@ def _emit_cols(start_wn, b_frags_chunk): return accs # ── Compute on one LDS buffer ── - def compute_tile(accs_in, lds_a, lds_b, lds_as, lds_bs, emit_filler=None, mid_compute_callback=None): + def compute_tile( + accs_in, + lds_a, + lds_b, + lds_as, + lds_bs, + emit_filler=None, + mid_compute_callback=None, + scale_k_base=None, + pf_a_scales=None, + pf_b_scales=None, + pf_a_data=None, + pf_a_data_scales=None, + ): current_accs = list(accs_in) + if const_expr(use_a_vgpr): + _a_vgpr_box[0] = pf_a_data + if const_expr(use_ascale_vgpr): + _a_vgpr_ascale_box[0] = pf_a_data_scales + if const_expr(use_general_vgpr_scale): + # Scales come from VGPR: use the loop-prefetched ring when provided, + # else issue the buffer_loads inline (tail path) for scale_k_base. + if const_expr(pf_a_scales is not None): + _vgpr_scale_box[0] = (pf_a_scales, pf_b_scales) + else: + # Inline tail load: barrier so the buffer_loads can't be hoisted + # above the caller's pipeline fence (mirrors the main-loop path). + rocdl.sched_barrier(0) + _vgpr_scale_box[0] = _bvs_prefetch(scale_k_base) a_buf, a_bases = _precompute_a_lane_bases(lds_a) b_buf, b_bases = _precompute_b_lane_bases(lds_b) as_buf, as_bases = _precompute_scale_lane_bases(lds_as, warp_m_base, wmma_m_rep, interleaved_scale_cols_a) @@ -1724,15 +2103,17 @@ def hot_loop_scheduler(): _half_wm = wmma_m_rep // 2 _half_wmma = _half_wm * wmma_n_rep _b_loads_per_frag = 2 if is_a8w4 else 4 - _scale_dsrd = 0 if is_ptpc else 2 + # No scale ds_loads when scales are in registers (PTPC epilogue / VGPR). 
+ _scale_dsrd = 0 if (is_ptpc or use_general_vgpr_scale) else 2 + _a_half_dsrd = 0 if use_a_vgpr else _half_wm * DS_LOADS_PER_A_FRAG for _ks in range_constexpr(k_wmma_steps): if const_expr(_ks == 0): - rocdl.sched_dsrd(wmma_n_rep * _b_loads_per_frag + _scale_dsrd + _half_wm * DS_LOADS_PER_A_FRAG) + rocdl.sched_dsrd(wmma_n_rep * _b_loads_per_frag + _scale_dsrd + _a_half_dsrd) else: - rocdl.sched_dsrd(_half_wm * DS_LOADS_PER_A_FRAG) + rocdl.sched_dsrd(_a_half_dsrd) rocdl.sched_mfma(_half_wmma) - rocdl.sched_dsrd(_half_wm * DS_LOADS_PER_A_FRAG) + rocdl.sched_dsrd(_a_half_dsrd) rocdl.sched_mfma(_half_wmma) if const_expr(_ks < k_wmma_steps - 1): rocdl.sched_dsrd(wmma_n_rep * _b_loads_per_frag + _scale_dsrd) @@ -1840,6 +2221,8 @@ def compute_tile_scheduled( scale_k_base=None, pf_a_scales=None, pf_b_scales=None, + pf_a_data=None, + pf_a_data_scales=None, ): if const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_B_STREAMING): return compute_tile_b_streaming( @@ -1895,6 +2278,11 @@ def compute_tile_scheduled( lds_bs, emit_filler=emit_filler, mid_compute_callback=mid_compute_callback, + scale_k_base=scale_k_base, + pf_a_scales=pf_a_scales, + pf_b_scales=pf_b_scales, + pf_a_data=pf_a_data, + pf_a_data_scales=pf_a_data_scales, ) def hot_loop_scheduler_b_streaming(): @@ -2206,21 +2594,39 @@ def _pack_dg0(pred, lds_addr, addr_lo, addr_hi): stages_as_lds_addr = [] stages_bs_lds_addr = [] for i in range_constexpr(num_buffers): - stages_a_lds_addr.append(_dg0_lane(make_desc_a(stages_a_mem[i], arith.index(0)), 1)) + if const_expr(not use_a_vgpr): + stages_a_lds_addr.append(_dg0_lane(make_desc_a(stages_a_mem[i], arith.index(0)), 1)) stages_b_lds_addr.append(_dg0_lane(make_desc_b(stages_b_mem[i], arith.index(0)), 1)) - if const_expr(not is_ptpc): + if const_expr(not is_ptpc and not use_ascale_vgpr): stages_as_lds_addr.append(_dg0_lane(make_desc_as(stages_as_mem[i], arith.index(0)), 1)) + if const_expr(not is_ptpc): stages_bs_lds_addr.append(_dg0_lane(make_desc_bs(stages_bs_mem[i], 
arith.index(0)), 1)) - desc_a_init = make_desc_a(stages_a_mem[0], split_k_base) + if const_expr(not use_a_vgpr): + desc_a_init = make_desc_a(stages_a_mem[0], split_k_base) desc_b_init = make_desc_b(stages_b_mem[0], split_k_base) - if const_expr(is_ptpc): - # No scale TDM for PTPC: alias the scale descriptors/addresses to A/B. - # Scale waves are predicated off, so these selections are never issued. - stages_as_lds_addr = stages_a_lds_addr - stages_bs_lds_addr = stages_b_lds_addr - desc_as_init = desc_a_init - desc_bs_init = desc_b_init + if const_expr(is_ptpc or use_ascale_vgpr): + # Alias unused A/A_scale slots to B (predicated off waves). + if const_expr(not stages_a_lds_addr): + stages_a_lds_addr = stages_b_lds_addr + if const_expr(not stages_as_lds_addr): + stages_as_lds_addr = stages_b_lds_addr + if const_expr(not stages_bs_lds_addr): + stages_bs_lds_addr = stages_b_lds_addr + if const_expr(use_a_vgpr): + desc_a_init = desc_b_init + if const_expr(is_ptpc): + desc_as_init = desc_b_init + desc_bs_init = desc_b_init + else: + desc_as_init = desc_b_init + desc_bs_init = make_desc_bs(stages_bs_mem[0], split_k_base) + elif const_expr(use_a_vgpr): + # A via VGPR, scales via TDM: alias A slot to B (wave0 predicated off). + stages_a_lds_addr = stages_b_lds_addr + desc_a_init = desc_b_init + desc_as_init = make_desc_as(stages_as_mem[0], split_k_base) + desc_bs_init = make_desc_bs(stages_bs_mem[0], split_k_base) else: desc_as_init = make_desc_as(stages_as_mem[0], split_k_base) desc_bs_init = make_desc_bs(stages_bs_mem[0], split_k_base) @@ -2240,6 +2646,20 @@ def _pack_dg0(pred, lds_addr, addr_lo, addr_hi): desc_a1_init = make_desc_a_half(stages_a_mem[0], split_k_base, 1) desc_b1_init = make_desc_b_half(stages_b_mem[0], split_k_base, 1) + if const_expr(_b_nsplit): + # N-direction B halves: wave0 -> N-groups [0:tile_n//32], wave1 -> + # [tile_n//32:tile_n//16]. 
N is the outer LDS dim, so the two halves + # write contiguous blocks that together equal the full-B tile layout + # (make_desc_b_half bakes the N offset into both the global offset and + # lds_byte_offset). load_b_frag reads are unchanged. + nstages_b0_lds_addr = [] + nstages_b1_lds_addr = [] + for i in range_constexpr(num_buffers): + nstages_b0_lds_addr.append(_dg0_lane(make_desc_b_half(stages_b_mem[i], arith.index(0), 0), 1)) + nstages_b1_lds_addr.append(_dg0_lane(make_desc_b_half(stages_b_mem[i], arith.index(0), 1), 1)) + desc_bn0_init = make_desc_b_half(stages_b_mem[0], split_k_base, 0) + desc_bn1_init = make_desc_b_half(stages_b_mem[0], split_k_base, 1) + adv_a_i32 = fx.Int32(tile_k // PACK_FACTOR_A) adv_b_i32 = fx.Int32(packed_tile_k_b * 16) adv_as_i32 = fx.Int32(tile_k // SCALE_BLOCK * wmma_m_rep) @@ -2247,9 +2667,17 @@ def _pack_dg0(pred, lds_addr, addr_lo, addr_hi): pred_const = fx.Int32(1) if const_expr(wave_specialized_tdm): - _drop_scale_waves = is_ptpc or (use_buffer_vgpr_scale and not use_ab_half_split) - _active_wave_limit = 2 if _drop_scale_waves else 4 - active_pred_const = arith.select(tdm_wave_id < fx.Int32(_active_wave_limit), fx.Int32(1), fx.Int32(0)) + _drop_scale_waves = is_ptpc or (use_buffer_vgpr_scale and not use_ab_half_split) or use_ascale_vgpr + if const_expr(_b_nsplit): + # B N-split: all 4 waves load (wave0=B N-half0, wave1=B N-half1, + # wave2=A_scale, wave3=B_scale). 
+ active_pred_const = pred_const + elif const_expr(_drop_a_loader_wave and not _drop_scale_waves): + # A via VGPR, scales via TDM: wave0 (A) idle, waves 1,2,3 active + active_pred_const = arith.select(tdm_wave_id >= fx.Int32(1), fx.Int32(1), fx.Int32(0)) + else: + _active_wave_limit = 2 if _drop_scale_waves else 4 + active_pred_const = arith.select(tdm_wave_id < fx.Int32(_active_wave_limit), fx.Int32(1), fx.Int32(0)) def _select4(values): return _select_wave_tdm_value(values[0], values[1], values[2], values[3]) @@ -2286,6 +2714,22 @@ def _select_active_tdm(stage_lds_addrs, descs, advs): (desc_a0_init, desc_b0_init, desc_a1_init, desc_b1_init), (adv_a_i32, adv_b_i32, adv_a_i32, adv_b_i32), ) + elif const_expr(_b_nsplit): + # B N-split: wave0=B N-half0, wave1=B N-half1 (both adv by a full + # tile_k; the N offset is constant), wave2=A_scale, wave3=B_scale. + active_stage_lds_addr, active_addr_lo, active_addr_hi, active_dgroup1, active_adv_i32 = _select_active_tdm( + (nstages_b0_lds_addr, nstages_b1_lds_addr, stages_as_lds_addr, stages_bs_lds_addr), + (desc_bn0_init, desc_bn1_init, desc_as_init, desc_bs_init), + (adv_b_i32, adv_b_i32, adv_as_i32, adv_bs_i32), + ) + elif const_expr(wave_specialized_tdm and use_ascale_vgpr): + # A + A_scale via VGPR: only B (wave0) and B_scale (wave1) need TDM. + # Remap: slot0=B, slot1=B_scale, slots 2,3 aliased (predicated off). 
+ active_stage_lds_addr, active_addr_lo, active_addr_hi, active_dgroup1, active_adv_i32 = _select_active_tdm( + (stages_b_lds_addr, stages_bs_lds_addr, stages_b_lds_addr, stages_bs_lds_addr), + (desc_b_init, desc_bs_init, desc_b_init, desc_bs_init), + (adv_b_i32, adv_bs_i32, adv_b_i32, adv_bs_i32), + ) elif const_expr(wave_specialized_tdm): active_stage_lds_addr, active_addr_lo, active_addr_hi, active_dgroup1, active_adv_i32 = _select_active_tdm( (stages_a_lds_addr, stages_b_lds_addr, stages_as_lds_addr, stages_bs_lds_addr), @@ -2347,13 +2791,19 @@ def _issue_active_tdm(load_stage, addr_box, k_prefetch=None): addr_lo_as = addr_lo_as + adv_as_i32 addr_lo_bs = addr_lo_bs + adv_bs_i32 - if const_expr(_bvs_active): + if const_expr(_bvs_active and loop_iters > 0): # Prologue: prefetch the first _bvs_D K-tiles (global->VGPR). Carried as - # FLAT lists of i32 (list-of-tuples can't be loop-carried). + # FLAT lists of i32 (list-of-tuples can't be loop-carried). Only when the + # main loop runs; a tail-only problem (loop_iters == 0) loads inline. _bvs_pf = [_bvs_prefetch(split_k_base + arith.index(_d * tile_k)) for _d in range(_bvs_D)] _bvs_ra = [_v for (_a, _b) in _bvs_pf for _v in _a] _bvs_rb = [_v for (_a, _b) in _bvs_pf for _v in _b] + if const_expr(_avr_active and loop_iters > 0): + _avr_pf = [_avr_prefetch(split_k_base + arith.index(_d * tile_k)) for _d in range(_avr_D)] + _avr_rf = [_v for (_f, _s) in _avr_pf for _v in _f] + _avr_rs = [_v for (_f, _s) in _avr_pf for _v in _s] + _pipeline_fence(outstanding=TDM_LOADS_PER_STEP * (num_buffers - 2)) # Main loop — acc_mixed style: fence at top, TDM_load mid-compute. 
@@ -2368,15 +2818,24 @@ def _issue_active_tdm(load_stage, addr_box, k_prefetch=None): init_args = list(accs) + [active_addr_lo] if const_expr(_bvs_active): init_args = init_args + _bvs_ra + _bvs_rb + if const_expr(_avr_active): + init_args = init_args + _avr_rf + _avr_rs for loop_iter, state in range(0, loop_iters, 1, init=init_args): accs_in = list(state[:n_accs]) cur_addr_lo = state[n_accs] + _state_off = n_accs + 1 if const_expr(_bvs_active): - _ra0 = n_accs + 1 - _ring_a = list(state[_ra0 : _ra0 + _bvs_D * wmma_m_rep]) - _rb0 = _ra0 + _bvs_D * wmma_m_rep - _ring_b = list(state[_rb0 : _rb0 + _bvs_D * b_scale_load_rep]) + _ra0 = _state_off + _ring_a = list(state[_ra0 : _ra0 + _bvs_D * _vs_tile_a]) + _rb0 = _ra0 + _bvs_D * _vs_tile_a + _ring_b = list(state[_rb0 : _rb0 + _bvs_D * _vs_tile_b]) + _state_off = _rb0 + _bvs_D * _vs_tile_b + if const_expr(_avr_active): + _af0 = _state_off + _avr_ring_f = list(state[_af0 : _af0 + _avr_D * _avr_frags_per_tile]) + _as0 = _af0 + _avr_D * _avr_frags_per_tile + _avr_ring_s = list(state[_as0 : _as0 + _avr_D * _avr_ascale_per_tile]) for buf_idx in range_constexpr(num_buffers): load_stage = (buf_idx + num_buffers - 1) % num_buffers @@ -2413,20 +2872,35 @@ def _late_tdm_ws_split_signal(): # NOTE: must stay AFTER the fence; issuing the scale # buffer_loads before the cluster barrier hangs the vgpr path. 
if const_expr(_bvs_active): - _cur_a = _ring_a[:wmma_m_rep] - _cur_b = _ring_b[:b_scale_load_rep] + _cur_a = _ring_a[:_vs_tile_a] + _cur_b = _ring_b[:_vs_tile_b] _next_kb = ( split_k_base + loop_iter * arith.index(num_buffers * tile_k) + arith.index((buf_idx + _bvs_D) * tile_k) ) _na, _nb2 = _bvs_prefetch(_next_kb) - _ring_a = _ring_a[wmma_m_rep:] + list(_na) - _ring_b = _ring_b[b_scale_load_rep:] + list(_nb2) + _ring_a = _ring_a[_vs_tile_a:] + list(_na) + _ring_b = _ring_b[_vs_tile_b:] + list(_nb2) else: _cur_a = None _cur_b = None + if const_expr(_avr_active): + _cur_ad = _avr_ring_f[:_avr_frags_per_tile] + _cur_as = _avr_ring_s[:_avr_ascale_per_tile] if _avr_ascale_per_tile else None + _next_akb = ( + split_k_base + + loop_iter * arith.index(num_buffers * tile_k) + + arith.index((buf_idx + _avr_D) * tile_k) + ) + _nf, _ns = _avr_prefetch(_next_akb) + _avr_ring_f = _avr_ring_f[_avr_frags_per_tile:] + list(_nf) + _avr_ring_s = _avr_ring_s[_avr_ascale_per_tile:] + list(_ns) + else: + _cur_ad = None + _cur_as = None + accs_in = compute_tile_scheduled( accs_in, stages_a_idx[buf_idx], @@ -2438,6 +2912,8 @@ def _late_tdm_ws_split_signal(): a0_prefetch=a0_prefetch, pf_a_scales=_cur_a, pf_b_scales=_cur_b, + pf_a_data=_cur_ad, + pf_a_data_scales=_cur_as, ) cur_addr_lo = addr_box[0] hot_loop_scheduler_scheduled() @@ -2446,7 +2922,11 @@ def _late_tdm_ws_split_signal(): _bvs_yield = _ring_a + _ring_b else: _bvs_yield = [] - results = yield list(accs_in) + [cur_addr_lo] + _bvs_yield + if const_expr(_avr_active): + _avr_yield = _avr_ring_f + _avr_ring_s + else: + _avr_yield = [] + results = yield list(accs_in) + [cur_addr_lo] + _bvs_yield + _avr_yield accs = list(results[:n_accs]) active_addr_lo = results[n_accs] @@ -2544,8 +3024,74 @@ def _bvs_tail_kb(): _bvs_tail_kt[0] += 1 return kb + # General VGPR scale: prefetch the tail's scales _bvs_D K-tiles ahead so + # each scale buffer_load overlaps an earlier tile's TDM wait instead of + # stalling the WMMA inline. 
The ref-segmented deep-pipeline path keeps its + # inline per-tile load to stay within its tight VGPR budget. + _bvs_tail_pf = use_general_vgpr_scale + _bvs_tail_ring = [] + _bvs_tail_issue_kt = [loop_iters * num_buffers] + # Preload (opt-in): all K-tiles' scales loaded up front, indexed by step. + _bvs_preload_ring = [] + _bvs_tail_step = [0] + + def _bvs_tail_issue_one(): + if const_expr(_bvs_tail_pf and not _bvs_preload and _bvs_tail_issue_kt[0] < num_k_tiles): + _kb = split_k_base + arith.index(_bvs_tail_issue_kt[0] * tile_k) + _bvs_tail_ring.append(_bvs_prefetch(_kb)) + _bvs_tail_issue_kt[0] += 1 + + def _bvs_tail_scales(): + # Per-tile (scale_k_base, pf_a_scales, pf_b_scales): consume the preload + # set or the prefetch ring on the general path, else fall back to the + # inline-load k_base. + if const_expr(_bvs_preload): + _i = _bvs_tail_step[0] + _bvs_tail_step[0] += 1 + _cur_a, _cur_b = _bvs_preload_ring[_i] + return None, _cur_a, _cur_b + if const_expr(_bvs_tail_pf): + _cur_a, _cur_b = _bvs_tail_ring.pop(0) + return None, _cur_a, _cur_b + return _bvs_tail_kb(), None, None + + if const_expr(_bvs_preload): + # One-shot: issue ALL K-tiles' scales up front (distinct voffset VGPRs, + # no shared-soffset reuse). All loads overlap the prologue/first B TDM. + rocdl.sched_barrier(0) + for _t in range_constexpr(num_k_tiles): + _kb = split_k_base + arith.index(_t * tile_k) + _bvs_preload_ring.append(_bvs_prefetch(_kb, preload=True)) + elif const_expr(_bvs_tail_pf): + # Prime the ring before the first tail fence so even tile 0's scale + # load overlaps its TDM wait rather than stalling the WMMA. 
+ rocdl.sched_barrier(0) + for _ in range_constexpr(_bvs_D): + _bvs_tail_issue_one() + + _avr_tail_ring = [] + _avr_tail_issue_kt = [loop_iters * num_buffers] + + def _avr_tail_issue_one(): + if const_expr(_avr_active and _avr_tail_issue_kt[0] < num_k_tiles): + _kb = split_k_base + arith.index(_avr_tail_issue_kt[0] * tile_k) + _avr_tail_ring.append(_avr_prefetch(_kb)) + _avr_tail_issue_kt[0] += 1 + + def _avr_tail_consume(): + if const_expr(_avr_active): + _f, _s = _avr_tail_ring.pop(0) + return _f, _s if _s else None + return None, None + + if const_expr(_avr_active): + rocdl.sched_barrier(0) + for _ in range_constexpr(_avr_D): + _avr_tail_issue_one() + for _load_stage, _compute_stage, _outstanding in tail_plan: - _entry_kb = _bvs_tail_kb() + _entry_kb, _pf_a_scales, _pf_b_scales = _bvs_tail_scales() + _tail_ad, _tail_as = _avr_tail_consume() if const_expr(_outstanding == -1): if const_expr(_tail_had_load): _pipeline_fence(outstanding=0) @@ -2560,6 +3106,10 @@ def _bvs_tail_kb(): emit_filler=(_load_ptpc_scales_once if is_ptpc else None), a0_prefetch=a0_prefetch, scale_k_base=_entry_kb, + pf_a_scales=_pf_a_scales, + pf_b_scales=_pf_b_scales, + pf_a_data=_tail_ad, + pf_a_data_scales=_tail_as, ) else: @@ -2577,6 +3127,10 @@ def _emit_epi_addrs(): emit_filler=_emit_epi_addrs, a0_prefetch=a0_prefetch, scale_k_base=_entry_kb, + pf_a_scales=_pf_a_scales, + pf_b_scales=_pf_b_scales, + pf_a_data=_tail_ad, + pf_a_data_scales=_tail_as, ) else: _pipeline_fence_signal(outstanding=_outstanding) @@ -2616,6 +3170,8 @@ def _tail_mid_nws(_ls=_load_stage, _ab=_tail_ab): a0_prefetch = maybe_prefetch_fp8_deep_a0(stages_a_idx[_compute_stage]) rocdl.sched_barrier(0) + _bvs_tail_issue_one() + _avr_tail_issue_one() accs = compute_tile_scheduled( accs, stages_a_idx[_compute_stage], @@ -2625,6 +3181,10 @@ def _tail_mid_nws(_ls=_load_stage, _ab=_tail_ab): mid_compute_callback=_tail_mid_cb, a0_prefetch=a0_prefetch, scale_k_base=_entry_kb, + pf_a_scales=_pf_a_scales, + 
pf_b_scales=_pf_b_scales, + pf_a_data=_tail_ad, + pf_a_data_scales=_tail_as, ) if const_expr(_load_stage is not None): @@ -2700,6 +3260,7 @@ def _emit_buffer_store(): atomic_barrier_enable, b_streaming, scale_load_path, + a_load_path, fp8_schedule, ) @@ -2838,6 +3399,7 @@ def compile_ptpc_gemm( __all__ = [ + "is_ref_segmented_lds_layout", "compile_fp8fp4_gemm", "compile_mxscale_gemm", "compile_mxfp4_gemm", diff --git a/tests/kernels/test_gemm_fp8fp4_gfx1250.py b/tests/kernels/test_gemm_fp8fp4_gfx1250.py index ad1daf3e..48e70c2e 100644 --- a/tests/kernels/test_gemm_fp8fp4_gfx1250.py +++ b/tests/kernels/test_gemm_fp8fp4_gfx1250.py @@ -24,7 +24,11 @@ import flydsl.compiler as flyc # noqa: E402,I001 from flydsl.runtime.device import get_rocm_arch # noqa: E402 -from kernels.gemm_fp8fp4_gfx1250 import compile_mxscale_gemm, compile_ptpc_gemm # noqa: E402 +from kernels.gemm_fp8fp4_gfx1250 import ( # noqa: E402 + compile_mxscale_gemm, + compile_ptpc_gemm, + is_ref_segmented_lds_layout, +) from tests.kernels.utils import fp4_utils # noqa: E402 if not torch.cuda.is_available(): @@ -49,6 +53,79 @@ def preshuffle_e8m0_scale_coalesced(scale: torch.Tensor, block: int = 128) -> to return g.view(M, Ks) +def preshuffle_e8m0_scale_coalesced_general( + scale: torch.Tensor, + warp_tile: int, + scale_k_per_tile: int, + kgrp_shift: int, + WMMA_DIM: int = 16, + row_align: int = None, + b128: bool = False, +) -> torch.Tensor: + """General lane-major scale layout for the buffer_load->VGPR path. + + Generalizes :func:`preshuffle_e8m0_scale_coalesced` (which is locked to + warp_tile=128, scale_k_per_tile=4, no lane_kgrp) to arbitrary ``warp_tile``, + ``scale_k_per_tile`` (=> ``k_wmma_steps``) and the a8w4/fp4 ``lane_kgrp`` + scale shift. The byte delivered to every *physical* 32-lane index + ``lane32 = lane_kgrp*16 + lane16`` is baked here so the kernel's load address + is lane_kgrp-agnostic and fully coalesced (32 lanes -> 512 contiguous bytes + per (k_wmma_step, rep_group)). 
+ + Per (row_block of ``warp_tile`` rows, K-tile): layout is + ``[k_wmma_step, rep_group(ceil(rep/4)), lane32(32), j(4), spw(4)]`` and the + value at ``(lane32=(G*16+L), rep=grp*4+j)`` is the original e8m0 quadruple of + row ``block*warp_tile + (rep + G*kgrp_shift)*16 + L``; slots whose shifted rep + reaches ``rep_count`` are filled with E8M0 127 (=1.0) guards. + + Mirrors the TDM/LDS path value-for-value (see the offline parity check in + ``flydsl_fp8_perf/verify_vgpr_scale_layout.py``), so it is correct by + construction against the working ``scale_load_path='tdm'`` path. + """ + rows, K_scale = scale.shape + assert K_scale % scale_k_per_tile == 0, f"K_scale={K_scale} % spt={scale_k_per_tile}" + assert scale_k_per_tile % 4 == 0, f"scale_k_per_tile={scale_k_per_tile} must be a multiple of 4" + align = row_align if row_align is not None else warp_tile + if rows % align != 0: + pad = _align_up(rows, align) - rows + scale = torch.cat([scale, torch.full((pad, K_scale), 127, dtype=scale.dtype, device=scale.device)], dim=0) + rows = scale.shape[0] + R = warp_tile // WMMA_DIM # rep_count (wmma reps per warp tile) + KS = scale_k_per_tile // 4 # k_wmma_steps + KG = K_scale // scale_k_per_tile # K-tiles + NG = (R + 3) // 4 # rep groups (vec4 b128 each) + num_mb = rows // warp_tile + + # Index grids over [mb, kt, ks, grp, lane32, j]; spw is the trailing dim. 
+ dev = scale.device + mb = torch.arange(num_mb, device=dev).view(num_mb, 1, 1, 1, 1, 1) + kt = torch.arange(KG, device=dev).view(1, KG, 1, 1, 1, 1) + ks = torch.arange(KS, device=dev).view(1, 1, KS, 1, 1, 1) + grp = torch.arange(NG, device=dev).view(1, 1, 1, NG, 1, 1) + l32 = torch.arange(32, device=dev).view(1, 1, 1, 1, 32, 1) + j = torch.arange(4, device=dev).view(1, 1, 1, 1, 1, 4) + G = l32 // 16 + L = l32 % 16 + rep = grp * 4 + j + orig_rep = rep + G * kgrp_shift + valid = orig_rep < R + orig_row = mb * warp_tile + orig_rep * WMMA_DIM + L + # Clamp out-of-range (guard) rows so the gather stays in bounds; masked below. + orig_row = torch.where(valid, orig_row, torch.zeros_like(orig_row)) + colg = (kt * KS + ks) # group-of-4 column index into the K dimension + # Gather the 4 spw bytes for each (row, colg): scale viewed as [rows, KG*KS, 4]. + scale_g = scale.view(rows, KG * KS, 4) + row_idx, colg_idx = torch.broadcast_tensors(orig_row, colg) + out = scale_g[row_idx, colg_idx] # [mb, KG, KS, NG, 32, 4, 4] + out = torch.where(valid.unsqueeze(-1), out, torch.full_like(out, 127)) + if b128: + # b128 variant: move ks to the innermost (next to spw) so a lane's + # KS scale-words are contiguous -> one buffer_load_b128 reads a whole + # tile's ks per (rep,lane). Output order [mb, kt, grp, j, lane32, ks, spw]. + out = out.permute(0, 1, 3, 5, 4, 2, 6).contiguous() + return out.reshape(num_mb, -1).contiguous() + + def preshuffle_e8m0_scale( scale: torch.Tensor, warp_tile: int, @@ -83,6 +160,30 @@ def preshuffle_e8m0_scale( return g.reshape(-1, k_groups * k_wmma_steps * wmma_rep * SCALES_PER_WMMA) +def preshuffle_scale_for_load_path(scale, warp_tile, skt, *, scale_load_path, data_format, ref_segmented, row_align=None): + """Host scale preshuffle matching the kernel's selected scale_load_path. + + - 'tdm': interleaved TDM/LDS layout. + - 'vgpr'/'vgpr_ab_split' on the ref-segmented deep-pipeline config: legacy + lane-major coalesced layout. 
+ - 'vgpr' on any other (general) config: general coalesced layout, with the + a8w4/fp4 lane_kgrp scale shift. + """ + if scale_load_path in ("vgpr", "vgpr_ab_split"): + if ref_segmented: + return preshuffle_e8m0_scale(scale, warp_tile, scale_k_per_tile=skt, coalesced=True) + kgrp_shift = 1 if data_format in ("a8w4", "fp4") else 0 + # FLYDSL_BUFFER_VGPR_SCALE_PRELOAD=1 switches the general vgpr path to the + # b128 layout (ks innermost) so the kernel reads scales with wide + # buffer_load_b128 instead of many b32. Must match the kernel flag of the + # same name (kernels/gemm_fp8fp4_gfx1250.py::_bvs_b128). + b128 = bool(int(os.environ.get("FLYDSL_BUFFER_VGPR_SCALE_PRELOAD", "0"))) + return preshuffle_e8m0_scale_coalesced_general( + scale, warp_tile, skt, kgrp_shift, row_align=row_align, b128=b128 + ) + return preshuffle_e8m0_scale(scale, warp_tile, scale_k_per_tile=skt, row_align=row_align) + + def random_fp8_data(rows: int, cols: int, *, device="cpu") -> torch.Tensor: """Generate random FP8/E4M3 data as uint8. 
Avoids NaN (0x7F/0xFF).""" return torch.randint(0, 126, (rows, cols), dtype=torch.uint8, device=device) @@ -367,6 +468,7 @@ def _run_mxscale_gemm_test( split_k=1, b_streaming=False, scale_load_path="tdm", + a_load_path="tdm", return_launch_fn=False, ): """Unified test body for FP4 and FP8.""" @@ -448,9 +550,19 @@ def _run_mxscale_gemm_test( skt = tile_k // SCALE_BLOCK warp_tile_m = tile_m // m_warp warp_tile_n = tile_n // n_warp - _coalesced_scale = scale_load_path in ("vgpr", "vgpr_ab_split") - a_scale = preshuffle_e8m0_scale(a_scale, warp_tile_m, scale_k_per_tile=skt, coalesced=_coalesced_scale) - b_scale = preshuffle_e8m0_scale(b_scale, warp_tile_n, scale_k_per_tile=skt, coalesced=_coalesced_scale) + _ref_seg = is_ref_segmented_lds_layout( + data_format=data_format, tile_m=tile_m, tile_n=tile_n, tile_k=tile_k, m_warp=m_warp, n_warp=n_warp, + num_buffers=num_buffers, split_k=split_k, wave_specialized_tdm=wave_specialized_tdm, + use_scale_opsel=use_scale_opsel, + ) + a_scale = preshuffle_scale_for_load_path( + a_scale, warp_tile_m, skt, scale_load_path=scale_load_path, data_format=data_format, + ref_segmented=_ref_seg, row_align=tile_m, + ) + b_scale = preshuffle_scale_for_load_path( + b_scale, warp_tile_n, skt, scale_load_path=scale_load_path, data_format=data_format, + ref_segmented=_ref_seg, row_align=tile_n, + ) # Preshuffle B data K_packed = padded_k // padded_shape["pack_b"] @@ -486,6 +598,7 @@ def _run_mxscale_gemm_test( expert_sched_mode=expert_sched_mode, b_streaming=b_streaming, scale_load_path=scale_load_path, + a_load_path=a_load_path, ) # Keep 2D — dynamic_layout=True packs shape as i32; flattening overflows for M*K >= 2^31. 
@@ -1742,7 +1855,7 @@ def _run_benchmark(args): print( f" Buffers={args.num_buffers}, out={args.out_dtype}, " f"opsel={args.use_scale_opsel}, inst_prefetch={args.inst_prefetch}, " - f"scale_load={args.scale_load_path}" + f"scale_load={args.scale_load_path}, a_load={args.a_load_path}" ) if args.split_k > 1: print(f" Split-K={args.split_k} (atomic accumulate, buffer-store epilogue)") @@ -1811,9 +1924,19 @@ def _run_benchmark(args): a, b, a_scale, b_scale = _pad_mxscale_inputs(a, b, a_scale, b_scale, padded_shape) skt = tile_k // SCALE_BLOCK - _coalesced_scale = args.scale_load_path in ("vgpr", "vgpr_ab_split") - a_scale = preshuffle_e8m0_scale(a_scale, warp_tile_m, scale_k_per_tile=skt, coalesced=_coalesced_scale) - b_scale = preshuffle_e8m0_scale(b_scale, warp_tile_n, scale_k_per_tile=skt, coalesced=_coalesced_scale) + _ref_seg = is_ref_segmented_lds_layout( + data_format=data_format, tile_m=tile_m, tile_n=tile_n, tile_k=tile_k, m_warp=args.m_warp, + n_warp=args.n_warp, num_buffers=args.num_buffers, split_k=args.split_k, + wave_specialized_tdm=args.wave_spec_tdm, use_scale_opsel=args.use_scale_opsel, + ) + a_scale = preshuffle_scale_for_load_path( + a_scale, warp_tile_m, skt, scale_load_path=args.scale_load_path, data_format=data_format, + ref_segmented=_ref_seg, row_align=tile_m, + ) + b_scale = preshuffle_scale_for_load_path( + b_scale, warp_tile_n, skt, scale_load_path=args.scale_load_path, data_format=data_format, + ref_segmented=_ref_seg, row_align=tile_n, + ) K_packed = padded_k // PACK_B b = fp4_utils.preshuffle_b_16x16(b, padded_n, K_packed) @@ -1877,6 +2000,7 @@ def _run_benchmark(args): atomic_barrier_enable=args.atomic_barrier_enable, b_streaming=args.b_streaming, scale_load_path=args.scale_load_path, + a_load_path=args.a_load_path, ) compiled_exe = flyc.compile( @@ -2025,9 +2149,19 @@ def _run_graph_verify(args): skt = tile_k // SCALE_BLOCK warp_tile_m = tile_m // args.m_warp warp_tile_n = tile_n // args.n_warp - _coalesced_scale = args.scale_load_path 
in ("vgpr", "vgpr_ab_split") - a_scale = preshuffle_e8m0_scale(a_scale, warp_tile_m, scale_k_per_tile=skt, coalesced=_coalesced_scale) - b_scale = preshuffle_e8m0_scale(b_scale, warp_tile_n, scale_k_per_tile=skt, coalesced=_coalesced_scale) + _ref_seg = is_ref_segmented_lds_layout( + data_format=data_format, tile_m=tile_m, tile_n=tile_n, tile_k=tile_k, m_warp=args.m_warp, + n_warp=args.n_warp, num_buffers=args.num_buffers, split_k=args.split_k, + wave_specialized_tdm=args.wave_spec_tdm, use_scale_opsel=args.use_scale_opsel, + ) + a_scale = preshuffle_scale_for_load_path( + a_scale, warp_tile_m, skt, scale_load_path=args.scale_load_path, data_format=data_format, + ref_segmented=_ref_seg, row_align=tile_m, + ) + b_scale = preshuffle_scale_for_load_path( + b_scale, warp_tile_n, skt, scale_load_path=args.scale_load_path, data_format=data_format, + ref_segmented=_ref_seg, row_align=tile_n, + ) K_packed = padded_k // padded_shape["pack_b"] b = fp4_utils.preshuffle_b_16x16(b, padded_n, K_packed) @@ -2065,6 +2199,7 @@ def _run_graph_verify(args): atomic_barrier_enable=args.atomic_barrier_enable, b_streaming=args.b_streaming, scale_load_path=args.scale_load_path, + a_load_path=args.a_load_path, ) c_flat = c_gpu.contiguous() @@ -2165,7 +2300,7 @@ def launch(): parser.add_argument("--tile-k", type=int, default=128) parser.add_argument("--m-warp", type=int, default=2) parser.add_argument("--n-warp", type=int, default=2) - parser.add_argument("--num-buffers", type=int, default=4, choices=[2, 3, 4]) + parser.add_argument("--num-buffers", type=int, default=4, choices=[2, 3, 4, 5, 6]) parser.add_argument("--split-k", type=int, default=1) parser.add_argument("--l2-prefetch-distance", type=int, default=2) parser.add_argument("--cluster-m", type=int, default=1) @@ -2182,6 +2317,12 @@ def launch(): default="tdm", choices=["tdm", "vgpr", "vgpr_ab_split"], ) + parser.add_argument( + "--a-load-path", + type=str, + default="tdm", + choices=["tdm", "vgpr", "vgpr_ascale"], + ) 
parser.add_argument("--disable-expert-sched-mode", dest="expert_sched_mode", action="store_false", default=True) parser.add_argument("--b-streaming", action="store_true", default=False) parser.add_argument( @@ -2275,4 +2416,5 @@ def launch(): expert_sched_mode=args.expert_sched_mode, b_streaming=args.b_streaming, scale_load_path=args.scale_load_path, + a_load_path=args.a_load_path, ) From 62154860da90e6ad7eb55f8ce46688efb80fe725 Mon Sep 17 00:00:00 2001 From: aoli26 Date: Thu, 11 Jun 2026 04:53:40 +0000 Subject: [PATCH 02/16] fix(test): fix test cache issue --- tests/kernels/test_gemm_fp8fp4_gfx1250.py | 151 ++++++++++++---------- 1 file changed, 83 insertions(+), 68 deletions(-) diff --git a/tests/kernels/test_gemm_fp8fp4_gfx1250.py b/tests/kernels/test_gemm_fp8fp4_gfx1250.py index 48e70c2e..a88fd786 100644 --- a/tests/kernels/test_gemm_fp8fp4_gfx1250.py +++ b/tests/kernels/test_gemm_fp8fp4_gfx1250.py @@ -1225,33 +1225,45 @@ def launch(): ) -def _bench_kernel_us_cudagraph(run_fn, warmup=10, iters=100, prep_fn=None, n_per_graph=20): - """Per-launch timer via hipGraph: capture n_per_graph launches, replay iters times, single event pair around the whole replay loop.""" +def _l2_cache_bytes() -> int: + """Reported L2 size (gfx1250 under-reports the effective LLC, so callers floor this).""" + return getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "L2_cache_size", 4 * 1024 * 1024) + + +def _rotate_slot_count(working_set_bytes: int, flush_l2: bool, cap: int = 256) -> int: + """Number of rotate-buffer copies so the pool exceeds the last-level cache.""" + if not flush_l2: + return 1 + POOL_TARGET = 1024 * 1024 * 1024 + target = max(_l2_cache_bytes() * 2, POOL_TARGET) + needed = -(-target // max(working_set_bytes, 1)) # ceil-div + return max(2, min(needed, cap)) + + +def _bench_kernel_us_cudagraph(run_slot, num_slots, warmup=10, iters=100): + """Per-launch timer via hipGraph: captures n_per_graph launches, replays them.""" + n_per_graph = 
max(num_slots, 20) capture_stream = torch.cuda.Stream() capture_stream.wait_stream(torch.cuda.current_stream()) with torch.cuda.stream(capture_stream): - for _ in range(warmup): - if prep_fn is not None: - prep_fn() - run_fn() + for i in range(warmup): + run_slot(i) torch.cuda.current_stream().wait_stream(capture_stream) torch.cuda.synchronize() g = torch.cuda.CUDAGraph() - if prep_fn is not None: - prep_fn() with torch.cuda.graph(g, stream=capture_stream): - for _ in range(n_per_graph): - run_fn() + for j in range(n_per_graph): + run_slot(j) torch.cuda.synchronize() # Sanity guard against empty graph capture. ref_start = torch.cuda.Event(enable_timing=True) ref_end = torch.cuda.Event(enable_timing=True) ref_start.record() - for _ in range(n_per_graph): - run_fn() + for j in range(n_per_graph): + run_slot(j) ref_end.record() torch.cuda.synchronize() ref_per_launch_us = ref_start.elapsed_time(ref_end) * 1e3 / n_per_graph @@ -1288,45 +1300,20 @@ def _bench_kernel_us_cudagraph(run_fn, warmup=10, iters=100, prep_fn=None, n_per return start_ev.elapsed_time(end_ev) * 1e3 / (iters * n_per_graph) -def _bench_kernel_us(run_fn, warmup=10, iters=50, flush_l2=True, prep_fn=None): - """Per-iter CUDA events with L2 flush + IQR-trimmed median; fast path uses a single event pair when no flush/prep is requested (preserves back-to-back launch pipelining).""" - flush_buf = None - if flush_l2: - l2_bytes = getattr( - torch.cuda.get_device_properties(torch.cuda.current_device()), "L2_cache_size", 4 * 1024 * 1024 - ) - alloc_bytes = max(l2_bytes * 2, 8 * 1024 * 1024) - flush_buf = torch.empty(alloc_bytes, dtype=torch.uint8, device="cuda") - - for _ in range(warmup): - if flush_buf is not None: - flush_buf.zero_() - if prep_fn is not None: - prep_fn() - run_fn() - torch.cuda.synchronize() +def _bench_kernel_us(run_slot, num_slots, warmup=10, iters=50): + """Per-iter CUDA-event timer with rotating buffers (cold L2) + IQR-trimmed median.""" + del num_slots # rotation is handled inside 
run_slot; kept for call-site symmetry - if flush_buf is None and prep_fn is None: - # Single event pair preserves back-to-back launch pipelining (returns mean latency). - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - for _ in range(iters): - run_fn() - end.record() - torch.cuda.synchronize() - return start.elapsed_time(end) * 1e3 / iters + for i in range(warmup): + run_slot(i) + torch.cuda.synchronize() start_ev = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] end_ev = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] for i in range(iters): - if flush_buf is not None: - flush_buf.zero_() - if prep_fn is not None: - prep_fn() start_ev[i].record() - run_fn() + run_slot(i) end_ev[i].record() torch.cuda.synchronize() @@ -1342,7 +1329,6 @@ def _bench_kernel_us(run_fn, warmup=10, iters=50, flush_l2=True, prep_fn=None): if filtered: latencies = filtered - del flush_buf return latencies[len(latencies) // 2] @@ -1859,9 +1845,13 @@ def _run_benchmark(args): ) if args.split_k > 1: print(f" Split-K={args.split_k} (atomic accumulate, buffer-store epilogue)") - l2_flush_label = "OFF (graph)" if getattr(args, "use_graph", False) else ("OFF" if args.no_flush_l2 else "ON") - print(f" Warmup={args.warmup}, Iters={args.iters}, L2 flush={l2_flush_label}") - print(" Output init: zero before warmup") + if args.no_flush_l2: + l2_flush_label = "OFF (hot L2, --no-flush-l2)" + elif getattr(args, "use_graph", False): + l2_flush_label = "OFF (graph replay is warm/L2-resident; use eager for cold HBM)" + else: + l2_flush_label = "ON (rotate buffers, cold HBM)" + print(f" Warmup={args.warmup}, Iters={args.iters}, L2 defeat={l2_flush_label}") if is_ptpc: # compile_ptpc_gemm forces these internally; flag the ones the user set off-default. 
_ptpc_ignored = [] @@ -2017,16 +2007,13 @@ def _run_benchmark(args): torch.cuda.current_stream(), ) - def prep_kernel(): - c_gpu.zero_() - - def run_kernel(): + def run_one(c_, a_, b_, as_, bs_): compiled_exe( - c_gpu, - a_gpu, - b_gpu, - as_gpu, - bs_gpu, + c_, + a_, + b_, + as_, + bs_, padded_m, padded_n, padded_k, @@ -2034,21 +2021,43 @@ def run_kernel(): torch.cuda.current_stream(), ) - prep_kernel() - run_kernel() + c_gpu.zero_() + run_one(c_gpu, a_gpu, b_gpu, as_gpu, bs_gpu) torch.cuda.synchronize() compile_ms = (time.perf_counter() - t0) * 1e3 print(f" Compile + first launch: {compile_ms:.0f} ms") + flush_l2 = not args.no_flush_l2 + working_set = sum(t.numel() * t.element_size() for t in (a_gpu, b_gpu, as_gpu, bs_gpu, c_gpu)) + num_slots = _rotate_slot_count(working_set, flush_l2) + a_pool = [a_gpu] + [a_gpu.clone() for _ in range(num_slots - 1)] + b_pool = [b_gpu] + [b_gpu.clone() for _ in range(num_slots - 1)] + as_pool = [as_gpu] + [as_gpu.clone() for _ in range(num_slots - 1)] + bs_pool = [bs_gpu] + [bs_gpu.clone() for _ in range(num_slots - 1)] + c_pool = [c_gpu] + [torch.zeros_like(c_gpu) for _ in range(num_slots - 1)] + print( + f" Rotate buffers: {num_slots} slot(s), pool={working_set * num_slots / 1e6:.1f} MB " + f"(working set {working_set / 1e6:.1f} MB)" + (" [HOT L2: --no-flush-l2]" if num_slots == 1 else "") + ) + + def run_slot(i): + s = i % num_slots + run_one(c_pool[s], a_pool[s], b_pool[s], as_pool[s], bs_pool[s]) + use_graph = getattr(args, "use_graph", False) if use_graph: + if not args.no_flush_l2: + print( + " WARNING: hipGraph capture aliases the kernel-param buffer across " + "replayed launches, so the rotate buffers above do NOT take effect under " + "replay -- this number is WARM (L2-resident). Use eager mode (drop " + "--use-graph) for the cold-HBM number." 
+ ) print(f"[2/3] Warming up ({args.warmup} iters) + bench via hipGraph " f"({args.iters} replays)...") - us = _bench_kernel_us_cudagraph(run_kernel, warmup=args.warmup, iters=args.iters) + us = _bench_kernel_us_cudagraph(run_slot, num_slots, warmup=args.warmup, iters=args.iters) else: print(f"[2/3] Warming up ({args.warmup} iters) + benchmarking ({args.iters} iters)...") - us = _bench_kernel_us( - run_kernel, warmup=args.warmup, iters=args.iters, flush_l2=not args.no_flush_l2, prep_fn=prep_kernel - ) + us = _bench_kernel_us(run_slot, num_slots, warmup=args.warmup, iters=args.iters) logical_flops = 2.0 * M * N * K kernel_flops = 2.0 * padded_m * padded_n * padded_k @@ -2337,15 +2346,21 @@ def launch(): ) parser.add_argument("--warmup", type=int, default=5) parser.add_argument("--iters", type=int, default=20) - parser.add_argument("--no-flush-l2", action="store_true", default=False) + parser.add_argument( + "--no-flush-l2", + action="store_true", + default=False, + help="Disable the rotate-buffer L2 defeat (use a single hot buffer) for a " + "warm-cache measurement. Applies to both eager and --use-graph modes.", + ) parser.add_argument( "--use-graph", action="store_true", default=False, - help="Time via hipGraph capture+replay to strip " - "host launch overhead from per-launch latency. " - "Implicitly disables L2 flush (graph replays " - "are back-to-back, hot-cache).", + help="Time via hipGraph capture+replay to strip host launch overhead from " + "per-launch latency. 
NOTE: graph replay measures the WARM (L2-resident) " + "regime -- rotate buffers do not survive hipGraph capture, so use the eager " + "path (drop --use-graph) for the cold-HBM number.", ) parser.add_argument( "--verify-graph", From 203841e00857b5cdc8b1a1471d5c3860a56a50b1 Mon Sep 17 00:00:00 2001 From: aoli26 Date: Thu, 11 Jun 2026 09:04:25 +0000 Subject: [PATCH 03/16] add bsplit expr --- kernels/gemm_fp8fp4_gfx1250.py | 154 +++++++-- tests/kernels/test_gemm_fp8fp4_gfx1250.py | 398 ++++++++++++++++------ 2 files changed, 429 insertions(+), 123 deletions(-) diff --git a/kernels/gemm_fp8fp4_gfx1250.py b/kernels/gemm_fp8fp4_gfx1250.py index cb534fc4..c113d3e7 100644 --- a/kernels/gemm_fp8fp4_gfx1250.py +++ b/kernels/gemm_fp8fp4_gfx1250.py @@ -147,6 +147,7 @@ def compile_fp8fp4_gemm( scale_load_path: str = "tdm", fp8_schedule: str = "auto", a_load_path: str = "tdm", + b_split_load: bool = False, ): """Compile an FP4/FP8/A8W4 GEMM kernel with TDM async copy. @@ -205,10 +206,11 @@ def compile_fp8fp4_gemm( raise ValueError("a_load_path != 'tdm' requires data_format='fp8' or 'a8w4'") if use_a_vgpr and is_ptpc: raise ValueError("a_load_path != 'tdm' requires scale_mode='mxscale'") + b_split_load = bool(b_split_load) effective_expert_sched_mode = bool(expert_sched_mode) if num_buffers not in (2, 3, 4, 5, 6): - raise ValueError(f"num_buffers must be 2, 3, 4 or 5, got {num_buffers}") + raise ValueError(f"num_buffers must be 2, 3, 4, 5 or 6, got {num_buffers}") if split_k < 1: raise ValueError(f"split_k must be >= 1, got {split_k}") @@ -324,11 +326,11 @@ def compile_fp8fp4_gemm( _as_ds_loads = _a_frag_ds + _scale_ds_loads lds_a_stride_bytes = packed_tile_k_a + LDS_PAD_A_BYTES - if scale_load_path == "vgpr_ab_split": + if scale_load_path == "vgpr_ab_split" or b_split_load: if tile_m % 2 != 0: - raise ValueError(f"scale_load_path='vgpr_ab_split' requires even tile_m, got {tile_m}") + raise ValueError("B/A split load variants require even tile_m, got " f"{tile_m}") if tile_n % 32 
!= 0: - raise ValueError(f"scale_load_path='vgpr_ab_split' requires tile_n divisible by 32, got {tile_n}") + raise ValueError("B/A split load variants require tile_n divisible by 32, got " f"{tile_n}") lds_a_data_bytes = 0 if use_a_vgpr else tile_m * lds_a_stride_bytes lds_b_data_bytes = tile_n * packed_tile_k_b @@ -453,7 +455,7 @@ def _align_up(value: int, align: int) -> int: # Valid only when A is VGPR with TDM scales (wave0 otherwise idle, waves # 1/2/3 = B/A_scale/B_scale) and tile_n splits into two equal N-group halves. _b_nsplit = ( - bool(int(os.environ.get("FLYDSL_B_KSPLIT", "0"))) + b_split_load and use_a_vgpr and not use_ascale_vgpr and not is_ptpc @@ -461,6 +463,27 @@ def _align_up(value: int, align: int) -> int: and wave_specialized_tdm and tile_n % 32 == 0 ) + # TDM B N-split while A still uses TDM. Wave assignment: + # wave0=A, wave1=B first N-half, wave2=B second N-half, + # wave3=A_scale then B_scale. + # Wave3 issues two tensor ops per K-tile, so the pipeline wait uses a + # wave-specific outstanding count in the fence helpers below. 
+ _tdm_b_nsplit_scale_combo = ( + b_split_load + and not use_a_vgpr + and data_format == "a8w4" + and not is_ptpc + and scale_load_path == "tdm" + and wave_specialized_tdm + and num_buffers == 4 + and tile_n % 32 == 0 + ) + if b_split_load and not (_b_nsplit or _tdm_b_nsplit_scale_combo): + raise ValueError( + "b_split_load currently supports either a_load_path='vgpr' with TDM scales, " + "or A8W4 a_load_path='tdm' scale_load_path='tdm' wave_specialized_tdm=True " + "num_buffers=4" + ) if use_ref_segmented_lds_layout: # The A/B data pools are no longer packed into the same per-stage @@ -884,7 +907,7 @@ def _avr_load_a_tile(k_base): frags = [] for ks in range_constexpr(k_wmma_steps): for wm in range_constexpr(wmma_m_rep): - row = warp_m_base + arith.index(wm * WMMA_M) + lane16 + row = blk_m + warp_m_base + arith.index(wm * WMMA_M) + lane16 row_off = row * _avr_lda_i32 + _avr_lane_kgrp_off + arith.index(ks * WMMA_K // PACK_FACTOR_A // 4) loads = [] for i in range_constexpr(DS_LOADS_PER_A_FRAG): @@ -2646,7 +2669,7 @@ def _pack_dg0(pred, lds_addr, addr_lo, addr_hi): desc_a1_init = make_desc_a_half(stages_a_mem[0], split_k_base, 1) desc_b1_init = make_desc_b_half(stages_b_mem[0], split_k_base, 1) - if const_expr(_b_nsplit): + if const_expr(_b_nsplit or _tdm_b_nsplit_scale_combo): # N-direction B halves: wave0 -> N-groups [0:tile_n//32], wave1 -> # [tile_n//32:tile_n//16]. N is the outer LDS dim, so the two halves # write contiguous blocks that together equal the full-B tile layout @@ -2668,13 +2691,17 @@ def _pack_dg0(pred, lds_addr, addr_lo, addr_hi): pred_const = fx.Int32(1) if const_expr(wave_specialized_tdm): _drop_scale_waves = is_ptpc or (use_buffer_vgpr_scale and not use_ab_half_split) or use_ascale_vgpr - if const_expr(_b_nsplit): - # B N-split: all 4 waves load (wave0=B N-half0, wave1=B N-half1, - # wave2=A_scale, wave3=B_scale). 
- active_pred_const = pred_const + if const_expr(_b_nsplit or _tdm_b_nsplit_scale_combo): + # Split variants use only the first four waves. Keep extra compute + # waves from falling through the 4-slot selector to the last slot. + active_pred_const = arith.select(tdm_wave_id < fx.Int32(4), fx.Int32(1), fx.Int32(0)) elif const_expr(_drop_a_loader_wave and not _drop_scale_waves): # A via VGPR, scales via TDM: wave0 (A) idle, waves 1,2,3 active - active_pred_const = arith.select(tdm_wave_id >= fx.Int32(1), fx.Int32(1), fx.Int32(0)) + active_pred_const = arith.select( + tdm_wave_id >= fx.Int32(1), + arith.select(tdm_wave_id < fx.Int32(4), fx.Int32(1), fx.Int32(0)), + fx.Int32(0), + ) else: _active_wave_limit = 2 if _drop_scale_waves else 4 active_pred_const = arith.select(tdm_wave_id < fx.Int32(_active_wave_limit), fx.Int32(1), fx.Int32(0)) @@ -2722,6 +2749,21 @@ def _select_active_tdm(stage_lds_addrs, descs, advs): (desc_bn0_init, desc_bn1_init, desc_as_init, desc_bs_init), (adv_b_i32, adv_b_i32, adv_as_i32, adv_bs_i32), ) + elif const_expr(_tdm_b_nsplit_scale_combo): + # A + B N-split + combined scale wave: + # wave0=A, wave1=B N-half0, wave2=B N-half1, wave3=A_scale; + # wave3 additionally issues B_scale below with an independent address. 
+ active_stage_lds_addr, active_addr_lo, active_addr_hi, active_dgroup1, active_adv_i32 = _select_active_tdm( + (stages_a_lds_addr, nstages_b0_lds_addr, nstages_b1_lds_addr, stages_as_lds_addr), + (desc_a_init, desc_bn0_init, desc_bn1_init, desc_as_init), + (adv_a_i32, adv_b_i32, adv_b_i32, adv_as_i32), + ) + active_extra_pred_const = arith.select(tdm_wave_id == fx.Int32(3), fx.Int32(1), fx.Int32(0)) + active_extra_stage_lds_addr = stages_bs_lds_addr + active_extra_addr_lo = _dg0_lane(desc_bs_init, 2) + active_extra_addr_hi = _dg0_lane(desc_bs_init, 3) + active_extra_dgroup1 = desc_bs_init.dgroup1 + active_extra_adv_i32 = adv_bs_i32 elif const_expr(wave_specialized_tdm and use_ascale_vgpr): # A + A_scale via VGPR: only B (wave0) and B_scale (wave1) need TDM. # Remap: slot0=B, slot1=B_scale, slots 2,3 aliased (predicated off). @@ -2751,18 +2793,52 @@ def _select_active_tdm(stage_lds_addrs, descs, advs): dgroup1_as = desc_as_init.dgroup1 dgroup1_bs = desc_bs_init.dgroup1 + def _pipeline_tensor_wait(outstanding=0): + if const_expr(_tdm_b_nsplit_scale_combo and outstanding > 0): + if_op = scf.IfOp(tdm_wave_id == fx.Int32(3), [], has_else=True) + with ir.InsertionPoint(if_op.then_block): + tdm_ops.tensor_wait(outstanding * 2) + scf.YieldOp([]) + with ir.InsertionPoint(if_op.else_block): + tdm_ops.tensor_wait(outstanding) + scf.YieldOp([]) + else: + tdm_ops.tensor_wait(outstanding) + def _pipeline_fence(outstanding=0): - pipeline_fence(outstanding=outstanding, use_cluster=use_cluster) + if const_expr(_tdm_b_nsplit_scale_combo): + _pipeline_tensor_wait(outstanding) + if const_expr(use_cluster): + cluster.cluster_barrier() + else: + gpu.barrier() + else: + pipeline_fence(outstanding=outstanding, use_cluster=use_cluster) def _pipeline_fence_signal(outstanding=0): - pipeline_fence_signal(outstanding=outstanding, use_cluster=use_cluster) + if const_expr(_tdm_b_nsplit_scale_combo): + _pipeline_tensor_wait(outstanding) + rocdl.s_barrier_signal(-1) + if const_expr(use_cluster): 
+ cluster.cluster_signal_once_per_wg() + else: + pipeline_fence_signal(outstanding=outstanding, use_cluster=use_cluster) if const_expr(wave_specialized_tdm): - def _issue_active_tdm(load_stage, addr_box, k_prefetch=None): + def _issue_active_tdm(load_stage, addr_box, extra_addr_box=None, k_prefetch=None): dg0 = _pack_dg0(active_pred_const, active_stage_lds_addr[load_stage], addr_box[0], active_addr_hi) tdm_ops.tensor_load_2d(tdm_ops.TDMDescriptor2D(dg0, active_dgroup1)) addr_box[0] = addr_box[0] + active_adv_i32 + if const_expr(_tdm_b_nsplit_scale_combo): + dg0_extra = _pack_dg0( + active_extra_pred_const, + active_extra_stage_lds_addr[load_stage], + extra_addr_box[0], + active_extra_addr_hi, + ) + tdm_ops.tensor_load_2d(tdm_ops.TDMDescriptor2D(dg0_extra, active_extra_dgroup1)) + extra_addr_box[0] = extra_addr_box[0] + active_extra_adv_i32 if k_prefetch is not None: _l2_prefetch(k_prefetch) @@ -2770,7 +2846,12 @@ def _issue_active_tdm(load_stage, addr_box, k_prefetch=None): if const_expr(wave_specialized_tdm): for i in range_constexpr(pre_loaded): addr_box = [active_addr_lo] - _issue_active_tdm(i, addr_box) + if const_expr(_tdm_b_nsplit_scale_combo): + extra_addr_box = [active_extra_addr_lo] + _issue_active_tdm(i, addr_box, extra_addr_box) + active_extra_addr_lo = extra_addr_box[0] + else: + _issue_active_tdm(i, addr_box) active_addr_lo = addr_box[0] else: for i in range_constexpr(pre_loaded): @@ -2816,6 +2897,8 @@ def _issue_active_tdm(load_stage, addr_box, k_prefetch=None): if const_expr(loop_iters > 0): if const_expr(wave_specialized_tdm): init_args = list(accs) + [active_addr_lo] + if const_expr(_tdm_b_nsplit_scale_combo): + init_args = init_args + [active_extra_addr_lo] if const_expr(_bvs_active): init_args = init_args + _bvs_ra + _bvs_rb if const_expr(_avr_active): @@ -2825,6 +2908,9 @@ def _issue_active_tdm(load_stage, addr_box, k_prefetch=None): accs_in = list(state[:n_accs]) cur_addr_lo = state[n_accs] _state_off = n_accs + 1 + if 
const_expr(_tdm_b_nsplit_scale_combo): + cur_extra_addr_lo = state[_state_off] + _state_off += 1 if const_expr(_bvs_active): _ra0 = _state_off _ring_a = list(state[_ra0 : _ra0 + _bvs_D * _vs_tile_a]) @@ -2841,17 +2927,22 @@ def _issue_active_tdm(load_stage, addr_box, k_prefetch=None): load_stage = (buf_idx + num_buffers - 1) % num_buffers addr_box = [cur_addr_lo] + if const_expr(_tdm_b_nsplit_scale_combo): + extra_addr_box = [cur_extra_addr_lo] + else: + extra_addr_box = None def _mid_tdm_ws( _ls=load_stage, _ab=addr_box, + _eb=extra_addr_box, _k_off=( split_k_base + loop_iter * arith.index(num_buffers * tile_k) + arith.index(buf_idx * tile_k) ), ): - _issue_active_tdm(_ls, _ab, k_prefetch=_k_off) + _issue_active_tdm(_ls, _ab, _eb, k_prefetch=_k_off) if const_expr(not use_ws_tdm_split_signal_overlap): _pipeline_fence_signal(outstanding=_fence_outstanding) @@ -2916,6 +3007,8 @@ def _late_tdm_ws_split_signal(): pf_a_data_scales=_cur_as, ) cur_addr_lo = addr_box[0] + if const_expr(_tdm_b_nsplit_scale_combo): + cur_extra_addr_lo = extra_addr_box[0] hot_loop_scheduler_scheduled() if const_expr(_bvs_active): @@ -2926,10 +3019,16 @@ def _late_tdm_ws_split_signal(): _avr_yield = _avr_ring_f + _avr_ring_s else: _avr_yield = [] - results = yield list(accs_in) + [cur_addr_lo] + _bvs_yield + _avr_yield + if const_expr(_tdm_b_nsplit_scale_combo): + _extra_yield = [cur_extra_addr_lo] + else: + _extra_yield = [] + results = yield list(accs_in) + [cur_addr_lo] + _extra_yield + _bvs_yield + _avr_yield accs = list(results[:n_accs]) active_addr_lo = results[n_accs] + if const_expr(_tdm_b_nsplit_scale_combo): + active_extra_addr_lo = results[n_accs + 1] else: init_args = list(accs) + [addr_lo_a, addr_lo_b, addr_lo_as, addr_lo_bs] @@ -3141,9 +3240,13 @@ def _emit_epi_addrs(): _tail_had_load = True if const_expr(wave_specialized_tdm): _tail_addr_box = [active_addr_lo] + if const_expr(_tdm_b_nsplit_scale_combo): + _tail_extra_addr_box = [active_extra_addr_lo] + else: + 
_tail_extra_addr_box = None - def _tail_mid_ws(_ls=_load_stage, _ab=_tail_addr_box): - _issue_active_tdm(_ls, _ab) + def _tail_mid_ws(_ls=_load_stage, _ab=_tail_addr_box, _eb=_tail_extra_addr_box): + _issue_active_tdm(_ls, _ab, _eb) _tail_mid_cb = _tail_mid_ws else: @@ -3190,6 +3293,8 @@ def _tail_mid_nws(_ls=_load_stage, _ab=_tail_ab): if const_expr(_load_stage is not None): if const_expr(wave_specialized_tdm): active_addr_lo = _tail_addr_box[0] + if const_expr(_tdm_b_nsplit_scale_combo): + active_extra_addr_lo = _tail_extra_addr_box[0] else: addr_lo_a = _tail_ab[0][0] addr_lo_b = _tail_ab[1][0] @@ -3261,6 +3366,7 @@ def _emit_buffer_store(): b_streaming, scale_load_path, a_load_path, + b_split_load, fp8_schedule, ) @@ -3323,18 +3429,26 @@ def launch_mxscale_gemm( def compile_mxscale_gemm(**kw): """Backward-compatible wrapper: MX block-scale (E8M0) GEMM.""" + if "b_split_load" not in kw: + kw["b_split_load"] = bool(int(os.environ.get("FLYDSL_B_KSPLIT", "0"))) return compile_fp8fp4_gemm(scale_mode="mxscale", **kw) def compile_mxfp4_gemm(**kw): + if "b_split_load" not in kw: + kw["b_split_load"] = bool(int(os.environ.get("FLYDSL_B_KSPLIT", "0"))) return compile_fp8fp4_gemm(data_format="fp4", scale_mode="mxscale", **kw) def compile_mxfp8_gemm(**kw): + if "b_split_load" not in kw: + kw["b_split_load"] = bool(int(os.environ.get("FLYDSL_B_KSPLIT", "0"))) return compile_fp8fp4_gemm(data_format="fp8", scale_mode="mxscale", **kw) def compile_a8w4_gemm(**kw): + if "b_split_load" not in kw: + kw["b_split_load"] = bool(int(os.environ.get("FLYDSL_B_KSPLIT", "0"))) return compile_fp8fp4_gemm(data_format="a8w4", scale_mode="mxscale", **kw) diff --git a/tests/kernels/test_gemm_fp8fp4_gfx1250.py b/tests/kernels/test_gemm_fp8fp4_gfx1250.py index a88fd786..d84c67ec 100644 --- a/tests/kernels/test_gemm_fp8fp4_gfx1250.py +++ b/tests/kernels/test_gemm_fp8fp4_gfx1250.py @@ -469,6 +469,7 @@ def _run_mxscale_gemm_test( b_streaming=False, scale_load_path="tdm", a_load_path="tdm", + 
b_split_load=False, return_launch_fn=False, ): """Unified test body for FP4 and FP8.""" @@ -512,11 +513,13 @@ def _run_mxscale_gemm_test( mcast_str = f", cluster=({cluster_m},{cluster_n})" if cluster_m > 1 or cluster_n > 1 else "" tdm_str = ", tdm_store" if use_tdm_store else ", buffer_store" scale_load_str = "" if scale_load_path == "tdm" else f", scale_load={scale_load_path}" + a_load_str = "" if a_load_path == "tdm" else f", a_load={a_load_path}" + b_split_str = ", b_split_load" if b_split_load else "" pad_str = _format_kernel_pad(M, N, K, padded_shape) print( f"\nRunning {fmt_name} GEMM: M={M}, N={N}, K={K}{pad_str}, " f"tiles=({tile_m},{tile_n},{tile_k}), bufs={num_buffers}" - f"{mcast_str}{tdm_str}{scale_load_str}, preshuffle, out={out_dtype}" + f"{mcast_str}{tdm_str}{scale_load_str}{a_load_str}{b_split_str}, preshuffle, out={out_dtype}" ) # Generate data @@ -599,6 +602,7 @@ def _run_mxscale_gemm_test( b_streaming=b_streaming, scale_load_path=scale_load_path, a_load_path=a_load_path, + b_split_load=b_split_load, ) # Keep 2D — dynamic_layout=True packs shape as i32; flattening overflows for M*K >= 2^31. 
@@ -913,6 +917,38 @@ def test_a8w4_gemm_irregular_m_tile16(M, N, K, use_tdm_store): ) +@pytest.mark.parametrize( + "scale_load_path, a_load_path, b_split_load", + [ + ("tdm", "vgpr", False), + ("tdm", "vgpr", True), + ("tdm", "tdm", True), + ("vgpr", "tdm", False), + ], +) +def test_a8w4_small_m_load_path_variants(scale_load_path, a_load_path, b_split_load): + _run_mxscale_gemm_test( + "a8w4", + 1, + 256, + 2048, + 16, + 64, + 512, + 1, + 4, + num_buffers=4, + use_tdm_store=True, + out_dtype="bf16", + wave_specialized_tdm=True, + l2_prefetch_distance=0, + use_scale_opsel=False, + scale_load_path=scale_load_path, + a_load_path=a_load_path, + b_split_load=b_split_load, + ) + + @pytest.mark.parametrize( "data_format, M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp", [ @@ -1230,56 +1266,123 @@ def _l2_cache_bytes() -> int: return getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "L2_cache_size", 4 * 1024 * 1024) -def _rotate_slot_count(working_set_bytes: int, flush_l2: bool, cap: int = 256) -> int: - """Number of rotate-buffer copies so the pool exceeds the last-level cache.""" - if not flush_l2: - return 1 - POOL_TARGET = 1024 * 1024 * 1024 - target = max(_l2_cache_bytes() * 2, POOL_TARGET) - needed = -(-target // max(working_set_bytes, 1)) # ceil-div +def _make_l2_flush_buffer(flush_l2: bool, flush_mb: int) -> torch.Tensor | None: + """Allocate a scratch buffer used only to evict data from L2.""" + if not flush_l2 or flush_mb <= 0: + return None + nbytes = int(flush_mb) * 1024 * 1024 + if nbytes <= 0: + return None + nelem = max(1, nbytes // torch.empty((), dtype=torch.int32).element_size()) + cache = torch.empty(nelem, dtype=torch.int32, device="cuda") + cache.zero_() + torch.cuda.synchronize() + return cache + + +def _graph_rotate_slot_count(working_set_bytes: int, target_bytes: int = 0, cap: int = 512) -> int: + """Number of graph-captured buffer slots for cold-L2 graph replay.""" + target = max(_l2_cache_bytes() * 5, int(target_bytes), 1) + 
needed = 1 + math.ceil(target / max(working_set_bytes, 1)) return max(2, min(needed, cap)) -def _bench_kernel_us_cudagraph(run_slot, num_slots, warmup=10, iters=100): - """Per-launch timer via hipGraph: captures n_per_graph launches, replays them.""" - n_per_graph = max(num_slots, 20) +def _flush_l2_cache(cache: torch.Tensor | None): + if cache is not None: + cache.zero_() + + +def _iqr_trimmed_median_us(latencies_us: list[float]) -> float: + latencies = sorted(latencies_us) + n = len(latencies) + if n >= 8: + q1, q3 = latencies[n // 4], latencies[3 * n // 4] + iqr = q3 - q1 + lo, hi = q1 - 1.5 * iqr, q3 + 1.5 * iqr + filtered = [x for x in latencies if lo <= x <= hi] + if filtered: + latencies = filtered + return latencies[len(latencies) // 2] + + +def _bench_kernel_us_cudagraph( + run_slot, + num_slots=1, + warmup=10, + iters=100, + n_per_graph=20, + post_run_slot=None, +): + """Per-launch timer via hipGraph.""" + cold_rotate = num_slots > 1 + n_per_graph = num_slots if cold_rotate else (1 if post_run_slot is not None else max(1, n_per_graph)) capture_stream = torch.cuda.Stream() capture_stream.wait_stream(torch.cuda.current_stream()) + def post_run_all_slots(): + if post_run_slot is not None: + for slot in range(num_slots): + post_run_slot(slot) + + def run_direct_graph_body(): + if cold_rotate: + for slot in range(num_slots): + run_slot(slot) + else: + for _ in range(n_per_graph): + run_slot(0) + + pre_capture_warmup = max(warmup, num_slots if cold_rotate else warmup) with torch.cuda.stream(capture_stream): - for i in range(warmup): - run_slot(i) + post_run_all_slots() + for i in range(pre_capture_warmup): + slot = i % num_slots + run_slot(slot) + if post_run_slot is not None: + post_run_slot(slot) torch.cuda.current_stream().wait_stream(capture_stream) torch.cuda.synchronize() + graphs = [] g = torch.cuda.CUDAGraph() - with torch.cuda.graph(g, stream=capture_stream): - for j in range(n_per_graph): - run_slot(j) + with torch.cuda.stream(capture_stream): + with 
torch.cuda.graph(g, stream=capture_stream): + run_direct_graph_body() + graphs.append(g) torch.cuda.synchronize() - # Sanity guard against empty graph capture. + def replay_graph_body(): + graphs[0].replay() + ref_start = torch.cuda.Event(enable_timing=True) ref_end = torch.cuda.Event(enable_timing=True) - ref_start.record() - for j in range(n_per_graph): - run_slot(j) - ref_end.record() + with torch.cuda.stream(capture_stream): + run_direct_graph_body() + post_run_all_slots() + ref_start.record() + run_direct_graph_body() + ref_end.record() + post_run_all_slots() torch.cuda.synchronize() ref_per_launch_us = ref_start.elapsed_time(ref_end) * 1e3 / n_per_graph rep_start = torch.cuda.Event(enable_timing=True) rep_end = torch.cuda.Event(enable_timing=True) - rep_start.record() - g.replay() - rep_end.record() + with torch.cuda.stream(capture_stream): + replay_graph_body() + post_run_all_slots() + rep_start.record() + replay_graph_body() + rep_end.record() + post_run_all_slots() torch.cuda.synchronize() first_replay_per_launch_us = rep_start.elapsed_time(rep_end) * 1e3 / n_per_graph print( f"SANITY_GRAPH,n_per_graph={n_per_graph}," f"ref_per_launch_us={ref_per_launch_us:.3f}," - f"first_replay_per_launch_us={first_replay_per_launch_us:.3f}", + f"first_replay_per_launch_us={first_replay_per_launch_us:.3f}," + f"cold_rotate_slots={num_slots if cold_rotate else 0}", file=sys.stderr, flush=True, ) @@ -1288,48 +1391,54 @@ def _bench_kernel_us_cudagraph(run_slot, num_slots, warmup=10, iters=100): f"hipGraph replay per-launch={first_replay_per_launch_us:.3f}us " f"<< ref direct-launch={ref_per_launch_us:.3f}us. " f"Graph capture likely empty (stream mismatch?)." - ) + ) + + # Stabilize graph replay before collecting samples. 
+ with torch.cuda.stream(capture_stream): + replay_graph_body() + post_run_all_slots() + torch.cuda.synchronize() - start_ev = torch.cuda.Event(enable_timing=True) - end_ev = torch.cuda.Event(enable_timing=True) - start_ev.record() - for _ in range(iters): - g.replay() - end_ev.record() + start_ev = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + end_ev = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + with torch.cuda.stream(capture_stream): + for i in range(iters): + start_ev[i].record() + replay_graph_body() + end_ev[i].record() + post_run_all_slots() torch.cuda.synchronize() - return start_ev.elapsed_time(end_ev) * 1e3 / (iters * n_per_graph) + latencies_us = [start_ev[i].elapsed_time(end_ev[i]) * 1e3 / n_per_graph for i in range(iters)] + return _iqr_trimmed_median_us(latencies_us) -def _bench_kernel_us(run_slot, num_slots, warmup=10, iters=50): - """Per-iter CUDA-event timer with rotating buffers (cold L2) + IQR-trimmed median.""" - del num_slots # rotation is handled inside run_slot; kept for call-site symmetry - for i in range(warmup): - run_slot(i) +def _bench_kernel_us(run_once, flush_cache=None, warmup=10, iters=50, post_run=None): + """Per-iter CUDA-event timer with optional pre-launch L2 flush + IQR-trimmed median.""" + if post_run is not None: + post_run() + for _ in range(warmup): + _flush_l2_cache(flush_cache) + run_once() + if post_run is not None: + post_run() torch.cuda.synchronize() start_ev = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] end_ev = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] for i in range(iters): + _flush_l2_cache(flush_cache) start_ev[i].record() - run_slot(i) + run_once() end_ev[i].record() + if post_run is not None: + post_run() torch.cuda.synchronize() - latencies = sorted(start_ev[i].elapsed_time(end_ev[i]) * 1e3 for i in range(iters)) - - n = len(latencies) - if n >= 8: - q1, q3 = latencies[n // 4], latencies[3 * n // 4] - iqr = q3 - q1 - lo, hi = q1 - 1.5 * iqr, 
q3 + 1.5 * iqr - filtered = [x for x in latencies if lo <= x <= hi] - if filtered: - latencies = filtered - - return latencies[len(latencies) // 2] + latencies_us = [start_ev[i].elapsed_time(end_ev[i]) * 1e3 for i in range(iters)] + return _iqr_trimmed_median_us(latencies_us) def reference_ptpc_gemm(data_format, a, b, sa, sb, M, N, K): @@ -1841,16 +1950,26 @@ def _run_benchmark(args): print( f" Buffers={args.num_buffers}, out={args.out_dtype}, " f"opsel={args.use_scale_opsel}, inst_prefetch={args.inst_prefetch}, " - f"scale_load={args.scale_load_path}, a_load={args.a_load_path}" + f"scale_load={args.scale_load_path}, a_load={args.a_load_path}, " + f"b_split_load={args.b_split_load}" ) + if args.warmup < 0: + raise ValueError(f"--warmup must be >= 0, got {args.warmup}") + if args.iters <= 0: + raise ValueError(f"--iters must be > 0, got {args.iters}") + if args.l2_flush_mb < 0: + raise ValueError(f"--l2-flush-mb must be >= 0, got {args.l2_flush_mb}") if args.split_k > 1: print(f" Split-K={args.split_k} (atomic accumulate, buffer-store epilogue)") + print(" Split-K timing excludes the required C reset from the reported kernel time") if args.no_flush_l2: l2_flush_label = "OFF (hot L2, --no-flush-l2)" + elif args.l2_flush_mb == 0: + l2_flush_label = "OFF (hot L2, --l2-flush-mb=0)" elif getattr(args, "use_graph", False): - l2_flush_label = "OFF (graph replay is warm/L2-resident; use eager for cold HBM)" + l2_flush_label = "ON (graph rotating buffers; compare against --no-flush-l2)" else: - l2_flush_label = "ON (rotate buffers, cold HBM)" + l2_flush_label = f"ON ({args.l2_flush_mb} MiB scratch clear before timed launches)" print(f" Warmup={args.warmup}, Iters={args.iters}, L2 defeat={l2_flush_label}") if is_ptpc: # compile_ptpc_gemm forces these internally; flag the ones the user set off-default. 
@@ -1865,6 +1984,8 @@ def _run_benchmark(args): _ptpc_ignored.append(f"--scale-load-path {args.scale_load_path}") if args.b_streaming: _ptpc_ignored.append("--b-streaming") + if args.b_split_load: + _ptpc_ignored.append("--b-split-load") if _ptpc_ignored: print(f" Note: PTPC ignores (forced internally): {', '.join(_ptpc_ignored)}") print("=" * 72) @@ -1991,6 +2112,7 @@ def _run_benchmark(args): b_streaming=args.b_streaming, scale_load_path=args.scale_load_path, a_load_path=args.a_load_path, + b_split_load=args.b_split_load, ) compiled_exe = flyc.compile( @@ -2027,43 +2149,111 @@ def run_one(c_, a_, b_, as_, bs_): compile_ms = (time.perf_counter() - t0) * 1e3 print(f" Compile + first launch: {compile_ms:.0f} ms") - flush_l2 = not args.no_flush_l2 + use_graph = getattr(args, "use_graph", False) + flush_l2 = not args.no_flush_l2 and args.l2_flush_mb > 0 working_set = sum(t.numel() * t.element_size() for t in (a_gpu, b_gpu, as_gpu, bs_gpu, c_gpu)) - num_slots = _rotate_slot_count(working_set, flush_l2) - a_pool = [a_gpu] + [a_gpu.clone() for _ in range(num_slots - 1)] - b_pool = [b_gpu] + [b_gpu.clone() for _ in range(num_slots - 1)] - as_pool = [as_gpu] + [as_gpu.clone() for _ in range(num_slots - 1)] - bs_pool = [bs_gpu] + [bs_gpu.clone() for _ in range(num_slots - 1)] - c_pool = [c_gpu] + [torch.zeros_like(c_gpu) for _ in range(num_slots - 1)] - print( - f" Rotate buffers: {num_slots} slot(s), pool={working_set * num_slots / 1e6:.1f} MB " - f"(working set {working_set / 1e6:.1f} MB)" + (" [HOT L2: --no-flush-l2]" if num_slots == 1 else "") - ) + flush_cache = None if use_graph else _make_l2_flush_buffer(flush_l2, args.l2_flush_mb) + graph_num_slots = 1 + if use_graph and flush_l2: + graph_rotate_target = max(_l2_cache_bytes() * 5, int(args.l2_flush_mb) * 1024 * 1024) + graph_num_slots = _graph_rotate_slot_count(working_set, graph_rotate_target) + graph_eviction_bytes = max(0, graph_num_slots - 1) * working_set + cap_note = ( + " [WARNING: capped below target]" + if 
graph_eviction_bytes < graph_rotate_target + else "" + ) + print( + f" L2 defeat: graph rotating buffers, slots={graph_num_slots}, " + f"pool={working_set * graph_num_slots / 1e6:.1f} MB " + f"(evict distance={graph_eviction_bytes / 1e6:.1f} MB, " + f"target={graph_rotate_target / 1e6:.1f} MB, " + f"reported L2={_l2_cache_bytes() / 1e6:.1f} MB, " + f"working set {working_set / 1e6:.1f} MB){cap_note}" + ) + elif flush_cache is None: + print(f" L2 defeat: OFF (hot-cache timing), working set {working_set / 1e6:.1f} MB") + else: + print( + f" L2 defeat: ON, scratch={flush_cache.numel() * flush_cache.element_size() / 1e6:.1f} MB " + f"(reported L2={_l2_cache_bytes() / 1e6:.1f} MB, working set {working_set / 1e6:.1f} MB)" + ) - def run_slot(i): - s = i % num_slots - run_one(c_pool[s], a_pool[s], b_pool[s], as_pool[s], bs_pool[s]) + clear_output_each_run = args.split_k > 1 + + def run_bench_once(): + run_one(c_gpu, a_gpu, b_gpu, as_gpu, bs_gpu) + + def reset_bench_output(): + c_gpu.zero_() - use_graph = getattr(args, "use_graph", False) if use_graph: - if not args.no_flush_l2: + if graph_num_slots == 1: + print(f"[2/3] Warming up ({args.warmup} iters) + bench via hot-cache hipGraph ({args.iters} replays)...") + us = _bench_kernel_us_cudagraph( + lambda _slot: run_bench_once(), + num_slots=1, + warmup=args.warmup, + iters=args.iters, + post_run_slot=(lambda _slot: reset_bench_output()) if clear_output_each_run else None, + ) + else: + a_pool = [a_gpu] + [a_gpu.clone() for _ in range(graph_num_slots - 1)] + b_pool = [b_gpu] + [b_gpu.clone() for _ in range(graph_num_slots - 1)] + as_pool = [as_gpu] + [as_gpu.clone() for _ in range(graph_num_slots - 1)] + bs_pool = [bs_gpu] + [bs_gpu.clone() for _ in range(graph_num_slots - 1)] + c_pool = [c_gpu] + [torch.zeros_like(c_gpu) for _ in range(graph_num_slots - 1)] + + def run_graph_slot(slot): + s = slot % graph_num_slots + run_one(c_pool[s], a_pool[s], b_pool[s], as_pool[s], bs_pool[s]) + + def reset_graph_slot(slot): + 
c_pool[slot % graph_num_slots].zero_() + print( - " WARNING: hipGraph capture aliases the kernel-param buffer across " - "replayed launches, so the rotate buffers above do NOT take effect under " - "replay -- this number is WARM (L2-resident). Use eager mode (drop " - "--use-graph) for the cold-HBM number." + f"[2/3] Warming up ({args.warmup} iters) + bench via rotating-buffer hipGraph " + f"({args.iters} replays × {graph_num_slots} launches/replay, " + f"rotating graph-captured buffer slots)..." + ) + us = _bench_kernel_us_cudagraph( + run_graph_slot, + num_slots=graph_num_slots, + warmup=args.warmup, + iters=args.iters, + post_run_slot=reset_graph_slot if clear_output_each_run else None, ) - print(f"[2/3] Warming up ({args.warmup} iters) + bench via hipGraph " f"({args.iters} replays)...") - us = _bench_kernel_us_cudagraph(run_slot, num_slots, warmup=args.warmup, iters=args.iters) else: print(f"[2/3] Warming up ({args.warmup} iters) + benchmarking ({args.iters} iters)...") - us = _bench_kernel_us(run_slot, num_slots, warmup=args.warmup, iters=args.iters) + us = _bench_kernel_us( + run_bench_once, + flush_cache, + warmup=args.warmup, + iters=args.iters, + post_run=reset_bench_output if clear_output_each_run else None, + ) + + WMMA_K = 128 + WMMA_N_EFF = 32 if is_fp4 else 16 + wmma_m_rep = warp_tile_m // 16 + wmma_n_rep = warp_tile_n // WMMA_N_EFF + k_wmma_steps = tile_k // WMMA_K + wmma_per_tile = wmma_m_rep * wmma_n_rep * k_wmma_steps + m_tiles = (padded_m + tile_m - 1) // tile_m + n_tiles = (padded_n + tile_n - 1) // tile_n + k_tiles = padded_k // tile_k + k_tiles_local = (padded_k // args.split_k) // tile_k + # Sequential WMMAs per workgroup (all k_tiles execute sequentially) + seq_wmma = k_tiles_local * wmma_per_tile + us_per_wmma = us / seq_wmma if seq_wmma > 0 else 0 logical_flops = 2.0 * M * N * K - kernel_flops = 2.0 * padded_m * padded_n * padded_k + tile_m_covered = m_tiles * tile_m + tile_n_covered = n_tiles * tile_n + tile_flops = 2.0 * tile_m_covered 
* tile_n_covered * padded_k time_s = us / 1e6 logical_tflops = logical_flops / time_s / 1e12 if time_s > 0 else 0.0 - kernel_tflops = kernel_flops / time_s / 1e12 if time_s > 0 else 0.0 + tile_tflops = tile_flops / time_s / 1e12 if time_s > 0 else 0.0 bytes_a = padded_m * padded_k // PACK_A bytes_b = padded_n * padded_k // PACK_B @@ -2076,26 +2266,12 @@ def run_slot(i): read_bw_gbs = read_bytes / 1e9 / time_s if time_s > 0 else 0.0 write_bw_gbs = write_bytes / 1e9 / time_s if time_s > 0 else 0.0 - WMMA_K = 128 - WMMA_N_EFF = 32 if is_fp4 else 16 - wmma_m_rep = warp_tile_m // 16 - wmma_n_rep = warp_tile_n // WMMA_N_EFF - k_wmma_steps = tile_k // WMMA_K - wmma_per_tile = wmma_m_rep * wmma_n_rep * k_wmma_steps - m_tiles = padded_m // tile_m - n_tiles = padded_n // tile_n - k_tiles = padded_k // tile_k - k_tiles_local = (padded_k // args.split_k) // tile_k - # Sequential WMMAs per workgroup (all k_tiles execute sequentially) - seq_wmma = k_tiles_local * wmma_per_tile - us_per_wmma = us / seq_wmma if seq_wmma > 0 else 0 - print("\n[3/3] Results:") print(f" Kernel time: {us:.1f} us ({us / 1e3:.4f} ms)") - if not needs_pad: - print(f" TFLOPS: {kernel_tflops:.4f}") + if tile_flops == logical_flops: + print(f" TFLOPS: {logical_tflops:.4f}") else: - print(f" TFLOPS: {logical_tflops:.4f} (logical), {kernel_tflops:.4f} (kernel)") + print(f" TFLOPS: {logical_tflops:.4f} (logical), {tile_tflops:.4f} (tile-covered)") print(f" Bandwidth: {bw_gbs:.1f} GB/s " f"(read: {read_bw_gbs:.1f} + write: {write_bw_gbs:.1f})") print( f" Bytes moved: {bytes_moved / 1e6:.1f} MB " @@ -2117,8 +2293,7 @@ def run_slot(i): print(f" WARNING: {us_per_wmma/1000:.1f} ms/WMMA indicates " f"WMMA_SCALE trap-handler emulation") print("=" * 72) - reported_tflops = kernel_tflops if not needs_pad else logical_tflops - return us, reported_tflops, bw_gbs + return us, logical_tflops, bw_gbs def _run_graph_verify(args): @@ -2209,6 +2384,7 @@ def _run_graph_verify(args): b_streaming=args.b_streaming, 
scale_load_path=args.scale_load_path, a_load_path=args.a_load_path, + b_split_load=args.b_split_load, ) c_flat = c_gpu.contiguous() @@ -2334,6 +2510,14 @@ def launch(): ) parser.add_argument("--disable-expert-sched-mode", dest="expert_sched_mode", action="store_false", default=True) parser.add_argument("--b-streaming", action="store_true", default=False) + parser.add_argument( + "--b-split-load", + action="store_true", + default=False, + help="Split the B TDM load by N groups. With --a-load-path vgpr this reuses " + "wave0 for B half 0; with TDM A/scale it uses wave0=A, wave1/2=B halves, " + "wave3=A_scale+B_scale.", + ) parser.add_argument( "--atomic-barrier-enable", action="store_true", @@ -2350,17 +2534,24 @@ def launch(): "--no-flush-l2", action="store_true", default=False, - help="Disable the rotate-buffer L2 defeat (use a single hot buffer) for a " - "warm-cache measurement. Applies to both eager and --use-graph modes.", + help="Disable L2 defeat for a hot-cache measurement. Applies to both eager " + "and --use-graph modes.", + ) + parser.add_argument( + "--l2-flush-mb", + type=int, + default=256, + help="Scratch buffer size in MiB for eager cold-cache timing, and the " + "minimum address-rotation target for --use-graph rotating-buffer timing.", ) parser.add_argument( "--use-graph", action="store_true", default=False, help="Time via hipGraph capture+replay to strip host launch overhead from " - "per-launch latency. NOTE: graph replay measures the WARM (L2-resident) " - "regime -- rotate buffers do not survive hipGraph capture, so use the eager " - "path (drop --use-graph) for the cold-HBM number.", + "per-launch latency. 
By default this captures a rotating-buffer graph to " + "avoid replaying the same tensor addresses; compare with --no-flush-l2 to " + "separate address-reuse/cache effects from launch overhead.", ) parser.add_argument( "--verify-graph", @@ -2432,4 +2623,5 @@ def launch(): b_streaming=args.b_streaming, scale_load_path=args.scale_load_path, a_load_path=args.a_load_path, + b_split_load=args.b_split_load, ) From 83f8273cf44bbffd434f53956d992926e75bfbde Mon Sep 17 00:00:00 2001 From: aoli26 Date: Fri, 12 Jun 2026 08:45:43 +0000 Subject: [PATCH 04/16] add decode shape cluster expr --- lib/Runtime/ROCm/FlyRocmRuntimeWrappers.cpp | 73 +++++++++++-- tests/kernels/test_gemm_fp8fp4_gfx1250.py | 114 +++++++++++--------- 2 files changed, 131 insertions(+), 56 deletions(-) diff --git a/lib/Runtime/ROCm/FlyRocmRuntimeWrappers.cpp b/lib/Runtime/ROCm/FlyRocmRuntimeWrappers.cpp index a8b2d3f3..21d302c9 100644 --- a/lib/Runtime/ROCm/FlyRocmRuntimeWrappers.cpp +++ b/lib/Runtime/ROCm/FlyRocmRuntimeWrappers.cpp @@ -17,8 +17,15 @@ #include #include "hip/hip_runtime.h" +#include "hip/hip_version.h" #include "mlir/ExecutionEngine/CRunnerUtils.h" +// TODO(gfx1250-cluster): TEMPORARY. This version check does NOT actually guarantee +// cluster launch support (this version doesn't support it either). It's a debug-only +// hack for gfx1250 bring-up and will likely break other environments. Replace with a +// real capability check once the supported version/path is confirmed. 
+#define FLY_HIP_HAS_CLUSTER_LAUNCH (HIP_VERSION >= 70000000) + #define HIP_REPORT_IF_ERROR(expr) \ [](hipError_t result) { \ if (!result) \ @@ -68,7 +75,52 @@ extern "C" void mgpuLaunchClusterKernel(hipFunction_t function, intptr_t cluster intptr_t blockY, intptr_t blockZ, int32_t smem, hipStream_t stream, void **params, void **extra, size_t /*paramsCount*/) { -#ifdef hipLaunchAttributeClusterDimension + const bool requestedRealCluster = (clusterX > 1) || (clusterY > 1) || (clusterZ > 1); + +#if FLY_HIP_HAS_CLUSTER_LAUNCH + hipStreamCaptureStatus capStatus = hipStreamCaptureStatusNone; + hipGraph_t capGraph = nullptr; + const hipGraphNode_t *capDeps = nullptr; + size_t numCapDeps = 0; + if (hipStreamGetCaptureInfo_v2(stream, &capStatus, /*id_out=*/nullptr, &capGraph, &capDeps, + &numCapDeps) == hipSuccess && + capStatus == hipStreamCaptureStatusActive) { + hipKernelNodeParams nodeParams{}; + nodeParams.func = reinterpret_cast(function); + nodeParams.gridDim = dim3(static_cast(gridX), static_cast(gridY), + static_cast(gridZ)); + nodeParams.blockDim = dim3(static_cast(blockX), static_cast(blockY), + static_cast(blockZ)); + nodeParams.sharedMemBytes = static_cast(smem); + nodeParams.kernelParams = params; + nodeParams.extra = extra; + + hipGraphNode_t node = nullptr; + hipError_t addErr = hipGraphAddKernelNode(&node, capGraph, capDeps, numCapDeps, &nodeParams); + if (addErr != hipSuccess) { + // Fail loudly: a silent empty graph would report bogus ~0us replay times. 
+ fprintf(stderr, + "[mgpuLaunchClusterKernel] hipGraphAddKernelNode failed (err=%d) during stream " + "capture for cluster=(%ld,%ld,%ld); cluster kernels cannot be captured into a " + "hipGraph on this HIP build.\n", + static_cast(addErr), static_cast(clusterX), static_cast(clusterY), + static_cast(clusterZ)); + HIP_REPORT_IF_ERROR(addErr); + return; + } + if (requestedRealCluster) { + hipKernelNodeAttrValue attrVal{}; + attrVal.clusterDim.x = static_cast(clusterX); + attrVal.clusterDim.y = static_cast(clusterY); + attrVal.clusterDim.z = static_cast(clusterZ); + HIP_REPORT_IF_ERROR( + hipGraphKernelNodeSetAttribute(node, hipLaunchAttributeClusterDimension, &attrVal)); + } + HIP_REPORT_IF_ERROR( + hipStreamUpdateCaptureDependencies(stream, &node, 1, hipStreamSetCaptureDependencies)); + return; + } + hipLaunchAttribute attrs[1]; attrs[0].id = hipLaunchAttributeClusterDimension; attrs[0].value.clusterDim.x = static_cast(clusterX); @@ -91,7 +143,6 @@ extern "C" void mgpuLaunchClusterKernel(hipFunction_t function, intptr_t cluster if (err == hipSuccess) return; - const bool requestedRealCluster = (clusterX > 1) || (clusterY > 1) || (clusterZ > 1); if (requestedRealCluster) { fprintf(stderr, "[mgpuLaunchClusterKernel] hipDrvLaunchKernelEx failed (err=%d) " @@ -103,6 +154,7 @@ extern "C" void mgpuLaunchClusterKernel(hipFunction_t function, intptr_t cluster return; } + // cluster=(1,1,1) carries no cluster semantics — plain launch is equivalent. fprintf(stderr, "[mgpuLaunchClusterKernel] hipDrvLaunchKernelEx failed (err=%d) " "for cluster=(1,1,1); falling back to hipModuleLaunchKernel.\n", @@ -110,15 +162,20 @@ extern "C" void mgpuLaunchClusterKernel(hipFunction_t function, intptr_t cluster HIP_REPORT_IF_ERROR(hipModuleLaunchKernel(function, gridX, gridY, gridZ, blockX, blockY, blockZ, smem, stream, params, extra)); #else - // Cluster launch not supported by this HIP version; ignore cluster dims - // and fall back to regular kernel launch. 
- if ((clusterX > 1) || (clusterY > 1) || (clusterZ > 1)) { + // HIP < 7.0: no cluster API. Refuse to downgrade silently — kernel relies on + // cluster semantics (multicast, cluster_barrier) that a plain launch breaks. + if (requestedRealCluster) { fprintf(stderr, "[mgpuLaunchClusterKernel] cluster=(%ld,%ld,%ld) requested but " - "hipLaunchAttributeClusterDimension is not available in this HIP " - "version; falling back to hipModuleLaunchKernel.\n", - static_cast(clusterX), static_cast(clusterY), static_cast(clusterZ)); + "FlyDSL was built against HIP %d (need HIP >= 7.0 / ROCm >= 7.0 " + "for hipDrvLaunchKernelEx + hipLaunchAttributeClusterDimension). " + "Aborting.\n", + static_cast(clusterX), static_cast(clusterY), static_cast(clusterZ), + HIP_VERSION); + HIP_REPORT_IF_ERROR(hipErrorNotSupported); + return; } + // cluster=(1,1,1): plain launch is equivalent. HIP_REPORT_IF_ERROR(hipModuleLaunchKernel(function, gridX, gridY, gridZ, blockX, blockY, blockZ, smem, stream, params, extra)); #endif diff --git a/tests/kernels/test_gemm_fp8fp4_gfx1250.py b/tests/kernels/test_gemm_fp8fp4_gfx1250.py index d84c67ec..6be1b0a8 100644 --- a/tests/kernels/test_gemm_fp8fp4_gfx1250.py +++ b/tests/kernels/test_gemm_fp8fp4_gfx1250.py @@ -1386,11 +1386,12 @@ def replay_graph_body(): file=sys.stderr, flush=True, ) - if first_replay_per_launch_us < 1.0 and ref_per_launch_us > 2.0: + if (ref_per_launch_us > 2.0 and first_replay_per_launch_us < 0.25 * ref_per_launch_us + and first_replay_per_launch_us < 1.0): raise RuntimeError( f"hipGraph replay per-launch={first_replay_per_launch_us:.3f}us " f"<< ref direct-launch={ref_per_launch_us:.3f}us. " - f"Graph capture likely empty (stream mismatch?)." + f"Graph capture likely empty (uncaptured cluster launch or stream mismatch?)." ) # Stabilize graph replay before collecting samples. 
@@ -2528,6 +2529,13 @@ def launch(): parser.add_argument( "--benchmark", action="store_true", default=False, help="Run benchmark mode (timing only, no correctness check)" ) + parser.add_argument( + "--verify", + action="store_true", + default=False, + help="With --benchmark, also run the correctness check before timing. " + "Without --benchmark, runs always verify and this flag is a no-op.", + ) parser.add_argument("--warmup", type=int, default=5) parser.add_argument("--iters", type=int, default=20) parser.add_argument( @@ -2572,56 +2580,66 @@ def launch(): if args.scale_mode == "ptpc" and args.verify_graph: raise SystemExit("--scale-mode ptpc does not support --verify-graph") + def _run_correctness_test(): + """Run the functional test (computes a reference and asserts correctness).""" + if args.scale_mode == "ptpc": + _run_ptpc_gemm_test( + args.M, + args.N, + args.K, + args.tile_m, + args.tile_n, + args.tile_k, + args.m_warp, + args.n_warp, + num_buffers=args.num_buffers, + out_dtype=args.out_dtype, + data_format=args.data_format, + l2_prefetch_distance=args.l2_prefetch_distance, + cluster_m=args.cluster_m, + cluster_n=args.cluster_n, + split_k=args.split_k, + ) + else: + use_tdm_store = not args.no_tdm_store and args.split_k == 1 + _run_mxscale_gemm_test( + args.data_format, + args.M, + args.N, + args.K, + args.tile_m, + args.tile_n, + args.tile_k, + args.m_warp, + args.n_warp, + num_buffers=args.num_buffers, + use_tdm_store=use_tdm_store, + out_dtype=args.out_dtype, + wave_specialized_tdm=args.wave_spec_tdm, + split_k=args.split_k, + use_scale_opsel=args.use_scale_opsel, + l2_prefetch_distance=args.l2_prefetch_distance, + cluster_m=args.cluster_m, + cluster_n=args.cluster_n, + inst_prefetch=args.inst_prefetch, + waves_per_eu=args.waves_per_eu, + expert_sched_mode=args.expert_sched_mode, + b_streaming=args.b_streaming, + scale_load_path=args.scale_load_path, + a_load_path=args.a_load_path, + b_split_load=args.b_split_load, + ) + if args.verify_graph: 
_run_graph_verify(args) if not args.benchmark: sys.exit(0) if args.benchmark: + # Benchmark defaults to timing-only; --verify opts into a correctness check first. + if args.verify: + print("Verifying correctness before benchmark (--verify)...") + _run_correctness_test() _run_benchmark(args) - elif args.scale_mode == "ptpc": - _run_ptpc_gemm_test( - args.M, - args.N, - args.K, - args.tile_m, - args.tile_n, - args.tile_k, - args.m_warp, - args.n_warp, - num_buffers=args.num_buffers, - out_dtype=args.out_dtype, - data_format=args.data_format, - l2_prefetch_distance=args.l2_prefetch_distance, - cluster_m=args.cluster_m, - cluster_n=args.cluster_n, - split_k=args.split_k, - ) else: - use_tdm_store = not args.no_tdm_store and args.split_k == 1 - _run_mxscale_gemm_test( - args.data_format, - args.M, - args.N, - args.K, - args.tile_m, - args.tile_n, - args.tile_k, - args.m_warp, - args.n_warp, - num_buffers=args.num_buffers, - use_tdm_store=use_tdm_store, - out_dtype=args.out_dtype, - wave_specialized_tdm=args.wave_spec_tdm, - split_k=args.split_k, - use_scale_opsel=args.use_scale_opsel, - l2_prefetch_distance=args.l2_prefetch_distance, - cluster_m=args.cluster_m, - cluster_n=args.cluster_n, - inst_prefetch=args.inst_prefetch, - waves_per_eu=args.waves_per_eu, - expert_sched_mode=args.expert_sched_mode, - b_streaming=args.b_streaming, - scale_load_path=args.scale_load_path, - a_load_path=args.a_load_path, - b_split_load=args.b_split_load, - ) + # Non-benchmark runs always verify. 
+ _run_correctness_test() From e00bf343446aa803fdff9dc82a8bf25e8ce71b3f Mon Sep 17 00:00:00 2001 From: aoli26 Date: Fri, 12 Jun 2026 14:54:43 +0000 Subject: [PATCH 05/16] gemm_fp8fp4_gfx1250: support cluster launch through hipGraph - runtime: capture cluster kernels into hipGraph nodes via hipGraphAddKernelNode, relying on baked-in amdgpu-cluster-dims (graph-node cluster attribute is unsupported on this HIP build) - test: parametrize test_mxscale_gemm_cudagraph with cluster cases; run functional test through hipGraph (graph vs ref) when --use-graph is set, covering --use-graph and --benchmark --use-graph --verify --- lib/Runtime/ROCm/FlyRocmRuntimeWrappers.cpp | 8 - tests/kernels/test_gemm_fp8fp4_gfx1250.py | 182 +++++++++++++++----- 2 files changed, 141 insertions(+), 49 deletions(-) diff --git a/lib/Runtime/ROCm/FlyRocmRuntimeWrappers.cpp b/lib/Runtime/ROCm/FlyRocmRuntimeWrappers.cpp index 21d302c9..66a3b68d 100644 --- a/lib/Runtime/ROCm/FlyRocmRuntimeWrappers.cpp +++ b/lib/Runtime/ROCm/FlyRocmRuntimeWrappers.cpp @@ -108,14 +108,6 @@ extern "C" void mgpuLaunchClusterKernel(hipFunction_t function, intptr_t cluster HIP_REPORT_IF_ERROR(addErr); return; } - if (requestedRealCluster) { - hipKernelNodeAttrValue attrVal{}; - attrVal.clusterDim.x = static_cast(clusterX); - attrVal.clusterDim.y = static_cast(clusterY); - attrVal.clusterDim.z = static_cast(clusterZ); - HIP_REPORT_IF_ERROR( - hipGraphKernelNodeSetAttribute(node, hipLaunchAttributeClusterDimension, &attrVal)); - } HIP_REPORT_IF_ERROR( hipStreamUpdateCaptureDependencies(stream, &node, 1, hipStreamSetCaptureDependencies)); return; diff --git a/tests/kernels/test_gemm_fp8fp4_gfx1250.py b/tests/kernels/test_gemm_fp8fp4_gfx1250.py index 6be1b0a8..029cfade 100644 --- a/tests/kernels/test_gemm_fp8fp4_gfx1250.py +++ b/tests/kernels/test_gemm_fp8fp4_gfx1250.py @@ -112,7 +112,7 @@ def preshuffle_e8m0_scale_coalesced_general( orig_row = mb * warp_tile + orig_rep * WMMA_DIM + L # Clamp out-of-range (guard) rows so the 
gather stays in bounds; masked below. orig_row = torch.where(valid, orig_row, torch.zeros_like(orig_row)) - colg = (kt * KS + ks) # group-of-4 column index into the K dimension + colg = kt * KS + ks # group-of-4 column index into the K dimension # Gather the 4 spw bytes for each (row, colg): scale viewed as [rows, KG*KS, 4]. scale_g = scale.view(rows, KG * KS, 4) row_idx, colg_idx = torch.broadcast_tensors(orig_row, colg) @@ -160,7 +160,9 @@ def preshuffle_e8m0_scale( return g.reshape(-1, k_groups * k_wmma_steps * wmma_rep * SCALES_PER_WMMA) -def preshuffle_scale_for_load_path(scale, warp_tile, skt, *, scale_load_path, data_format, ref_segmented, row_align=None): +def preshuffle_scale_for_load_path( + scale, warp_tile, skt, *, scale_load_path, data_format, ref_segmented, row_align=None +): """Host scale preshuffle matching the kernel's selected scale_load_path. - 'tdm': interleaved TDM/LDS layout. @@ -471,6 +473,7 @@ def _run_mxscale_gemm_test( a_load_path="tdm", b_split_load=False, return_launch_fn=False, + use_graph=False, ): """Unified test body for FP4 and FP8.""" is_fp4 = data_format == "fp4" @@ -515,11 +518,12 @@ def _run_mxscale_gemm_test( scale_load_str = "" if scale_load_path == "tdm" else f", scale_load={scale_load_path}" a_load_str = "" if a_load_path == "tdm" else f", a_load={a_load_path}" b_split_str = ", b_split_load" if b_split_load else "" + graph_str = ", graph" if use_graph else "" pad_str = _format_kernel_pad(M, N, K, padded_shape) print( f"\nRunning {fmt_name} GEMM: M={M}, N={N}, K={K}{pad_str}, " f"tiles=({tile_m},{tile_n},{tile_k}), bufs={num_buffers}" - f"{mcast_str}{tdm_str}{scale_load_str}{a_load_str}{b_split_str}, preshuffle, out={out_dtype}" + f"{mcast_str}{tdm_str}{scale_load_str}{a_load_str}{b_split_str}{graph_str}, preshuffle, out={out_dtype}" ) # Generate data @@ -554,17 +558,34 @@ def _run_mxscale_gemm_test( warp_tile_m = tile_m // m_warp warp_tile_n = tile_n // n_warp _ref_seg = is_ref_segmented_lds_layout( - 
data_format=data_format, tile_m=tile_m, tile_n=tile_n, tile_k=tile_k, m_warp=m_warp, n_warp=n_warp, - num_buffers=num_buffers, split_k=split_k, wave_specialized_tdm=wave_specialized_tdm, + data_format=data_format, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + m_warp=m_warp, + n_warp=n_warp, + num_buffers=num_buffers, + split_k=split_k, + wave_specialized_tdm=wave_specialized_tdm, use_scale_opsel=use_scale_opsel, ) a_scale = preshuffle_scale_for_load_path( - a_scale, warp_tile_m, skt, scale_load_path=scale_load_path, data_format=data_format, - ref_segmented=_ref_seg, row_align=tile_m, + a_scale, + warp_tile_m, + skt, + scale_load_path=scale_load_path, + data_format=data_format, + ref_segmented=_ref_seg, + row_align=tile_m, ) b_scale = preshuffle_scale_for_load_path( - b_scale, warp_tile_n, skt, scale_load_path=scale_load_path, data_format=data_format, - ref_segmented=_ref_seg, row_align=tile_n, + b_scale, + warp_tile_n, + skt, + scale_load_path=scale_load_path, + data_format=data_format, + ref_segmented=_ref_seg, + row_align=tile_n, ) # Preshuffle B data @@ -612,7 +633,7 @@ def _run_mxscale_gemm_test( as_flat = as_gpu.contiguous() bs_flat = bs_gpu.contiguous() - flyc.compile( + compiled_exe = flyc.compile( launch_fn, c_flat, a_flat, @@ -625,6 +646,35 @@ def _run_mxscale_gemm_test( padded_n, torch.cuda.current_stream(), ) + + if use_graph: + + def _launch(): + compiled_exe( + c_flat, + a_flat, + b_flat, + as_flat, + bs_flat, + padded_m, + padded_n, + padded_k, + padded_n, + torch.cuda.current_stream(), + ) + + g = torch.cuda.CUDAGraph() + s = torch.cuda.Stream() + s.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(s): + _launch() + torch.cuda.current_stream().wait_stream(s) + torch.cuda.synchronize() + c_gpu.zero_() + with torch.cuda.graph(g, stream=s): + _launch() + c_gpu.zero_() + g.replay() torch.cuda.synchronize() c_out = c_gpu[:M, :N].to(torch_out_dtype).cpu() @@ -1135,14 +1185,23 @@ def test_mxfp4_gemm_mcast( @pytest.mark.parametrize( - 
"data_format, M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp", + "data_format, M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp, cluster_m, cluster_n", [ - ("fp8", 128, 256, 256, 128, 256, 128, 2, 2), - ("fp4", 128, 256, 256, 128, 256, 128, 2, 2), + ("fp8", 128, 256, 256, 128, 256, 128, 2, 2, 1, 1), + ("fp4", 128, 256, 256, 128, 256, 128, 2, 2, 1, 1), + ("fp8", 256, 512, 256, 128, 256, 128, 2, 2, 2, 2), + ("fp4", 256, 512, 256, 128, 256, 128, 2, 2, 2, 2), + ("a8w4", 256, 512, 256, 128, 256, 128, 2, 4, 2, 2), + ], + ids=[ + "fp8-128x256x256", + "fp4-128x256x256", + "fp8-256x512x256-cluster2x2", + "fp4-256x512x256-cluster2x2", + "a8w4-256x512x256-cluster2x2", ], - ids=["fp8-128x256x256", "fp4-128x256x256"], ) -def test_mxscale_gemm_cudagraph(data_format, M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp): +def test_mxscale_gemm_cudagraph(data_format, M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp, cluster_m, cluster_n): """Verify that the gfx1250 MX-scale GEMM kernel works inside a hipGraph. Captures one launch, replays once, and checks the replay output is @@ -1157,11 +1216,15 @@ def test_mxscale_gemm_cudagraph(data_format, M, N, K, tile_m, tile_n, tile_k, m_ pytest.skip("hipGraph capture/replay not supported on simulator") is_fp4 = data_format == "fp4" + is_a8w4 = data_format == "a8w4" # Build inputs (mirrors _run_mxscale_gemm_test, but no padding needed # because we pick a clean shape). 
torch.manual_seed(0) - if is_fp4: + if is_a8w4: + a = random_fp8_data(M, K) # FP8 activation + b = fp4_utils.random_fp4_packed(N, K) # FP4 weight + elif is_fp4: a = fp4_utils.random_fp4_packed(M, K) b = fp4_utils.random_fp4_packed(N, K) else: @@ -1175,7 +1238,7 @@ def test_mxscale_gemm_cudagraph(data_format, M, N, K, tile_m, tile_n, tile_k, m_ warp_tile_n = tile_n // n_warp a_scale_ps = preshuffle_e8m0_scale(a_scale, warp_tile_m, scale_k_per_tile=skt) b_scale_ps = preshuffle_e8m0_scale(b_scale, warp_tile_n, scale_k_per_tile=skt) - pack_b = 2 if is_fp4 else 1 + pack_b = 2 if (is_fp4 or is_a8w4) else 1 b_ps = fp4_utils.preshuffle_b_16x16(b, N, K // pack_b) a_gpu = a.cuda() @@ -1198,6 +1261,8 @@ def test_mxscale_gemm_cudagraph(data_format, M, N, K, tile_m, tile_n, tile_k, m_ out_dtype="bf16", wave_specialized_tdm=False, split_k=1, + cluster_m=cluster_m, + cluster_n=cluster_n, ) c_flat = c_gpu.contiguous() @@ -1386,13 +1451,16 @@ def replay_graph_body(): file=sys.stderr, flush=True, ) - if (ref_per_launch_us > 2.0 and first_replay_per_launch_us < 0.25 * ref_per_launch_us - and first_replay_per_launch_us < 1.0): + if ( + ref_per_launch_us > 2.0 + and first_replay_per_launch_us < 0.25 * ref_per_launch_us + and first_replay_per_launch_us < 1.0 + ): raise RuntimeError( f"hipGraph replay per-launch={first_replay_per_launch_us:.3f}us " f"<< ref direct-launch={ref_per_launch_us:.3f}us. " f"Graph capture likely empty (uncaptured cluster launch or stream mismatch?)." - ) + ) # Stabilize graph replay before collecting samples. 
with torch.cuda.stream(capture_stream): @@ -2037,17 +2105,34 @@ def _run_benchmark(args): skt = tile_k // SCALE_BLOCK _ref_seg = is_ref_segmented_lds_layout( - data_format=data_format, tile_m=tile_m, tile_n=tile_n, tile_k=tile_k, m_warp=args.m_warp, - n_warp=args.n_warp, num_buffers=args.num_buffers, split_k=args.split_k, - wave_specialized_tdm=args.wave_spec_tdm, use_scale_opsel=args.use_scale_opsel, + data_format=data_format, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + m_warp=args.m_warp, + n_warp=args.n_warp, + num_buffers=args.num_buffers, + split_k=args.split_k, + wave_specialized_tdm=args.wave_spec_tdm, + use_scale_opsel=args.use_scale_opsel, ) a_scale = preshuffle_scale_for_load_path( - a_scale, warp_tile_m, skt, scale_load_path=args.scale_load_path, data_format=data_format, - ref_segmented=_ref_seg, row_align=tile_m, + a_scale, + warp_tile_m, + skt, + scale_load_path=args.scale_load_path, + data_format=data_format, + ref_segmented=_ref_seg, + row_align=tile_m, ) b_scale = preshuffle_scale_for_load_path( - b_scale, warp_tile_n, skt, scale_load_path=args.scale_load_path, data_format=data_format, - ref_segmented=_ref_seg, row_align=tile_n, + b_scale, + warp_tile_n, + skt, + scale_load_path=args.scale_load_path, + data_format=data_format, + ref_segmented=_ref_seg, + row_align=tile_n, ) K_packed = padded_k // PACK_B @@ -2159,11 +2244,7 @@ def run_one(c_, a_, b_, as_, bs_): graph_rotate_target = max(_l2_cache_bytes() * 5, int(args.l2_flush_mb) * 1024 * 1024) graph_num_slots = _graph_rotate_slot_count(working_set, graph_rotate_target) graph_eviction_bytes = max(0, graph_num_slots - 1) * working_set - cap_note = ( - " [WARNING: capped below target]" - if graph_eviction_bytes < graph_rotate_target - else "" - ) + cap_note = " [WARNING: capped below target]" if graph_eviction_bytes < graph_rotate_target else "" print( f" L2 defeat: graph rotating buffers, slots={graph_num_slots}, " f"pool={working_set * graph_num_slots / 1e6:.1f} MB " @@ -2335,17 +2416,34 @@ 
def _run_graph_verify(args): warp_tile_m = tile_m // args.m_warp warp_tile_n = tile_n // args.n_warp _ref_seg = is_ref_segmented_lds_layout( - data_format=data_format, tile_m=tile_m, tile_n=tile_n, tile_k=tile_k, m_warp=args.m_warp, - n_warp=args.n_warp, num_buffers=args.num_buffers, split_k=args.split_k, - wave_specialized_tdm=args.wave_spec_tdm, use_scale_opsel=args.use_scale_opsel, + data_format=data_format, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + m_warp=args.m_warp, + n_warp=args.n_warp, + num_buffers=args.num_buffers, + split_k=args.split_k, + wave_specialized_tdm=args.wave_spec_tdm, + use_scale_opsel=args.use_scale_opsel, ) a_scale = preshuffle_scale_for_load_path( - a_scale, warp_tile_m, skt, scale_load_path=args.scale_load_path, data_format=data_format, - ref_segmented=_ref_seg, row_align=tile_m, + a_scale, + warp_tile_m, + skt, + scale_load_path=args.scale_load_path, + data_format=data_format, + ref_segmented=_ref_seg, + row_align=tile_m, ) b_scale = preshuffle_scale_for_load_path( - b_scale, warp_tile_n, skt, scale_load_path=args.scale_load_path, data_format=data_format, - ref_segmented=_ref_seg, row_align=tile_n, + b_scale, + warp_tile_n, + skt, + scale_load_path=args.scale_load_path, + data_format=data_format, + ref_segmented=_ref_seg, + row_align=tile_n, ) K_packed = padded_k // padded_shape["pack_b"] b = fp4_utils.preshuffle_b_16x16(b, padded_n, K_packed) @@ -2542,8 +2640,7 @@ def launch(): "--no-flush-l2", action="store_true", default=False, - help="Disable L2 defeat for a hot-cache measurement. Applies to both eager " - "and --use-graph modes.", + help="Disable L2 defeat for a hot-cache measurement. 
Applies to both eager " "and --use-graph modes.", ) parser.add_argument( "--l2-flush-mb", @@ -2579,6 +2676,8 @@ def launch(): if args.scale_mode == "ptpc" and args.verify_graph: raise SystemExit("--scale-mode ptpc does not support --verify-graph") + if args.scale_mode == "ptpc" and args.use_graph and not args.benchmark: + raise SystemExit("--scale-mode ptpc does not support --use-graph for functional tests (use --benchmark)") def _run_correctness_test(): """Run the functional test (computes a reference and asserts correctness).""" @@ -2628,6 +2727,7 @@ def _run_correctness_test(): scale_load_path=args.scale_load_path, a_load_path=args.a_load_path, b_split_load=args.b_split_load, + use_graph=args.use_graph, ) if args.verify_graph: From cfbdb931a05c6293bce7d768d3c6e4e6e06d6a9c Mon Sep 17 00:00:00 2001 From: aoli26 Date: Sun, 14 Jun 2026 15:04:02 +0000 Subject: [PATCH 06/16] b scale tile independent preshuffle --- kernels/gemm_common_gfx1250.py | 23 ++ kernels/gemm_fp8fp4_gfx1250.py | 189 ++++++++++++++-- tests/kernels/test_gemm_fp8fp4_gfx1250.py | 257 ++++------------------ 3 files changed, 239 insertions(+), 230 deletions(-) diff --git a/kernels/gemm_common_gfx1250.py b/kernels/gemm_common_gfx1250.py index b269192d..863fea77 100644 --- a/kernels/gemm_common_gfx1250.py +++ b/kernels/gemm_common_gfx1250.py @@ -93,6 +93,29 @@ def lds_load_b128_raw(lds_base_idx, byte_offset): return llvm_dialect.load(ir.VectorType.get([4], ir.IntegerType.get_signless(32)), ptr_val) +def lds_load_b32_raw(lds_base_idx, byte_offset): + """Load 4 bytes (one i32) from LDS using a pre-extracted base index (raw LLVM). + + Unlike :func:`lds_load_b128_raw`, this only requires 4-byte alignment, so it + suits scale layouts where consumed words sit at 4-byte (not 16-byte) granular + offsets (e.g. the N4K4 B-scale layout's per-N-block reads). 
+ """ + ptr_val = _raw_lds_ptr(lds_base_idx, byte_offset) + return llvm_dialect.load(ir.IntegerType.get_signless(32), ptr_val) + + +def lds_load_b64_raw(lds_base_idx, byte_offset): + """Load 8 bytes (``vector<2xi32>``) from LDS using a pre-extracted base index. + + Requires 8-byte alignment. Sits between :func:`lds_load_b32_raw` and + :func:`lds_load_b128_raw` for layouts whose contiguous read width is 2 words + (e.g. the N4K4 B-scale layout when ``wmma_n_rep`` is even but not a multiple + of 4, where each aligned batch covers exactly 2 N-blocks). + """ + ptr_val = _raw_lds_ptr(lds_base_idx, byte_offset) + return llvm_dialect.load(ir.VectorType.get([2], ir.IntegerType.get_signless(32)), ptr_val) + + def lds_transpose_load_raw(result_type, lds_base_idx, byte_offset): """Transpose-load 16 bytes from LDS using a pre-extracted base index.""" from flydsl._mlir.dialects import rocdl as _rocdl diff --git a/kernels/gemm_fp8fp4_gfx1250.py b/kernels/gemm_fp8fp4_gfx1250.py index c113d3e7..24458607 100644 --- a/kernels/gemm_fp8fp4_gfx1250.py +++ b/kernels/gemm_fp8fp4_gfx1250.py @@ -24,6 +24,8 @@ extract_lds_base_idx, get_lds_memref, issue_tdm_loads, + lds_load_b32_raw, + lds_load_b64_raw, lds_load_b128_raw, pipeline_fence, pipeline_fence_signal, @@ -78,6 +80,7 @@ def _vec_chunks(n: int): done += w return chunks + LDS_PAD_A_BYTES = 16 LDS_PAD_D_BYTES = 16 LDS_SEGMENT_BYTES = 64 * 1024 @@ -118,6 +121,44 @@ def is_ref_segmented_lds_layout( ) +def use_n4k4_bscale_layout( + *, + data_format, + tile_m, + tile_n, + tile_k, + m_warp, + n_warp, + n, + scale_mode="mxscale", + scale_load_path="tdm", + use_scale_opsel=False, + b_streaming=False, + b_split_load=False, +): + """Whether B-scale uses the tile-independent N4K4 preshuffle layout.""" + if scale_mode != "mxscale": + return False + if data_format not in ("fp8", "a8w4"): + return False + if scale_load_path != "tdm": + return False + if use_scale_opsel or b_split_load: + return False + if tile_k % 128 != 0: + return False + if n % 
64 != 0: + return False + if tile_n % 64 != 0 and 64 % tile_n != 0: + return False + wmma_m_rep = (tile_m // m_warp) // WMMA_M + wmma_n_rep = (tile_n // n_warp) // WMMA_N + n_accs = wmma_m_rep * wmma_n_rep + # Row-major streaming is selected exactly when a rep is odd or n_accs < 8 + # (see _pick_compute_schedule_kind); otherwise fp8/a8w4 route to quadrant. + return (not b_streaming) and (wmma_m_rep % 2 != 0 or wmma_n_rep % 2 != 0 or n_accs < 8) + + @functools.lru_cache(maxsize=256) def compile_fp8fp4_gemm( *, @@ -312,6 +353,34 @@ def compile_fp8fp4_gemm( # FP4 A/B swap: BScale rep derived from WMMA_M, not WMMA_N_EFF b_scale_load_rep = warp_tile_n // WMMA_M if is_fp4 else wmma_n_rep + use_n4k4_bscale = use_n4k4_bscale_layout( + data_format=data_format, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + m_warp=m_warp, + n_warp=n_warp, + n=N, + scale_mode=scale_mode, + scale_load_path=scale_load_path, + use_scale_opsel=use_scale_opsel, + b_streaming=b_streaming, + b_split_load=b_split_load, + ) + if use_n4k4_bscale: + if K_scale % 4 != 0: + raise ValueError(f"N4K4 B-scale requires K_scale % 4 == 0, got {K_scale}") + n4k4_n_groups = N // 64 + n4k4_bs_global_row_stride = (K // WMMA_K) * 256 + n4k4_bs_lds_row_stride = k_wmma_steps * 256 + # Impact: cost-free when tile_n//64 is already a power of + # two (tile_n=64/128/256 -> 1/2/4 groups); only a non-pow2 group count + # (e.g. tile_n=192: 3->4) copies one extra oob-clipped scale group per + # tile (~0.1% of B traffic, no extra WMMA). 
+ _n4k4_groups = (tile_n + 63) // 64 + n4k4_bs_tile_groups = 1 << (_n4k4_groups - 1).bit_length() + n4k4_bs_lds_rows = n4k4_bs_tile_groups + _b_frag_loads_per_wn = 2 if is_a8w4 else 4 _a_frag_loads_per_wm = 2 if is_fp4 else 4 # _scale_ds_loads counts scale ds_loads issued alongside A/B fragment loads in @@ -338,7 +407,10 @@ def compile_fp8fp4_gemm( ab_split_b_groups = tile_n // 32 _scale_guard_bytes = 16 lds_a_scale_bytes = 0 if (is_ptpc or use_ascale_vgpr) else tile_m * scale_k_per_tile + _scale_guard_bytes - lds_b_scale_bytes = 0 if is_ptpc else tile_n * scale_k_per_tile + _scale_guard_bytes + if use_n4k4_bscale: + lds_b_scale_bytes = n4k4_bs_lds_rows * n4k4_bs_lds_row_stride + _scale_guard_bytes + else: + lds_b_scale_bytes = 0 if is_ptpc else tile_n * scale_k_per_tile + _scale_guard_bytes interleaved_scale_cols_a = wmma_m_rep * scale_k_per_tile interleaved_scale_cols_b = b_scale_load_rep * scale_k_per_tile @@ -627,6 +699,9 @@ def _pick_compute_schedule_kind(): use_fp8_quadrant_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_FP8_QUADRANT use_fp8_deep_pipeline_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_FP8_DEEP_PIPELINE use_b_streaming_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_B_STREAMING + + if use_n4k4_bscale: + assert compute_schedule_kind == COMPUTE_SCHEDULE_ROW_MAJOR_STREAMING if use_buffer_vgpr_scale: # General coalesced VGPR scale is supported on the row-major streaming # schedule (mxscale fp8/a8w4, no scale_opsel, wave-specialized TDM); the @@ -646,9 +721,7 @@ def _pick_compute_schedule_kind(): "wave_specialized_tdm" ) if use_a_vgpr and compute_schedule_kind != COMPUTE_SCHEDULE_ROW_MAJOR_STREAMING: - raise ValueError( - f"a_load_path={a_load_path!r} requires the row-major streaming schedule" - ) + raise ValueError(f"a_load_path={a_load_path!r} requires the row-major streaming schedule") use_ws_tdm_split_signal_overlap = ( wave_specialized_tdm and (use_fp8_quadrant_schedule or use_fp8_deep_pipeline_schedule) @@ -908,13 +981,17 @@ 
def _avr_load_a_tile(k_base): for ks in range_constexpr(k_wmma_steps): for wm in range_constexpr(wmma_m_rep): row = blk_m + warp_m_base + arith.index(wm * WMMA_M) + lane16 - row_off = row * _avr_lda_i32 + _avr_lane_kgrp_off + arith.index(ks * WMMA_K // PACK_FACTOR_A // 4) + row_off = ( + row * _avr_lda_i32 + _avr_lane_kgrp_off + arith.index(ks * WMMA_K // PACK_FACTOR_A // 4) + ) loads = [] for i in range_constexpr(DS_LOADS_PER_A_FRAG): off = arith.index_cast(T.i32, row_off + arith.index(i * 8)) - v = fx.Vector(buffer_ops.buffer_load( - _avr_a_rsrc, off, vec_width=4, dtype=T.i32, soffset_bytes=kt_soff - )) + v = fx.Vector( + buffer_ops.buffer_load( + _avr_a_rsrc, off, vec_width=4, dtype=T.i32, soffset_bytes=kt_soff + ) + ) loads.append(v) if const_expr(DS_LOADS_PER_A_FRAG == 2): frag = loads[0].shuffle(loads[1], list(range(8))) @@ -937,9 +1014,9 @@ def _avr_load_ascale(k_base): for grp in range_constexpr(_NG): grp_i32 = base_i32 + arith.index((ks * _NG + grp) * 32 * 4) + _avr_as_lane32 * arith.index(4) off = arith.index_cast(T.i32, grp_i32) - v = fx.Vector(buffer_ops.buffer_load( - _avr_as_rsrc, off, vec_width=4, dtype=T.i32, soffset_bytes=kt_soff - )) + v = fx.Vector( + buffer_ops.buffer_load(_avr_as_rsrc, off, vec_width=4, dtype=T.i32, soffset_bytes=kt_soff) + ) for j in range_constexpr(4): if const_expr(grp * 4 + j < wmma_m_rep): vals.append(v[j]) @@ -1067,6 +1144,28 @@ def make_desc_as(memref, k_base): ) def make_desc_bs(memref, k_base): + if const_expr(use_n4k4_bscale): + # N4K4: copy this tile's N-groups x K-blocks slice of the + # preshuffled [N//64, (K//128)*256] B-scale tensor. Each row is + # one 64-N group; the contiguous dim1 = tile_k//128 * 256B blocks. 
+ g_off = blk_n // arith.index(64) + col_off = (k_base // arith.index(WMMA_K)) * arith.index(256) + return _make_tdm_desc( + global_ptr=arg_b_scale, + lds_memref=memref, + global_offset=(g_off, col_off), + tensor_shape=(n4k4_n_groups, n4k4_bs_global_row_stride), + strides=(n4k4_bs_global_row_stride, 1), + tile_shape=(n4k4_bs_tile_groups, n4k4_bs_lds_row_stride), + elem_bytes=1, + pad_interval=0, + pad_amount=0, + num_warps=tdm_desc_num_warps, + workgroup_mask=b_mcast_mask, + atomic_barrier_enable=atomic_barrier_enable, + early_timeout=True, + oob_outer_bound=n4k4_n_groups, + ) k_scale_off = k_base // arith.index(SCALE_BLOCK) outer_off = blk_n // arith.index(b_scale_load_rep) inner_off = k_scale_off * arith.index(b_scale_load_rep) @@ -1227,6 +1326,46 @@ def _precompute_scale_lane_bases(lds_ptr, warp_base, reps, interleaved_cols): base = base + lane_kgrp * arith.index(SCALES_PER_WMMA) return lds_ptr, [base] + def _precompute_n4k4_bscale_bases(lds_ptr): + """Precompute (first_block, lane_byte) for this warp's N4K4 reads. + + The TDM copies the 64-N group(s) containing the tile; within a copied + group a lane's 4 N-blocks are 16 contiguous bytes and consecutive + groups are n4k4_bs_lds_row_stride apart. ``b0`` is the warp's first + N-block *inside the copied group(s)*: its own N offset plus, for tiles + smaller than one 64-N group, the tile's slice offset within its group. + Only lanes 0..15 carry the consumed scale (scaleAType=0, no op_sel); + lanes 16..31 read the same word. + """ + b0 = wave_n_idx * arith.index(b_scale_load_rep) + if const_expr(tile_n < 64): + # Sub-64 tile: whole containing group was copied; shift to this + # tile's slice (row offset 0/16/32/48 -> N-block 0/1/2/3). 
+ b0 = b0 + (blk_n % arith.index(64)) // arith.index(16) + lane_off = lane16 * arith.index(16) + return lds_ptr, (b0, lane_off) + + _N4K4_LOADERS = {1: lds_load_b32_raw, 2: lds_load_b64_raw, 4: lds_load_b128_raw} + + def load_n4k4_bscale(lds_buffer, bases, reps, ks=0): + """Load *reps* B-scale i32s from the N4K4 LDS layout for K-subtile *ks*.""" + b0, lane_off = bases + ks_off = arith.index(ks * 256) + row_stride = arith.index(n4k4_bs_lds_row_stride) + per_load = 4 if reps % 4 == 0 else (2 if reps % 2 == 0 else 1) + results = [] + for i in range_constexpr(reps // per_load): + blk = b0 + arith.index(i * per_load) + off = (blk // arith.index(4)) * row_stride + (blk % arith.index(4)) * arith.index(4) + lane_off + ks_off + raw = _N4K4_LOADERS[per_load](lds_buffer, off) + if const_expr(per_load == 1): + results.append(raw) + else: + vec = fx.Vector(raw) + for j in range_constexpr(per_load): + results.append(vec[j]) + return results + def load_scale_b128(lds_buffer, scale_base, reps, ks=0): """Load all wmma_rep scales via ds_load_b128(s) for K-subtile *ks*.""" ks_byte_off = ks * reps * SCALES_PER_WMMA @@ -1261,6 +1400,12 @@ def load_scale_slice_b128(lds_buffer, scale_base, full_reps, rep_start, rep_coun # consume is sequential at emit time (same pattern as epi_addrs_box). _vgpr_scale_box = [None] + def _load_b_scale_lds(bs_buf, bs_bases, ks): + """Load B-scale from LDS, dispatching to the N4K4 or legacy layout.""" + if const_expr(use_n4k4_bscale): + return load_n4k4_bscale(bs_buf, bs_bases, b_scale_load_rep, ks) + return load_scale_b128(bs_buf, bs_bases[0], b_scale_load_rep, ks) + def _scales_for_emit(as_buf, as_bases, bs_buf, bs_bases, ks): """Load both scale tensors and apply op_sel downsampling per format. @@ -1278,11 +1423,10 @@ def _scales_for_emit(as_buf, as_bases, bs_buf, bs_bases, ks): if const_expr(use_ascale_vgpr): # A_scale from VGPR (bundled with A prefetch ring), B_scale from LDS. 
a = _a_vgpr_ascale_box[0][ks * wmma_m_rep : (ks + 1) * wmma_m_rep] - b_all = load_scale_b128(bs_buf, bs_bases[0], b_scale_load_rep, ks) - b = b_all + b = _load_b_scale_lds(bs_buf, bs_bases, ks) return a, b a_all = load_scale_b128(as_buf, as_bases[0], wmma_m_rep, ks) - b_all = load_scale_b128(bs_buf, bs_bases[0], b_scale_load_rep, ks) + b_all = _load_b_scale_lds(bs_buf, bs_bases, ks) if const_expr(use_scale_opsel): a = a_all[::2] b = b_all if const_expr(is_fp4) else b_all[::2] @@ -1423,7 +1567,9 @@ def _emit_rows(start_wm, a_frags): mid_compute_callback() if const_expr(_back_wm > 0): - a_frags_back = [load_a_frag(a_buf, a_bases[_front_wm + h], ks, wm=_front_wm + h) for h in range_constexpr(_back_wm)] + a_frags_back = [ + load_a_frag(a_buf, a_bases[_front_wm + h], ks, wm=_front_wm + h) for h in range_constexpr(_back_wm) + ] _back_drain = _bs_ds_loads if _use_partial_drain else 0 rocdl.s_wait_dscnt(_back_drain) _emit_rows(_front_wm, a_frags_back) @@ -1522,9 +1668,12 @@ def compute_tile( a_buf, a_bases = _precompute_a_lane_bases(lds_a) b_buf, b_bases = _precompute_b_lane_bases(lds_b) as_buf, as_bases = _precompute_scale_lane_bases(lds_as, warp_m_base, wmma_m_rep, interleaved_scale_cols_a) - bs_buf, bs_bases = _precompute_scale_lane_bases( - lds_bs, warp_n_base, b_scale_load_rep, interleaved_scale_cols_b - ) + if const_expr(use_n4k4_bscale): + bs_buf, bs_bases = _precompute_n4k4_bscale_bases(lds_bs) + else: + bs_buf, bs_bases = _precompute_scale_lane_bases( + lds_bs, warp_n_base, b_scale_load_rep, interleaved_scale_cols_b + ) if const_expr(k_wmma_steps == 1): b_frags, b_scales, a_scales = _load_b_and_scales(b_buf, b_bases, bs_buf, bs_bases, as_buf, as_bases, 0) @@ -2686,7 +2835,9 @@ def _pack_dg0(pred, lds_addr, addr_lo, addr_hi): adv_a_i32 = fx.Int32(tile_k // PACK_FACTOR_A) adv_b_i32 = fx.Int32(packed_tile_k_b * 16) adv_as_i32 = fx.Int32(tile_k // SCALE_BLOCK * wmma_m_rep) - adv_bs_i32 = fx.Int32(tile_k // SCALE_BLOCK * b_scale_load_rep) + # N4K4 advances by one 
tile's worth of K-blocks (k_wmma_steps*256B) per + # K-step; the legacy interleaved layout advances by scale_k_per_tile*rep. + adv_bs_i32 = fx.Int32(n4k4_bs_lds_row_stride if use_n4k4_bscale else tile_k // SCALE_BLOCK * b_scale_load_rep) pred_const = fx.Int32(1) if const_expr(wave_specialized_tdm): diff --git a/tests/kernels/test_gemm_fp8fp4_gfx1250.py b/tests/kernels/test_gemm_fp8fp4_gfx1250.py index 029cfade..3e1b4f4f 100644 --- a/tests/kernels/test_gemm_fp8fp4_gfx1250.py +++ b/tests/kernels/test_gemm_fp8fp4_gfx1250.py @@ -28,6 +28,7 @@ compile_mxscale_gemm, compile_ptpc_gemm, is_ref_segmented_lds_layout, + use_n4k4_bscale_layout, ) from tests.kernels.utils import fp4_utils # noqa: E402 @@ -160,6 +161,28 @@ def preshuffle_e8m0_scale( return g.reshape(-1, k_groups * k_wmma_steps * wmma_rep * SCALES_PER_WMMA) +def preshuffle_e8m0_bscale_n4k4(scale: torch.Tensor) -> torch.Tensor: + """Tile-independent N4K4 B-scale preshuffle: [N, K_scale] -> [N//64, (K_scale//4)*256]. + + Atomic block = 4 N-blocks x 1 K-block = 64 N-rows x 4 scale-bytes = 256B, so + the byte layout depends only on the constants (64, 16, 4, 4) and never on + tile_n/n_warp/tile_k. Weights are preshuffled once and served to any tile + config landing on the default row-major streaming schedule. + + B_scale_pre[g, kb, n, r, k] = scale[g*64 + r*16 + n, kb*4 + k] + + where g = N//64 group, kb = K//128 block (4 scale bytes = one WMMA's K=128), + n = lane16 (N-row within a 16-block), r = the 4 N-blocks in a 64-group, + k = the 4 scale bytes within one WMMA. Mirrors the kernel's N4K4 TDM+LDS + read (see flydsl_fp8_perf/verify_n4k4_bscale_layout.py for the parity proof). 
+ """ + N, Ks = scale.shape + assert N % 64 == 0 and Ks % 4 == 0, f"N4K4 B-scale needs N%64==0, Ks%4==0; got N={N} Ks={Ks}" + g = scale.view(N // 64, 4, 16, Ks // 4, 4) # [g, r, n, kb, k] + g = g.permute(0, 3, 2, 1, 4).contiguous() # [g, kb, n, r, k] + return g.reshape(N // 64, (Ks // 4) * 256) + + def preshuffle_scale_for_load_path( scale, warp_tile, skt, *, scale_load_path, data_format, ref_segmented, row_align=None ): @@ -578,15 +601,30 @@ def _run_mxscale_gemm_test( ref_segmented=_ref_seg, row_align=tile_m, ) - b_scale = preshuffle_scale_for_load_path( - b_scale, - warp_tile_n, - skt, - scale_load_path=scale_load_path, + if use_n4k4_bscale_layout( data_format=data_format, - ref_segmented=_ref_seg, - row_align=tile_n, - ) + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + m_warp=m_warp, + n_warp=n_warp, + n=padded_n, + scale_load_path=scale_load_path, + use_scale_opsel=use_scale_opsel, + b_streaming=b_streaming, + b_split_load=b_split_load, + ): + b_scale = preshuffle_e8m0_bscale_n4k4(b_scale) + else: + b_scale = preshuffle_scale_for_load_path( + b_scale, + warp_tile_n, + skt, + scale_load_path=scale_load_path, + data_format=data_format, + ref_segmented=_ref_seg, + row_align=tile_n, + ) # Preshuffle B data K_packed = padded_k // padded_shape["pack_b"] @@ -806,32 +844,6 @@ def test_mxfp4_gemm( ) -@pytest.mark.parametrize("out_dtype", ["bf16", "f16"]) -def test_mxfp4_metadata_and_spill_regression(out_dtype): - launch_fn = _run_mxscale_gemm_test( - "fp4", - 1024, - 1024, - 1024, - 256, - 256, - 256, - 2, - 2, - num_buffers=4, - use_tdm_store=True, - out_dtype=out_dtype, - return_launch_fn=True, - ) - artifact = _get_latest_artifact(launch_fn) - - assert ( - "known_block_size = array" in artifact.source_ir - ), f"expected known_block_size metadata in source IR:\n{artifact.source_ir}" - - compiled_ir = artifact.ir - assert _extract_i64_metadata(compiled_ir, "max_flat_workgroup_size") == 128 - assert _extract_i64_metadata(compiled_ir, "vgpr_spill_count") == 0 
@pytest.mark.parametrize( @@ -967,183 +979,6 @@ def test_a8w4_gemm_irregular_m_tile16(M, N, K, use_tdm_store): ) -@pytest.mark.parametrize( - "scale_load_path, a_load_path, b_split_load", - [ - ("tdm", "vgpr", False), - ("tdm", "vgpr", True), - ("tdm", "tdm", True), - ("vgpr", "tdm", False), - ], -) -def test_a8w4_small_m_load_path_variants(scale_load_path, a_load_path, b_split_load): - _run_mxscale_gemm_test( - "a8w4", - 1, - 256, - 2048, - 16, - 64, - 512, - 1, - 4, - num_buffers=4, - use_tdm_store=True, - out_dtype="bf16", - wave_specialized_tdm=True, - l2_prefetch_distance=0, - use_scale_opsel=False, - scale_load_path=scale_load_path, - a_load_path=a_load_path, - b_split_load=b_split_load, - ) - - -@pytest.mark.parametrize( - "data_format, M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp", - [ - ("fp4", 128, 512, 7168, 128, 128, 256, 2, 2), - ("fp8", 128, 256, 256, 128, 256, 128, 2, 4), - ("a8w4", 128, 256, 256, 128, 256, 128, 2, 4), - ], -) -def test_b_streaming_correctness(data_format, M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp): - _run_mxscale_gemm_test( - data_format, - M, - N, - K, - tile_m, - tile_n, - tile_k, - m_warp, - n_warp, - num_buffers=2, - use_tdm_store=True, - out_dtype="bf16", - l2_prefetch_distance=2, - b_streaming=True, - ) - - -@pytest.mark.parametrize( - "data_format, M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp", - [ - ("fp4", 128, 256, 512, 128, 128, 256, 2, 2), - ("fp8", 128, 256, 256, 128, 256, 128, 2, 2), - ("a8w4", 128, 256, 256, 128, 256, 128, 2, 2), - ], -) -def test_b_streaming_with_wave_spec_tdm(data_format, M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp): - _run_mxscale_gemm_test( - data_format, - M, - N, - K, - tile_m, - tile_n, - tile_k, - m_warp, - n_warp, - num_buffers=2, - use_tdm_store=True, - out_dtype="bf16", - l2_prefetch_distance=2, - b_streaming=True, - wave_specialized_tdm=True, - ) - - -@pytest.mark.parametrize("num_buffers", [2, 3]) -@pytest.mark.parametrize("use_tdm_store", [True, False]) 
-@pytest.mark.parametrize("use_scale_opsel", [False, True]) -def test_mxfp8_wave_spec_scale_load_tdm(num_buffers, use_tdm_store, use_scale_opsel): - _run_mxscale_gemm_test( - "fp8", - 128, - 256, - 384, - 128, - 256, - 128, - 2, - 2, - num_buffers=num_buffers, - use_tdm_store=use_tdm_store, - out_dtype="bf16", - l2_prefetch_distance=2, - wave_specialized_tdm=True, - use_scale_opsel=use_scale_opsel, - scale_load_path="tdm", - ) - - -@pytest.mark.parametrize("scale_load_path", ["vgpr", "vgpr_ab_split"]) -@pytest.mark.parametrize("cluster_m, cluster_n", [(1, 1), (2, 2)]) -def test_mxfp8_vgpr_scale_load(scale_load_path, cluster_m, cluster_n): - _run_mxscale_gemm_test( - "fp8", - 256 * cluster_m, - 256 * cluster_n, - 512, - 256, - 256, - 128, - 2, - 2, - num_buffers=4, - use_tdm_store=True, - out_dtype="bf16", - l2_prefetch_distance=2, - wave_specialized_tdm=True, - cluster_m=cluster_m, - cluster_n=cluster_n, - scale_load_path=scale_load_path, - ) - - -@pytest.mark.parametrize( - "data_format, M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp, cluster_m, cluster_n", - [ - ("fp4", 256, 512, 256, 128, 256, 128, 2, 2, 2, 2), - ("fp8", 256, 512, 256, 128, 256, 128, 2, 2, 2, 2), - ], -) -def test_b_streaming_with_cluster_mcast( - data_format, - M, - N, - K, - tile_m, - tile_n, - tile_k, - m_warp, - n_warp, - cluster_m, - cluster_n, -): - if str(get_rocm_arch()) != "gfx1250": - pytest.skip("requires gfx1250") - if "FFMLITE_TOPOLOGY" in os.environ or "AM_TOPOLOGY" in os.environ: - pytest.skip("cluster multicast not supported on simulator") - _run_mxscale_gemm_test( - data_format, - M, - N, - K, - tile_m, - tile_n, - tile_k, - m_warp, - n_warp, - num_buffers=2, - use_tdm_store=True, - out_dtype="bf16", - l2_prefetch_distance=2, - b_streaming=True, - cluster_m=cluster_m, - cluster_n=cluster_n, - ) @pytest.mark.parametrize( From e93772d81ce25fbdfb0a368b4e93018054136445 Mon Sep 17 00:00:00 2001 From: aoli26 Date: Sun, 14 Jun 2026 15:57:13 +0000 Subject: [PATCH 07/16] remove a vgpr/b 
nsplit/b split load/cluster graph exps --- kernels/gemm_fp8fp4_gfx1250.py | 555 ++------------------ lib/Runtime/ROCm/FlyRocmRuntimeWrappers.cpp | 65 +-- tests/kernels/test_gemm_fp8fp4_gfx1250.py | 188 ++++--- 3 files changed, 142 insertions(+), 666 deletions(-) diff --git a/kernels/gemm_fp8fp4_gfx1250.py b/kernels/gemm_fp8fp4_gfx1250.py index 24458607..ec256b37 100644 --- a/kernels/gemm_fp8fp4_gfx1250.py +++ b/kernels/gemm_fp8fp4_gfx1250.py @@ -134,7 +134,6 @@ def use_n4k4_bscale_layout( scale_load_path="tdm", use_scale_opsel=False, b_streaming=False, - b_split_load=False, ): """Whether B-scale uses the tile-independent N4K4 preshuffle layout.""" if scale_mode != "mxscale": @@ -143,7 +142,7 @@ def use_n4k4_bscale_layout( return False if scale_load_path != "tdm": return False - if use_scale_opsel or b_split_load: + if use_scale_opsel: return False if tile_k % 128 != 0: return False @@ -187,8 +186,6 @@ def compile_fp8fp4_gemm( b_streaming: bool = False, scale_load_path: str = "tdm", fp8_schedule: str = "auto", - a_load_path: str = "tdm", - b_split_load: bool = False, ): """Compile an FP4/FP8/A8W4 GEMM kernel with TDM async copy. @@ -222,9 +219,8 @@ def compile_fp8fp4_gemm( raise ValueError(f"out_dtype must be 'f32', 'bf16', or 'f16', got {out_dtype!r}") elem_bytes_d = 2 if out_dtype in ("bf16", "f16") else 4 # scale_load_path: "tdm" = TDM->LDS (default); "vgpr" = buffer_load->VGPR, - # off the LDS/TDM/barrier path; "vgpr_ab_split" = "vgpr" plus repurposing the - # idle scale waves 2,3 to load the second A/B halves. - scale_load_paths = ("tdm", "vgpr", "vgpr_ab_split") + # off the LDS/TDM/barrier path. 
+ scale_load_paths = ("tdm", "vgpr") if scale_load_path not in scale_load_paths: raise ValueError(f"scale_load_path must be one of {scale_load_paths}, got {scale_load_path!r}") fp8_schedule_modes = ("auto", "quadrant", "deep-pipeline") @@ -234,20 +230,6 @@ def compile_fp8fp4_gemm( raise ValueError(f"fp8_schedule={fp8_schedule!r} is only valid for data_format='fp8'") if fp8_schedule != "auto" and b_streaming: raise ValueError("fp8_schedule cannot be combined with b_streaming=True") - a_load_path_modes = ("tdm", "vgpr", "vgpr_ascale") - if a_load_path not in a_load_path_modes: - raise ValueError(f"a_load_path must be one of {a_load_path_modes}, got {a_load_path!r}") - use_a_vgpr = a_load_path != "tdm" - use_ascale_vgpr = a_load_path == "vgpr_ascale" - if use_a_vgpr and scale_load_path != "tdm": - raise ValueError("a_load_path and scale_load_path cannot both bypass TDM") - if use_a_vgpr and not wave_specialized_tdm: - raise ValueError("a_load_path != 'tdm' requires wave_specialized_tdm=True") - if use_a_vgpr and data_format not in ("fp8", "a8w4"): - raise ValueError("a_load_path != 'tdm' requires data_format='fp8' or 'a8w4'") - if use_a_vgpr and is_ptpc: - raise ValueError("a_load_path != 'tdm' requires scale_mode='mxscale'") - b_split_load = bool(b_split_load) effective_expert_sched_mode = bool(expert_sched_mode) if num_buffers not in (2, 3, 4, 5, 6): @@ -267,21 +249,11 @@ def compile_fp8fp4_gemm( raise ValueError(f"block_threads must be <= 1024, got {block_threads}") # Wave-specialized TDM dedicates one loader wave per TDM tensor. - # Determine which tensors bypass TDM to calculate minimum wave count. - # A data: bypasses TDM when use_a_vgpr - # A_scale: bypasses TDM when use_ascale_vgpr, is_ptpc, or scale_load_path=="vgpr" - # B_scale: bypasses TDM when is_ptpc or scale_load_path=="vgpr" - # Remaining TDM tensors determine wave assignment and min warp count. 
- _drop_scale_loader_waves = is_ptpc or scale_load_path == "vgpr" or use_ascale_vgpr - _drop_a_loader_wave = use_a_vgpr - if _drop_a_loader_wave and _drop_scale_loader_waves: - _min_wave_spec_warps = 2 # only B + B_scale (or just B for ptpc) - elif _drop_scale_loader_waves: - _min_wave_spec_warps = 2 # only A + B - elif _drop_a_loader_wave: - _min_wave_spec_warps = 4 # B + A_scale + B_scale (wave0 idle) - else: - _min_wave_spec_warps = 4 # A + B + A_scale + B_scale + # Scales bypass TDM (no dedicated loader waves) for ptpc or the buffer->VGPR + # scale path, leaving only A + B -> 2 waves; otherwise A + B + A_scale + + # B_scale -> 4 waves. + _drop_scale_loader_waves = is_ptpc or scale_load_path == "vgpr" + _min_wave_spec_warps = 2 if _drop_scale_loader_waves else 4 if wave_specialized_tdm and num_warps < _min_wave_spec_warps: raise ValueError(f"wave_specialized_tdm requires at least {_min_wave_spec_warps} waves, got {num_warps}") @@ -365,7 +337,6 @@ def compile_fp8fp4_gemm( scale_load_path=scale_load_path, use_scale_opsel=use_scale_opsel, b_streaming=b_streaming, - b_split_load=b_split_load, ) if use_n4k4_bscale: if K_scale % 4 != 0: @@ -387,26 +358,19 @@ def compile_fp8fp4_gemm( # the streaming schedule (used for the partial-drain s_wait_dscnt bookkeeping). # The general VGPR scale path holds scales in registers (no ds_load), so it # contributes zero. Finalized below once use_general_vgpr_scale is known. 
- _a_scale_ds = 0 if use_ascale_vgpr else (wmma_m_rep + 3) // 4 + _a_scale_ds = (wmma_m_rep + 3) // 4 _b_scale_ds = (b_scale_load_rep + 3) // 4 _scale_ds_loads = _a_scale_ds + _b_scale_ds - _a_frag_ds = 0 if use_a_vgpr else wmma_m_rep * _a_frag_loads_per_wm + _a_frag_ds = wmma_m_rep * _a_frag_loads_per_wm _bs_ds_loads = wmma_n_rep * _b_frag_loads_per_wn + _scale_ds_loads _as_ds_loads = _a_frag_ds + _scale_ds_loads lds_a_stride_bytes = packed_tile_k_a + LDS_PAD_A_BYTES - if scale_load_path == "vgpr_ab_split" or b_split_load: - if tile_m % 2 != 0: - raise ValueError("B/A split load variants require even tile_m, got " f"{tile_m}") - if tile_n % 32 != 0: - raise ValueError("B/A split load variants require tile_n divisible by 32, got " f"{tile_n}") - lds_a_data_bytes = 0 if use_a_vgpr else tile_m * lds_a_stride_bytes + lds_a_data_bytes = tile_m * lds_a_stride_bytes lds_b_data_bytes = tile_n * packed_tile_k_b - ab_split_a_rows = tile_m // 2 - ab_split_b_groups = tile_n // 32 _scale_guard_bytes = 16 - lds_a_scale_bytes = 0 if (is_ptpc or use_ascale_vgpr) else tile_m * scale_k_per_tile + _scale_guard_bytes + lds_a_scale_bytes = 0 if is_ptpc else tile_m * scale_k_per_tile + _scale_guard_bytes if use_n4k4_bscale: lds_b_scale_bytes = n4k4_bs_lds_rows * n4k4_bs_lds_row_stride + _scale_guard_bytes else: @@ -469,16 +433,14 @@ def _align_up(value: int, align: int) -> int: use_scale_opsel=use_scale_opsel, ) - # "vgpr"/"vgpr_ab_split": load scale global->VGPR via buffer_load, bypassing + # "vgpr": load scale global->VGPR via buffer_load, bypassing # TDM+LDS entirely. Two layouts coexist: the reference segmented deep-pipeline # path (use_ref_segmented_lds_layout, fp8 256x256x128) and the general # coalesced path (use_general_vgpr_scale) used by the row-major streaming # schedule (a8w4/fp8, arbitrary warp_tile / tile_k). The full schedule+format # eligibility check runs once the compute schedule is known (below). 
- use_buffer_vgpr_scale = scale_load_path in ("vgpr", "vgpr_ab_split") + use_buffer_vgpr_scale = scale_load_path == "vgpr" use_general_vgpr_scale = use_buffer_vgpr_scale and not use_ref_segmented_lds_layout - if use_general_vgpr_scale and scale_load_path == "vgpr_ab_split": - raise ValueError("scale_load_path='vgpr_ab_split' requires the reference segmented LDS layout") if use_general_vgpr_scale: # General VGPR scales live in registers: no scale ds_loads to wait on. _scale_ds_loads = 0 @@ -501,62 +463,9 @@ def _align_up(value: int, align: int) -> int: # all-up-front variant. Full-K scale must fit in VGPRs -- NOT general. _bvs_b128 = use_general_vgpr_scale and bool(int(os.environ.get("FLYDSL_BUFFER_VGPR_SCALE_PRELOAD", "0"))) _bvs_preload = _bvs_b128 and loop_iters == 0 - # ab_half_split: repurpose the (under "vgpr") idle scale waves 2,3 as the - # second halves of A/B, so all 4 waves share the A/B TDM (wave0=A0, wave1=B0, - # wave2=A1, wave3=B1). Measured wall-neutral. - use_ab_half_split = scale_load_path == "vgpr_ab_split" # The buffer_load->VGPR scale ring is built only when scale is actually loaded. _bvs_active = use_buffer_vgpr_scale - # A VGPR prefetch: buffer_load A data directly into VGPRs, bypassing TDM/LDS. - # Per-tile A frag count: k_wmma_steps * wmma_m_rep vec<16xi32> (or vec<8xi32> for fp4). - _avr_active = use_a_vgpr - _avr_D = max(1, int(os.environ.get("FLYDSL_A_VGPR_DEPTH", str(num_buffers)))) - _avr_frag_width = 8 if is_fp4 else 16 # vec elements per A fragment - _avr_frags_per_tile = k_wmma_steps * wmma_m_rep - # When vgpr_ascale, A_scale is bundled with the A ring. - _avr_ascale_per_tile = k_wmma_steps * wmma_m_rep if use_ascale_vgpr else 0 - - # B N-split (env FLYDSL_B_KSPLIT, default off): on the VGPR A path wave0 is - # freed from loading A, so wave0 and wave1 co-load B — wave0 the first half of - # the tile's N-groups, wave1 the second — issued in parallel under the normal - # single unified barrier. 
N is the outer LDS dim, so the two halves write the - # same contiguous tile the full-B descriptor would; the LDS layout, the compute - # loop, and the fence are all unchanged. The ONLY delta vs the plain vgpr path - # is: wave0 is activated and the B load is split across waves 0/1 by N-group. - # Valid only when A is VGPR with TDM scales (wave0 otherwise idle, waves - # 1/2/3 = B/A_scale/B_scale) and tile_n splits into two equal N-group halves. - _b_nsplit = ( - b_split_load - and use_a_vgpr - and not use_ascale_vgpr - and not is_ptpc - and scale_load_path == "tdm" - and wave_specialized_tdm - and tile_n % 32 == 0 - ) - # TDM B N-split while A still uses TDM. Wave assignment: - # wave0=A, wave1=B first N-half, wave2=B second N-half, - # wave3=A_scale then B_scale. - # Wave3 issues two tensor ops per K-tile, so the pipeline wait uses a - # wave-specific outstanding count in the fence helpers below. - _tdm_b_nsplit_scale_combo = ( - b_split_load - and not use_a_vgpr - and data_format == "a8w4" - and not is_ptpc - and scale_load_path == "tdm" - and wave_specialized_tdm - and num_buffers == 4 - and tile_n % 32 == 0 - ) - if b_split_load and not (_b_nsplit or _tdm_b_nsplit_scale_combo): - raise ValueError( - "b_split_load currently supports either a_load_path='vgpr' with TDM scales, " - "or A8W4 a_load_path='tdm' scale_load_path='tdm' wave_specialized_tdm=True " - "num_buffers=4" - ) - if use_ref_segmented_lds_layout: # The A/B data pools are no longer packed into the same per-stage # 64KiB segment window. 
Scale pools keep the reference 0x800 stride so @@ -720,8 +629,6 @@ def _pick_compute_schedule_kind(): "the row-major streaming schedule with mxscale fp8/a8w4, no scale_opsel, and " "wave_specialized_tdm" ) - if use_a_vgpr and compute_schedule_kind != COMPUTE_SCHEDULE_ROW_MAJOR_STREAMING: - raise ValueError(f"a_load_path={a_load_path!r} requires the row-major streaming schedule") use_ws_tdm_split_signal_overlap = ( wave_specialized_tdm and (use_fp8_quadrant_schedule or use_fp8_deep_pipeline_schedule) @@ -951,92 +858,6 @@ def _bvs_prefetch(k_base, preload=False): else: lda_packed = fx.Index(i32_lda) / arith.index(PACK_FACTOR_A) - if const_expr(_avr_active): - # arg_a is dynamically shaped (runtime M), so max_size=False would fall - # back to a max-sized descriptor and disable hardware OOB. Clip - # num_records to M*lda bytes (fp8 A: 1 byte/elem) so rows >= M read 0 - # for non-tile-aligned M, matching the TDM path. - _avr_a_rsrc = buffer_ops.create_buffer_resource(arg_a, num_records_bytes=m_idx * lda_packed) - # buffer_load voffset is in i32 (4-byte) elements (it multiplies by 4 - # internally), so the byte-domain row/K/lane offsets must be divided by - # 4. k_base goes in the byte-domain soffset and stays as-is. - _avr_lda_i32 = lda_packed // arith.index(4) - _avr_lane_kgrp_off = lane_kgrp * arith.index(4) # 16 bytes / 4 - if const_expr(use_ascale_vgpr): - _avr_as_rsrc = buffer_ops.create_buffer_resource(arg_a_scale, max_size=False) - _avr_as_Kt = K // tile_k - _avr_as_mb = blk_m // arith.index(warp_tile_m) + wave_m_idx - _avr_as_lane32 = lane_kgrp * arith.index(16) + lane16 - - def _avr_load_a_tile(k_base): - """Issue buffer_load_b128 for one K-tile of A data. - - Returns list of vec<_avr_frag_width xi32> IR values, length - k_wmma_steps * wmma_m_rep (indexed [ks * wmma_m_rep + wm]). - The K-tile offset goes in soffset (scalar, same for all lanes); - per-lane row/kgrp address stays in voffset (reused across tiles). 
- """ - kt_soff = arith.index_cast(T.i32, k_base) - frags = [] - for ks in range_constexpr(k_wmma_steps): - for wm in range_constexpr(wmma_m_rep): - row = blk_m + warp_m_base + arith.index(wm * WMMA_M) + lane16 - row_off = ( - row * _avr_lda_i32 + _avr_lane_kgrp_off + arith.index(ks * WMMA_K // PACK_FACTOR_A // 4) - ) - loads = [] - for i in range_constexpr(DS_LOADS_PER_A_FRAG): - off = arith.index_cast(T.i32, row_off + arith.index(i * 8)) - v = fx.Vector( - buffer_ops.buffer_load( - _avr_a_rsrc, off, vec_width=4, dtype=T.i32, soffset_bytes=kt_soff - ) - ) - loads.append(v) - if const_expr(DS_LOADS_PER_A_FRAG == 2): - frag = loads[0].shuffle(loads[1], list(range(8))) - else: - v01 = loads[0].shuffle(loads[1], list(range(8))) - v23 = loads[2].shuffle(loads[3], list(range(8))) - frag = v01.shuffle(v23, list(range(16))) - frags.append(frag.ir_value()) - return frags - - def _avr_load_ascale(k_base): - """Load A_scale for one K-tile via buffer_load (coalesced layout).""" - kt = k_base // arith.index(tile_k) - _NG = (wmma_m_rep + 3) // 4 - _S = k_wmma_steps * _NG * 32 * 4 - base_i32 = _avr_as_mb * arith.index(_avr_as_Kt) * arith.index(_S) - kt_soff = arith.index_cast(T.i32, kt * arith.index(_S) * arith.index(4)) - vals = [] - for ks in range_constexpr(k_wmma_steps): - for grp in range_constexpr(_NG): - grp_i32 = base_i32 + arith.index((ks * _NG + grp) * 32 * 4) + _avr_as_lane32 * arith.index(4) - off = arith.index_cast(T.i32, grp_i32) - v = fx.Vector( - buffer_ops.buffer_load(_avr_as_rsrc, off, vec_width=4, dtype=T.i32, soffset_bytes=kt_soff) - ) - for j in range_constexpr(4): - if const_expr(grp * 4 + j < wmma_m_rep): - vals.append(v[j]) - return vals - - def _avr_prefetch(k_base): - """Issue A data (and optionally A_scale) prefetch for one K-tile. - - Returns (a_frags, a_scales) where a_frags is a list of - vec IR values and a_scales is a list of i32 (or empty). 
- """ - a_frags = _avr_load_a_tile(k_base) - if const_expr(use_ascale_vgpr): - a_scales = _avr_load_ascale(k_base) - else: - a_scales = [] - return a_frags, a_scales - - _a_vgpr_box = [None] - _a_vgpr_ascale_box = [None] n_stride = fx.Index(i32_ldc) c_nrec = m_idx * n_stride * arith.index(elem_bytes_d) c_rsrc = buffer_ops.create_buffer_resource(arg_c, num_records_bytes=c_nrec) @@ -1082,47 +903,6 @@ def make_desc_b(memref, k_base): early_timeout=True, ) - def make_desc_a_half(memref, k_base, m_half: int): - row_start = m_half * ab_split_a_rows - k_packed_off = k_base // arith.index(PACK_FACTOR_A) - return _make_tdm_desc( - global_ptr=arg_a, - lds_memref=memref, - global_offset=(blk_m + arith.index(row_start), k_packed_off), - tensor_shape=(tile_m, packed_tile_k_a), - strides=(lda_packed, 1), - tile_shape=(ab_split_a_rows, packed_tile_k_a), - elem_bytes=1, - pad_interval=packed_tile_k_a, - pad_amount=LDS_PAD_A_BYTES, - num_warps=1, - workgroup_mask=a_mcast_mask, - lds_byte_offset=arith.index(row_start * lds_a_stride_bytes), - atomic_barrier_enable=atomic_barrier_enable, - early_timeout=True, - oob_outer_bound=i32_m, - ) - - def make_desc_b_half(memref, k_base, n_half: int): - group_start = n_half * ab_split_b_groups - k_packed_off = k_base // arith.index(PACK_FACTOR_B) - return _make_tdm_desc( - global_ptr=arg_b, - lds_memref=memref, - global_offset=(blk_n // arith.index(16) + arith.index(group_start), k_packed_off * arith.index(16)), - tensor_shape=(N // 16, K_packed_b * 16), - strides=(K_packed_b * 16, 1), - tile_shape=(ab_split_b_groups, packed_tile_k_b * 16), - elem_bytes=1, - pad_interval=0, - pad_amount=0, - num_warps=1, - workgroup_mask=b_mcast_mask, - lds_byte_offset=arith.index(group_start * packed_tile_k_b * 16), - atomic_barrier_enable=atomic_barrier_enable, - early_timeout=True, - ) - def make_desc_as(memref, k_base): k_scale_off = k_base // arith.index(SCALE_BLOCK) outer_off = blk_m // arith.index(wmma_m_rep) @@ -1210,8 +990,8 @@ def 
_precompute_a_lane_bases(lds_ptr): bases.append(base) return lds_ptr, bases - def load_a_frag(lds_buffer, a_lane_base, ks, wm=0): - """Load one A-fragment from LDS (or VGPR box when use_a_vgpr). + def load_a_frag(lds_buffer, a_lane_base, ks): + """Load one A-fragment from LDS. FP4: vec<8xi32> via 2 × ds_load_b128 (32 bytes per lane). FP8/A8W4: vec<16xi32> via 4 × ds_load_b128 (64 bytes per lane). @@ -1219,8 +999,6 @@ def load_a_frag(lds_buffer, a_lane_base, ks, wm=0): kgrp0 reads bytes [0:15],[32:47],[64:79],[96:111] (stride=32) kgrp1 reads bytes [16:31],[48:63],[80:95],[112:127] (stride=32) """ - if const_expr(use_a_vgpr): - return _a_vgpr_box[0][ks * wmma_m_rep + wm] k_byte_off = arith.index(ks * WMMA_K // PACK_FACTOR_A) byte_off = a_lane_base + k_byte_off v0 = fx.Vector(lds_load_b128_raw(lds_buffer, byte_off)) @@ -1420,11 +1198,6 @@ def _scales_for_emit(as_buf, as_bases, bs_buf, bs_bases, ks): a = pf_a[ks * wmma_m_rep : (ks + 1) * wmma_m_rep] b = pf_b[ks * b_scale_load_rep : (ks + 1) * b_scale_load_rep] return a, b - if const_expr(use_ascale_vgpr): - # A_scale from VGPR (bundled with A prefetch ring), B_scale from LDS. 
- a = _a_vgpr_ascale_box[0][ks * wmma_m_rep : (ks + 1) * wmma_m_rep] - b = _load_b_scale_lds(bs_buf, bs_bases, ks) - return a, b a_all = load_scale_b128(as_buf, as_bases[0], wmma_m_rep, ks) b_all = _load_b_scale_lds(bs_buf, bs_bases, ks) if const_expr(use_scale_opsel): @@ -1440,7 +1213,7 @@ def _load_b_and_scales(b_buf, b_bases, bs_buf, bs_bases, as_buf, as_bases, ks): return b_frags, b_scales, a_scales def _load_a_and_scales(a_buf, a_bases, as_buf, as_bases, bs_buf, bs_bases, ks): - a_frags = [load_a_frag(a_buf, a_bases[wm], ks, wm=wm) for wm in range_constexpr(wmma_m_rep)] + a_frags = [load_a_frag(a_buf, a_bases[wm], ks) for wm in range_constexpr(wmma_m_rep)] a_scales, b_scales = _scales_for_emit(as_buf, as_bases, bs_buf, bs_bases, ks) return a_frags, a_scales, b_scales @@ -1549,7 +1322,7 @@ def _emit_rows(start_wm, a_frags): wn = (wmma_n_rep - 1 - wn_raw) if (wm % 2 == 1) else wn_raw _emit_wmma(accs, wm, wn, a_frags[frag_i], b_frags[wn], a_scales, b_scales) - a_frags_front = [load_a_frag(a_buf, a_bases[wm], ks, wm=wm) for wm in range_constexpr(_front_wm)] + a_frags_front = [load_a_frag(a_buf, a_bases[wm], ks) for wm in range_constexpr(_front_wm)] _use_partial_drain = next_bs_info is not None and _front_wm * wmma_n_rep >= 4 @@ -1567,9 +1340,7 @@ def _emit_rows(start_wm, a_frags): mid_compute_callback() if const_expr(_back_wm > 0): - a_frags_back = [ - load_a_frag(a_buf, a_bases[_front_wm + h], ks, wm=_front_wm + h) for h in range_constexpr(_back_wm) - ] + a_frags_back = [load_a_frag(a_buf, a_bases[_front_wm + h], ks) for h in range_constexpr(_back_wm)] _back_drain = _bs_ds_loads if _use_partial_drain else 0 rocdl.s_wait_dscnt(_back_drain) _emit_rows(_front_wm, a_frags_back) @@ -1647,14 +1418,8 @@ def compute_tile( scale_k_base=None, pf_a_scales=None, pf_b_scales=None, - pf_a_data=None, - pf_a_data_scales=None, ): current_accs = list(accs_in) - if const_expr(use_a_vgpr): - _a_vgpr_box[0] = pf_a_data - if const_expr(use_ascale_vgpr): - _a_vgpr_ascale_box[0] = 
pf_a_data_scales if const_expr(use_general_vgpr_scale): # Scales come from VGPR: use the loop-prefetched ring when provided, # else issue the buffer_loads inline (tail path) for scale_k_base. @@ -2277,7 +2042,7 @@ def hot_loop_scheduler(): _b_loads_per_frag = 2 if is_a8w4 else 4 # No scale ds_loads when scales are in registers (PTPC epilogue / VGPR). _scale_dsrd = 0 if (is_ptpc or use_general_vgpr_scale) else 2 - _a_half_dsrd = 0 if use_a_vgpr else _half_wm * DS_LOADS_PER_A_FRAG + _a_half_dsrd = _half_wm * DS_LOADS_PER_A_FRAG for _ks in range_constexpr(k_wmma_steps): if const_expr(_ks == 0): @@ -2393,8 +2158,6 @@ def compute_tile_scheduled( scale_k_base=None, pf_a_scales=None, pf_b_scales=None, - pf_a_data=None, - pf_a_data_scales=None, ): if const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_B_STREAMING): return compute_tile_b_streaming( @@ -2453,8 +2216,6 @@ def compute_tile_scheduled( scale_k_base=scale_k_base, pf_a_scales=pf_a_scales, pf_b_scales=pf_b_scales, - pf_a_data=pf_a_data, - pf_a_data_scales=pf_a_data_scales, ) def hot_loop_scheduler_b_streaming(): @@ -2766,71 +2527,24 @@ def _pack_dg0(pred, lds_addr, addr_lo, addr_hi): stages_as_lds_addr = [] stages_bs_lds_addr = [] for i in range_constexpr(num_buffers): - if const_expr(not use_a_vgpr): - stages_a_lds_addr.append(_dg0_lane(make_desc_a(stages_a_mem[i], arith.index(0)), 1)) + stages_a_lds_addr.append(_dg0_lane(make_desc_a(stages_a_mem[i], arith.index(0)), 1)) stages_b_lds_addr.append(_dg0_lane(make_desc_b(stages_b_mem[i], arith.index(0)), 1)) - if const_expr(not is_ptpc and not use_ascale_vgpr): - stages_as_lds_addr.append(_dg0_lane(make_desc_as(stages_as_mem[i], arith.index(0)), 1)) if const_expr(not is_ptpc): + stages_as_lds_addr.append(_dg0_lane(make_desc_as(stages_as_mem[i], arith.index(0)), 1)) stages_bs_lds_addr.append(_dg0_lane(make_desc_bs(stages_bs_mem[i], arith.index(0)), 1)) - if const_expr(not use_a_vgpr): - desc_a_init = make_desc_a(stages_a_mem[0], split_k_base) + desc_a_init = 
make_desc_a(stages_a_mem[0], split_k_base) desc_b_init = make_desc_b(stages_b_mem[0], split_k_base) - if const_expr(is_ptpc or use_ascale_vgpr): - # Alias unused A/A_scale slots to B (predicated off waves). - if const_expr(not stages_a_lds_addr): - stages_a_lds_addr = stages_b_lds_addr - if const_expr(not stages_as_lds_addr): - stages_as_lds_addr = stages_b_lds_addr - if const_expr(not stages_bs_lds_addr): - stages_bs_lds_addr = stages_b_lds_addr - if const_expr(use_a_vgpr): - desc_a_init = desc_b_init - if const_expr(is_ptpc): - desc_as_init = desc_b_init - desc_bs_init = desc_b_init - else: - desc_as_init = desc_b_init - desc_bs_init = make_desc_bs(stages_bs_mem[0], split_k_base) - elif const_expr(use_a_vgpr): - # A via VGPR, scales via TDM: alias A slot to B (wave0 predicated off). - stages_a_lds_addr = stages_b_lds_addr - desc_a_init = desc_b_init - desc_as_init = make_desc_as(stages_as_mem[0], split_k_base) - desc_bs_init = make_desc_bs(stages_bs_mem[0], split_k_base) + if const_expr(is_ptpc): + # No scale TDM for PTPC: alias the scale descriptors/addresses to A/B. + # Scale waves are predicated off, so these selections are never issued. 
+ stages_as_lds_addr = stages_a_lds_addr + stages_bs_lds_addr = stages_b_lds_addr + desc_as_init = desc_a_init + desc_bs_init = desc_b_init else: desc_as_init = make_desc_as(stages_as_mem[0], split_k_base) desc_bs_init = make_desc_bs(stages_bs_mem[0], split_k_base) - if const_expr(use_ab_half_split): - stages_a0_lds_addr = [] - stages_b0_lds_addr = [] - stages_a1_lds_addr = [] - stages_b1_lds_addr = [] - for i in range_constexpr(num_buffers): - stages_a0_lds_addr.append(_dg0_lane(make_desc_a_half(stages_a_mem[i], arith.index(0), 0), 1)) - stages_b0_lds_addr.append(_dg0_lane(make_desc_b_half(stages_b_mem[i], arith.index(0), 0), 1)) - stages_a1_lds_addr.append(_dg0_lane(make_desc_a_half(stages_a_mem[i], arith.index(0), 1), 1)) - stages_b1_lds_addr.append(_dg0_lane(make_desc_b_half(stages_b_mem[i], arith.index(0), 1), 1)) - - desc_a0_init = make_desc_a_half(stages_a_mem[0], split_k_base, 0) - desc_b0_init = make_desc_b_half(stages_b_mem[0], split_k_base, 0) - desc_a1_init = make_desc_a_half(stages_a_mem[0], split_k_base, 1) - desc_b1_init = make_desc_b_half(stages_b_mem[0], split_k_base, 1) - - if const_expr(_b_nsplit or _tdm_b_nsplit_scale_combo): - # N-direction B halves: wave0 -> N-groups [0:tile_n//32], wave1 -> - # [tile_n//32:tile_n//16]. N is the outer LDS dim, so the two halves - # write contiguous blocks that together equal the full-B tile layout - # (make_desc_b_half bakes the N offset into both the global offset and - # lds_byte_offset). load_b_frag reads are unchanged. 
- nstages_b0_lds_addr = [] - nstages_b1_lds_addr = [] - for i in range_constexpr(num_buffers): - nstages_b0_lds_addr.append(_dg0_lane(make_desc_b_half(stages_b_mem[i], arith.index(0), 0), 1)) - nstages_b1_lds_addr.append(_dg0_lane(make_desc_b_half(stages_b_mem[i], arith.index(0), 1), 1)) - desc_bn0_init = make_desc_b_half(stages_b_mem[0], split_k_base, 0) - desc_bn1_init = make_desc_b_half(stages_b_mem[0], split_k_base, 1) adv_a_i32 = fx.Int32(tile_k // PACK_FACTOR_A) adv_b_i32 = fx.Int32(packed_tile_k_b * 16) @@ -2841,21 +2555,9 @@ def _pack_dg0(pred, lds_addr, addr_lo, addr_hi): pred_const = fx.Int32(1) if const_expr(wave_specialized_tdm): - _drop_scale_waves = is_ptpc or (use_buffer_vgpr_scale and not use_ab_half_split) or use_ascale_vgpr - if const_expr(_b_nsplit or _tdm_b_nsplit_scale_combo): - # Split variants use only the first four waves. Keep extra compute - # waves from falling through the 4-slot selector to the last slot. - active_pred_const = arith.select(tdm_wave_id < fx.Int32(4), fx.Int32(1), fx.Int32(0)) - elif const_expr(_drop_a_loader_wave and not _drop_scale_waves): - # A via VGPR, scales via TDM: wave0 (A) idle, waves 1,2,3 active - active_pred_const = arith.select( - tdm_wave_id >= fx.Int32(1), - arith.select(tdm_wave_id < fx.Int32(4), fx.Int32(1), fx.Int32(0)), - fx.Int32(0), - ) - else: - _active_wave_limit = 2 if _drop_scale_waves else 4 - active_pred_const = arith.select(tdm_wave_id < fx.Int32(_active_wave_limit), fx.Int32(1), fx.Int32(0)) + _drop_scale_waves = is_ptpc or use_buffer_vgpr_scale + _active_wave_limit = 2 if _drop_scale_waves else 4 + active_pred_const = arith.select(tdm_wave_id < fx.Int32(_active_wave_limit), fx.Int32(1), fx.Int32(0)) def _select4(values): return _select_wave_tdm_value(values[0], values[1], values[2], values[3]) @@ -2884,46 +2586,7 @@ def _select_active_tdm(stage_lds_addrs, descs, advs): else: active_pred_const = pred_const - if const_expr(use_ab_half_split): - # All 4 waves load A/B halves: wave0=A0, wave1=B0, 
wave2=A1, wave3=B1. - # Both halves of A share adv_a (same K-step); both halves of B share adv_b. - active_stage_lds_addr, active_addr_lo, active_addr_hi, active_dgroup1, active_adv_i32 = _select_active_tdm( - (stages_a0_lds_addr, stages_b0_lds_addr, stages_a1_lds_addr, stages_b1_lds_addr), - (desc_a0_init, desc_b0_init, desc_a1_init, desc_b1_init), - (adv_a_i32, adv_b_i32, adv_a_i32, adv_b_i32), - ) - elif const_expr(_b_nsplit): - # B N-split: wave0=B N-half0, wave1=B N-half1 (both adv by a full - # tile_k; the N offset is constant), wave2=A_scale, wave3=B_scale. - active_stage_lds_addr, active_addr_lo, active_addr_hi, active_dgroup1, active_adv_i32 = _select_active_tdm( - (nstages_b0_lds_addr, nstages_b1_lds_addr, stages_as_lds_addr, stages_bs_lds_addr), - (desc_bn0_init, desc_bn1_init, desc_as_init, desc_bs_init), - (adv_b_i32, adv_b_i32, adv_as_i32, adv_bs_i32), - ) - elif const_expr(_tdm_b_nsplit_scale_combo): - # A + B N-split + combined scale wave: - # wave0=A, wave1=B N-half0, wave2=B N-half1, wave3=A_scale; - # wave3 additionally issues B_scale below with an independent address. - active_stage_lds_addr, active_addr_lo, active_addr_hi, active_dgroup1, active_adv_i32 = _select_active_tdm( - (stages_a_lds_addr, nstages_b0_lds_addr, nstages_b1_lds_addr, stages_as_lds_addr), - (desc_a_init, desc_bn0_init, desc_bn1_init, desc_as_init), - (adv_a_i32, adv_b_i32, adv_b_i32, adv_as_i32), - ) - active_extra_pred_const = arith.select(tdm_wave_id == fx.Int32(3), fx.Int32(1), fx.Int32(0)) - active_extra_stage_lds_addr = stages_bs_lds_addr - active_extra_addr_lo = _dg0_lane(desc_bs_init, 2) - active_extra_addr_hi = _dg0_lane(desc_bs_init, 3) - active_extra_dgroup1 = desc_bs_init.dgroup1 - active_extra_adv_i32 = adv_bs_i32 - elif const_expr(wave_specialized_tdm and use_ascale_vgpr): - # A + A_scale via VGPR: only B (wave0) and B_scale (wave1) need TDM. - # Remap: slot0=B, slot1=B_scale, slots 2,3 aliased (predicated off). 
- active_stage_lds_addr, active_addr_lo, active_addr_hi, active_dgroup1, active_adv_i32 = _select_active_tdm( - (stages_b_lds_addr, stages_bs_lds_addr, stages_b_lds_addr, stages_bs_lds_addr), - (desc_b_init, desc_bs_init, desc_b_init, desc_bs_init), - (adv_b_i32, adv_bs_i32, adv_b_i32, adv_bs_i32), - ) - elif const_expr(wave_specialized_tdm): + if const_expr(wave_specialized_tdm): active_stage_lds_addr, active_addr_lo, active_addr_hi, active_dgroup1, active_adv_i32 = _select_active_tdm( (stages_a_lds_addr, stages_b_lds_addr, stages_as_lds_addr, stages_bs_lds_addr), (desc_a_init, desc_b_init, desc_as_init, desc_bs_init), @@ -2944,52 +2607,18 @@ def _select_active_tdm(stage_lds_addrs, descs, advs): dgroup1_as = desc_as_init.dgroup1 dgroup1_bs = desc_bs_init.dgroup1 - def _pipeline_tensor_wait(outstanding=0): - if const_expr(_tdm_b_nsplit_scale_combo and outstanding > 0): - if_op = scf.IfOp(tdm_wave_id == fx.Int32(3), [], has_else=True) - with ir.InsertionPoint(if_op.then_block): - tdm_ops.tensor_wait(outstanding * 2) - scf.YieldOp([]) - with ir.InsertionPoint(if_op.else_block): - tdm_ops.tensor_wait(outstanding) - scf.YieldOp([]) - else: - tdm_ops.tensor_wait(outstanding) - def _pipeline_fence(outstanding=0): - if const_expr(_tdm_b_nsplit_scale_combo): - _pipeline_tensor_wait(outstanding) - if const_expr(use_cluster): - cluster.cluster_barrier() - else: - gpu.barrier() - else: - pipeline_fence(outstanding=outstanding, use_cluster=use_cluster) + pipeline_fence(outstanding=outstanding, use_cluster=use_cluster) def _pipeline_fence_signal(outstanding=0): - if const_expr(_tdm_b_nsplit_scale_combo): - _pipeline_tensor_wait(outstanding) - rocdl.s_barrier_signal(-1) - if const_expr(use_cluster): - cluster.cluster_signal_once_per_wg() - else: - pipeline_fence_signal(outstanding=outstanding, use_cluster=use_cluster) + pipeline_fence_signal(outstanding=outstanding, use_cluster=use_cluster) if const_expr(wave_specialized_tdm): - def _issue_active_tdm(load_stage, addr_box, 
extra_addr_box=None, k_prefetch=None): + def _issue_active_tdm(load_stage, addr_box, k_prefetch=None): dg0 = _pack_dg0(active_pred_const, active_stage_lds_addr[load_stage], addr_box[0], active_addr_hi) tdm_ops.tensor_load_2d(tdm_ops.TDMDescriptor2D(dg0, active_dgroup1)) addr_box[0] = addr_box[0] + active_adv_i32 - if const_expr(_tdm_b_nsplit_scale_combo): - dg0_extra = _pack_dg0( - active_extra_pred_const, - active_extra_stage_lds_addr[load_stage], - extra_addr_box[0], - active_extra_addr_hi, - ) - tdm_ops.tensor_load_2d(tdm_ops.TDMDescriptor2D(dg0_extra, active_extra_dgroup1)) - extra_addr_box[0] = extra_addr_box[0] + active_extra_adv_i32 if k_prefetch is not None: _l2_prefetch(k_prefetch) @@ -2997,12 +2626,7 @@ def _issue_active_tdm(load_stage, addr_box, extra_addr_box=None, k_prefetch=None if const_expr(wave_specialized_tdm): for i in range_constexpr(pre_loaded): addr_box = [active_addr_lo] - if const_expr(_tdm_b_nsplit_scale_combo): - extra_addr_box = [active_extra_addr_lo] - _issue_active_tdm(i, addr_box, extra_addr_box) - active_extra_addr_lo = extra_addr_box[0] - else: - _issue_active_tdm(i, addr_box) + _issue_active_tdm(i, addr_box) active_addr_lo = addr_box[0] else: for i in range_constexpr(pre_loaded): @@ -3031,11 +2655,6 @@ def _issue_active_tdm(load_stage, addr_box, extra_addr_box=None, k_prefetch=None _bvs_ra = [_v for (_a, _b) in _bvs_pf for _v in _a] _bvs_rb = [_v for (_a, _b) in _bvs_pf for _v in _b] - if const_expr(_avr_active and loop_iters > 0): - _avr_pf = [_avr_prefetch(split_k_base + arith.index(_d * tile_k)) for _d in range(_avr_D)] - _avr_rf = [_v for (_f, _s) in _avr_pf for _v in _f] - _avr_rs = [_v for (_f, _s) in _avr_pf for _v in _s] - _pipeline_fence(outstanding=TDM_LOADS_PER_STEP * (num_buffers - 2)) # Main loop — acc_mixed style: fence at top, TDM_load mid-compute. 
@@ -3048,52 +2667,35 @@ def _issue_active_tdm(load_stage, addr_box, extra_addr_box=None, k_prefetch=None if const_expr(loop_iters > 0): if const_expr(wave_specialized_tdm): init_args = list(accs) + [active_addr_lo] - if const_expr(_tdm_b_nsplit_scale_combo): - init_args = init_args + [active_extra_addr_lo] if const_expr(_bvs_active): init_args = init_args + _bvs_ra + _bvs_rb - if const_expr(_avr_active): - init_args = init_args + _avr_rf + _avr_rs for loop_iter, state in range(0, loop_iters, 1, init=init_args): accs_in = list(state[:n_accs]) cur_addr_lo = state[n_accs] _state_off = n_accs + 1 - if const_expr(_tdm_b_nsplit_scale_combo): - cur_extra_addr_lo = state[_state_off] - _state_off += 1 if const_expr(_bvs_active): _ra0 = _state_off _ring_a = list(state[_ra0 : _ra0 + _bvs_D * _vs_tile_a]) _rb0 = _ra0 + _bvs_D * _vs_tile_a _ring_b = list(state[_rb0 : _rb0 + _bvs_D * _vs_tile_b]) _state_off = _rb0 + _bvs_D * _vs_tile_b - if const_expr(_avr_active): - _af0 = _state_off - _avr_ring_f = list(state[_af0 : _af0 + _avr_D * _avr_frags_per_tile]) - _as0 = _af0 + _avr_D * _avr_frags_per_tile - _avr_ring_s = list(state[_as0 : _as0 + _avr_D * _avr_ascale_per_tile]) for buf_idx in range_constexpr(num_buffers): load_stage = (buf_idx + num_buffers - 1) % num_buffers addr_box = [cur_addr_lo] - if const_expr(_tdm_b_nsplit_scale_combo): - extra_addr_box = [cur_extra_addr_lo] - else: - extra_addr_box = None def _mid_tdm_ws( _ls=load_stage, _ab=addr_box, - _eb=extra_addr_box, _k_off=( split_k_base + loop_iter * arith.index(num_buffers * tile_k) + arith.index(buf_idx * tile_k) ), ): - _issue_active_tdm(_ls, _ab, _eb, k_prefetch=_k_off) + _issue_active_tdm(_ls, _ab, k_prefetch=_k_off) if const_expr(not use_ws_tdm_split_signal_overlap): _pipeline_fence_signal(outstanding=_fence_outstanding) @@ -3128,21 +2730,6 @@ def _late_tdm_ws_split_signal(): _cur_a = None _cur_b = None - if const_expr(_avr_active): - _cur_ad = _avr_ring_f[:_avr_frags_per_tile] - _cur_as = 
_avr_ring_s[:_avr_ascale_per_tile] if _avr_ascale_per_tile else None - _next_akb = ( - split_k_base - + loop_iter * arith.index(num_buffers * tile_k) - + arith.index((buf_idx + _avr_D) * tile_k) - ) - _nf, _ns = _avr_prefetch(_next_akb) - _avr_ring_f = _avr_ring_f[_avr_frags_per_tile:] + list(_nf) - _avr_ring_s = _avr_ring_s[_avr_ascale_per_tile:] + list(_ns) - else: - _cur_ad = None - _cur_as = None - accs_in = compute_tile_scheduled( accs_in, stages_a_idx[buf_idx], @@ -3154,32 +2741,18 @@ def _late_tdm_ws_split_signal(): a0_prefetch=a0_prefetch, pf_a_scales=_cur_a, pf_b_scales=_cur_b, - pf_a_data=_cur_ad, - pf_a_data_scales=_cur_as, ) cur_addr_lo = addr_box[0] - if const_expr(_tdm_b_nsplit_scale_combo): - cur_extra_addr_lo = extra_addr_box[0] hot_loop_scheduler_scheduled() if const_expr(_bvs_active): _bvs_yield = _ring_a + _ring_b else: _bvs_yield = [] - if const_expr(_avr_active): - _avr_yield = _avr_ring_f + _avr_ring_s - else: - _avr_yield = [] - if const_expr(_tdm_b_nsplit_scale_combo): - _extra_yield = [cur_extra_addr_lo] - else: - _extra_yield = [] - results = yield list(accs_in) + [cur_addr_lo] + _extra_yield + _bvs_yield + _avr_yield + results = yield list(accs_in) + [cur_addr_lo] + _bvs_yield accs = list(results[:n_accs]) active_addr_lo = results[n_accs] - if const_expr(_tdm_b_nsplit_scale_combo): - active_extra_addr_lo = results[n_accs + 1] else: init_args = list(accs) + [addr_lo_a, addr_lo_b, addr_lo_as, addr_lo_bs] @@ -3319,29 +2892,8 @@ def _bvs_tail_scales(): for _ in range_constexpr(_bvs_D): _bvs_tail_issue_one() - _avr_tail_ring = [] - _avr_tail_issue_kt = [loop_iters * num_buffers] - - def _avr_tail_issue_one(): - if const_expr(_avr_active and _avr_tail_issue_kt[0] < num_k_tiles): - _kb = split_k_base + arith.index(_avr_tail_issue_kt[0] * tile_k) - _avr_tail_ring.append(_avr_prefetch(_kb)) - _avr_tail_issue_kt[0] += 1 - - def _avr_tail_consume(): - if const_expr(_avr_active): - _f, _s = _avr_tail_ring.pop(0) - return _f, _s if _s else None - 
return None, None - - if const_expr(_avr_active): - rocdl.sched_barrier(0) - for _ in range_constexpr(_avr_D): - _avr_tail_issue_one() - for _load_stage, _compute_stage, _outstanding in tail_plan: _entry_kb, _pf_a_scales, _pf_b_scales = _bvs_tail_scales() - _tail_ad, _tail_as = _avr_tail_consume() if const_expr(_outstanding == -1): if const_expr(_tail_had_load): _pipeline_fence(outstanding=0) @@ -3358,8 +2910,6 @@ def _avr_tail_consume(): scale_k_base=_entry_kb, pf_a_scales=_pf_a_scales, pf_b_scales=_pf_b_scales, - pf_a_data=_tail_ad, - pf_a_data_scales=_tail_as, ) else: @@ -3379,8 +2929,6 @@ def _emit_epi_addrs(): scale_k_base=_entry_kb, pf_a_scales=_pf_a_scales, pf_b_scales=_pf_b_scales, - pf_a_data=_tail_ad, - pf_a_data_scales=_tail_as, ) else: _pipeline_fence_signal(outstanding=_outstanding) @@ -3391,13 +2939,9 @@ def _emit_epi_addrs(): _tail_had_load = True if const_expr(wave_specialized_tdm): _tail_addr_box = [active_addr_lo] - if const_expr(_tdm_b_nsplit_scale_combo): - _tail_extra_addr_box = [active_extra_addr_lo] - else: - _tail_extra_addr_box = None - def _tail_mid_ws(_ls=_load_stage, _ab=_tail_addr_box, _eb=_tail_extra_addr_box): - _issue_active_tdm(_ls, _ab, _eb) + def _tail_mid_ws(_ls=_load_stage, _ab=_tail_addr_box): + _issue_active_tdm(_ls, _ab) _tail_mid_cb = _tail_mid_ws else: @@ -3425,7 +2969,6 @@ def _tail_mid_nws(_ls=_load_stage, _ab=_tail_ab): a0_prefetch = maybe_prefetch_fp8_deep_a0(stages_a_idx[_compute_stage]) rocdl.sched_barrier(0) _bvs_tail_issue_one() - _avr_tail_issue_one() accs = compute_tile_scheduled( accs, stages_a_idx[_compute_stage], @@ -3437,15 +2980,11 @@ def _tail_mid_nws(_ls=_load_stage, _ab=_tail_ab): scale_k_base=_entry_kb, pf_a_scales=_pf_a_scales, pf_b_scales=_pf_b_scales, - pf_a_data=_tail_ad, - pf_a_data_scales=_tail_as, ) if const_expr(_load_stage is not None): if const_expr(wave_specialized_tdm): active_addr_lo = _tail_addr_box[0] - if const_expr(_tdm_b_nsplit_scale_combo): - active_extra_addr_lo = 
_tail_extra_addr_box[0] else: addr_lo_a = _tail_ab[0][0] addr_lo_b = _tail_ab[1][0] @@ -3516,8 +3055,6 @@ def _emit_buffer_store(): atomic_barrier_enable, b_streaming, scale_load_path, - a_load_path, - b_split_load, fp8_schedule, ) @@ -3580,26 +3117,18 @@ def launch_mxscale_gemm( def compile_mxscale_gemm(**kw): """Backward-compatible wrapper: MX block-scale (E8M0) GEMM.""" - if "b_split_load" not in kw: - kw["b_split_load"] = bool(int(os.environ.get("FLYDSL_B_KSPLIT", "0"))) return compile_fp8fp4_gemm(scale_mode="mxscale", **kw) def compile_mxfp4_gemm(**kw): - if "b_split_load" not in kw: - kw["b_split_load"] = bool(int(os.environ.get("FLYDSL_B_KSPLIT", "0"))) return compile_fp8fp4_gemm(data_format="fp4", scale_mode="mxscale", **kw) def compile_mxfp8_gemm(**kw): - if "b_split_load" not in kw: - kw["b_split_load"] = bool(int(os.environ.get("FLYDSL_B_KSPLIT", "0"))) return compile_fp8fp4_gemm(data_format="fp8", scale_mode="mxscale", **kw) def compile_a8w4_gemm(**kw): - if "b_split_load" not in kw: - kw["b_split_load"] = bool(int(os.environ.get("FLYDSL_B_KSPLIT", "0"))) return compile_fp8fp4_gemm(data_format="a8w4", scale_mode="mxscale", **kw) diff --git a/lib/Runtime/ROCm/FlyRocmRuntimeWrappers.cpp b/lib/Runtime/ROCm/FlyRocmRuntimeWrappers.cpp index 66a3b68d..a8b2d3f3 100644 --- a/lib/Runtime/ROCm/FlyRocmRuntimeWrappers.cpp +++ b/lib/Runtime/ROCm/FlyRocmRuntimeWrappers.cpp @@ -17,15 +17,8 @@ #include #include "hip/hip_runtime.h" -#include "hip/hip_version.h" #include "mlir/ExecutionEngine/CRunnerUtils.h" -// TODO(gfx1250-cluster): TEMPORARY. This version check does NOT actually guarantee -// cluster launch support (this version doesn't support it either). It's a debug-only -// hack for gfx1250 bring-up and will likely break other environments. Replace with a -// real capability check once the supported version/path is confirmed. 
-#define FLY_HIP_HAS_CLUSTER_LAUNCH (HIP_VERSION >= 70000000) - #define HIP_REPORT_IF_ERROR(expr) \ [](hipError_t result) { \ if (!result) \ @@ -75,44 +68,7 @@ extern "C" void mgpuLaunchClusterKernel(hipFunction_t function, intptr_t cluster intptr_t blockY, intptr_t blockZ, int32_t smem, hipStream_t stream, void **params, void **extra, size_t /*paramsCount*/) { - const bool requestedRealCluster = (clusterX > 1) || (clusterY > 1) || (clusterZ > 1); - -#if FLY_HIP_HAS_CLUSTER_LAUNCH - hipStreamCaptureStatus capStatus = hipStreamCaptureStatusNone; - hipGraph_t capGraph = nullptr; - const hipGraphNode_t *capDeps = nullptr; - size_t numCapDeps = 0; - if (hipStreamGetCaptureInfo_v2(stream, &capStatus, /*id_out=*/nullptr, &capGraph, &capDeps, - &numCapDeps) == hipSuccess && - capStatus == hipStreamCaptureStatusActive) { - hipKernelNodeParams nodeParams{}; - nodeParams.func = reinterpret_cast(function); - nodeParams.gridDim = dim3(static_cast(gridX), static_cast(gridY), - static_cast(gridZ)); - nodeParams.blockDim = dim3(static_cast(blockX), static_cast(blockY), - static_cast(blockZ)); - nodeParams.sharedMemBytes = static_cast(smem); - nodeParams.kernelParams = params; - nodeParams.extra = extra; - - hipGraphNode_t node = nullptr; - hipError_t addErr = hipGraphAddKernelNode(&node, capGraph, capDeps, numCapDeps, &nodeParams); - if (addErr != hipSuccess) { - // Fail loudly: a silent empty graph would report bogus ~0us replay times. 
- fprintf(stderr, - "[mgpuLaunchClusterKernel] hipGraphAddKernelNode failed (err=%d) during stream " - "capture for cluster=(%ld,%ld,%ld); cluster kernels cannot be captured into a " - "hipGraph on this HIP build.\n", - static_cast(addErr), static_cast(clusterX), static_cast(clusterY), - static_cast(clusterZ)); - HIP_REPORT_IF_ERROR(addErr); - return; - } - HIP_REPORT_IF_ERROR( - hipStreamUpdateCaptureDependencies(stream, &node, 1, hipStreamSetCaptureDependencies)); - return; - } - +#ifdef hipLaunchAttributeClusterDimension hipLaunchAttribute attrs[1]; attrs[0].id = hipLaunchAttributeClusterDimension; attrs[0].value.clusterDim.x = static_cast(clusterX); @@ -135,6 +91,7 @@ extern "C" void mgpuLaunchClusterKernel(hipFunction_t function, intptr_t cluster if (err == hipSuccess) return; + const bool requestedRealCluster = (clusterX > 1) || (clusterY > 1) || (clusterZ > 1); if (requestedRealCluster) { fprintf(stderr, "[mgpuLaunchClusterKernel] hipDrvLaunchKernelEx failed (err=%d) " @@ -146,7 +103,6 @@ extern "C" void mgpuLaunchClusterKernel(hipFunction_t function, intptr_t cluster return; } - // cluster=(1,1,1) carries no cluster semantics — plain launch is equivalent. fprintf(stderr, "[mgpuLaunchClusterKernel] hipDrvLaunchKernelEx failed (err=%d) " "for cluster=(1,1,1); falling back to hipModuleLaunchKernel.\n", @@ -154,20 +110,15 @@ extern "C" void mgpuLaunchClusterKernel(hipFunction_t function, intptr_t cluster HIP_REPORT_IF_ERROR(hipModuleLaunchKernel(function, gridX, gridY, gridZ, blockX, blockY, blockZ, smem, stream, params, extra)); #else - // HIP < 7.0: no cluster API. Refuse to downgrade silently — kernel relies on - // cluster semantics (multicast, cluster_barrier) that a plain launch breaks. - if (requestedRealCluster) { + // Cluster launch not supported by this HIP version; ignore cluster dims + // and fall back to regular kernel launch. 
+ if ((clusterX > 1) || (clusterY > 1) || (clusterZ > 1)) { fprintf(stderr, "[mgpuLaunchClusterKernel] cluster=(%ld,%ld,%ld) requested but " - "FlyDSL was built against HIP %d (need HIP >= 7.0 / ROCm >= 7.0 " - "for hipDrvLaunchKernelEx + hipLaunchAttributeClusterDimension). " - "Aborting.\n", - static_cast(clusterX), static_cast(clusterY), static_cast(clusterZ), - HIP_VERSION); - HIP_REPORT_IF_ERROR(hipErrorNotSupported); - return; + "hipLaunchAttributeClusterDimension is not available in this HIP " + "version; falling back to hipModuleLaunchKernel.\n", + static_cast(clusterX), static_cast(clusterY), static_cast(clusterZ)); } - // cluster=(1,1,1): plain launch is equivalent. HIP_REPORT_IF_ERROR(hipModuleLaunchKernel(function, gridX, gridY, gridZ, blockX, blockY, blockZ, smem, stream, params, extra)); #endif diff --git a/tests/kernels/test_gemm_fp8fp4_gfx1250.py b/tests/kernels/test_gemm_fp8fp4_gfx1250.py index 3e1b4f4f..2ad704d8 100644 --- a/tests/kernels/test_gemm_fp8fp4_gfx1250.py +++ b/tests/kernels/test_gemm_fp8fp4_gfx1250.py @@ -138,7 +138,7 @@ def preshuffle_e8m0_scale( """Preshuffle E8M0 scale: optional byte swap + interleave for WMMA access. ``coalesced=True`` produces the lane-major layout the scale_load_path - "vgpr"/"vgpr_ab_split" buffer_load->VGPR path expects. + "vgpr" buffer_load->VGPR path expects. """ if coalesced: return preshuffle_e8m0_scale_coalesced(scale, block=warp_tile) @@ -189,12 +189,12 @@ def preshuffle_scale_for_load_path( """Host scale preshuffle matching the kernel's selected scale_load_path. - 'tdm': interleaved TDM/LDS layout. - - 'vgpr'/'vgpr_ab_split' on the ref-segmented deep-pipeline config: legacy - lane-major coalesced layout. + - 'vgpr' on the ref-segmented deep-pipeline config: legacy lane-major + coalesced layout. - 'vgpr' on any other (general) config: general coalesced layout, with the a8w4/fp4 lane_kgrp scale shift. 
""" - if scale_load_path in ("vgpr", "vgpr_ab_split"): + if scale_load_path == "vgpr": if ref_segmented: return preshuffle_e8m0_scale(scale, warp_tile, scale_k_per_tile=skt, coalesced=True) kgrp_shift = 1 if data_format in ("a8w4", "fp4") else 0 @@ -493,10 +493,7 @@ def _run_mxscale_gemm_test( split_k=1, b_streaming=False, scale_load_path="tdm", - a_load_path="tdm", - b_split_load=False, return_launch_fn=False, - use_graph=False, ): """Unified test body for FP4 and FP8.""" is_fp4 = data_format == "fp4" @@ -539,14 +536,11 @@ def _run_mxscale_gemm_test( mcast_str = f", cluster=({cluster_m},{cluster_n})" if cluster_m > 1 or cluster_n > 1 else "" tdm_str = ", tdm_store" if use_tdm_store else ", buffer_store" scale_load_str = "" if scale_load_path == "tdm" else f", scale_load={scale_load_path}" - a_load_str = "" if a_load_path == "tdm" else f", a_load={a_load_path}" - b_split_str = ", b_split_load" if b_split_load else "" - graph_str = ", graph" if use_graph else "" pad_str = _format_kernel_pad(M, N, K, padded_shape) print( f"\nRunning {fmt_name} GEMM: M={M}, N={N}, K={K}{pad_str}, " f"tiles=({tile_m},{tile_n},{tile_k}), bufs={num_buffers}" - f"{mcast_str}{tdm_str}{scale_load_str}{a_load_str}{b_split_str}{graph_str}, preshuffle, out={out_dtype}" + f"{mcast_str}{tdm_str}{scale_load_str}, preshuffle, out={out_dtype}" ) # Generate data @@ -612,7 +606,6 @@ def _run_mxscale_gemm_test( scale_load_path=scale_load_path, use_scale_opsel=use_scale_opsel, b_streaming=b_streaming, - b_split_load=b_split_load, ): b_scale = preshuffle_e8m0_bscale_n4k4(b_scale) else: @@ -660,8 +653,6 @@ def _run_mxscale_gemm_test( expert_sched_mode=expert_sched_mode, b_streaming=b_streaming, scale_load_path=scale_load_path, - a_load_path=a_load_path, - b_split_load=b_split_load, ) # Keep 2D — dynamic_layout=True packs shape as i32; flattening overflows for M*K >= 2^31. 
@@ -671,7 +662,7 @@ def _run_mxscale_gemm_test( as_flat = as_gpu.contiguous() bs_flat = bs_gpu.contiguous() - compiled_exe = flyc.compile( + flyc.compile( launch_fn, c_flat, a_flat, @@ -684,35 +675,6 @@ def _run_mxscale_gemm_test( padded_n, torch.cuda.current_stream(), ) - - if use_graph: - - def _launch(): - compiled_exe( - c_flat, - a_flat, - b_flat, - as_flat, - bs_flat, - padded_m, - padded_n, - padded_k, - padded_n, - torch.cuda.current_stream(), - ) - - g = torch.cuda.CUDAGraph() - s = torch.cuda.Stream() - s.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(s): - _launch() - torch.cuda.current_stream().wait_stream(s) - torch.cuda.synchronize() - c_gpu.zero_() - with torch.cuda.graph(g, stream=s): - _launch() - c_gpu.zero_() - g.replay() torch.cuda.synchronize() c_out = c_gpu[:M, :N].to(torch_out_dtype).cpu() @@ -844,8 +806,6 @@ def test_mxfp4_gemm( ) - - @pytest.mark.parametrize( "M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp", [ @@ -979,6 +939,83 @@ def test_a8w4_gemm_irregular_m_tile16(M, N, K, use_tdm_store): ) +# ── Tile-independent N4K4 B-scale coverage ── +# tile_m=16, m_warp=1 -> wmma_m_rep=1 (odd) -> the default row-major streaming +# schedule, which is the (phase-1) N4K4 B-scale path. The sweep covers every +# tile_n/n_warp that maps to a distinct read shape (b32/b64/b128 per_load and +# group counts 1/2/4 and the non-power-of-2 group count 3 that exercises the +# TDM warp-distribution power-of-two padding), both data formats, k_wmma_steps +# 1/2/4, wave-spec on/off, f32/bf16, multi-buffer, and ragged/decode M. 
+_N4K4_N_FOR_TN = {16: 128, 32: 128, 64: 128, 128: 256, 192: 384, 256: 512} +_N4K4_TN_NW = [ + (16, 1), (32, 1), (32, 2), (64, 1), (64, 2), (64, 4), + (128, 1), (128, 2), (128, 4), (192, 1), (192, 2), (192, 4), + (256, 1), (256, 2), (256, 4), +] # fmt: skip + + +def _gen_n4k4_configs(): + cfgs, seen = [], set() + + def add(fmt, M, tile_n, n_warp, tile_k, nbuf, od, ws): + N = _N4K4_N_FOR_TN[tile_n] + K = tile_k * max(nbuf, 2) # >= nbuf K-tiles for double/triple buffering + key = (fmt, M, N, K, tile_n, tile_k, n_warp, nbuf, od, ws) + if key not in seen: + seen.add(key) + cfgs.append(key) + + for fmt in ("fp8", "a8w4"): + # 1) full tile_n x n_warp shape sweep (all rep/group/per_load cases), + # non-wave-spec so the cooperative TDM warp distribution is exercised. + for tn, nw in _N4K4_TN_NW: + add(fmt, 16, tn, nw, 256, 2, "bf16", False) + # 2) wave-spec (needs >=4 waves -> n_warp=4), M=1 decode-like. The real + # decode shape (tile_n=64) uses deep K + 4 buffers; larger tile_n keeps + # a modest tile so LDS fits while still exercising the wave-spec TDM. + add(fmt, 1, 64, 4, 512, 4, "bf16", True) + for tn in (128, 192, 256): + add(fmt, 1, tn, 4, 256, 2, "bf16", True) + # 3) k_wmma_steps 1/2/4 on the next_pow2 (192) and clean (256/64) shapes. + for tn, nw in [(192, 4), (256, 4), (64, 4)]: + for tk in (128, 512): + add(fmt, 16, tn, nw, tk, 2, "bf16", False) + # 4) f32 + triple buffering on a few shapes. + for tn, nw in [(192, 4), (128, 2), (32, 2)]: + add(fmt, 16, tn, nw, 256, 3, "f32", False) + # 5) ragged / decode / OOB M. + for M in (1, 13, 33): + add(fmt, M, 256, 4, 256, 2, "bf16", False) + return cfgs + + +@pytest.mark.parametrize( + "data_format, M, N, K, tile_n, tile_k, n_warp, num_buffers, out_dtype, ws", _gen_n4k4_configs() +) +def test_mxscale_n4k4_bscale(data_format, M, N, K, tile_n, tile_k, n_warp, num_buffers, out_dtype, ws): + # Guard: every config here must actually take the N4K4 B-scale layout, else + # the sweep would silently test the legacy path instead. 
+ assert use_n4k4_bscale_layout( + data_format=data_format, tile_m=16, tile_n=tile_n, tile_k=tile_k, m_warp=1, n_warp=n_warp, n=N + ), f"config does not hit the N4K4 gate: {(data_format, tile_n, tile_k, n_warp, N)}" + _run_mxscale_gemm_test( + data_format, + M, + N, + K, + 16, + tile_n, + tile_k, + 1, + n_warp, + num_buffers, + use_tdm_store=True, + out_dtype=out_dtype, + wave_specialized_tdm=ws, + l2_prefetch_distance=0, + use_scale_opsel=False, + scale_load_path="tdm", + ) @pytest.mark.parametrize( @@ -1020,23 +1057,14 @@ def test_mxfp4_gemm_mcast( @pytest.mark.parametrize( - "data_format, M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp, cluster_m, cluster_n", + "data_format, M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp", [ - ("fp8", 128, 256, 256, 128, 256, 128, 2, 2, 1, 1), - ("fp4", 128, 256, 256, 128, 256, 128, 2, 2, 1, 1), - ("fp8", 256, 512, 256, 128, 256, 128, 2, 2, 2, 2), - ("fp4", 256, 512, 256, 128, 256, 128, 2, 2, 2, 2), - ("a8w4", 256, 512, 256, 128, 256, 128, 2, 4, 2, 2), - ], - ids=[ - "fp8-128x256x256", - "fp4-128x256x256", - "fp8-256x512x256-cluster2x2", - "fp4-256x512x256-cluster2x2", - "a8w4-256x512x256-cluster2x2", + ("fp8", 128, 256, 256, 128, 256, 128, 2, 2), + ("fp4", 128, 256, 256, 128, 256, 128, 2, 2), ], + ids=["fp8-128x256x256", "fp4-128x256x256"], ) -def test_mxscale_gemm_cudagraph(data_format, M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp, cluster_m, cluster_n): +def test_mxscale_gemm_cudagraph(data_format, M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp): """Verify that the gfx1250 MX-scale GEMM kernel works inside a hipGraph. Captures one launch, replays once, and checks the replay output is @@ -1051,15 +1079,11 @@ def test_mxscale_gemm_cudagraph(data_format, M, N, K, tile_m, tile_n, tile_k, m_ pytest.skip("hipGraph capture/replay not supported on simulator") is_fp4 = data_format == "fp4" - is_a8w4 = data_format == "a8w4" # Build inputs (mirrors _run_mxscale_gemm_test, but no padding needed # because we pick a clean shape). 
torch.manual_seed(0) - if is_a8w4: - a = random_fp8_data(M, K) # FP8 activation - b = fp4_utils.random_fp4_packed(N, K) # FP4 weight - elif is_fp4: + if is_fp4: a = fp4_utils.random_fp4_packed(M, K) b = fp4_utils.random_fp4_packed(N, K) else: @@ -1073,7 +1097,7 @@ def test_mxscale_gemm_cudagraph(data_format, M, N, K, tile_m, tile_n, tile_k, m_ warp_tile_n = tile_n // n_warp a_scale_ps = preshuffle_e8m0_scale(a_scale, warp_tile_m, scale_k_per_tile=skt) b_scale_ps = preshuffle_e8m0_scale(b_scale, warp_tile_n, scale_k_per_tile=skt) - pack_b = 2 if (is_fp4 or is_a8w4) else 1 + pack_b = 2 if is_fp4 else 1 b_ps = fp4_utils.preshuffle_b_16x16(b, N, K // pack_b) a_gpu = a.cuda() @@ -1096,8 +1120,6 @@ def test_mxscale_gemm_cudagraph(data_format, M, N, K, tile_m, tile_n, tile_k, m_ out_dtype="bf16", wave_specialized_tdm=False, split_k=1, - cluster_m=cluster_m, - cluster_n=cluster_n, ) c_flat = c_gpu.contiguous() @@ -1854,8 +1876,7 @@ def _run_benchmark(args): print( f" Buffers={args.num_buffers}, out={args.out_dtype}, " f"opsel={args.use_scale_opsel}, inst_prefetch={args.inst_prefetch}, " - f"scale_load={args.scale_load_path}, a_load={args.a_load_path}, " - f"b_split_load={args.b_split_load}" + f"scale_load={args.scale_load_path}" ) if args.warmup < 0: raise ValueError(f"--warmup must be >= 0, got {args.warmup}") @@ -1888,8 +1909,6 @@ def _run_benchmark(args): _ptpc_ignored.append(f"--scale-load-path {args.scale_load_path}") if args.b_streaming: _ptpc_ignored.append("--b-streaming") - if args.b_split_load: - _ptpc_ignored.append("--b-split-load") if _ptpc_ignored: print(f" Note: PTPC ignores (forced internally): {', '.join(_ptpc_ignored)}") print("=" * 72) @@ -2032,8 +2051,6 @@ def _run_benchmark(args): atomic_barrier_enable=args.atomic_barrier_enable, b_streaming=args.b_streaming, scale_load_path=args.scale_load_path, - a_load_path=args.a_load_path, - b_split_load=args.b_split_load, ) compiled_exe = flyc.compile( @@ -2317,8 +2334,6 @@ def _run_graph_verify(args): 
atomic_barrier_enable=args.atomic_barrier_enable, b_streaming=args.b_streaming, scale_load_path=args.scale_load_path, - a_load_path=args.a_load_path, - b_split_load=args.b_split_load, ) c_flat = c_gpu.contiguous() @@ -2434,24 +2449,10 @@ def launch(): "--scale-load-path", type=str, default="tdm", - choices=["tdm", "vgpr", "vgpr_ab_split"], - ) - parser.add_argument( - "--a-load-path", - type=str, - default="tdm", - choices=["tdm", "vgpr", "vgpr_ascale"], + choices=["tdm", "vgpr"], ) parser.add_argument("--disable-expert-sched-mode", dest="expert_sched_mode", action="store_false", default=True) parser.add_argument("--b-streaming", action="store_true", default=False) - parser.add_argument( - "--b-split-load", - action="store_true", - default=False, - help="Split the B TDM load by N groups. With --a-load-path vgpr this reuses " - "wave0 for B half 0; with TDM A/scale it uses wave0=A, wave1/2=B halves, " - "wave3=A_scale+B_scale.", - ) parser.add_argument( "--atomic-barrier-enable", action="store_true", @@ -2511,8 +2512,6 @@ def launch(): if args.scale_mode == "ptpc" and args.verify_graph: raise SystemExit("--scale-mode ptpc does not support --verify-graph") - if args.scale_mode == "ptpc" and args.use_graph and not args.benchmark: - raise SystemExit("--scale-mode ptpc does not support --use-graph for functional tests (use --benchmark)") def _run_correctness_test(): """Run the functional test (computes a reference and asserts correctness).""" @@ -2560,9 +2559,6 @@ def _run_correctness_test(): expert_sched_mode=args.expert_sched_mode, b_streaming=args.b_streaming, scale_load_path=args.scale_load_path, - a_load_path=args.a_load_path, - b_split_load=args.b_split_load, - use_graph=args.use_graph, ) if args.verify_graph: From 2dad2bfbfd9e42b26ed23a94a5d83902fc783f8a Mon Sep 17 00:00:00 2001 From: aoli26 Date: Mon, 15 Jun 2026 07:10:50 +0000 Subject: [PATCH 08/16] add b scale op_sel optimization --- kernels/gemm_fp8fp4_gfx1250.py | 34 ++++++++++++++++++++++++---------- 1 file 
changed, 24 insertions(+), 10 deletions(-) diff --git a/kernels/gemm_fp8fp4_gfx1250.py b/kernels/gemm_fp8fp4_gfx1250.py index ec256b37..c73cd108 100644 --- a/kernels/gemm_fp8fp4_gfx1250.py +++ b/kernels/gemm_fp8fp4_gfx1250.py @@ -338,19 +338,24 @@ def compile_fp8fp4_gemm( use_scale_opsel=use_scale_opsel, b_streaming=b_streaming, ) + use_n4k4_opsel = False if use_n4k4_bscale: if K_scale % 4 != 0: raise ValueError(f"N4K4 B-scale requires K_scale % 4 == 0, got {K_scale}") n4k4_n_groups = N // 64 n4k4_bs_global_row_stride = (K // WMMA_K) * 256 n4k4_bs_lds_row_stride = k_wmma_steps * 256 - # Impact: cost-free when tile_n//64 is already a power of - # two (tile_n=64/128/256 -> 1/2/4 groups); only a non-pow2 group count - # (e.g. tile_n=192: 3->4) copies one extra oob-clipped scale group per - # tile (~0.1% of B traffic, no extra WMMA). - _n4k4_groups = (tile_n + 63) // 64 - n4k4_bs_tile_groups = 1 << (_n4k4_groups - 1).bit_length() + # Per-tile N-group count padded to a power of two so the TDM warp split + # stays clean (a non-pow2 count, e.g. 192->3, miscopies LDS). Cost-free + # for 64/128/256 (1/2/4 groups); tile_n=192 copies 1 extra oob-clipped group. + n4k4_bs_tile_groups = 1 << ((tile_n + 63) // 64 - 1).bit_length() n4k4_bs_lds_rows = n4k4_bs_tile_groups + # N op_sel: pack blocks (j, j+rep/2) into one VGPR via lane_kgrp (kgrp1 = + # the "second half"), halving B-scale loads & VGPRs. Power-of-2 rep only + # (then the kgrp byte offset is uniform); rep 1/3/6/12 stay off. 
+ use_n4k4_opsel = wmma_n_rep >= 2 and (wmma_n_rep & (wmma_n_rep - 1)) == 0 + _half = wmma_n_rep // 2 + n4k4_opsel_kgrp_off = (_half // 4) * n4k4_bs_lds_row_stride + (_half % 4) * 4 _b_frag_loads_per_wn = 2 if is_a8w4 else 4 _a_frag_loads_per_wm = 2 if is_fp4 else 4 @@ -1126,15 +1131,21 @@ def _precompute_n4k4_bscale_bases(lds_ptr): _N4K4_LOADERS = {1: lds_load_b32_raw, 2: lds_load_b64_raw, 4: lds_load_b128_raw} def load_n4k4_bscale(lds_buffer, bases, reps, ks=0): - """Load *reps* B-scale i32s from the N4K4 LDS layout for K-subtile *ks*.""" + """Load N4K4 B-scale i32s for K-subtile *ks*.""" b0, lane_off = bases ks_off = arith.index(ks * 256) row_stride = arith.index(n4k4_bs_lds_row_stride) - per_load = 4 if reps % 4 == 0 else (2 if reps % 2 == 0 else 1) + if const_expr(use_n4k4_opsel): + n_load = reps // 2 # read first half; kgrp1 supplies the matching second half + lane = lane_off + lane_kgrp * arith.index(n4k4_opsel_kgrp_off) + else: + n_load = reps + lane = lane_off + per_load = 4 if n_load % 4 == 0 else (2 if n_load % 2 == 0 else 1) results = [] - for i in range_constexpr(reps // per_load): + for i in range_constexpr(n_load // per_load): blk = b0 + arith.index(i * per_load) - off = (blk // arith.index(4)) * row_stride + (blk % arith.index(4)) * arith.index(4) + lane_off + ks_off + off = (blk // arith.index(4)) * row_stride + (blk % arith.index(4)) * arith.index(4) + lane + ks_off raw = _N4K4_LOADERS[per_load](lds_buffer, off) if const_expr(per_load == 1): results.append(raw) @@ -1273,6 +1284,9 @@ def _emit_wmma(accs, wm, wn, a_frag, b_frag, a_scales, b_scales): if const_expr(use_scale_opsel): b_scale_idx = wn // 2 b_opsel = wn % 2 + elif const_expr(use_n4k4_opsel): + b_scale_idx = wn % (wmma_n_rep // 2) + b_opsel = wn // (wmma_n_rep // 2) else: b_scale_idx = wn b_opsel = 0 From d60d9ff12d85be33f9d8a1ccb1cdd9d45a48477b Mon Sep 17 00:00:00 2001 From: aoli26 Date: Mon, 15 Jun 2026 07:33:57 +0000 Subject: [PATCH 09/16] remove the b streaming exp codes --- 
kernels/gemm_fp8fp4_gfx1250.py | 158 +--------------------- tests/kernels/test_gemm_fp8fp4_gfx1250.py | 9 -- 2 files changed, 2 insertions(+), 165 deletions(-) diff --git a/kernels/gemm_fp8fp4_gfx1250.py b/kernels/gemm_fp8fp4_gfx1250.py index c73cd108..883dfef8 100644 --- a/kernels/gemm_fp8fp4_gfx1250.py +++ b/kernels/gemm_fp8fp4_gfx1250.py @@ -133,7 +133,6 @@ def use_n4k4_bscale_layout( scale_mode="mxscale", scale_load_path="tdm", use_scale_opsel=False, - b_streaming=False, ): """Whether B-scale uses the tile-independent N4K4 preshuffle layout.""" if scale_mode != "mxscale": @@ -155,7 +154,7 @@ def use_n4k4_bscale_layout( n_accs = wmma_m_rep * wmma_n_rep # Row-major streaming is selected exactly when a rep is odd or n_accs < 8 # (see _pick_compute_schedule_kind); otherwise fp8/a8w4 route to quadrant. - return (not b_streaming) and (wmma_m_rep % 2 != 0 or wmma_n_rep % 2 != 0 or n_accs < 8) + return wmma_m_rep % 2 != 0 or wmma_n_rep % 2 != 0 or n_accs < 8 @functools.lru_cache(maxsize=256) @@ -183,7 +182,6 @@ def compile_fp8fp4_gemm( use_scale_opsel: bool = False, expert_sched_mode: bool = True, atomic_barrier_enable: bool = False, - b_streaming: bool = False, scale_load_path: str = "tdm", fp8_schedule: str = "auto", ): @@ -228,8 +226,6 @@ def compile_fp8fp4_gemm( raise ValueError(f"fp8_schedule must be one of {fp8_schedule_modes}, got {fp8_schedule!r}") if fp8_schedule != "auto" and data_format != "fp8": raise ValueError(f"fp8_schedule={fp8_schedule!r} is only valid for data_format='fp8'") - if fp8_schedule != "auto" and b_streaming: - raise ValueError("fp8_schedule cannot be combined with b_streaming=True") effective_expert_sched_mode = bool(expert_sched_mode) if num_buffers not in (2, 3, 4, 5, 6): @@ -336,7 +332,6 @@ def compile_fp8fp4_gemm( scale_mode=scale_mode, scale_load_path=scale_load_path, use_scale_opsel=use_scale_opsel, - b_streaming=b_streaming, ) use_n4k4_opsel = False if use_n4k4_bscale: @@ -568,7 +563,6 @@ def _align_up(value: int, align: int) -> int: 
COMPUTE_SCHEDULE_FP4_COL_BAND = "fp4_col_band" COMPUTE_SCHEDULE_FP8_QUADRANT = "fp8_quadrant" COMPUTE_SCHEDULE_FP8_DEEP_PIPELINE = "fp8_deep_pipeline" - COMPUTE_SCHEDULE_B_STREAMING = "b_streaming" fp8_deep_pipeline_eligible = ( data_format in ("fp8", "a8w4") @@ -590,8 +584,6 @@ def _align_up(value: int, align: int) -> int: ) def _pick_compute_schedule_kind(): - if b_streaming: - return COMPUTE_SCHEDULE_B_STREAMING if wmma_m_rep % 2 != 0 or wmma_n_rep % 2 != 0 or n_accs < 8: return COMPUTE_SCHEDULE_ROW_MAJOR_STREAMING # Quadrant schedules split B into left/right halves and compute @@ -612,7 +604,6 @@ def _pick_compute_schedule_kind(): use_fp4_bank_friendly_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_FP4_COL_BAND use_fp8_quadrant_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_FP8_QUADRANT use_fp8_deep_pipeline_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_FP8_DEEP_PIPELINE - use_b_streaming_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_B_STREAMING if use_n4k4_bscale: assert compute_schedule_kind == COMPUTE_SCHEDULE_ROW_MAJOR_STREAMING @@ -640,11 +631,6 @@ def _pick_compute_schedule_kind(): and num_buffers == 4 and use_cluster ) - if use_b_streaming_schedule: - print( - f"[b_streaming] {data_format} tile=({tile_m},{tile_n},{tile_k}) " f"M_r={wmma_m_rep} N_r={wmma_n_rep}", - flush=True, - ) if use_fp4_bank_friendly_schedule: _bank_half_wm = wmma_m_rep // 2 @@ -1367,59 +1353,6 @@ def _emit_rows(start_wm, a_frags): return accs, next_result return accs - def _b_streaming_compute( - accs, - b_buf, - b_bases, - a_frags, - a_scales, - b_scales, - ks, - emit_filler=None, - next_info=None, - mid_compute_callback=None, - ): - """B-streaming counterpart to _a_streaming_compute (A held, B streamed).""" - next_result = None - _front_wn = (wmma_n_rep + 1) // 2 - _back_wn = wmma_n_rep - _front_wn - - def _emit_cols(start_wn, b_frags_chunk): - for frag_i in range_constexpr(len(b_frags_chunk)): - wn = start_wn + frag_i - if const_expr(wn == wmma_n_rep - 1 
and emit_filler is not None): - rocdl.sched_barrier(0) - emit_filler() - for wm_raw in range_constexpr(wmma_m_rep): - wm = (wmma_m_rep - 1 - wm_raw) if (wn % 2 == 1) else wm_raw - _emit_wmma(accs, wm, wn, a_frags[wm], b_frags_chunk[frag_i], a_scales, b_scales) - - b_frags_front = [load_b_frag(b_buf, b_bases, wn, ks) for wn in range_constexpr(_front_wn)] - _use_partial_drain = next_info is not None and _front_wn * wmma_m_rep >= 4 - - if const_expr(_use_partial_drain): - next_result = _load_a_and_scales(*next_info) - rocdl.s_wait_dscnt(_as_ds_loads) - else: - rocdl.s_wait_dscnt(0) - - _emit_cols(0, b_frags_front) - - if const_expr(mid_compute_callback is not None): - rocdl.sched_barrier(0) - mid_compute_callback() - - if const_expr(_back_wn > 0): - b_frags_back = [load_b_frag(b_buf, b_bases, _front_wn + h, ks) for h in range_constexpr(_back_wn)] - rocdl.s_wait_dscnt(_as_ds_loads if _use_partial_drain else 0) - _emit_cols(_front_wn, b_frags_back) - - if const_expr(_use_partial_drain): - return accs, next_result - if const_expr(next_info is not None): - return accs, _load_a_and_scales(*next_info) - return accs - # ── Compute on one LDS buffer ── def compute_tile( accs_in, @@ -1999,57 +1932,6 @@ def _prefetch_a2(): return current_accs - def compute_tile_b_streaming( - accs_in, lds_a, lds_b, lds_as, lds_bs, emit_filler=None, mid_compute_callback=None - ): - """compute_tile counterpart with A held and B streamed.""" - current_accs = list(accs_in) - a_buf, a_bases = _precompute_a_lane_bases(lds_a) - b_buf, b_bases = _precompute_b_lane_bases(lds_b) - as_buf, as_bases = _precompute_scale_lane_bases(lds_as, warp_m_base, wmma_m_rep, interleaved_scale_cols_a) - bs_buf, bs_bases = _precompute_scale_lane_bases( - lds_bs, warp_n_base, b_scale_load_rep, interleaved_scale_cols_b - ) - load_args = (a_buf, a_bases, as_buf, as_bases, bs_buf, bs_bases) - - if const_expr(k_wmma_steps == 1): - a_frags, a_scales, b_scales = _load_a_and_scales(*load_args, 0) - return _b_streaming_compute( - 
current_accs, - b_buf, - b_bases, - a_frags, - a_scales, - b_scales, - 0, - emit_filler=emit_filler, - mid_compute_callback=mid_compute_callback, - ) - - prev_a, prev_as, prev_bs = _load_a_and_scales(*load_args, 0) - for ks in range_constexpr(k_wmma_steps - 1): - current_accs, (prev_a, prev_as, prev_bs) = _b_streaming_compute( - current_accs, - b_buf, - b_bases, - prev_a, - prev_as, - prev_bs, - ks, - next_info=load_args + (ks + 1,), - mid_compute_callback=mid_compute_callback if ks == 0 else None, - ) - return _b_streaming_compute( - current_accs, - b_buf, - b_bases, - prev_a, - prev_as, - prev_bs, - k_wmma_steps - 1, - emit_filler=emit_filler, - ) - def hot_loop_scheduler(): _half_wm = wmma_m_rep // 2 _half_wmma = _half_wm * wmma_n_rep @@ -2173,16 +2055,6 @@ def compute_tile_scheduled( pf_a_scales=None, pf_b_scales=None, ): - if const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_B_STREAMING): - return compute_tile_b_streaming( - accs_in, - lds_a, - lds_b, - lds_as, - lds_bs, - emit_filler=emit_filler, - mid_compute_callback=mid_compute_callback, - ) if const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_FP4_COL_BAND): return compute_tile_fp4_bank_friendly( accs_in, @@ -2232,32 +2104,8 @@ def compute_tile_scheduled( pf_b_scales=pf_b_scales, ) - def hot_loop_scheduler_b_streaming(): - """hot_loop_scheduler counterpart for B-streaming.""" - _front_wn = (wmma_n_rep + 1) // 2 - _back_wn = wmma_n_rep - _front_wn - _a_loads_total = wmma_m_rep * DS_LOADS_PER_A_FRAG - _front_b_loads = _front_wn * _b_frag_loads_per_wn - _back_b_loads = _back_wn * _b_frag_loads_per_wn - _next_ks_loads = _a_loads_total + _scale_ds_loads - - for _ks in range_constexpr(k_wmma_steps): - if const_expr(_ks == 0): - rocdl.sched_dsrd(_next_ks_loads + _front_b_loads) - else: - rocdl.sched_dsrd(_front_b_loads) - rocdl.sched_mfma(_front_wn * wmma_m_rep) - if const_expr(_back_wn > 0): - rocdl.sched_dsrd(_back_b_loads) - rocdl.sched_mfma(_back_wn * wmma_m_rep) - if const_expr(_ks < k_wmma_steps - 1): 
- rocdl.sched_dsrd(_next_ks_loads) - rocdl.sched_barrier(0) - def hot_loop_scheduler_scheduled(): - if const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_B_STREAMING): - hot_loop_scheduler_b_streaming() - elif const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_FP4_COL_BAND): + if const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_FP4_COL_BAND): hot_loop_scheduler_fp4_bank_friendly() elif const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_FP8_DEEP_PIPELINE): hot_loop_scheduler_fp8_deep_pipeline() @@ -3067,7 +2915,6 @@ def _emit_buffer_store(): use_scale_opsel, expert_sched_mode, atomic_barrier_enable, - b_streaming, scale_load_path, fp8_schedule, ) @@ -3180,7 +3027,6 @@ def compile_ptpc_gemm( return compile_fp8fp4_gemm( data_format=data_format, scale_mode="ptpc", - b_streaming=False, wave_specialized_tdm=True, use_scale_opsel=False, fp8_schedule="auto", diff --git a/tests/kernels/test_gemm_fp8fp4_gfx1250.py b/tests/kernels/test_gemm_fp8fp4_gfx1250.py index 2ad704d8..5b380f45 100644 --- a/tests/kernels/test_gemm_fp8fp4_gfx1250.py +++ b/tests/kernels/test_gemm_fp8fp4_gfx1250.py @@ -491,7 +491,6 @@ def _run_mxscale_gemm_test( waves_per_eu=None, expert_sched_mode=True, split_k=1, - b_streaming=False, scale_load_path="tdm", return_launch_fn=False, ): @@ -605,7 +604,6 @@ def _run_mxscale_gemm_test( n=padded_n, scale_load_path=scale_load_path, use_scale_opsel=use_scale_opsel, - b_streaming=b_streaming, ): b_scale = preshuffle_e8m0_bscale_n4k4(b_scale) else: @@ -651,7 +649,6 @@ def _run_mxscale_gemm_test( split_k=split_k, use_scale_opsel=use_scale_opsel, expert_sched_mode=expert_sched_mode, - b_streaming=b_streaming, scale_load_path=scale_load_path, ) @@ -1907,8 +1904,6 @@ def _run_benchmark(args): _ptpc_ignored.append("--use-scale-opsel") if args.scale_load_path != "tdm": _ptpc_ignored.append(f"--scale-load-path {args.scale_load_path}") - if args.b_streaming: - _ptpc_ignored.append("--b-streaming") if _ptpc_ignored: print(f" Note: PTPC ignores (forced internally): 
{', '.join(_ptpc_ignored)}") print("=" * 72) @@ -2049,7 +2044,6 @@ def _run_benchmark(args): use_scale_opsel=args.use_scale_opsel, expert_sched_mode=args.expert_sched_mode, atomic_barrier_enable=args.atomic_barrier_enable, - b_streaming=args.b_streaming, scale_load_path=args.scale_load_path, ) @@ -2332,7 +2326,6 @@ def _run_graph_verify(args): use_scale_opsel=args.use_scale_opsel, expert_sched_mode=args.expert_sched_mode, atomic_barrier_enable=args.atomic_barrier_enable, - b_streaming=args.b_streaming, scale_load_path=args.scale_load_path, ) @@ -2452,7 +2445,6 @@ def launch(): choices=["tdm", "vgpr"], ) parser.add_argument("--disable-expert-sched-mode", dest="expert_sched_mode", action="store_false", default=True) - parser.add_argument("--b-streaming", action="store_true", default=False) parser.add_argument( "--atomic-barrier-enable", action="store_true", @@ -2557,7 +2549,6 @@ def _run_correctness_test(): inst_prefetch=args.inst_prefetch, waves_per_eu=args.waves_per_eu, expert_sched_mode=args.expert_sched_mode, - b_streaming=args.b_streaming, scale_load_path=args.scale_load_path, ) From bd1e4e764258565d6c65dd24f11d3b7f24466f3f Mon Sep 17 00:00:00 2001 From: aoli26 Date: Mon, 15 Jun 2026 10:34:44 +0000 Subject: [PATCH 10/16] implement a scale buffer load vgpr path --- kernels/gemm_fp8fp4_gfx1250.py | 320 ++++++++++++++++++---- tests/kernels/test_gemm_fp8fp4_gfx1250.py | 103 ++++++- 2 files changed, 367 insertions(+), 56 deletions(-) diff --git a/kernels/gemm_fp8fp4_gfx1250.py b/kernels/gemm_fp8fp4_gfx1250.py index 883dfef8..a20459ce 100644 --- a/kernels/gemm_fp8fp4_gfx1250.py +++ b/kernels/gemm_fp8fp4_gfx1250.py @@ -121,6 +121,11 @@ def is_ref_segmented_lds_layout( ) +def _is_row_major_streaming(wmma_m_rep, wmma_n_rep, n_accs): + # Mirrors _pick_compute_schedule_kind: row-major when a rep is odd or n_accs < 8. 
+ return wmma_m_rep % 2 != 0 or wmma_n_rep % 2 != 0 or n_accs < 8 + + def use_n4k4_bscale_layout( *, data_format, @@ -133,8 +138,14 @@ def use_n4k4_bscale_layout( scale_mode="mxscale", scale_load_path="tdm", use_scale_opsel=False, + num_buffers=2, + out_dtype="f32", + wave_specialized_tdm=False, + fp8_schedule="auto", ): - """Whether B-scale uses the tile-independent N4K4 preshuffle layout.""" + """B-scale uses the tile-independent N4K4 layout on row-major-streaming and the + FP8/A8W4 quadrant schedule (one preshuffle serves both); deep-pipeline and fp4 + keep the legacy layout.""" if scale_mode != "mxscale": return False if data_format not in ("fp8", "a8w4"): @@ -152,9 +163,65 @@ def use_n4k4_bscale_layout( wmma_m_rep = (tile_m // m_warp) // WMMA_M wmma_n_rep = (tile_n // n_warp) // WMMA_N n_accs = wmma_m_rep * wmma_n_rep - # Row-major streaming is selected exactly when a rep is odd or n_accs < 8 - # (see _pick_compute_schedule_kind); otherwise fp8/a8w4 route to quadrant. - return wmma_m_rep % 2 != 0 or wmma_n_rep % 2 != 0 or n_accs < 8 + if _is_row_major_streaming(wmma_m_rep, wmma_n_rep, n_accs): + return True + # Even-rep FP8/A8W4: quadrant uses N4K4; the deep-pipeline shape keeps legacy + # (its B-scale rides the VGPR ring, not LDS). + deep_eligible = ( + tile_m == 256 + and tile_n == 256 + and tile_k == 128 + and m_warp == 2 + and n_warp == 2 + and num_buffers == 4 + and wave_specialized_tdm + and out_dtype == "bf16" + and not use_scale_opsel + ) + is_deep = fp8_schedule == "deep-pipeline" or (fp8_schedule == "auto" and deep_eligible) + return not is_deep + + +def use_natural_ascale_vgpr( + *, + data_format, + tile_m, + tile_n, + tile_k, + m_warp, + n_warp, + n, + scale_mode="mxscale", + scale_load_path="tdm", + ascale_load_path="vgpr", + use_scale_opsel=False, + wave_specialized_tdm=False, +): + """Whether A-scale uses the natural (un-reshuffled) buffer_load->VGPR path. 
+ + Row-major-streaming (decode) only: pairs with N4K4 B (TDM), A read straight from + runtime ``A_scale[M, K//32]`` into VGPRs, loop-ahead prefetched. Quadrant (prefill) + keeps legacy/TDM A-scale (its target is A=tdm-M4K4). Requires wave-specialized TDM.""" + if ascale_load_path != "vgpr": + return False + if not wave_specialized_tdm: + return False + if not use_n4k4_bscale_layout( + data_format=data_format, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + m_warp=m_warp, + n_warp=n_warp, + n=n, + scale_mode=scale_mode, + scale_load_path=scale_load_path, + use_scale_opsel=use_scale_opsel, + ): + return False + wmma_m_rep = (tile_m // m_warp) // WMMA_M + wmma_n_rep = (tile_n // n_warp) // WMMA_N + return _is_row_major_streaming(wmma_m_rep, wmma_n_rep, wmma_m_rep * wmma_n_rep) @functools.lru_cache(maxsize=256) @@ -183,6 +250,7 @@ def compile_fp8fp4_gemm( expert_sched_mode: bool = True, atomic_barrier_enable: bool = False, scale_load_path: str = "tdm", + ascale_load_path: str = "vgpr", fp8_schedule: str = "auto", ): """Compile an FP4/FP8/A8W4 GEMM kernel with TDM async copy. @@ -249,9 +317,8 @@ def compile_fp8fp4_gemm( # scale path, leaving only A + B -> 2 waves; otherwise A + B + A_scale + # B_scale -> 4 waves. _drop_scale_loader_waves = is_ptpc or scale_load_path == "vgpr" - _min_wave_spec_warps = 2 if _drop_scale_loader_waves else 4 - if wave_specialized_tdm and num_warps < _min_wave_spec_warps: - raise ValueError(f"wave_specialized_tdm requires at least {_min_wave_spec_warps} waves, got {num_warps}") + # Min loader-wave check is finalized after use_natural_ascale is known (natural + # A-scale frees its wave, allowing >=2; see below). 
# ── Format-dependent compile-time constants ── # A8W4: activation is FP8 (PACK_FACTOR_A=1), weight is FP4 (PACK_FACTOR_B=2) @@ -332,6 +399,10 @@ def compile_fp8fp4_gemm( scale_mode=scale_mode, scale_load_path=scale_load_path, use_scale_opsel=use_scale_opsel, + num_buffers=num_buffers, + out_dtype=out_dtype, + wave_specialized_tdm=wave_specialized_tdm, + fp8_schedule=fp8_schedule, ) use_n4k4_opsel = False if use_n4k4_bscale: @@ -352,6 +423,37 @@ def compile_fp8fp4_gemm( _half = wmma_n_rep // 2 n4k4_opsel_kgrp_off = (_half // 4) * n4k4_bs_lds_row_stride + (_half % 4) * 4 + # A-scale natural buffer_load->VGPR (no reshuffle), paired with N4K4 B (TDM). + # TEMP: ascale_load_path='tdm' fallback + old scale_load_path/non-ws kept until + # those legacy paths are retired; the natural vgpr path is the target. + if ascale_load_path not in ("vgpr", "tdm"): + raise ValueError(f"ascale_load_path must be 'vgpr' or 'tdm', got {ascale_load_path!r}") + use_natural_ascale = use_natural_ascale_vgpr( + data_format=data_format, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + m_warp=m_warp, + n_warp=n_warp, + n=N, + scale_mode=scale_mode, + scale_load_path=scale_load_path, + ascale_load_path=ascale_load_path, + use_scale_opsel=use_scale_opsel, + wave_specialized_tdm=wave_specialized_tdm, + ) + # M op_sel pairs A-blocks (wm, wm+rep/2) into one VGPR via lane_kgrp; power-of-2 rep. + use_natural_ascale_opsel = use_natural_ascale and wmma_m_rep >= 2 and (wmma_m_rep & (wmma_m_rep - 1)) == 0 + nat_as_half = wmma_m_rep // 2 + nat_as_load = nat_as_half if use_natural_ascale_opsel else wmma_m_rep + # Natural TDM tensors = {A-data, B-data, B-scale}; at exactly 2 waves wave0 also + # issues B-scale (secondary), so the natural path needs only >=2 loader waves. 
+ natural_two_wave = use_natural_ascale and num_warps == 2 + + _min_wave_spec_warps = 2 if (_drop_scale_loader_waves or use_natural_ascale) else 4 + if wave_specialized_tdm and num_warps < _min_wave_spec_warps: + raise ValueError(f"wave_specialized_tdm requires at least {_min_wave_spec_warps} waves, got {num_warps}") + _b_frag_loads_per_wn = 2 if is_a8w4 else 4 _a_frag_loads_per_wm = 2 if is_fp4 else 4 # _scale_ds_loads counts scale ds_loads issued alongside A/B fragment loads in @@ -370,7 +472,8 @@ def compile_fp8fp4_gemm( lds_a_data_bytes = tile_m * lds_a_stride_bytes lds_b_data_bytes = tile_n * packed_tile_k_b _scale_guard_bytes = 16 - lds_a_scale_bytes = 0 if is_ptpc else tile_m * scale_k_per_tile + _scale_guard_bytes + # Natural A-scale lives in VGPRs (buffer_load), so it needs no LDS. + lds_a_scale_bytes = 0 if (is_ptpc or use_natural_ascale) else tile_m * scale_k_per_tile + _scale_guard_bytes if use_n4k4_bscale: lds_b_scale_bytes = n4k4_bs_lds_rows * n4k4_bs_lds_row_stride + _scale_guard_bytes else: @@ -446,12 +549,18 @@ def _align_up(value: int, align: int) -> int: _scale_ds_loads = 0 _bs_ds_loads = wmma_n_rep * _b_frag_loads_per_wn _as_ds_loads = _a_frag_ds + elif use_natural_ascale: + # Only A-scale leaves LDS (VGPR); B-scale stays an N4K4 ds_load. + _a_scale_ds = 0 + _scale_ds_loads = _b_scale_ds + _bs_ds_loads = wmma_n_rep * _b_frag_loads_per_wn + _scale_ds_loads + _as_ds_loads = _a_frag_ds + _scale_ds_loads # Scale prefetch depth (K-tiles ahead) for the buffer->VGPR path. The # ref-segmented deep-pipeline path is VGPR-bound (D=2 doubles scale VGPRs -> # spill + ~18% regression), so it stays at D=1. The general coalesced path # (thin row-major streaming tiles) has spare VGPRs, so it prefetches deeper # to overlap each scale buffer_load with an earlier tile's TDM wait. 
- _bvs_D_default = 3 if use_general_vgpr_scale else 1 + _bvs_D_default = 3 if (use_general_vgpr_scale or use_natural_ascale) else 1 _bvs_D = max(1, int(os.environ.get("FLYDSL_BUFFER_VGPR_SCALE_DEPTH", str(_bvs_D_default)))) # FLYDSL_BUFFER_VGPR_SCALE_PRELOAD=1 (experiment, small-M only): switch the # general vgpr scale path to the b128 layout (ks innermost) so scales load @@ -463,8 +572,9 @@ def _align_up(value: int, align: int) -> int: # all-up-front variant. Full-K scale must fit in VGPRs -- NOT general. _bvs_b128 = use_general_vgpr_scale and bool(int(os.environ.get("FLYDSL_BUFFER_VGPR_SCALE_PRELOAD", "0"))) _bvs_preload = _bvs_b128 and loop_iters == 0 - # The buffer_load->VGPR scale ring is built only when scale is actually loaded. - _bvs_active = use_buffer_vgpr_scale + # The buffer_load->VGPR scale ring is built only when scale is actually loaded + # (coalesced A+B vgpr path, or the natural A-scale path with B still on TDM). + _bvs_active = use_buffer_vgpr_scale or use_natural_ascale if use_ref_segmented_lds_layout: # The A/B data pools are no longer packed into the same per-stage @@ -606,7 +716,7 @@ def _pick_compute_schedule_kind(): use_fp8_deep_pipeline_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_FP8_DEEP_PIPELINE if use_n4k4_bscale: - assert compute_schedule_kind == COMPUTE_SCHEDULE_ROW_MAJOR_STREAMING + assert compute_schedule_kind in (COMPUTE_SCHEDULE_ROW_MAJOR_STREAMING, COMPUTE_SCHEDULE_FP8_QUADRANT) if use_buffer_vgpr_scale: # General coalesced VGPR scale is supported on the row-major streaming # schedule (mxscale fp8/a8w4, no scale_opsel, wave-specialized TDM); the @@ -655,7 +765,13 @@ def _pick_compute_schedule_kind(): _fp8_half_wm = wmma_m_rep // 2 _fp8_half_wn = wmma_n_rep // 2 _fp8_group_size = _fp8_half_wm * _fp8_half_wn - _fp8_b_scale_loads = 0 if is_ptpc else (b_scale_load_rep + 3) // 4 + if use_n4k4_bscale: + # N4K4 B-scale ds_load instruction count (matches load_n4k4_bscale). 
+ _n4k4_bn = b_scale_load_rep // 2 if use_n4k4_opsel else b_scale_load_rep + _n4k4_bpl = 4 if _n4k4_bn % 4 == 0 else (2 if _n4k4_bn % 2 == 0 else 1) + _fp8_b_scale_loads = _n4k4_bn // _n4k4_bpl + else: + _fp8_b_scale_loads = 0 if is_ptpc else (b_scale_load_rep + 3) // 4 if use_fp8_deep_pipeline_schedule: _fp8_pair_wm = 2 _fp8_pair_wn = 2 @@ -717,6 +833,23 @@ def kernel_mxscale_gemm( warp_m_base = wave_m_idx * arith.index(warp_tile_m) warp_n_base = wave_n_idx * arith.index(warp_tile_n) + def _load_contig_i32(rsrc, base_idx, n, soff): + # Load n contiguous i32 from base_idx (i32-element units) via the widest + # buffer_load chunks (b128/b64/b32). Returns a list of n values. + out = [None] * n + _chunks = _vec_chunks(n) + for _ci in range_constexpr(len(_chunks)): + start, w = _chunks[_ci] + off = arith.index_cast(T.i32, base_idx + arith.index(start)) + r = buffer_ops.buffer_load(rsrc, off, vec_width=w, dtype=T.i32, soffset_bytes=soff) + if const_expr(w == 1): + out[start] = r + else: + rv = fx.Vector(r) + for c in range_constexpr(w): + out[start + c] = rv[c] + return out + if const_expr(use_buffer_vgpr_scale): # Direct global->VGPR scale load (no TDM/LDS). One K-tile's scales for # all reps land in VGPRs; the loop carries a small prefetch ring. @@ -784,23 +917,6 @@ def _gvs_load_scales(rsrc, mb, rep, k_base): vals.append(v[j]) return vals - def _load_contig_i32(rsrc, base_idx, n, soff): - # Load n contiguous i32 from base_idx (element units) via the widest - # buffer_load chunks (b128/b64/b32). Returns a list of n values. 
- out = [None] * n - _chunks = _vec_chunks(n) - for _ci in range_constexpr(len(_chunks)): - start, w = _chunks[_ci] - off = arith.index_cast(T.i32, base_idx + arith.index(start)) - r = buffer_ops.buffer_load(rsrc, off, vec_width=w, dtype=T.i32, soffset_bytes=soff) - if const_expr(w == 1): - out[start] = r - else: - rv = fx.Vector(r) - for c in range_constexpr(w): - out[start + c] = rv[c] - return out - def _gvs_load_scales_b128(rsrc, mb, rep, k_base, in_voffset): # b128 layout [kt, grp, j, lane32, ks, spw]: a lane's KS scale-words # are contiguous, so each (rep,lane) is one wide load (b128 for @@ -841,6 +957,34 @@ def _bvs_prefetch(k_base, preload=False): b = _bvs_load_scales(_bvs_b_rsrc, _bvs_mb_b, b_scale_load_rep, k_base) return a, b + elif const_expr(use_natural_ascale): + # Natural A-scale: read A_scale[M, K//32] straight into VGPRs (no reshuffle). + # A row's K-scales are contiguous -> one wide load per M-block grabs all ks. + # kt rides the scalar soffset so the per-lane voffset is K-tile-invariant + # (CSE'd -> loads fully hidden, like _gvs_load_scales). M op_sel: kgrp1 reads + # block wm+rep/2 (a pair per load). 
+ _nat_as_rsrc = buffer_ops.create_buffer_resource(arg_a_scale, max_size=False) + _nat_row_i32 = K_scale // 4 # i32 elements per A_scale row (K_scale = K//32, %4==0) + _nat_row0 = blk_m + warp_m_base + lane16 + if const_expr(use_natural_ascale_opsel): + _nat_row0 = _nat_row0 + lane_kgrp * arith.index(nat_as_half * WMMA_M) + _vs_tile_a = k_wmma_steps * nat_as_load + _vs_tile_b = 0 + + def _nat_as_load(k_base): + kt = k_base / arith.index(tile_k) + soff = arith.index_cast(T.i32, kt * arith.index(scale_k_per_tile)) + vals = [None] * (k_wmma_steps * nat_as_load) + for i in range_constexpr(nat_as_load): + vidx = (_nat_row0 + arith.index(i * WMMA_M)) * arith.index(_nat_row_i32) + ks_vals = _load_contig_i32(_nat_as_rsrc, vidx, k_wmma_steps, soff) + for ks in range_constexpr(k_wmma_steps): + vals[ks * nat_as_load + i] = ks_vals[ks] + return vals + + def _bvs_prefetch(k_base, preload=False): + return _nat_as_load(k_base), [] + m_idx = fx.Index(i32_m) # Runtime leading-dim strides (strided A/C). Dense callers pass lda == K, # ldc == N for byte-identical addressing. A's stride is in packed elements. @@ -1195,6 +1339,13 @@ def _scales_for_emit(as_buf, as_bases, bs_buf, bs_bases, ks): a = pf_a[ks * wmma_m_rep : (ks + 1) * wmma_m_rep] b = pf_b[ks * b_scale_load_rep : (ks + 1) * b_scale_load_rep] return a, b + if const_expr(use_natural_ascale): + # A from the natural VGPR ring (slice this ks; M op_sel handled in + # _emit via nat_as_half); B from the N4K4 LDS layout. 
+ pf_a, _ = _vgpr_scale_box[0] + a = pf_a[ks * nat_as_load : (ks + 1) * nat_as_load] + b = _load_b_scale_lds(bs_buf, bs_bases, ks) + return a, b a_all = load_scale_b128(as_buf, as_bases[0], wmma_m_rep, ks) b_all = _load_b_scale_lds(bs_buf, bs_bases, ks) if const_expr(use_scale_opsel): @@ -1249,6 +1400,10 @@ def _emit_wmma(accs, wm, wn, a_frag, b_frag, a_scales, b_scales): if const_expr(use_scale_opsel): a_scale_idx = wm // 2 a_opsel = wm % 2 + elif const_expr(use_natural_ascale_opsel): + # Natural A M op_sel pairs (j, j+rep/2): kgrp1 carries the second half. + a_scale_idx = wm % nat_as_half + a_opsel = wm // nat_as_half else: a_scale_idx = wm a_opsel = 0 @@ -1367,7 +1522,7 @@ def compute_tile( pf_b_scales=None, ): current_accs = list(accs_in) - if const_expr(use_general_vgpr_scale): + if const_expr(use_general_vgpr_scale or use_natural_ascale): # Scales come from VGPR: use the loop-prefetched ring when provided, # else issue the buffer_loads inline (tail path) for scale_k_base. if const_expr(pf_a_scales is not None): @@ -1379,7 +1534,12 @@ def compute_tile( _vgpr_scale_box[0] = _bvs_prefetch(scale_k_base) a_buf, a_bases = _precompute_a_lane_bases(lds_a) b_buf, b_bases = _precompute_b_lane_bases(lds_b) - as_buf, as_bases = _precompute_scale_lane_bases(lds_as, warp_m_base, wmma_m_rep, interleaved_scale_cols_a) + if const_expr(use_natural_ascale): + as_buf, as_bases = None, None # A-scale from the VGPR ring, not LDS + else: + as_buf, as_bases = _precompute_scale_lane_bases( + lds_as, warp_m_base, wmma_m_rep, interleaved_scale_cols_a + ) if const_expr(use_n4k4_bscale): bs_buf, bs_bases = _precompute_n4k4_bscale_bases(lds_bs) else: @@ -1583,9 +1743,12 @@ def compute_tile_fp8_quadrant( a_buf, a_bases = _precompute_a_lane_bases(lds_a) b_buf, b_bases = _precompute_b_lane_bases(lds_b) as_buf, as_bases = _precompute_scale_lane_bases(lds_as, warp_m_base, wmma_m_rep, interleaved_scale_cols_a) - bs_buf, bs_bases = _precompute_scale_lane_bases( - lds_bs, warp_n_base, 
b_scale_load_rep, interleaved_scale_cols_b - ) + if const_expr(use_n4k4_bscale): + bs_buf, bs_bases = _precompute_n4k4_bscale_bases(lds_bs) + else: + bs_buf, bs_bases = _precompute_scale_lane_bases( + lds_bs, warp_n_base, b_scale_load_rep, interleaved_scale_cols_b + ) _b_half_loads = _fp8_half_wn * _b_frag_loads_per_wn _b_left_bundle_loads = _b_half_loads + _fp8_b_scale_loads @@ -1608,6 +1771,8 @@ def _load_a_scales(ks): def _load_b_scales(ks): if const_expr(is_ptpc): return None # PTPC: scale applied in epilogue, not in K-loop + if const_expr(use_n4k4_bscale): + return _load_b_scale_lds(bs_buf, bs_bases, ks) # op_sel in _emit_wmma b_scales = load_scale_b128(bs_buf, bs_bases[0], b_scale_load_rep, ks) if const_expr(use_scale_opsel): return b_scales[::2] @@ -2392,7 +2557,9 @@ def _pack_dg0(pred, lds_addr, addr_lo, addr_hi): stages_a_lds_addr.append(_dg0_lane(make_desc_a(stages_a_mem[i], arith.index(0)), 1)) stages_b_lds_addr.append(_dg0_lane(make_desc_b(stages_b_mem[i], arith.index(0)), 1)) if const_expr(not is_ptpc): - stages_as_lds_addr.append(_dg0_lane(make_desc_as(stages_as_mem[i], arith.index(0)), 1)) + # Natural A-scale has no TDM (VGPR); B-scale keeps its TDM descriptor. + if const_expr(not use_natural_ascale): + stages_as_lds_addr.append(_dg0_lane(make_desc_as(stages_as_mem[i], arith.index(0)), 1)) stages_bs_lds_addr.append(_dg0_lane(make_desc_bs(stages_bs_mem[i], arith.index(0)), 1)) desc_a_init = make_desc_a(stages_a_mem[0], split_k_base) @@ -2405,7 +2572,12 @@ def _pack_dg0(pred, lds_addr, addr_lo, addr_hi): desc_as_init = desc_a_init desc_bs_init = desc_b_init else: - desc_as_init = make_desc_as(stages_as_mem[0], split_k_base) + if const_expr(use_natural_ascale): + # A-scale on VGPR: alias its (never-issued) TDM slot to A; wave2 carries B-scale. 
+ stages_as_lds_addr = stages_a_lds_addr + desc_as_init = desc_a_init + else: + desc_as_init = make_desc_as(stages_as_mem[0], split_k_base) desc_bs_init = make_desc_bs(stages_bs_mem[0], split_k_base) adv_a_i32 = fx.Int32(tile_k // PACK_FACTOR_A) @@ -2418,7 +2590,12 @@ def _pack_dg0(pred, lds_addr, addr_lo, addr_hi): pred_const = fx.Int32(1) if const_expr(wave_specialized_tdm): _drop_scale_waves = is_ptpc or use_buffer_vgpr_scale - _active_wave_limit = 2 if _drop_scale_waves else 4 + if const_expr(use_natural_ascale): + # wave0=A, wave1=B, wave2=B-scale (>=3 waves); A-scale is VGPR. At 2 + # waves wave2 doesn't exist and B-scale rides wave0 as a secondary. + _active_wave_limit = min(num_warps, 3) + else: + _active_wave_limit = 2 if _drop_scale_waves else 4 active_pred_const = arith.select(tdm_wave_id < fx.Int32(_active_wave_limit), fx.Int32(1), fx.Int32(0)) def _select4(values): @@ -2449,11 +2626,28 @@ def _select_active_tdm(stage_lds_addrs, descs, advs): active_pred_const = pred_const if const_expr(wave_specialized_tdm): + if const_expr(use_natural_ascale): + # Remap: wave2 (the old A-scale slot) now issues B-scale; wave3 is the + # padded 4th slot (predicated off by _active_wave_limit=3). 
+ _tdm_stage_sel = (stages_a_lds_addr, stages_b_lds_addr, stages_bs_lds_addr, stages_bs_lds_addr) + _tdm_desc_sel = (desc_a_init, desc_b_init, desc_bs_init, desc_bs_init) + _tdm_adv_sel = (adv_a_i32, adv_b_i32, adv_bs_i32, adv_bs_i32) + else: + _tdm_stage_sel = (stages_a_lds_addr, stages_b_lds_addr, stages_as_lds_addr, stages_bs_lds_addr) + _tdm_desc_sel = (desc_a_init, desc_b_init, desc_as_init, desc_bs_init) + _tdm_adv_sel = (adv_a_i32, adv_b_i32, adv_as_i32, adv_bs_i32) active_stage_lds_addr, active_addr_lo, active_addr_hi, active_dgroup1, active_adv_i32 = _select_active_tdm( - (stages_a_lds_addr, stages_b_lds_addr, stages_as_lds_addr, stages_bs_lds_addr), - (desc_a_init, desc_b_init, desc_as_init, desc_bs_init), - (adv_a_i32, adv_b_i32, adv_as_i32, adv_bs_i32), + _tdm_stage_sel, _tdm_desc_sel, _tdm_adv_sel ) + if const_expr(natural_two_wave): + # Secondary TDM: B-scale issued by wave0 only (2-wave packs A-data + + # B-scale onto wave0). Static, wave-independent; carried addr_lo below. + sec_pred_const = arith.select(tdm_wave_id == fx.Int32(0), fx.Int32(1), fx.Int32(0)) + sec_stage_lds_addr = stages_bs_lds_addr + sec_addr_hi = _dg0_lane(desc_bs_init, 3) + sec_dgroup1 = desc_bs_init.dgroup1 + sec_adv_i32 = adv_bs_i32 + sec_addr_lo_init = _dg0_lane(desc_bs_init, 2) else: addr_lo_a = _dg0_lane(desc_a_init, 2) addr_hi_a = _dg0_lane(desc_a_init, 3) @@ -2477,18 +2671,30 @@ def _pipeline_fence_signal(outstanding=0): if const_expr(wave_specialized_tdm): - def _issue_active_tdm(load_stage, addr_box, k_prefetch=None): + def _issue_active_tdm(load_stage, addr_box, k_prefetch=None, sec_box=None): dg0 = _pack_dg0(active_pred_const, active_stage_lds_addr[load_stage], addr_box[0], active_addr_hi) tdm_ops.tensor_load_2d(tdm_ops.TDMDescriptor2D(dg0, active_dgroup1)) addr_box[0] = addr_box[0] + active_adv_i32 + if const_expr(natural_two_wave): + # wave0's second descriptor: B-scale (predicated to wave0). 
+ dg0s = _pack_dg0(sec_pred_const, sec_stage_lds_addr[load_stage], sec_box[0], sec_addr_hi) + tdm_ops.tensor_load_2d(tdm_ops.TDMDescriptor2D(dg0s, sec_dgroup1)) + sec_box[0] = sec_box[0] + sec_adv_i32 if k_prefetch is not None: _l2_prefetch(k_prefetch) # Prologue if const_expr(wave_specialized_tdm): + if const_expr(natural_two_wave): + active_sec_lo = sec_addr_lo_init for i in range_constexpr(pre_loaded): addr_box = [active_addr_lo] - _issue_active_tdm(i, addr_box) + if const_expr(natural_two_wave): + sec_box = [active_sec_lo] + _issue_active_tdm(i, addr_box, sec_box=sec_box) + active_sec_lo = sec_box[0] + else: + _issue_active_tdm(i, addr_box) active_addr_lo = addr_box[0] else: for i in range_constexpr(pre_loaded): @@ -2529,6 +2735,8 @@ def _issue_active_tdm(load_stage, addr_box, k_prefetch=None): if const_expr(loop_iters > 0): if const_expr(wave_specialized_tdm): init_args = list(accs) + [active_addr_lo] + if const_expr(natural_two_wave): + init_args = init_args + [active_sec_lo] if const_expr(_bvs_active): init_args = init_args + _bvs_ra + _bvs_rb @@ -2536,6 +2744,9 @@ def _issue_active_tdm(load_stage, addr_box, k_prefetch=None): accs_in = list(state[:n_accs]) cur_addr_lo = state[n_accs] _state_off = n_accs + 1 + if const_expr(natural_two_wave): + cur_sec_lo = state[_state_off] + _state_off = _state_off + 1 if const_expr(_bvs_active): _ra0 = _state_off _ring_a = list(state[_ra0 : _ra0 + _bvs_D * _vs_tile_a]) @@ -2547,17 +2758,19 @@ def _issue_active_tdm(load_stage, addr_box, k_prefetch=None): load_stage = (buf_idx + num_buffers - 1) % num_buffers addr_box = [cur_addr_lo] + sec_box = [cur_sec_lo] if natural_two_wave else None def _mid_tdm_ws( _ls=load_stage, _ab=addr_box, + _sb=sec_box, _k_off=( split_k_base + loop_iter * arith.index(num_buffers * tile_k) + arith.index(buf_idx * tile_k) ), ): - _issue_active_tdm(_ls, _ab, k_prefetch=_k_off) + _issue_active_tdm(_ls, _ab, k_prefetch=_k_off, sec_box=_sb) if const_expr(not use_ws_tdm_split_signal_overlap): 
_pipeline_fence_signal(outstanding=_fence_outstanding) @@ -2605,16 +2818,21 @@ def _late_tdm_ws_split_signal(): pf_b_scales=_cur_b, ) cur_addr_lo = addr_box[0] + if const_expr(natural_two_wave): + cur_sec_lo = sec_box[0] hot_loop_scheduler_scheduled() if const_expr(_bvs_active): _bvs_yield = _ring_a + _ring_b else: _bvs_yield = [] - results = yield list(accs_in) + [cur_addr_lo] + _bvs_yield + _sec_yield = [cur_sec_lo] if natural_two_wave else [] + results = yield list(accs_in) + [cur_addr_lo] + _sec_yield + _bvs_yield accs = list(results[:n_accs]) active_addr_lo = results[n_accs] + if const_expr(natural_two_wave): + active_sec_lo = results[n_accs + 1] else: init_args = list(accs) + [addr_lo_a, addr_lo_b, addr_lo_as, addr_lo_bs] @@ -2712,8 +2930,9 @@ def _bvs_tail_kb(): # General VGPR scale: prefetch the tail's scales _bvs_D K-tiles ahead so # each scale buffer_load overlaps an earlier tile's TDM wait instead of # stalling the WMMA inline. The ref-segmented deep-pipeline path keeps its - # inline per-tile load to stay within its tight VGPR budget. - _bvs_tail_pf = use_general_vgpr_scale + # inline per-tile load to stay within its tight VGPR budget. Natural A-scale + # uses the same ahead-of-time tail prefetch. + _bvs_tail_pf = use_general_vgpr_scale or use_natural_ascale _bvs_tail_ring = [] _bvs_tail_issue_kt = [loop_iters * num_buffers] # Preload (opt-in): all K-tiles' scales loaded up front, indexed by step. 
@@ -2801,9 +3020,10 @@ def _emit_epi_addrs(): _tail_had_load = True if const_expr(wave_specialized_tdm): _tail_addr_box = [active_addr_lo] + _tail_sec_box = [active_sec_lo] if natural_two_wave else None - def _tail_mid_ws(_ls=_load_stage, _ab=_tail_addr_box): - _issue_active_tdm(_ls, _ab) + def _tail_mid_ws(_ls=_load_stage, _ab=_tail_addr_box, _sb=_tail_sec_box): + _issue_active_tdm(_ls, _ab, sec_box=_sb) _tail_mid_cb = _tail_mid_ws else: @@ -2847,6 +3067,8 @@ def _tail_mid_nws(_ls=_load_stage, _ab=_tail_ab): if const_expr(_load_stage is not None): if const_expr(wave_specialized_tdm): active_addr_lo = _tail_addr_box[0] + if const_expr(natural_two_wave): + active_sec_lo = _tail_sec_box[0] else: addr_lo_a = _tail_ab[0][0] addr_lo_b = _tail_ab[1][0] diff --git a/tests/kernels/test_gemm_fp8fp4_gfx1250.py b/tests/kernels/test_gemm_fp8fp4_gfx1250.py index 5b380f45..4c66ce8d 100644 --- a/tests/kernels/test_gemm_fp8fp4_gfx1250.py +++ b/tests/kernels/test_gemm_fp8fp4_gfx1250.py @@ -29,6 +29,7 @@ compile_ptpc_gemm, is_ref_segmented_lds_layout, use_n4k4_bscale_layout, + use_natural_ascale_vgpr, ) from tests.kernels.utils import fp4_utils # noqa: E402 @@ -492,6 +493,7 @@ def _run_mxscale_gemm_test( expert_sched_mode=True, split_k=1, scale_load_path="tdm", + ascale_load_path="vgpr", return_launch_fn=False, ): """Unified test body for FP4 and FP8.""" @@ -585,15 +587,33 @@ def _run_mxscale_gemm_test( wave_specialized_tdm=wave_specialized_tdm, use_scale_opsel=use_scale_opsel, ) - a_scale = preshuffle_scale_for_load_path( - a_scale, - warp_tile_m, - skt, - scale_load_path=scale_load_path, + _natural_ascale = use_natural_ascale_vgpr( data_format=data_format, - ref_segmented=_ref_seg, - row_align=tile_m, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + m_warp=m_warp, + n_warp=n_warp, + n=padded_n, + scale_load_path=scale_load_path, + ascale_load_path=ascale_load_path, + use_scale_opsel=use_scale_opsel, + wave_specialized_tdm=wave_specialized_tdm, ) + if _natural_ascale: + # 
Natural path reads A_scale[M, K//32] straight from VGPRs -- no reshuffle, + # the (already row-padded) tensor is uploaded as-is. + pass + else: + a_scale = preshuffle_scale_for_load_path( + a_scale, + warp_tile_m, + skt, + scale_load_path=scale_load_path, + data_format=data_format, + ref_segmented=_ref_seg, + row_align=tile_m, + ) if use_n4k4_bscale_layout( data_format=data_format, tile_m=tile_m, @@ -604,6 +624,9 @@ def _run_mxscale_gemm_test( n=padded_n, scale_load_path=scale_load_path, use_scale_opsel=use_scale_opsel, + num_buffers=num_buffers, + out_dtype=out_dtype, + wave_specialized_tdm=wave_specialized_tdm, ): b_scale = preshuffle_e8m0_bscale_n4k4(b_scale) else: @@ -650,6 +673,7 @@ def _run_mxscale_gemm_test( use_scale_opsel=use_scale_opsel, expert_sched_mode=expert_sched_mode, scale_load_path=scale_load_path, + ascale_load_path=ascale_load_path, ) # Keep 2D — dynamic_layout=True packs shape as i32; flattening overflows for M*K >= 2^31. @@ -1015,6 +1039,71 @@ def test_mxscale_n4k4_bscale(data_format, M, N, K, tile_n, tile_k, n_warp, num_b ) +def _gen_natural_ascale_configs(): + # (fmt, M, tile_m, tile_n, tile_k, m_warp, n_warp, nbuf) for the natural A-scale + # buffer_load path (wave-spec, default ascale_load_path='vgpr'). Covers wave + # counts 2/3/4, A-scale M op_sel via tile_m (rep 1/2/4/8/16), tile_k (k_steps + # 1/2/4), multi-64 tile_n, and ragged M. tile_k kept small at large tile_m so + # LDS fits. M==tile_m exercises a full M tile; M rep_n=1 row-major): rep_m sweep via tile_m. 
+ cfgs += [ + (fmt, 16, 16, 64, 512, 1, 4, 2), # rep1, k_steps=4 + (fmt, 32, 32, 64, 512, 1, 4, 2), # rep2 (op_sel) + (fmt, 64, 64, 64, 256, 1, 4, 2), # rep4 (op_sel) + (fmt, 128, 128, 64, 256, 1, 4, 2), # rep8 (op_sel) + (fmt, 256, 256, 64, 128, 1, 4, 2), # rep16 (op_sel) + (fmt, 16, 16, 64, 128, 1, 4, 2), # k_steps=1 + (fmt, 16, 16, 64, 256, 1, 4, 2), # k_steps=2 + (fmt, 16, 16, 128, 256, 1, 4, 2), # tile_n=128 + (fmt, 16, 16, 192, 256, 1, 4, 2), # tile_n=192 (next_pow2) + (fmt, 16, 16, 256, 256, 1, 4, 2), # tile_n=256 + ] + # 2-wave (wave0 issues A-data + B-scale): rep_m 1/2 (row-major needs n_accs<8). + cfgs += [(fmt, 16, 16, 64, 512, 1, 2, 2), (fmt, 32, 32, 64, 512, 1, 2, 2)] + # 3-wave (wave0/1/2 = A/B/B-scale). + cfgs += [(fmt, 16, 16, 192, 256, 1, 3, 2)] + # ragged / OOB M. + cfgs += [(fmt, 13, 16, 64, 512, 1, 4, 2), (fmt, 33, 64, 64, 256, 1, 4, 2)] + return cfgs + + +@pytest.mark.parametrize("data_format, M, tile_m, tile_n, tile_k, m_warp, n_warp, nbuf", _gen_natural_ascale_configs()) +def test_mxscale_natural_ascale(data_format, M, tile_m, tile_n, tile_k, m_warp, n_warp, nbuf): + # Guard: every config must take BOTH the natural A-scale path and N4K4 B-scale. 
+ N = 2 * tile_n + K = tile_k * nbuf + assert use_natural_ascale_vgpr( + data_format=data_format, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + m_warp=m_warp, + n_warp=n_warp, + n=N, + wave_specialized_tdm=True, + ), f"config does not hit the natural A-scale gate: {(data_format, tile_m, tile_n, tile_k, m_warp, n_warp)}" + _run_mxscale_gemm_test( + data_format, + M, + N, + K, + tile_m, + tile_n, + tile_k, + m_warp, + n_warp, + nbuf, + use_tdm_store=True, + out_dtype="bf16", + wave_specialized_tdm=True, + l2_prefetch_distance=0, + use_scale_opsel=False, + scale_load_path="tdm", + ) + + @pytest.mark.parametrize( "M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp, cluster_m, cluster_n", [ From ff3480e74c8c9adc650c558a33dff0e5cfe54d74 Mon Sep 17 00:00:00 2001 From: aoli26 Date: Mon, 15 Jun 2026 12:10:48 +0000 Subject: [PATCH 11/16] remove segmented lds and scale load vgpr exp codes --- kernels/gemm_fp8fp4_gfx1250.py | 372 +++------------------- tests/kernels/test_gemm_fp8fp4_gfx1250.py | 245 +------------- 2 files changed, 57 insertions(+), 560 deletions(-) diff --git a/kernels/gemm_fp8fp4_gfx1250.py b/kernels/gemm_fp8fp4_gfx1250.py index a20459ce..646da442 100644 --- a/kernels/gemm_fp8fp4_gfx1250.py +++ b/kernels/gemm_fp8fp4_gfx1250.py @@ -87,40 +87,6 @@ def _vec_chunks(n: int): LDS_GFX1250_MAX_BYTES = 5 * LDS_SEGMENT_BYTES -def is_ref_segmented_lds_layout( - *, - data_format, - tile_m, - tile_n, - tile_k, - m_warp, - n_warp, - num_buffers, - split_k, - wave_specialized_tdm, - use_scale_opsel, -): - """Whether this config uses the reference segmented LDS layout. - - Single source of truth shared by the kernel and the test host-side scale - preshuffle: the buffer_load->VGPR scale path uses the legacy lane-major - coalesced layout here (deep-pipeline) and the general coalesced layout - otherwise. 
- """ - return ( - data_format == "fp8" - and tile_m == 256 - and tile_n == 256 - and tile_k == 128 - and m_warp == 2 - and n_warp == 2 - and num_buffers == 4 - and split_k == 1 - and wave_specialized_tdm - and not use_scale_opsel - ) - - def _is_row_major_streaming(wmma_m_rep, wmma_n_rep, n_accs): # Mirrors _pick_compute_schedule_kind: row-major when a rep is odd or n_accs < 8. return wmma_m_rep % 2 != 0 or wmma_n_rep % 2 != 0 or n_accs < 8 @@ -136,7 +102,6 @@ def use_n4k4_bscale_layout( n_warp, n, scale_mode="mxscale", - scale_load_path="tdm", use_scale_opsel=False, num_buffers=2, out_dtype="f32", @@ -150,8 +115,6 @@ def use_n4k4_bscale_layout( return False if data_format not in ("fp8", "a8w4"): return False - if scale_load_path != "tdm": - return False if use_scale_opsel: return False if tile_k % 128 != 0: @@ -192,7 +155,6 @@ def use_natural_ascale_vgpr( n_warp, n, scale_mode="mxscale", - scale_load_path="tdm", ascale_load_path="vgpr", use_scale_opsel=False, wave_specialized_tdm=False, @@ -215,7 +177,6 @@ def use_natural_ascale_vgpr( n_warp=n_warp, n=n, scale_mode=scale_mode, - scale_load_path=scale_load_path, use_scale_opsel=use_scale_opsel, ): return False @@ -249,7 +210,6 @@ def compile_fp8fp4_gemm( use_scale_opsel: bool = False, expert_sched_mode: bool = True, atomic_barrier_enable: bool = False, - scale_load_path: str = "tdm", ascale_load_path: str = "vgpr", fp8_schedule: str = "auto", ): @@ -284,11 +244,6 @@ def compile_fp8fp4_gemm( if out_dtype not in ("f32", "bf16", "f16"): raise ValueError(f"out_dtype must be 'f32', 'bf16', or 'f16', got {out_dtype!r}") elem_bytes_d = 2 if out_dtype in ("bf16", "f16") else 4 - # scale_load_path: "tdm" = TDM->LDS (default); "vgpr" = buffer_load->VGPR, - # off the LDS/TDM/barrier path. 
- scale_load_paths = ("tdm", "vgpr") - if scale_load_path not in scale_load_paths: - raise ValueError(f"scale_load_path must be one of {scale_load_paths}, got {scale_load_path!r}") fp8_schedule_modes = ("auto", "quadrant", "deep-pipeline") if fp8_schedule not in fp8_schedule_modes: raise ValueError(f"fp8_schedule must be one of {fp8_schedule_modes}, got {fp8_schedule!r}") @@ -316,7 +271,7 @@ def compile_fp8fp4_gemm( # Scales bypass TDM (no dedicated loader waves) for ptpc or the buffer->VGPR # scale path, leaving only A + B -> 2 waves; otherwise A + B + A_scale + # B_scale -> 4 waves. - _drop_scale_loader_waves = is_ptpc or scale_load_path == "vgpr" + _drop_scale_loader_waves = is_ptpc # Min loader-wave check is finalized after use_natural_ascale is known (natural # A-scale frees its wave, allowing >=2; see below). @@ -397,7 +352,6 @@ def compile_fp8fp4_gemm( n_warp=n_warp, n=N, scale_mode=scale_mode, - scale_load_path=scale_load_path, use_scale_opsel=use_scale_opsel, num_buffers=num_buffers, out_dtype=out_dtype, @@ -424,8 +378,8 @@ def compile_fp8fp4_gemm( n4k4_opsel_kgrp_off = (_half // 4) * n4k4_bs_lds_row_stride + (_half % 4) * 4 # A-scale natural buffer_load->VGPR (no reshuffle), paired with N4K4 B (TDM). - # TEMP: ascale_load_path='tdm' fallback + old scale_load_path/non-ws kept until - # those legacy paths are retired; the natural vgpr path is the target. + # TEMP: ascale_load_path='tdm' fallback + non-ws kept until those legacy paths + # are retired; the natural vgpr path is the target. 
if ascale_load_path not in ("vgpr", "tdm"): raise ValueError(f"ascale_load_path must be 'vgpr' or 'tdm', got {ascale_load_path!r}") use_natural_ascale = use_natural_ascale_vgpr( @@ -437,7 +391,6 @@ def compile_fp8fp4_gemm( n_warp=n_warp, n=N, scale_mode=scale_mode, - scale_load_path=scale_load_path, ascale_load_path=ascale_load_path, use_scale_opsel=use_scale_opsel, wave_specialized_tdm=wave_specialized_tdm, @@ -459,7 +412,7 @@ def compile_fp8fp4_gemm( # _scale_ds_loads counts scale ds_loads issued alongside A/B fragment loads in # the streaming schedule (used for the partial-drain s_wait_dscnt bookkeeping). # The general VGPR scale path holds scales in registers (no ds_load), so it - # contributes zero. Finalized below once use_general_vgpr_scale is known. + # contributes zero. Finalized below once use_natural_ascale is known. _a_scale_ds = (wmma_m_rep + 3) // 4 _b_scale_ds = (b_scale_load_rep + 3) // 4 _scale_ds_loads = _a_scale_ds + _b_scale_ds @@ -523,110 +476,37 @@ def _align_up(value: int, align: int) -> int: ), ) - use_ref_segmented_lds_layout = is_ref_segmented_lds_layout( - data_format=data_format, - tile_m=tile_m, - tile_n=tile_n, - tile_k=tile_k, - m_warp=m_warp, - n_warp=n_warp, - num_buffers=num_buffers, - split_k=split_k, - wave_specialized_tdm=wave_specialized_tdm, - use_scale_opsel=use_scale_opsel, - ) - - # "vgpr": load scale global->VGPR via buffer_load, bypassing - # TDM+LDS entirely. Two layouts coexist: the reference segmented deep-pipeline - # path (use_ref_segmented_lds_layout, fp8 256x256x128) and the general - # coalesced path (use_general_vgpr_scale) used by the row-major streaming - # schedule (a8w4/fp8, arbitrary warp_tile / tile_k). The full schedule+format - # eligibility check runs once the compute schedule is known (below). 
- use_buffer_vgpr_scale = scale_load_path == "vgpr" - use_general_vgpr_scale = use_buffer_vgpr_scale and not use_ref_segmented_lds_layout - if use_general_vgpr_scale: - # General VGPR scales live in registers: no scale ds_loads to wait on. - _scale_ds_loads = 0 - _bs_ds_loads = wmma_n_rep * _b_frag_loads_per_wn - _as_ds_loads = _a_frag_ds - elif use_natural_ascale: - # Only A-scale leaves LDS (VGPR); B-scale stays an N4K4 ds_load. + if use_natural_ascale: + # A-scale leaves LDS (VGPR); B-scale stays an N4K4 ds_load. _a_scale_ds = 0 _scale_ds_loads = _b_scale_ds _bs_ds_loads = wmma_n_rep * _b_frag_loads_per_wn + _scale_ds_loads _as_ds_loads = _a_frag_ds + _scale_ds_loads - # Scale prefetch depth (K-tiles ahead) for the buffer->VGPR path. The - # ref-segmented deep-pipeline path is VGPR-bound (D=2 doubles scale VGPRs -> - # spill + ~18% regression), so it stays at D=1. The general coalesced path - # (thin row-major streaming tiles) has spare VGPRs, so it prefetches deeper - # to overlap each scale buffer_load with an earlier tile's TDM wait. - _bvs_D_default = 3 if (use_general_vgpr_scale or use_natural_ascale) else 1 + # Scale prefetch depth (K-tiles ahead) for the A-scale VGPR ring: prefetch + # deeper so each scale buffer_load overlaps an earlier tile's TDM wait. + _bvs_D_default = 3 if use_natural_ascale else 1 _bvs_D = max(1, int(os.environ.get("FLYDSL_BUFFER_VGPR_SCALE_DEPTH", str(_bvs_D_default)))) - # FLYDSL_BUFFER_VGPR_SCALE_PRELOAD=1 (experiment, small-M only): switch the - # general vgpr scale path to the b128 layout (ks innermost) so scales load - # with wide buffer_load_b128 instead of many b32, and -- when the whole K - # runs in the tail (loop_iters == 0) -- load them ALL up front into VGPRs - # (preload) instead of a per-tile ring. The b128 layout must match the host - # (test preshuffle_scale_for_load_path). _bvs_b128 is independent of - # loop_iters so host<->kernel layout always agree; _bvs_preload is the - # all-up-front variant. 
Full-K scale must fit in VGPRs -- NOT general. - _bvs_b128 = use_general_vgpr_scale and bool(int(os.environ.get("FLYDSL_BUFFER_VGPR_SCALE_PRELOAD", "0"))) - _bvs_preload = _bvs_b128 and loop_iters == 0 - # The buffer_load->VGPR scale ring is built only when scale is actually loaded - # (coalesced A+B vgpr path, or the natural A-scale path with B still on TDM). - _bvs_active = use_buffer_vgpr_scale or use_natural_ascale - - if use_ref_segmented_lds_layout: - # The A/B data pools are no longer packed into the same per-stage - # 64KiB segment window. Scale pools keep the reference 0x800 stride so - # every TDM LDS target remains 2KiB-aligned. - ref_a_stage_stride = 0x9000 - ref_b_stage_stride = 0x8000 - ref_scale_stage_stride = 0x800 - if lds_a_data_bytes > ref_a_stage_stride: - raise RuntimeError( - "reference segmented LDS layout requires A stage <= 0x9000 bytes, " f"got {lds_a_data_bytes}" - ) - if lds_b_data_bytes > ref_b_stage_stride: - raise RuntimeError( - "reference segmented LDS layout requires B stage <= 0x8000 bytes, " f"got {lds_b_data_bytes}" - ) - if lds_a_scale_bytes > ref_scale_stage_stride or lds_b_scale_bytes > ref_scale_stage_stride: - raise RuntimeError( - "reference segmented LDS layout requires scale stage <= 0x800 bytes, " - f"got A={lds_a_scale_bytes} B={lds_b_scale_bytes}" - ) - - stage_a_data_off = [0x00000, 0x09000, 0x16000, 0x1F000] - stage_a_scale_off = [0x12000 + i * ref_scale_stage_stride for i in range(num_buffers)] - stage_b_scale_off = [0x28000 + i * ref_scale_stage_stride for i in range(num_buffers)] - stage_b_data_off = [0x30000 + i * ref_b_stage_stride for i in range(num_buffers)] - arena_alloc.ptr = LDS_GFX1250_MAX_BYTES - arena_total_bytes = arena_alloc.ptr - - # The epilogue may reuse the prefix only after all main/tail TDM traffic - # is fully fenced. This is outside the hot loop and avoids assuming a - # single monotonic per-stage base for the segmented pool layout. 
- epilogue_fence_threshold_bytes = 0 - else: - stage_phys_order = [i for i in range(num_buffers) if i != _last_compute_stage] - stage_phys_order.append(_last_compute_stage) - stage_base_off = [0] * num_buffers - for phys_i, logical_i in enumerate(stage_phys_order): - stage_base_off[logical_i] = phys_i * stage_pitch_bytes - arena_alloc.ptr = stage_pitch_bytes * num_buffers - arena_total_bytes = arena_alloc.ptr - epilogue_fence_threshold_bytes = tdm_epilogue_fence_threshold_bytes( - stage_base_off=stage_base_off, - tail_plan=_base_tail_plan, - loop_iters=loop_iters, - extra=extra, - ) + # The buffer_load->VGPR scale ring is built only for the natural A-scale path. + _bvs_active = use_natural_ascale + + stage_phys_order = [i for i in range(num_buffers) if i != _last_compute_stage] + stage_phys_order.append(_last_compute_stage) + stage_base_off = [0] * num_buffers + for phys_i, logical_i in enumerate(stage_phys_order): + stage_base_off[logical_i] = phys_i * stage_pitch_bytes + arena_alloc.ptr = stage_pitch_bytes * num_buffers + arena_total_bytes = arena_alloc.ptr + epilogue_fence_threshold_bytes = tdm_epilogue_fence_threshold_bytes( + stage_base_off=stage_base_off, + tail_plan=_base_tail_plan, + loop_iters=loop_iters, + extra=extra, + ) - stage_a_data_off = [stage_base_off[i] + stage_a_data_rel_off for i in range(num_buffers)] - stage_b_data_off = [stage_base_off[i] + stage_b_data_rel_off for i in range(num_buffers)] - stage_a_scale_off = [stage_base_off[i] + stage_a_scale_rel_off for i in range(num_buffers)] - stage_b_scale_off = [stage_base_off[i] + stage_b_scale_rel_off for i in range(num_buffers)] + stage_a_data_off = [stage_base_off[i] + stage_a_data_rel_off for i in range(num_buffers)] + stage_b_data_off = [stage_base_off[i] + stage_b_data_rel_off for i in range(num_buffers)] + stage_a_scale_off = [stage_base_off[i] + stage_a_scale_rel_off for i in range(num_buffers)] + stage_b_scale_off = [stage_base_off[i] + stage_b_scale_rel_off for i in range(num_buffers)] if 
use_tdm_store: lds_d_row_stride = warp_tile_n * elem_bytes_d + LDS_PAD_D_BYTES @@ -717,24 +597,6 @@ def _pick_compute_schedule_kind(): if use_n4k4_bscale: assert compute_schedule_kind in (COMPUTE_SCHEDULE_ROW_MAJOR_STREAMING, COMPUTE_SCHEDULE_FP8_QUADRANT) - if use_buffer_vgpr_scale: - # General coalesced VGPR scale is supported on the row-major streaming - # schedule (mxscale fp8/a8w4, no scale_opsel, wave-specialized TDM); the - # ref-segmented path keeps the FP8 deep-pipeline schedule. - _vgpr_streaming_ok = ( - use_general_vgpr_scale - and compute_schedule_kind == COMPUTE_SCHEDULE_ROW_MAJOR_STREAMING - and data_format in ("fp8", "a8w4") - and not is_ptpc - and not use_scale_opsel - and wave_specialized_tdm - ) - if not (use_fp8_deep_pipeline_schedule or _vgpr_streaming_ok): - raise ValueError( - f"scale_load_path={scale_load_path!r} requires the FP8 deep-pipeline schedule, or " - "the row-major streaming schedule with mxscale fp8/a8w4, no scale_opsel, and " - "wave_specialized_tdm" - ) use_ws_tdm_split_signal_overlap = ( wave_specialized_tdm and (use_fp8_quadrant_schedule or use_fp8_deep_pipeline_schedule) @@ -850,119 +712,11 @@ def _load_contig_i32(rsrc, base_idx, n, soff): out[start + c] = rv[c] return out - if const_expr(use_buffer_vgpr_scale): - # Direct global->VGPR scale load (no TDM/LDS). One K-tile's scales for - # all reps land in VGPRs; the loop carries a small prefetch ring. - # - # Two host layouts share the same prefetch/consume plumbing: - # - deep-pipeline (use_ref_segmented_lds_layout): lane-major - # [M_block(128), K_tile, group(2), lane16(16), 4 i32], single - # k_wmma_step, no lane_kgrp shift. i32_off = (mb*Kt+kt)*128 + - # group*64 + lane16*4. - # - general (use_general_vgpr_scale): coalesced - # [mb, K_tile, k_wmma_step, rep_group, lane32(32), j(4), spw] with the - # a8w4 lane_kgrp shift baked per physical lane (lane32 = kgrp*16+L), so - # the address is kgrp-agnostic. i32_off = - # (((mb*Kt+kt)*KS+ks)*NG + grp)*32*4 + lane32*4 + j. 
Matches the TDM - # path value-for-value (see flydsl_fp8_perf/verify_vgpr_scale_layout.py). - _bvs_a_rsrc = buffer_ops.create_buffer_resource(arg_a_scale, max_size=False) - _bvs_b_rsrc = buffer_ops.create_buffer_resource(arg_b_scale, max_size=False) - _bvs_Kt = K // tile_k # total K-tiles - _bvs_mb_a = blk_m // arith.index(warp_tile_m) + wave_m_idx - _bvs_mb_b = blk_n // arith.index(warp_tile_n) + wave_n_idx - _bvs_lane4 = lane16 * arith.index(4) - _gvs_lane32 = lane_kgrp * arith.index(16) + lane16 - # Per-tile VGPR scale count (flat, ordered [k_wmma_step][rep]); reduces to - # `rep` for the deep-pipeline path (k_wmma_steps == 1). - _vs_tile_a = k_wmma_steps * wmma_m_rep - _vs_tile_b = k_wmma_steps * b_scale_load_rep - - def _bvs_load_scales(rsrc, mb, rep, k_base): - # Deep-pipeline lane-major layout (k_wmma_steps == 1). - kt = k_base // arith.index(tile_k) - tile_i32 = (mb * arith.index(_bvs_Kt) + kt) * arith.index(128) - vals = [] - for ld in range_constexpr(rep // 4): # rep=8 -> 2 groups of 4 i32 - off = arith.index_cast(T.i32, tile_i32 + arith.index(ld * 64) + _bvs_lane4) - v = fx.Vector(buffer_ops.buffer_load(rsrc, off, vec_width=4, dtype=T.i32)) - for j in range_constexpr(4): - vals.append(v[j]) - return vals - - def _gvs_load_scales(rsrc, mb, rep, k_base): - # General coalesced layout: k_wmma_steps * rep i32, flat [ks][rep]. - # The per-tile K term (kt) goes in the scalar soffset, NOT the - # per-lane voffset VGPR: that keeps the voffset identical across - # prefetched K-tiles, so the backend CSEs it to one address - # register instead of recomputing it per tile (which forced - # s_wait_xcnt address-drain serialization and left the scale - # buffer_loads only partially hidden). The within-tile ks/grp - # delta stays in voffset, where it folds into the buffer - # instruction's immediate offset. 
- kt = k_base // arith.index(tile_k) - _NG = (rep + 3) // 4 - _S = k_wmma_steps * _NG * 32 * 4 - base_i32 = mb * arith.index(_bvs_Kt) * arith.index(_S) - kt_soff = arith.index_cast(T.i32, kt * arith.index(_S) * arith.index(4)) - vals = [] - for ks in range_constexpr(k_wmma_steps): - for grp in range_constexpr(_NG): - grp_i32 = base_i32 + arith.index((ks * _NG + grp) * 32 * 4) + _gvs_lane32 * arith.index(4) - off = arith.index_cast(T.i32, grp_i32) - v = fx.Vector( - buffer_ops.buffer_load(rsrc, off, vec_width=4, dtype=T.i32, soffset_bytes=kt_soff) - ) - for j in range_constexpr(4): - if const_expr(grp * 4 + j < rep): - vals.append(v[j]) - return vals - - def _gvs_load_scales_b128(rsrc, mb, rep, k_base, in_voffset): - # b128 layout [kt, grp, j, lane32, ks, spw]: a lane's KS scale-words - # are contiguous, so each (rep,lane) is one wide load (b128 for - # KS==4). Returns vals flat [ks][rep] to match _scales_for_emit. - # in_voffset=True bakes kt into the voffset VGPR (preload: all tiles' - # voffsets live at once -> distinct regs, no s0 reuse). in_voffset= - # False keeps the per-lane voffset constant across tiles and puts kt - # in the scalar soffset (per-tile ring). 
- kt = k_base // arith.index(tile_k) - KS = k_wmma_steps - _NG = (rep + 3) // 4 - _S128 = _NG * 4 * 32 * KS - if const_expr(in_voffset): - base = (mb * arith.index(_bvs_Kt) + kt) * arith.index(_S128) - soff = None - else: - base = mb * arith.index(_bvs_Kt) * arith.index(_S128) - soff = arith.index_cast(T.i32, kt * arith.index(_S128) * arith.index(4)) - vals = [None] * (KS * rep) - for _rep in range_constexpr(rep): - rep_off = base + arith.index(_rep * 32 * KS) + _gvs_lane32 * arith.index(KS) - ks_vals = _load_contig_i32(rsrc, rep_off, KS, soff) - for ks in range_constexpr(KS): - vals[ks * rep + _rep] = ks_vals[ks] - return vals - - def _bvs_prefetch(k_base, preload=False): - # Issue scale buffer_load for one K-tile; returns (a, b) VGPR lists, - # each flat [k_wmma_step][rep] (length _vs_tile_a / _vs_tile_b). - if const_expr(_bvs_b128): - a = _gvs_load_scales_b128(_bvs_a_rsrc, _bvs_mb_a, wmma_m_rep, k_base, preload) - b = _gvs_load_scales_b128(_bvs_b_rsrc, _bvs_mb_b, b_scale_load_rep, k_base, preload) - elif const_expr(use_general_vgpr_scale): - a = _gvs_load_scales(_bvs_a_rsrc, _bvs_mb_a, wmma_m_rep, k_base) - b = _gvs_load_scales(_bvs_b_rsrc, _bvs_mb_b, b_scale_load_rep, k_base) - else: - a = _bvs_load_scales(_bvs_a_rsrc, _bvs_mb_a, wmma_m_rep, k_base) - b = _bvs_load_scales(_bvs_b_rsrc, _bvs_mb_b, b_scale_load_rep, k_base) - return a, b - - elif const_expr(use_natural_ascale): + if const_expr(use_natural_ascale): # Natural A-scale: read A_scale[M, K//32] straight into VGPRs (no reshuffle). # A row's K-scales are contiguous -> one wide load per M-block grabs all ks. # kt rides the scalar soffset so the per-lane voffset is K-tile-invariant - # (CSE'd -> loads fully hidden, like _gvs_load_scales). M op_sel: kgrp1 reads - # block wm+rep/2 (a pair per load). + # (CSE'd -> loads fully hidden). M op_sel: kgrp1 reads block wm+rep/2. 
_nat_as_rsrc = buffer_ops.create_buffer_resource(arg_a_scale, max_size=False) _nat_row_i32 = K_scale // 4 # i32 elements per A_scale row (K_scale = K//32, %4==0) _nat_row0 = blk_m + warp_m_base + lane16 @@ -972,7 +726,7 @@ def _bvs_prefetch(k_base, preload=False): _vs_tile_b = 0 def _nat_as_load(k_base): - kt = k_base / arith.index(tile_k) + kt = k_base // arith.index(tile_k) soff = arith.index_cast(T.i32, kt * arith.index(scale_k_per_tile)) vals = [None] * (k_wmma_steps * nat_as_load) for i in range_constexpr(nat_as_load): @@ -982,7 +736,7 @@ def _nat_as_load(k_base): vals[ks * nat_as_load + i] = ks_vals[ks] return vals - def _bvs_prefetch(k_base, preload=False): + def _bvs_prefetch(k_base): return _nat_as_load(k_base), [] m_idx = fx.Index(i32_m) @@ -1333,12 +1087,6 @@ def _scales_for_emit(as_buf, as_bases, bs_buf, bs_bases, ks): """ if const_expr(is_ptpc): return None, None - if const_expr(use_general_vgpr_scale): - # VGPR scales (no op_sel in this path); slice the prefetched ring. - pf_a, pf_b = _vgpr_scale_box[0] - a = pf_a[ks * wmma_m_rep : (ks + 1) * wmma_m_rep] - b = pf_b[ks * b_scale_load_rep : (ks + 1) * b_scale_load_rep] - return a, b if const_expr(use_natural_ascale): # A from the natural VGPR ring (slice this ks; M op_sel handled in # _emit via nat_as_half); B from the N4K4 LDS layout. @@ -1522,8 +1270,8 @@ def compute_tile( pf_b_scales=None, ): current_accs = list(accs_in) - if const_expr(use_general_vgpr_scale or use_natural_ascale): - # Scales come from VGPR: use the loop-prefetched ring when provided, + if const_expr(use_natural_ascale): + # A-scale comes from VGPR: use the loop-prefetched ring when provided, # else issue the buffer_loads inline (tail path) for scale_k_base. 
if const_expr(pf_a_scales is not None): _vgpr_scale_box[0] = (pf_a_scales, pf_b_scales) @@ -1951,19 +1699,11 @@ def load_b_pair(wn_pair, ks): def _load_a_scales(ks): if const_expr(is_ptpc): return None # PTPC: scale applied in epilogue, not in K-loop - if const_expr(use_buffer_vgpr_scale): - if const_expr(pf_a_scales is not None): - return pf_a_scales # prefetched (issued in the prior compute tile) - return _bvs_load_scales(_bvs_a_rsrc, _bvs_mb_a, wmma_m_rep, scale_k_base) return load_scale_b128(as_buf, as_bases[0], wmma_m_rep, ks) def _load_b_scales(ks): if const_expr(is_ptpc): return None # PTPC: scale applied in epilogue, not in K-loop - if const_expr(use_buffer_vgpr_scale): - if const_expr(pf_b_scales is not None): - return pf_b_scales - return _bvs_load_scales(_bvs_b_rsrc, _bvs_mb_b, b_scale_load_rep, scale_k_base) return load_scale_b128(bs_buf, bs_bases[0], b_scale_load_rep, ks) def emit_panel_2x2( @@ -2102,7 +1842,7 @@ def hot_loop_scheduler(): _half_wmma = _half_wm * wmma_n_rep _b_loads_per_frag = 2 if is_a8w4 else 4 # No scale ds_loads when scales are in registers (PTPC epilogue / VGPR). - _scale_dsrd = 0 if (is_ptpc or use_general_vgpr_scale) else 2 + _scale_dsrd = 0 if is_ptpc else 2 _a_half_dsrd = _half_wm * DS_LOADS_PER_A_FRAG for _ks in range_constexpr(k_wmma_steps): @@ -2589,7 +2329,7 @@ def _pack_dg0(pred, lds_addr, addr_lo, addr_hi): pred_const = fx.Int32(1) if const_expr(wave_specialized_tdm): - _drop_scale_waves = is_ptpc or use_buffer_vgpr_scale + _drop_scale_waves = is_ptpc if const_expr(use_natural_ascale): # wave0=A, wave1=B, wave2=B-scale (>=3 waves); A-scale is VGPR. At 2 # waves wave2 doesn't exist and B-scale rides wave0 as a secondary. @@ -2927,46 +2667,28 @@ def _bvs_tail_kb(): _bvs_tail_kt[0] += 1 return kb - # General VGPR scale: prefetch the tail's scales _bvs_D K-tiles ahead so - # each scale buffer_load overlaps an earlier tile's TDM wait instead of - # stalling the WMMA inline. 
The ref-segmented deep-pipeline path keeps its - # inline per-tile load to stay within its tight VGPR budget. Natural A-scale - # uses the same ahead-of-time tail prefetch. - _bvs_tail_pf = use_general_vgpr_scale or use_natural_ascale + # Natural A-scale: prefetch the tail's scales _bvs_D K-tiles ahead so each + # scale buffer_load overlaps an earlier tile's TDM wait instead of stalling + # the WMMA inline. + _bvs_tail_pf = use_natural_ascale _bvs_tail_ring = [] _bvs_tail_issue_kt = [loop_iters * num_buffers] - # Preload (opt-in): all K-tiles' scales loaded up front, indexed by step. - _bvs_preload_ring = [] - _bvs_tail_step = [0] def _bvs_tail_issue_one(): - if const_expr(_bvs_tail_pf and not _bvs_preload and _bvs_tail_issue_kt[0] < num_k_tiles): + if const_expr(_bvs_tail_pf and _bvs_tail_issue_kt[0] < num_k_tiles): _kb = split_k_base + arith.index(_bvs_tail_issue_kt[0] * tile_k) _bvs_tail_ring.append(_bvs_prefetch(_kb)) _bvs_tail_issue_kt[0] += 1 def _bvs_tail_scales(): - # Per-tile (scale_k_base, pf_a_scales, pf_b_scales): consume the preload - # set or the prefetch ring on the general path, else fall back to the - # inline-load k_base. - if const_expr(_bvs_preload): - _i = _bvs_tail_step[0] - _bvs_tail_step[0] += 1 - _cur_a, _cur_b = _bvs_preload_ring[_i] - return None, _cur_a, _cur_b + # Per-tile (scale_k_base, pf_a_scales, pf_b_scales): consume the prefetch + # ring on the natural path, else fall back to the inline-load k_base. if const_expr(_bvs_tail_pf): _cur_a, _cur_b = _bvs_tail_ring.pop(0) return None, _cur_a, _cur_b return _bvs_tail_kb(), None, None - if const_expr(_bvs_preload): - # One-shot: issue ALL K-tiles' scales up front (distinct voffset VGPRs, - # no shared-soffset reuse). All loads overlap the prologue/first B TDM. 
- rocdl.sched_barrier(0) - for _t in range_constexpr(num_k_tiles): - _kb = split_k_base + arith.index(_t * tile_k) - _bvs_preload_ring.append(_bvs_prefetch(_kb, preload=True)) - elif const_expr(_bvs_tail_pf): + if const_expr(_bvs_tail_pf): # Prime the ring before the first tail fence so even tile 0's scale # load overlaps its TDM wait rather than stalling the WMMA. rocdl.sched_barrier(0) @@ -3137,7 +2859,7 @@ def _emit_buffer_store(): use_scale_opsel, expert_sched_mode, atomic_barrier_enable, - scale_load_path, + ascale_load_path, fp8_schedule, ) @@ -3252,7 +2974,6 @@ def compile_ptpc_gemm( wave_specialized_tdm=True, use_scale_opsel=False, fp8_schedule="auto", - scale_load_path="tdm", use_tdm_store=(split_k == 1), N=N, K=K, @@ -3275,7 +2996,6 @@ def compile_ptpc_gemm( __all__ = [ - "is_ref_segmented_lds_layout", "compile_fp8fp4_gemm", "compile_mxscale_gemm", "compile_mxfp4_gemm", diff --git a/tests/kernels/test_gemm_fp8fp4_gfx1250.py b/tests/kernels/test_gemm_fp8fp4_gfx1250.py index 4c66ce8d..2a7d80af 100644 --- a/tests/kernels/test_gemm_fp8fp4_gfx1250.py +++ b/tests/kernels/test_gemm_fp8fp4_gfx1250.py @@ -27,7 +27,6 @@ from kernels.gemm_fp8fp4_gfx1250 import ( # noqa: E402 compile_mxscale_gemm, compile_ptpc_gemm, - is_ref_segmented_lds_layout, use_n4k4_bscale_layout, use_natural_ascale_vgpr, ) @@ -40,109 +39,14 @@ SCALE_BLOCK = 32 -def preshuffle_e8m0_scale_coalesced(scale: torch.Tensor, block: int = 128) -> torch.Tensor: - """Lane-major scale layout for direct buffer_load->VGPR. - - Per (M_block=128, K_tile): [group(2), lane16(16), 4 i32], so a buffer_load_b128's - 16 lanes read 256 contiguous bytes. M = mb*128 + (group*4 + j)*16 + lane16. 
- """ - M, Ks = scale.shape - assert M % block == 0 and Ks % 4 == 0, f"M={M} Ks={Ks} block={block}" - assert block == 128, "coalesced scale layout assumes warp_tile=128 (8 subtiles)" - Kt = Ks // 4 - g = scale.view(M // block, 2, 4, 16, Kt, 4) # [mb, group, j, lane16, kt, spw] - g = g.permute(0, 4, 1, 3, 2, 5).contiguous() # [mb, kt, group, lane16, j, spw] - return g.view(M, Ks) - - -def preshuffle_e8m0_scale_coalesced_general( - scale: torch.Tensor, - warp_tile: int, - scale_k_per_tile: int, - kgrp_shift: int, - WMMA_DIM: int = 16, - row_align: int = None, - b128: bool = False, -) -> torch.Tensor: - """General lane-major scale layout for the buffer_load->VGPR path. - - Generalizes :func:`preshuffle_e8m0_scale_coalesced` (which is locked to - warp_tile=128, scale_k_per_tile=4, no lane_kgrp) to arbitrary ``warp_tile``, - ``scale_k_per_tile`` (=> ``k_wmma_steps``) and the a8w4/fp4 ``lane_kgrp`` - scale shift. The byte delivered to every *physical* 32-lane index - ``lane32 = lane_kgrp*16 + lane16`` is baked here so the kernel's load address - is lane_kgrp-agnostic and fully coalesced (32 lanes -> 512 contiguous bytes - per (k_wmma_step, rep_group)). - - Per (row_block of ``warp_tile`` rows, K-tile): layout is - ``[k_wmma_step, rep_group(ceil(rep/4)), lane32(32), j(4), spw(4)]`` and the - value at ``(lane32=(G*16+L), rep=grp*4+j)`` is the original e8m0 quadruple of - row ``block*warp_tile + (rep + G*kgrp_shift)*16 + L``; slots whose shifted rep - reaches ``rep_count`` are filled with E8M0 127 (=1.0) guards. - - Mirrors the TDM/LDS path value-for-value (see the offline parity check in - ``flydsl_fp8_perf/verify_vgpr_scale_layout.py``), so it is correct by - construction against the working ``scale_load_path='tdm'`` path. 
- """ - rows, K_scale = scale.shape - assert K_scale % scale_k_per_tile == 0, f"K_scale={K_scale} % spt={scale_k_per_tile}" - assert scale_k_per_tile % 4 == 0, f"scale_k_per_tile={scale_k_per_tile} must be a multiple of 4" - align = row_align if row_align is not None else warp_tile - if rows % align != 0: - pad = _align_up(rows, align) - rows - scale = torch.cat([scale, torch.full((pad, K_scale), 127, dtype=scale.dtype, device=scale.device)], dim=0) - rows = scale.shape[0] - R = warp_tile // WMMA_DIM # rep_count (wmma reps per warp tile) - KS = scale_k_per_tile // 4 # k_wmma_steps - KG = K_scale // scale_k_per_tile # K-tiles - NG = (R + 3) // 4 # rep groups (vec4 b128 each) - num_mb = rows // warp_tile - - # Index grids over [mb, kt, ks, grp, lane32, j]; spw is the trailing dim. - dev = scale.device - mb = torch.arange(num_mb, device=dev).view(num_mb, 1, 1, 1, 1, 1) - kt = torch.arange(KG, device=dev).view(1, KG, 1, 1, 1, 1) - ks = torch.arange(KS, device=dev).view(1, 1, KS, 1, 1, 1) - grp = torch.arange(NG, device=dev).view(1, 1, 1, NG, 1, 1) - l32 = torch.arange(32, device=dev).view(1, 1, 1, 1, 32, 1) - j = torch.arange(4, device=dev).view(1, 1, 1, 1, 1, 4) - G = l32 // 16 - L = l32 % 16 - rep = grp * 4 + j - orig_rep = rep + G * kgrp_shift - valid = orig_rep < R - orig_row = mb * warp_tile + orig_rep * WMMA_DIM + L - # Clamp out-of-range (guard) rows so the gather stays in bounds; masked below. - orig_row = torch.where(valid, orig_row, torch.zeros_like(orig_row)) - colg = kt * KS + ks # group-of-4 column index into the K dimension - # Gather the 4 spw bytes for each (row, colg): scale viewed as [rows, KG*KS, 4]. 
- scale_g = scale.view(rows, KG * KS, 4) - row_idx, colg_idx = torch.broadcast_tensors(orig_row, colg) - out = scale_g[row_idx, colg_idx] # [mb, KG, KS, NG, 32, 4, 4] - out = torch.where(valid.unsqueeze(-1), out, torch.full_like(out, 127)) - if b128: - # b128 variant: move ks to the innermost (next to spw) so a lane's - # KS scale-words are contiguous -> one buffer_load_b128 reads a whole - # tile's ks per (rep,lane). Output order [mb, kt, grp, j, lane32, ks, spw]. - out = out.permute(0, 1, 3, 5, 4, 2, 6).contiguous() - return out.reshape(num_mb, -1).contiguous() - - def preshuffle_e8m0_scale( scale: torch.Tensor, warp_tile: int, scale_k_per_tile: int = 4, WMMA_DIM: int = 16, - coalesced: bool = False, row_align: int = None, ) -> torch.Tensor: - """Preshuffle E8M0 scale: optional byte swap + interleave for WMMA access. - - ``coalesced=True`` produces the lane-major layout the scale_load_path - "vgpr" buffer_load->VGPR path expects. - """ - if coalesced: - return preshuffle_e8m0_scale_coalesced(scale, block=warp_tile) + """Preshuffle E8M0 scale: byte swap + interleave for WMMA TDM/LDS access.""" rows, K_scale = scale.shape assert K_scale % 4 == 0, f"K_scale must be divisible by 4, got {K_scale}" # Accept an unpadded row count (M for a_scale / N for b_scale): pad rows to @@ -184,29 +88,8 @@ def preshuffle_e8m0_bscale_n4k4(scale: torch.Tensor) -> torch.Tensor: return g.reshape(N // 64, (Ks // 4) * 256) -def preshuffle_scale_for_load_path( - scale, warp_tile, skt, *, scale_load_path, data_format, ref_segmented, row_align=None -): - """Host scale preshuffle matching the kernel's selected scale_load_path. - - - 'tdm': interleaved TDM/LDS layout. - - 'vgpr' on the ref-segmented deep-pipeline config: legacy lane-major - coalesced layout. - - 'vgpr' on any other (general) config: general coalesced layout, with the - a8w4/fp4 lane_kgrp scale shift. 
- """ - if scale_load_path == "vgpr": - if ref_segmented: - return preshuffle_e8m0_scale(scale, warp_tile, scale_k_per_tile=skt, coalesced=True) - kgrp_shift = 1 if data_format in ("a8w4", "fp4") else 0 - # FLYDSL_BUFFER_VGPR_SCALE_PRELOAD=1 switches the general vgpr path to the - # b128 layout (ks innermost) so the kernel reads scales with wide - # buffer_load_b128 instead of many b32. Must match the kernel flag of the - # same name (kernels/gemm_fp8fp4_gfx1250.py::_bvs_b128). - b128 = bool(int(os.environ.get("FLYDSL_BUFFER_VGPR_SCALE_PRELOAD", "0"))) - return preshuffle_e8m0_scale_coalesced_general( - scale, warp_tile, skt, kgrp_shift, row_align=row_align, b128=b128 - ) +def preshuffle_scale_for_load_path(scale, warp_tile, skt, *, row_align=None): + """Host scale preshuffle for the TDM/LDS interleaved layout.""" return preshuffle_e8m0_scale(scale, warp_tile, scale_k_per_tile=skt, row_align=row_align) @@ -492,7 +375,6 @@ def _run_mxscale_gemm_test( waves_per_eu=None, expert_sched_mode=True, split_k=1, - scale_load_path="tdm", ascale_load_path="vgpr", return_launch_fn=False, ): @@ -536,12 +418,11 @@ def _run_mxscale_gemm_test( fmt_name = "A8W4" if is_a8w4 else ("MXFP4" if is_fp4 else "MXFP8") mcast_str = f", cluster=({cluster_m},{cluster_n})" if cluster_m > 1 or cluster_n > 1 else "" tdm_str = ", tdm_store" if use_tdm_store else ", buffer_store" - scale_load_str = "" if scale_load_path == "tdm" else f", scale_load={scale_load_path}" pad_str = _format_kernel_pad(M, N, K, padded_shape) print( f"\nRunning {fmt_name} GEMM: M={M}, N={N}, K={K}{pad_str}, " f"tiles=({tile_m},{tile_n},{tile_k}), bufs={num_buffers}" - f"{mcast_str}{tdm_str}{scale_load_str}, preshuffle, out={out_dtype}" + f"{mcast_str}{tdm_str}, preshuffle, out={out_dtype}" ) # Generate data @@ -575,18 +456,6 @@ def _run_mxscale_gemm_test( skt = tile_k // SCALE_BLOCK warp_tile_m = tile_m // m_warp warp_tile_n = tile_n // n_warp - _ref_seg = is_ref_segmented_lds_layout( - data_format=data_format, - 
tile_m=tile_m, - tile_n=tile_n, - tile_k=tile_k, - m_warp=m_warp, - n_warp=n_warp, - num_buffers=num_buffers, - split_k=split_k, - wave_specialized_tdm=wave_specialized_tdm, - use_scale_opsel=use_scale_opsel, - ) _natural_ascale = use_natural_ascale_vgpr( data_format=data_format, tile_m=tile_m, @@ -595,7 +464,6 @@ def _run_mxscale_gemm_test( m_warp=m_warp, n_warp=n_warp, n=padded_n, - scale_load_path=scale_load_path, ascale_load_path=ascale_load_path, use_scale_opsel=use_scale_opsel, wave_specialized_tdm=wave_specialized_tdm, @@ -605,15 +473,7 @@ def _run_mxscale_gemm_test( # the (already row-padded) tensor is uploaded as-is. pass else: - a_scale = preshuffle_scale_for_load_path( - a_scale, - warp_tile_m, - skt, - scale_load_path=scale_load_path, - data_format=data_format, - ref_segmented=_ref_seg, - row_align=tile_m, - ) + a_scale = preshuffle_scale_for_load_path(a_scale, warp_tile_m, skt, row_align=tile_m) if use_n4k4_bscale_layout( data_format=data_format, tile_m=tile_m, @@ -622,7 +482,6 @@ def _run_mxscale_gemm_test( m_warp=m_warp, n_warp=n_warp, n=padded_n, - scale_load_path=scale_load_path, use_scale_opsel=use_scale_opsel, num_buffers=num_buffers, out_dtype=out_dtype, @@ -630,15 +489,7 @@ def _run_mxscale_gemm_test( ): b_scale = preshuffle_e8m0_bscale_n4k4(b_scale) else: - b_scale = preshuffle_scale_for_load_path( - b_scale, - warp_tile_n, - skt, - scale_load_path=scale_load_path, - data_format=data_format, - ref_segmented=_ref_seg, - row_align=tile_n, - ) + b_scale = preshuffle_scale_for_load_path(b_scale, warp_tile_n, skt, row_align=tile_n) # Preshuffle B data K_packed = padded_k // padded_shape["pack_b"] @@ -672,7 +523,6 @@ def _run_mxscale_gemm_test( split_k=split_k, use_scale_opsel=use_scale_opsel, expert_sched_mode=expert_sched_mode, - scale_load_path=scale_load_path, ascale_load_path=ascale_load_path, ) @@ -839,7 +689,6 @@ def test_mxfp4_gemm( @pytest.mark.parametrize("use_tdm_store", [True, False]) @pytest.mark.parametrize("use_scale_opsel", [True, 
False]) @pytest.mark.parametrize("out_dtype", ["f32", "bf16"]) -@pytest.mark.parametrize("scale_load_path", ["tdm"]) def test_mxfp8_gemm( M, N, @@ -853,7 +702,6 @@ def test_mxfp8_gemm( use_tdm_store, out_dtype, use_scale_opsel, - scale_load_path, ): _run_mxscale_gemm_test( "fp8", @@ -870,7 +718,6 @@ def test_mxfp8_gemm( out_dtype, l2_prefetch_distance=2, use_scale_opsel=use_scale_opsel, - scale_load_path=scale_load_path, ) @@ -1035,7 +882,6 @@ def test_mxscale_n4k4_bscale(data_format, M, N, K, tile_n, tile_k, n_warp, num_b wave_specialized_tdm=ws, l2_prefetch_distance=0, use_scale_opsel=False, - scale_load_path="tdm", ) @@ -1100,7 +946,6 @@ def test_mxscale_natural_ascale(data_format, M, tile_m, tile_n, tile_k, m_warp, wave_specialized_tdm=True, l2_prefetch_distance=0, use_scale_opsel=False, - scale_load_path="tdm", ) @@ -1961,8 +1806,7 @@ def _run_benchmark(args): print(f" Tile: ({tile_m}, {tile_n}, {tile_k}), warps=({args.m_warp}x{args.n_warp})") print( f" Buffers={args.num_buffers}, out={args.out_dtype}, " - f"opsel={args.use_scale_opsel}, inst_prefetch={args.inst_prefetch}, " - f"scale_load={args.scale_load_path}" + f"opsel={args.use_scale_opsel}, inst_prefetch={args.inst_prefetch}" ) if args.warmup < 0: raise ValueError(f"--warmup must be >= 0, got {args.warmup}") @@ -1991,8 +1835,6 @@ def _run_benchmark(args): _ptpc_ignored.append("--no-wave-spec-tdm") if args.use_scale_opsel: _ptpc_ignored.append("--use-scale-opsel") - if args.scale_load_path != "tdm": - _ptpc_ignored.append(f"--scale-load-path {args.scale_load_path}") if _ptpc_ignored: print(f" Note: PTPC ignores (forced internally): {', '.join(_ptpc_ignored)}") print("=" * 72) @@ -2042,36 +1884,8 @@ def _run_benchmark(args): a, b, a_scale, b_scale = _pad_mxscale_inputs(a, b, a_scale, b_scale, padded_shape) skt = tile_k // SCALE_BLOCK - _ref_seg = is_ref_segmented_lds_layout( - data_format=data_format, - tile_m=tile_m, - tile_n=tile_n, - tile_k=tile_k, - m_warp=args.m_warp, - n_warp=args.n_warp, - 
num_buffers=args.num_buffers, - split_k=args.split_k, - wave_specialized_tdm=args.wave_spec_tdm, - use_scale_opsel=args.use_scale_opsel, - ) - a_scale = preshuffle_scale_for_load_path( - a_scale, - warp_tile_m, - skt, - scale_load_path=args.scale_load_path, - data_format=data_format, - ref_segmented=_ref_seg, - row_align=tile_m, - ) - b_scale = preshuffle_scale_for_load_path( - b_scale, - warp_tile_n, - skt, - scale_load_path=args.scale_load_path, - data_format=data_format, - ref_segmented=_ref_seg, - row_align=tile_n, - ) + a_scale = preshuffle_scale_for_load_path(a_scale, warp_tile_m, skt, row_align=tile_m) + b_scale = preshuffle_scale_for_load_path(b_scale, warp_tile_n, skt, row_align=tile_n) K_packed = padded_k // PACK_B b = fp4_utils.preshuffle_b_16x16(b, padded_n, K_packed) @@ -2133,7 +1947,6 @@ def _run_benchmark(args): use_scale_opsel=args.use_scale_opsel, expert_sched_mode=args.expert_sched_mode, atomic_barrier_enable=args.atomic_barrier_enable, - scale_load_path=args.scale_load_path, ) compiled_exe = flyc.compile( @@ -2350,36 +2163,8 @@ def _run_graph_verify(args): skt = tile_k // SCALE_BLOCK warp_tile_m = tile_m // args.m_warp warp_tile_n = tile_n // args.n_warp - _ref_seg = is_ref_segmented_lds_layout( - data_format=data_format, - tile_m=tile_m, - tile_n=tile_n, - tile_k=tile_k, - m_warp=args.m_warp, - n_warp=args.n_warp, - num_buffers=args.num_buffers, - split_k=args.split_k, - wave_specialized_tdm=args.wave_spec_tdm, - use_scale_opsel=args.use_scale_opsel, - ) - a_scale = preshuffle_scale_for_load_path( - a_scale, - warp_tile_m, - skt, - scale_load_path=args.scale_load_path, - data_format=data_format, - ref_segmented=_ref_seg, - row_align=tile_m, - ) - b_scale = preshuffle_scale_for_load_path( - b_scale, - warp_tile_n, - skt, - scale_load_path=args.scale_load_path, - data_format=data_format, - ref_segmented=_ref_seg, - row_align=tile_n, - ) + a_scale = preshuffle_scale_for_load_path(a_scale, warp_tile_m, skt, row_align=tile_m) + b_scale = 
preshuffle_scale_for_load_path(b_scale, warp_tile_n, skt, row_align=tile_n) K_packed = padded_k // padded_shape["pack_b"] b = fp4_utils.preshuffle_b_16x16(b, padded_n, K_packed) @@ -2415,7 +2200,6 @@ def _run_graph_verify(args): use_scale_opsel=args.use_scale_opsel, expert_sched_mode=args.expert_sched_mode, atomic_barrier_enable=args.atomic_barrier_enable, - scale_load_path=args.scale_load_path, ) c_flat = c_gpu.contiguous() @@ -2527,12 +2311,6 @@ def launch(): parser.add_argument("--no-wave-spec-tdm", dest="wave_spec_tdm", action="store_false", default=True) parser.add_argument("--waves-per-eu", type=int, default=None) parser.add_argument("--use-scale-opsel", action="store_true", default=False) - parser.add_argument( - "--scale-load-path", - type=str, - default="tdm", - choices=["tdm", "vgpr"], - ) parser.add_argument("--disable-expert-sched-mode", dest="expert_sched_mode", action="store_false", default=True) parser.add_argument( "--atomic-barrier-enable", @@ -2638,7 +2416,6 @@ def _run_correctness_test(): inst_prefetch=args.inst_prefetch, waves_per_eu=args.waves_per_eu, expert_sched_mode=args.expert_sched_mode, - scale_load_path=args.scale_load_path, ) if args.verify_graph: From 6a0d5d92ceac7ceed1ecc6233158ea95464419cf Mon Sep 17 00:00:00 2001 From: aoli26 Date: Mon, 15 Jun 2026 13:59:49 +0000 Subject: [PATCH 12/16] Switch B-scale to 32x4 preshuffle layout --- kernels/gemm_fp8fp4_gfx1250.py | 170 +++++++++++----------- tests/kernels/test_gemm_fp8fp4_gfx1250.py | 33 ++--- 2 files changed, 94 insertions(+), 109 deletions(-) diff --git a/kernels/gemm_fp8fp4_gfx1250.py b/kernels/gemm_fp8fp4_gfx1250.py index 646da442..48e496b4 100644 --- a/kernels/gemm_fp8fp4_gfx1250.py +++ b/kernels/gemm_fp8fp4_gfx1250.py @@ -25,7 +25,6 @@ get_lds_memref, issue_tdm_loads, lds_load_b32_raw, - lds_load_b64_raw, lds_load_b128_raw, pipeline_fence, pipeline_fence_signal, @@ -110,7 +109,8 @@ def use_n4k4_bscale_layout( ): """B-scale uses the tile-independent N4K4 layout on 
row-major-streaming and the FP8/A8W4 quadrant schedule (one preshuffle serves both); deep-pipeline and fp4 - keep the legacy layout.""" + keep the legacy layout. (Production layout is 32x4 `preshuffle_scale`; the gate + name keeps "n4k4" until the naming cleanup TODO.)""" if scale_mode != "mxscale": return False if data_format not in ("fp8", "a8w4"): @@ -119,9 +119,9 @@ def use_n4k4_bscale_layout( return False if tile_k % 128 != 0: return False - if n % 64 != 0: + if n % 32 != 0: return False - if tile_n % 64 != 0 and 64 % tile_n != 0: + if tile_n % 32 != 0: return False wmma_m_rep = (tile_m // m_warp) // WMMA_M wmma_n_rep = (tile_n // n_warp) // WMMA_N @@ -358,24 +358,22 @@ def compile_fp8fp4_gemm( wave_specialized_tdm=wave_specialized_tdm, fp8_schedule=fp8_schedule, ) - use_n4k4_opsel = False + # 32x4 B-scale layout (preshuffle_scale): [N//32, K//128, 32, 4]. A 128B atomic + # (32 N-rows x 4 K-scales) = one 32-lane WMMA scale VGPR. op_sel pairs an atom's + # two 16-N halves into one b32 load (1 load -> 2 WMMAs); enabled for every even- + # rep FP8/A8W4 schedule. Odd rep / warp_tile_n=16 owns half an atom (runtime + # 16-half select) and fp4 is already 1 atom = 1 WMMA, so both stay op_sel=0. + bs32_opsel = False if use_n4k4_bscale: - if K_scale % 4 != 0: - raise ValueError(f"N4K4 B-scale requires K_scale % 4 == 0, got {K_scale}") - n4k4_n_groups = N // 64 - n4k4_bs_global_row_stride = (K // WMMA_K) * 256 - n4k4_bs_lds_row_stride = k_wmma_steps * 256 - # Per-tile N-group count padded to a power of two so the TDM warp split - # stays clean (a non-pow2 count, e.g. 192->3, miscopies LDS). Cost-free - # for 64/128/256 (1/2/4 groups); tile_n=192 copies 1 extra oob-clipped group. - n4k4_bs_tile_groups = 1 << ((tile_n + 63) // 64 - 1).bit_length() - n4k4_bs_lds_rows = n4k4_bs_tile_groups - # N op_sel: pack blocks (j, j+rep/2) into one VGPR via lane_kgrp (kgrp1 = - # the "second half"), halving B-scale loads & VGPRs. 
Power-of-2 rep only - # (then the kgrp byte offset is uniform); rep 1/3/6/12 stay off. - use_n4k4_opsel = wmma_n_rep >= 2 and (wmma_n_rep & (wmma_n_rep - 1)) == 0 - _half = wmma_n_rep // 2 - n4k4_opsel_kgrp_off = (_half // 4) * n4k4_bs_lds_row_stride + (_half % 4) * 4 + bs32_atom_bytes = 128 # 32 rows x 4 K-scales + bs32_global_row_stride = (K // WMMA_K) * bs32_atom_bytes # bytes per atom row (= K) + bs32_lds_row_stride = k_wmma_steps * bs32_atom_bytes # LDS bytes per atom row + bs32_tile_atoms = tile_n // 32 + # Pad atom count to pow2 so the TDM warp split stays clean (non-pow2, e.g. + # 6, miscopies LDS). Cost-free for pow2 atom counts; else 1-2 oob-clipped. + bs32_tile_atoms_pad = 1 << (bs32_tile_atoms - 1).bit_length() + bs32_opsel = (not is_fp4) and (wmma_n_rep % 2 == 0) + bs32_n_load = (wmma_n_rep // 2) if bs32_opsel else wmma_n_rep # b32 loads per ks # A-scale natural buffer_load->VGPR (no reshuffle), paired with N4K4 B (TDM). # TEMP: ascale_load_path='tdm' fallback + non-ws kept until those legacy paths @@ -414,7 +412,8 @@ def compile_fp8fp4_gemm( # The general VGPR scale path holds scales in registers (no ds_load), so it # contributes zero. Finalized below once use_natural_ascale is known. _a_scale_ds = (wmma_m_rep + 3) // 4 - _b_scale_ds = (b_scale_load_rep + 3) // 4 + # 32x4 B-scale issues bs32_n_load b32 ds_loads per ks; legacy packs into b128s. + _b_scale_ds = bs32_n_load if use_n4k4_bscale else (b_scale_load_rep + 3) // 4 _scale_ds_loads = _a_scale_ds + _b_scale_ds _a_frag_ds = wmma_m_rep * _a_frag_loads_per_wm _bs_ds_loads = wmma_n_rep * _b_frag_loads_per_wn + _scale_ds_loads @@ -428,7 +427,7 @@ def compile_fp8fp4_gemm( # Natural A-scale lives in VGPRs (buffer_load), so it needs no LDS. 
lds_a_scale_bytes = 0 if (is_ptpc or use_natural_ascale) else tile_m * scale_k_per_tile + _scale_guard_bytes if use_n4k4_bscale: - lds_b_scale_bytes = n4k4_bs_lds_rows * n4k4_bs_lds_row_stride + _scale_guard_bytes + lds_b_scale_bytes = bs32_tile_atoms_pad * bs32_lds_row_stride + _scale_guard_bytes else: lds_b_scale_bytes = 0 if is_ptpc else tile_n * scale_k_per_tile + _scale_guard_bytes interleaved_scale_cols_a = wmma_m_rep * scale_k_per_tile @@ -628,10 +627,7 @@ def _pick_compute_schedule_kind(): _fp8_half_wn = wmma_n_rep // 2 _fp8_group_size = _fp8_half_wm * _fp8_half_wn if use_n4k4_bscale: - # N4K4 B-scale ds_load instruction count (matches load_n4k4_bscale). - _n4k4_bn = b_scale_load_rep // 2 if use_n4k4_opsel else b_scale_load_rep - _n4k4_bpl = 4 if _n4k4_bn % 4 == 0 else (2 if _n4k4_bn % 2 == 0 else 1) - _fp8_b_scale_loads = _n4k4_bn // _n4k4_bpl + _fp8_b_scale_loads = bs32_n_load # 32x4: one b32 per atom-or-WMMA per ks else: _fp8_b_scale_loads = 0 if is_ptpc else (b_scale_load_rep + 3) // 4 if use_fp8_deep_pipeline_schedule: @@ -814,18 +810,18 @@ def make_desc_as(memref, k_base): def make_desc_bs(memref, k_base): if const_expr(use_n4k4_bscale): - # N4K4: copy this tile's N-groups x K-blocks slice of the - # preshuffled [N//64, (K//128)*256] B-scale tensor. Each row is - # one 64-N group; the contiguous dim1 = tile_k//128 * 256B blocks. - g_off = blk_n // arith.index(64) - col_off = (k_base // arith.index(WMMA_K)) * arith.index(256) + # 32x4: copy this tile's 32-N atoms x K-blocks slice of the + # preshuffled [N//32, (K//128)*128] B-scale tensor. Each row is one + # 32-N atom group; contiguous dim1 = tile_k//128 * 128B atomics. 
+ a_off = blk_n // arith.index(32) + col_off = (k_base // arith.index(WMMA_K)) * arith.index(bs32_atom_bytes) return _make_tdm_desc( global_ptr=arg_b_scale, lds_memref=memref, - global_offset=(g_off, col_off), - tensor_shape=(n4k4_n_groups, n4k4_bs_global_row_stride), - strides=(n4k4_bs_global_row_stride, 1), - tile_shape=(n4k4_bs_tile_groups, n4k4_bs_lds_row_stride), + global_offset=(a_off, col_off), + tensor_shape=(N // 32, bs32_global_row_stride), + strides=(bs32_global_row_stride, 1), + tile_shape=(bs32_tile_atoms_pad, bs32_lds_row_stride), elem_bytes=1, pad_interval=0, pad_amount=0, @@ -833,7 +829,7 @@ def make_desc_bs(memref, k_base): workgroup_mask=b_mcast_mask, atomic_barrier_enable=atomic_barrier_enable, early_timeout=True, - oob_outer_bound=n4k4_n_groups, + oob_outer_bound=N // 32, ) k_scale_off = k_base // arith.index(SCALE_BLOCK) outer_off = blk_n // arith.index(b_scale_load_rep) @@ -993,50 +989,43 @@ def _precompute_scale_lane_bases(lds_ptr, warp_base, reps, interleaved_cols): base = base + lane_kgrp * arith.index(SCALES_PER_WMMA) return lds_ptr, [base] - def _precompute_n4k4_bscale_bases(lds_ptr): - """Precompute (first_block, lane_byte) for this warp's N4K4 reads. + def _precompute_bs32_bases(lds_ptr): + """Tile-local 32-N atom base for the warp's 32x4 B-scale read. - The TDM copies the 64-N group(s) containing the tile; within a copied - group a lane's 4 N-blocks are 16 contiguous bytes and consecutive - groups are n4k4_bs_lds_row_stride apart. ``b0`` is the warp's first - N-block *inside the copied group(s)*: its own N offset plus, for tiles - smaller than one 64-N group, the tile's slice offset within its group. - Only lanes 0..15 carry the consumed scale (scaleAType=0, no op_sel); - lanes 16..31 read the same word. + An LDS atom row (32 N-rows x 4 K-scales = 128B) is one 32-lane WMMA scale + VGPR. op_sel path (even rep): the warp owns whole atoms atom0+j. Else + (fp4 / odd rep): each WMMA reads its own 16/32-N into the operand lanes. 
""" - b0 = wave_n_idx * arith.index(b_scale_load_rep) - if const_expr(tile_n < 64): - # Sub-64 tile: whole containing group was copied; shift to this - # tile's slice (row offset 0/16/32/48 -> N-block 0/1/2/3). - b0 = b0 + (blk_n % arith.index(64)) // arith.index(16) - lane_off = lane16 * arith.index(16) - return lds_ptr, (b0, lane_off) - - _N4K4_LOADERS = {1: lds_load_b32_raw, 2: lds_load_b64_raw, 4: lds_load_b128_raw} - - def load_n4k4_bscale(lds_buffer, bases, reps, ks=0): - """Load N4K4 B-scale i32s for K-subtile *ks*.""" - b0, lane_off = bases - ks_off = arith.index(ks * 256) - row_stride = arith.index(n4k4_bs_lds_row_stride) - if const_expr(use_n4k4_opsel): - n_load = reps // 2 # read first half; kgrp1 supplies the matching second half - lane = lane_off + lane_kgrp * arith.index(n4k4_opsel_kgrp_off) - else: - n_load = reps - lane = lane_off - per_load = 4 if n_load % 4 == 0 else (2 if n_load % 2 == 0 else 1) + return lds_ptr, warp_n_base // arith.index(32) + + def load_bs32_bscale(lds_buffer, atom0, ks): + """Load 32x4 B-scale i32s for K-subtile *ks* (one b32 per atom-or-WMMA).""" + stride = arith.index(bs32_lds_row_stride) + ks_off = arith.index(ks * bs32_atom_bytes) results = [] - for i in range_constexpr(n_load // per_load): - blk = b0 + arith.index(i * per_load) - off = (blk // arith.index(4)) * row_stride + (blk % arith.index(4)) * arith.index(4) + lane + ks_off - raw = _N4K4_LOADERS[per_load](lds_buffer, off) - if const_expr(per_load == 1): - results.append(raw) - else: - vec = fx.Vector(raw) - for j in range_constexpr(per_load): - results.append(vec[j]) + if const_expr(bs32_opsel): + # Even rep: full 32-lane atom; op_sel picks the 16-half in _emit_wmma. + lane = (lane_kgrp * arith.index(16) + lane16) * arith.index(4) + for j in range_constexpr(wmma_n_rep // 2): + off = (atom0 + arith.index(j)) * stride + ks_off + lane + results.append(lds_load_b32_raw(lds_buffer, off)) + elif const_expr(is_fp4): + # fp4: one 32-N atom per WMMA (no op_sel). 
+ lane = (lane_kgrp * arith.index(16) + lane16) * arith.index(4) + for wn in range_constexpr(wmma_n_rep): + off = (atom0 + arith.index(wn)) * stride + ks_off + lane + results.append(lds_load_b32_raw(lds_buffer, off)) + else: + # fp8 odd rep: each WMMA's 16-N into lanes 0-15 (op_sel=0); the atom + # and its 16-half are runtime (warp may start mid-atom). + for wn in range_constexpr(wmma_n_rep): + row16 = warp_n_base + arith.index(wn * 16) + off = ( + (row16 // arith.index(32)) * stride + + ks_off + + (row16 % arith.index(32) + lane16) * arith.index(4) + ) + results.append(lds_load_b32_raw(lds_buffer, off)) return results def load_scale_b128(lds_buffer, scale_base, reps, ks=0): @@ -1074,9 +1063,9 @@ def load_scale_slice_b128(lds_buffer, scale_base, full_reps, rep_start, rep_coun _vgpr_scale_box = [None] def _load_b_scale_lds(bs_buf, bs_bases, ks): - """Load B-scale from LDS, dispatching to the N4K4 or legacy layout.""" + """Load B-scale from LDS, dispatching to the 32x4 or legacy layout.""" if const_expr(use_n4k4_bscale): - return load_n4k4_bscale(bs_buf, bs_bases, b_scale_load_rep, ks) + return load_bs32_bscale(bs_buf, bs_bases, ks) return load_scale_b128(bs_buf, bs_bases[0], b_scale_load_rep, ks) def _scales_for_emit(as_buf, as_bases, bs_buf, bs_bases, ks): @@ -1169,13 +1158,12 @@ def _emit_wmma(accs, wm, wn, a_frag, b_frag, a_scales, b_scales): scaleBType=a_opsel, ) else: - # 16x16x128 WMMA: A8W4 (fmtA=FP4) or FP8 (fmtA=FP8) - if const_expr(use_scale_opsel): + # 16x16x128 WMMA: A8W4 (fmtA=FP4) or FP8 (fmtA=FP8). op_sel pairs + # adjacent 16-N halves (legacy use_scale_opsel or 32x4 even rep); + # else one scale per WMMA (32x4 odd rep, or no op_sel). 
+ if const_expr(use_scale_opsel or bs32_opsel): b_scale_idx = wn // 2 b_opsel = wn % 2 - elif const_expr(use_n4k4_opsel): - b_scale_idx = wn % (wmma_n_rep // 2) - b_opsel = wn // (wmma_n_rep // 2) else: b_scale_idx = wn b_opsel = 0 @@ -1289,7 +1277,7 @@ def compute_tile( lds_as, warp_m_base, wmma_m_rep, interleaved_scale_cols_a ) if const_expr(use_n4k4_bscale): - bs_buf, bs_bases = _precompute_n4k4_bscale_bases(lds_bs) + bs_buf, bs_bases = _precompute_bs32_bases(lds_bs) else: bs_buf, bs_bases = _precompute_scale_lane_bases( lds_bs, warp_n_base, b_scale_load_rep, interleaved_scale_cols_b @@ -1492,7 +1480,7 @@ def compute_tile_fp8_quadrant( b_buf, b_bases = _precompute_b_lane_bases(lds_b) as_buf, as_bases = _precompute_scale_lane_bases(lds_as, warp_m_base, wmma_m_rep, interleaved_scale_cols_a) if const_expr(use_n4k4_bscale): - bs_buf, bs_bases = _precompute_n4k4_bscale_bases(lds_bs) + bs_buf, bs_bases = _precompute_bs32_bases(lds_bs) else: bs_buf, bs_bases = _precompute_scale_lane_bases( lds_bs, warp_n_base, b_scale_load_rep, interleaved_scale_cols_b @@ -1586,7 +1574,11 @@ def _emit_group_col(wm_base, wn_base, a_frags, b_frags, a_scales, b_scales, wn_l ) b_left_frags, b_scales = _load_b_left_bundle(0) - _first_top_row_keep = max((_fp8_half_wm - 1) * DS_LOADS_PER_A_FRAG - _fp8_b_scale_loads, 0) + # Margin = a-top drain depth (b-scale is issued earlier, so it is unrelated); + # keep it at the per-WMMA count so op_sel's fewer b-scale loads don't widen + # keep and race the top-row A frags. 
+ _top_keep_margin = b_scale_load_rep if const_expr(bs32_opsel) else _fp8_b_scale_loads + _first_top_row_keep = max((_fp8_half_wm - 1) * DS_LOADS_PER_A_FRAG - _top_keep_margin, 0) _bottom_left_keep = max(_b_half_loads - DS_LOADS_PER_A_FRAG, 0) for ks in range_constexpr(k_wmma_steps): @@ -2325,7 +2317,7 @@ def _pack_dg0(pred, lds_addr, addr_lo, addr_hi): adv_as_i32 = fx.Int32(tile_k // SCALE_BLOCK * wmma_m_rep) # N4K4 advances by one tile's worth of K-blocks (k_wmma_steps*256B) per # K-step; the legacy interleaved layout advances by scale_k_per_tile*rep. - adv_bs_i32 = fx.Int32(n4k4_bs_lds_row_stride if use_n4k4_bscale else tile_k // SCALE_BLOCK * b_scale_load_rep) + adv_bs_i32 = fx.Int32(bs32_lds_row_stride if use_n4k4_bscale else tile_k // SCALE_BLOCK * b_scale_load_rep) pred_const = fx.Int32(1) if const_expr(wave_specialized_tdm): diff --git a/tests/kernels/test_gemm_fp8fp4_gfx1250.py b/tests/kernels/test_gemm_fp8fp4_gfx1250.py index 2a7d80af..30667f26 100644 --- a/tests/kernels/test_gemm_fp8fp4_gfx1250.py +++ b/tests/kernels/test_gemm_fp8fp4_gfx1250.py @@ -66,26 +66,19 @@ def preshuffle_e8m0_scale( return g.reshape(-1, k_groups * k_wmma_steps * wmma_rep * SCALES_PER_WMMA) -def preshuffle_e8m0_bscale_n4k4(scale: torch.Tensor) -> torch.Tensor: - """Tile-independent N4K4 B-scale preshuffle: [N, K_scale] -> [N//64, (K_scale//4)*256]. +def preshuffle_scale(scale: torch.Tensor) -> torch.Tensor: + """32x4 scale layout (A or B): [R, Ks] -> [R//32, K] (Ks = K//32). - Atomic block = 4 N-blocks x 1 K-block = 64 N-rows x 4 scale-bytes = 256B, so - the byte layout depends only on the constants (64, 16, 4, 4) and never on - tile_n/n_warp/tile_k. Weights are preshuffled once and served to any tile - config landing on the default row-major streaming schedule. 
+ out[r_o, k_o, r_i, k_i] = scale[r_o*32 + r_i, k_o*4 + k_i] - B_scale_pre[g, kb, n, r, k] = scale[g*64 + r*16 + n, kb*4 + k] - - where g = N//64 group, kb = K//128 block (4 scale bytes = one WMMA's K=128), - n = lane16 (N-row within a 16-block), r = the 4 N-blocks in a 64-group, - k = the 4 scale bytes within one WMMA. Mirrors the kernel's N4K4 TDM+LDS - read (see flydsl_fp8_perf/verify_n4k4_bscale_layout.py for the parity proof). + A 128B atomic (32 rows x 4 K-scales) is exactly one 32-lane WMMA scale VGPR + (lane L = row L's 4 K-scales). tile-independent; serves fp8 (16x16, op_sel) and + fp4 (32x16) uniformly. """ - N, Ks = scale.shape - assert N % 64 == 0 and Ks % 4 == 0, f"N4K4 B-scale needs N%64==0, Ks%4==0; got N={N} Ks={Ks}" - g = scale.view(N // 64, 4, 16, Ks // 4, 4) # [g, r, n, kb, k] - g = g.permute(0, 3, 2, 1, 4).contiguous() # [g, kb, n, r, k] - return g.reshape(N // 64, (Ks // 4) * 256) + R, Ks = scale.shape + assert R % 32 == 0 and Ks % 4 == 0, f"preshuffle_scale needs R%32==0, Ks%4==0; got R={R} Ks={Ks}" + x = scale.view(R // 32, 32, Ks // 4, 4).permute(0, 2, 1, 3).contiguous() # [R//32, Ks//4, 32, 4] + return x.reshape(R // 32, -1) # [R//32, K] def preshuffle_scale_for_load_path(scale, warp_tile, skt, *, row_align=None): @@ -487,7 +480,7 @@ def _run_mxscale_gemm_test( out_dtype=out_dtype, wave_specialized_tdm=wave_specialized_tdm, ): - b_scale = preshuffle_e8m0_bscale_n4k4(b_scale) + b_scale = preshuffle_scale(b_scale) else: b_scale = preshuffle_scale_for_load_path(b_scale, warp_tile_n, skt, row_align=tile_n) @@ -814,9 +807,9 @@ def test_a8w4_gemm_irregular_m_tile16(M, N, K, use_tdm_store): # group counts 1/2/4 and the non-power-of-2 group count 3 that exercises the # TDM warp-distribution power-of-two padding), both data formats, k_wmma_steps # 1/2/4, wave-spec on/off, f32/bf16, multi-buffer, and ragged/decode M. 
-_N4K4_N_FOR_TN = {16: 128, 32: 128, 64: 128, 128: 256, 192: 384, 256: 512} +_N4K4_N_FOR_TN = {32: 128, 64: 128, 128: 256, 192: 384, 256: 512} _N4K4_TN_NW = [ - (16, 1), (32, 1), (32, 2), (64, 1), (64, 2), (64, 4), + (32, 1), (32, 2), (64, 1), (64, 2), (64, 4), (128, 1), (128, 2), (128, 4), (192, 1), (192, 2), (192, 4), (256, 1), (256, 2), (256, 4), ] # fmt: skip From 6b537fa3873b4468b8a27eb3b25f1df10c30a15c Mon Sep 17 00:00:00 2001 From: aoli Date: Mon, 15 Jun 2026 15:57:14 +0000 Subject: [PATCH 13/16] deep pipeline use 32x4 b scale preshuffle --- kernels/gemm_fp8fp4_gfx1250.py | 161 +++++++++++++++------- tests/kernels/test_gemm_fp8fp4_gfx1250.py | 41 ++++++ 2 files changed, 155 insertions(+), 47 deletions(-) diff --git a/kernels/gemm_fp8fp4_gfx1250.py b/kernels/gemm_fp8fp4_gfx1250.py index 48e496b4..88c0437f 100644 --- a/kernels/gemm_fp8fp4_gfx1250.py +++ b/kernels/gemm_fp8fp4_gfx1250.py @@ -86,6 +86,39 @@ def _vec_chunks(n: int): LDS_GFX1250_MAX_BYTES = 5 * LDS_SEGMENT_BYTES +def _is_fp8_deep_pipeline( + *, + data_format, + tile_m, + tile_n, + tile_k, + m_warp, + n_warp, + num_buffers, + out_dtype, + wave_specialized_tdm, + use_scale_opsel, + fp8_schedule, +): + """Whether this config takes the FP8/A8W4 deep-pipeline schedule (fixed + 256x256x128 / nbuf4 / wave-spec shape). Mirrors the compile-time eligibility + + schedule selection so the module-level gates agree with the kernel.""" + if data_format not in ("fp8", "a8w4"): + return False + eligible = ( + tile_m == 256 + and tile_n == 256 + and tile_k == 128 + and m_warp == 2 + and n_warp == 2 + and num_buffers == 4 + and wave_specialized_tdm + and out_dtype == "bf16" + and not use_scale_opsel + ) + return fp8_schedule == "deep-pipeline" or (fp8_schedule == "auto" and eligible) + + def _is_row_major_streaming(wmma_m_rep, wmma_n_rep, n_accs): # Mirrors _pick_compute_schedule_kind: row-major when a rep is odd or n_accs < 8. 
return wmma_m_rep % 2 != 0 or wmma_n_rep % 2 != 0 or n_accs < 8 @@ -107,10 +140,10 @@ def use_n4k4_bscale_layout( wave_specialized_tdm=False, fp8_schedule="auto", ): - """B-scale uses the tile-independent N4K4 layout on row-major-streaming and the - FP8/A8W4 quadrant schedule (one preshuffle serves both); deep-pipeline and fp4 - keep the legacy layout. (Production layout is 32x4 `preshuffle_scale`; the gate - name keeps "n4k4" until the naming cleanup TODO.)""" + """B-scale uses the 32x4 `preshuffle_scale` layout on every FP8/A8W4 mxscale + schedule (row-major, quadrant, deep-pipeline); fp4 keeps the legacy layout. (The + gate name keeps "n4k4" until the naming cleanup TODO; extra kwargs are kept for + signature compatibility.)""" if scale_mode != "mxscale": return False if data_format not in ("fp8", "a8w4"): @@ -123,26 +156,7 @@ def use_n4k4_bscale_layout( return False if tile_n % 32 != 0: return False - wmma_m_rep = (tile_m // m_warp) // WMMA_M - wmma_n_rep = (tile_n // n_warp) // WMMA_N - n_accs = wmma_m_rep * wmma_n_rep - if _is_row_major_streaming(wmma_m_rep, wmma_n_rep, n_accs): - return True - # Even-rep FP8/A8W4: quadrant uses N4K4; the deep-pipeline shape keeps legacy - # (its B-scale rides the VGPR ring, not LDS). - deep_eligible = ( - tile_m == 256 - and tile_n == 256 - and tile_k == 128 - and m_warp == 2 - and n_warp == 2 - and num_buffers == 4 - and wave_specialized_tdm - and out_dtype == "bf16" - and not use_scale_opsel - ) - is_deep = fp8_schedule == "deep-pipeline" or (fp8_schedule == "auto" and deep_eligible) - return not is_deep + return True def use_natural_ascale_vgpr( @@ -158,12 +172,15 @@ def use_natural_ascale_vgpr( ascale_load_path="vgpr", use_scale_opsel=False, wave_specialized_tdm=False, + num_buffers=2, + out_dtype="bf16", + fp8_schedule="auto", ): """Whether A-scale uses the natural (un-reshuffled) buffer_load->VGPR path. 
- Row-major-streaming (decode) only: pairs with N4K4 B (TDM), A read straight from - runtime ``A_scale[M, K//32]`` into VGPRs, loop-ahead prefetched. Quadrant (prefill) - keeps legacy/TDM A-scale (its target is A=tdm-M4K4). Requires wave-specialized TDM.""" + Row-major-streaming (decode) and deep-pipeline: A read straight from runtime + ``A_scale[M, K//32]`` into VGPRs (loop-ahead prefetched), paired with 32x4 B (TDM). + Quadrant keeps legacy/TDM A-scale (its target is A=tdm-M4K4). Requires wave-spec TDM.""" if ascale_load_path != "vgpr": return False if not wave_specialized_tdm: @@ -182,7 +199,22 @@ def use_natural_ascale_vgpr( return False wmma_m_rep = (tile_m // m_warp) // WMMA_M wmma_n_rep = (tile_n // n_warp) // WMMA_N - return _is_row_major_streaming(wmma_m_rep, wmma_n_rep, wmma_m_rep * wmma_n_rep) + if _is_row_major_streaming(wmma_m_rep, wmma_n_rep, wmma_m_rep * wmma_n_rep): + return True + # Deep-pipeline (even rep) also defaults to natural A-scale; quadrant does not. + return _is_fp8_deep_pipeline( + data_format=data_format, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + m_warp=m_warp, + n_warp=n_warp, + num_buffers=num_buffers, + out_dtype=out_dtype, + wave_specialized_tdm=wave_specialized_tdm, + use_scale_opsel=use_scale_opsel, + fp8_schedule=fp8_schedule, + ) @functools.lru_cache(maxsize=256) @@ -210,7 +242,7 @@ def compile_fp8fp4_gemm( use_scale_opsel: bool = False, expert_sched_mode: bool = True, atomic_barrier_enable: bool = False, - ascale_load_path: str = "vgpr", + ascale_load_path: str = "auto", fp8_schedule: str = "auto", ): """Compile an FP4/FP8/A8W4 GEMM kernel with TDM async copy. @@ -237,6 +269,28 @@ def compile_fp8fp4_gemm( if scale_mode == "ptpc" and data_format not in ("fp8", "a8w4"): raise ValueError("scale_mode='ptpc' currently only supports data_format='fp8' or 'a8w4'") + # Deep-pipeline defaults to TDM A-scale: the natural A-scale VGPR ring adds + # register pressure that spills on mxfp8 deep. 
vgpr is the default everywhere + # else and stays explicitly selectable for deep until that spill is resolved. + if ascale_load_path == "auto": + ascale_load_path = ( + "tdm" + if _is_fp8_deep_pipeline( + data_format=data_format, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + m_warp=m_warp, + n_warp=n_warp, + num_buffers=num_buffers, + out_dtype=out_dtype, + wave_specialized_tdm=wave_specialized_tdm, + use_scale_opsel=use_scale_opsel, + fp8_schedule=fp8_schedule, + ) + else "vgpr" + ) + is_fp4 = data_format == "fp4" is_a8w4 = data_format == "a8w4" is_ptpc = scale_mode == "ptpc" @@ -392,6 +446,9 @@ def compile_fp8fp4_gemm( ascale_load_path=ascale_load_path, use_scale_opsel=use_scale_opsel, wave_specialized_tdm=wave_specialized_tdm, + num_buffers=num_buffers, + out_dtype=out_dtype, + fp8_schedule=fp8_schedule, ) # M op_sel pairs A-blocks (wm, wm+rep/2) into one VGPR via lane_kgrp; power-of-2 rep. use_natural_ascale_opsel = use_natural_ascale and wmma_m_rep >= 2 and (wmma_m_rep & (wmma_m_rep - 1)) == 0 @@ -595,7 +652,11 @@ def _pick_compute_schedule_kind(): use_fp8_deep_pipeline_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_FP8_DEEP_PIPELINE if use_n4k4_bscale: - assert compute_schedule_kind in (COMPUTE_SCHEDULE_ROW_MAJOR_STREAMING, COMPUTE_SCHEDULE_FP8_QUADRANT) + assert compute_schedule_kind in ( + COMPUTE_SCHEDULE_ROW_MAJOR_STREAMING, + COMPUTE_SCHEDULE_FP8_QUADRANT, + COMPUTE_SCHEDULE_FP8_DEEP_PIPELINE, + ) use_ws_tdm_split_signal_overlap = ( wave_specialized_tdm and (use_fp8_quadrant_schedule or use_fp8_deep_pipeline_schedule) @@ -637,7 +698,9 @@ def _pick_compute_schedule_kind(): _fp8_wn_pairs = wmma_n_rep // _fp8_pair_wn _fp8_pair_a_loads = _fp8_pair_wm * DS_LOADS_PER_A_FRAG _fp8_pair_b_loads = _fp8_pair_wn * _b_frag_loads_per_wn - _fp8_scale_loads = 0 if is_ptpc else (wmma_m_rep + 3) // 4 + (b_scale_load_rep + 3) // 4 + # Scale ds_loads issued at the loop top: a-scale (0 when natural-VGPR/ptpc) + + # b-scale (bs32_n_load for 32x4). 
Uses the finalized module-level ds counts. + _fp8_scale_loads = 0 if is_ptpc else (_a_scale_ds + _b_scale_ds) @flyc.kernel(known_block_size=[block_threads, 1, 1]) def kernel_mxscale_gemm( @@ -1669,12 +1732,27 @@ def compute_tile_fp8_deep_pipeline( pf_b_scales=None, ): current_accs = list(accs_in) + if const_expr(use_natural_ascale): + # A-scale from the VGPR ring (loop-prefetched, else inline tail load). + if const_expr(pf_a_scales is not None): + _vgpr_scale_box[0] = (pf_a_scales, pf_b_scales) + else: + rocdl.sched_barrier(0) + _vgpr_scale_box[0] = _bvs_prefetch(scale_k_base) a_buf, a_bases = _precompute_a_lane_bases(lds_a) b_buf, b_bases = _precompute_b_lane_bases(lds_b) - as_buf, as_bases = _precompute_scale_lane_bases(lds_as, warp_m_base, wmma_m_rep, interleaved_scale_cols_a) - bs_buf, bs_bases = _precompute_scale_lane_bases( - lds_bs, warp_n_base, b_scale_load_rep, interleaved_scale_cols_b - ) + if const_expr(use_natural_ascale): + as_buf, as_bases = None, None # A-scale from the VGPR ring, not LDS + else: + as_buf, as_bases = _precompute_scale_lane_bases( + lds_as, warp_m_base, wmma_m_rep, interleaved_scale_cols_a + ) + if const_expr(use_n4k4_bscale): + bs_buf, bs_bases = _precompute_bs32_bases(lds_bs) + else: + bs_buf, bs_bases = _precompute_scale_lane_bases( + lds_bs, warp_n_base, b_scale_load_rep, interleaved_scale_cols_b + ) def load_a_pair(wm_pair, ks): wm_base = wm_pair * _fp8_pair_wm @@ -1688,16 +1766,6 @@ def load_b_pair(wn_pair, ks): load_b_frag(b_buf, b_bases, wn_base + wn_local, ks) for wn_local in range_constexpr(_fp8_pair_wn) ] - def _load_a_scales(ks): - if const_expr(is_ptpc): - return None # PTPC: scale applied in epilogue, not in K-loop - return load_scale_b128(as_buf, as_bases[0], wmma_m_rep, ks) - - def _load_b_scales(ks): - if const_expr(is_ptpc): - return None # PTPC: scale applied in epilogue, not in K-loop - return load_scale_b128(bs_buf, bs_bases[0], b_scale_load_rep, ks) - def emit_panel_2x2( wm_pair, wn_pair, @@ -1752,8 +1820,7 @@ 
def emit_panel_2x2_row(wm_pair, wn_pair, row_local, a_pair, b_pair, scale_pair): for ks in range_constexpr(k_wmma_steps): is_last_ks = ks == k_wmma_steps - 1 - a_scales = _load_a_scales(ks) - b_scales = _load_b_scales(ks) + a_scales, b_scales = _scales_for_emit(as_buf, as_bases, bs_buf, bs_bases, ks) scale_pair = (a_scales, b_scales) b0 = load_b_pair(0, ks) diff --git a/tests/kernels/test_gemm_fp8fp4_gfx1250.py b/tests/kernels/test_gemm_fp8fp4_gfx1250.py index 30667f26..f6da0ef5 100644 --- a/tests/kernels/test_gemm_fp8fp4_gfx1250.py +++ b/tests/kernels/test_gemm_fp8fp4_gfx1250.py @@ -25,6 +25,7 @@ from flydsl.runtime.device import get_rocm_arch # noqa: E402 from kernels.gemm_fp8fp4_gfx1250 import ( # noqa: E402 + _is_fp8_deep_pipeline, compile_mxscale_gemm, compile_ptpc_gemm, use_n4k4_bscale_layout, @@ -460,6 +461,8 @@ def _run_mxscale_gemm_test( ascale_load_path=ascale_load_path, use_scale_opsel=use_scale_opsel, wave_specialized_tdm=wave_specialized_tdm, + num_buffers=num_buffers, + out_dtype=out_dtype, ) if _natural_ascale: # Natural path reads A_scale[M, K//32] straight from VGPRs -- no reshuffle, @@ -942,6 +945,44 @@ def test_mxscale_natural_ascale(data_format, M, tile_m, tile_n, tile_k, m_warp, ) +@pytest.mark.parametrize("ascale_load_path", ["vgpr", "tdm"]) +@pytest.mark.parametrize("data_format", ["fp8", "a8w4"]) +def test_mxscale_deep_pipeline(data_format, ascale_load_path): + # Deep-pipeline (fixed 256x256x128 / nbuf4 / wave-spec): 32x4 B-scale + A-scale + # via natural VGPR (default) or 32x4 TDM. Guard: must hit the deep schedule. 
+ assert _is_fp8_deep_pipeline( + data_format=data_format, + tile_m=256, + tile_n=256, + tile_k=128, + m_warp=2, + n_warp=2, + num_buffers=4, + out_dtype="bf16", + wave_specialized_tdm=True, + use_scale_opsel=False, + fp8_schedule="auto", + ), "config does not hit the deep-pipeline schedule" + _run_mxscale_gemm_test( + data_format, + 256, + 256, + 512, + 256, + 256, + 128, + 2, + 2, + 4, + use_tdm_store=True, + out_dtype="bf16", + wave_specialized_tdm=True, + l2_prefetch_distance=0, + use_scale_opsel=False, + ascale_load_path=ascale_load_path, + ) + + @pytest.mark.parametrize( "M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp, cluster_m, cluster_n", [ From 118299991ad241e53c201df0036f3611b25ce919 Mon Sep 17 00:00:00 2001 From: aoli26 Date: Tue, 16 Jun 2026 05:46:16 +0000 Subject: [PATCH 14/16] remove old bscale preshuffle & use_scale_opsel option --- kernels/gemm_fp8fp4_gfx1250.py | 283 +++++++++------------- tests/kernels/test_gemm_fp8fp4_gfx1250.py | 52 +--- 2 files changed, 124 insertions(+), 211 deletions(-) diff --git a/kernels/gemm_fp8fp4_gfx1250.py b/kernels/gemm_fp8fp4_gfx1250.py index 88c0437f..1954fbb3 100644 --- a/kernels/gemm_fp8fp4_gfx1250.py +++ b/kernels/gemm_fp8fp4_gfx1250.py @@ -97,7 +97,6 @@ def _is_fp8_deep_pipeline( num_buffers, out_dtype, wave_specialized_tdm, - use_scale_opsel, fp8_schedule, ): """Whether this config takes the FP8/A8W4 deep-pipeline schedule (fixed @@ -114,16 +113,10 @@ def _is_fp8_deep_pipeline( and num_buffers == 4 and wave_specialized_tdm and out_dtype == "bf16" - and not use_scale_opsel ) return fp8_schedule == "deep-pipeline" or (fp8_schedule == "auto" and eligible) -def _is_row_major_streaming(wmma_m_rep, wmma_n_rep, n_accs): - # Mirrors _pick_compute_schedule_kind: row-major when a rep is odd or n_accs < 8. 
- return wmma_m_rep % 2 != 0 or wmma_n_rep % 2 != 0 or n_accs < 8 - - def use_n4k4_bscale_layout( *, data_format, @@ -134,27 +127,19 @@ def use_n4k4_bscale_layout( n_warp, n, scale_mode="mxscale", - use_scale_opsel=False, num_buffers=2, out_dtype="f32", wave_specialized_tdm=False, fp8_schedule="auto", ): - """B-scale uses the 32x4 `preshuffle_scale` layout on every FP8/A8W4 mxscale - schedule (row-major, quadrant, deep-pipeline); fp4 keeps the legacy layout. (The - gate name keeps "n4k4" until the naming cleanup TODO; extra kwargs are kept for - signature compatibility.)""" + """B-scale uses the 32x4 `preshuffle_scale` layout on every mxscale schedule + (row-major, quadrant, deep-pipeline, fp4 bank-friendly); ptpc has no K-loop + B-scale. The N/tile_n%32 and tile_k%128 requirements are enforced in + compile_fp8fp4_gemm. (The gate name keeps "n4k4" until the naming cleanup TODO; + extra kwargs are kept for signature compatibility.)""" if scale_mode != "mxscale": return False - if data_format not in ("fp8", "a8w4"): - return False - if use_scale_opsel: - return False - if tile_k % 128 != 0: - return False - if n % 32 != 0: - return False - if tile_n % 32 != 0: + if data_format not in ("fp8", "a8w4", "fp4"): return False return True @@ -170,7 +155,6 @@ def use_natural_ascale_vgpr( n, scale_mode="mxscale", ascale_load_path="vgpr", - use_scale_opsel=False, wave_specialized_tdm=False, num_buffers=2, out_dtype="bf16", @@ -178,9 +162,11 @@ def use_natural_ascale_vgpr( ): """Whether A-scale uses the natural (un-reshuffled) buffer_load->VGPR path. - Row-major-streaming (decode) and deep-pipeline: A read straight from runtime - ``A_scale[M, K//32]`` into VGPRs (loop-ahead prefetched), paired with 32x4 B (TDM). - Quadrant keeps legacy/TDM A-scale (its target is A=tdm-M4K4). Requires wave-spec TDM.""" + A read straight from runtime ``A_scale[M, K//32]`` into VGPRs (loop-ahead + prefetched), paired with 32x4 B (TDM). 
Used by every ws schedule (row-major, + quadrant, fp4 bank-friendly, deep-with-explicit-vgpr). Requires wave-spec TDM.""" + if data_format not in ("fp8", "a8w4", "fp4"): + return False if ascale_load_path != "vgpr": return False if not wave_specialized_tdm: @@ -194,27 +180,9 @@ def use_natural_ascale_vgpr( n_warp=n_warp, n=n, scale_mode=scale_mode, - use_scale_opsel=use_scale_opsel, ): return False - wmma_m_rep = (tile_m // m_warp) // WMMA_M - wmma_n_rep = (tile_n // n_warp) // WMMA_N - if _is_row_major_streaming(wmma_m_rep, wmma_n_rep, wmma_m_rep * wmma_n_rep): - return True - # Deep-pipeline (even rep) also defaults to natural A-scale; quadrant does not. - return _is_fp8_deep_pipeline( - data_format=data_format, - tile_m=tile_m, - tile_n=tile_n, - tile_k=tile_k, - m_warp=m_warp, - n_warp=n_warp, - num_buffers=num_buffers, - out_dtype=out_dtype, - wave_specialized_tdm=wave_specialized_tdm, - use_scale_opsel=use_scale_opsel, - fp8_schedule=fp8_schedule, - ) + return True @functools.lru_cache(maxsize=256) @@ -239,7 +207,6 @@ def compile_fp8fp4_gemm( inst_prefetch: bool = False, wave_specialized_tdm: bool = False, split_k: int = 1, - use_scale_opsel: bool = False, expert_sched_mode: bool = True, atomic_barrier_enable: bool = False, ascale_load_path: str = "auto", @@ -285,7 +252,6 @@ def compile_fp8fp4_gemm( num_buffers=num_buffers, out_dtype=out_dtype, wave_specialized_tdm=wave_specialized_tdm, - use_scale_opsel=use_scale_opsel, fp8_schedule=fp8_schedule, ) else "vgpr" @@ -379,6 +345,14 @@ def compile_fp8fp4_gemm( if warp_tile_n % WMMA_N_EFF != 0: raise ValueError(f"warp_tile_n={warp_tile_n} must be a multiple of {WMMA_N_EFF}") + # mxscale B-scale is always the 32x4 `preshuffle_scale` layout: require N/tile_n a + # multiple of 32 and tile_k a multiple of 128 (no legacy sub-32 fallback). 
+ if scale_mode == "mxscale" and (N % 32 != 0 or tile_n % 32 != 0 or tile_k % 128 != 0): + raise ValueError( + f"mxscale 32x4 B-scale requires N%32==0, tile_n%32==0, tile_k%128==0; " + f"got N={N}, tile_n={tile_n}, tile_k={tile_k}" + ) + if split_k > 1 and use_tdm_store: raise ValueError("split_k > 1 currently requires use_tdm_store=False") @@ -406,7 +380,6 @@ def compile_fp8fp4_gemm( n_warp=n_warp, n=N, scale_mode=scale_mode, - use_scale_opsel=use_scale_opsel, num_buffers=num_buffers, out_dtype=out_dtype, wave_specialized_tdm=wave_specialized_tdm, @@ -444,7 +417,6 @@ def compile_fp8fp4_gemm( n=N, scale_mode=scale_mode, ascale_load_path=ascale_load_path, - use_scale_opsel=use_scale_opsel, wave_specialized_tdm=wave_specialized_tdm, num_buffers=num_buffers, out_dtype=out_dtype, @@ -488,7 +460,6 @@ def compile_fp8fp4_gemm( else: lds_b_scale_bytes = 0 if is_ptpc else tile_n * scale_k_per_tile + _scale_guard_bytes interleaved_scale_cols_a = wmma_m_rep * scale_k_per_tile - interleaved_scale_cols_b = b_scale_load_rep * scale_k_per_tile def _align_up(value: int, align: int) -> int: if value % align == 0: @@ -620,13 +591,11 @@ def _align_up(value: int, align: int) -> int: and num_buffers == 4 and wave_specialized_tdm and out_dtype == "bf16" - and not use_scale_opsel ) if fp8_schedule == "deep-pipeline" and not fp8_deep_pipeline_eligible: raise ValueError( "fp8_schedule='deep-pipeline' requires fp8 256x256x128, " - "m_warp=n_warp=2, num_buffers=4, wave_specialized_tdm=True, " - "out_dtype='bf16', and use_scale_opsel=False" + "m_warp=n_warp=2, num_buffers=4, wave_specialized_tdm=True, out_dtype='bf16'" ) def _pick_compute_schedule_kind(): @@ -656,6 +625,7 @@ def _pick_compute_schedule_kind(): COMPUTE_SCHEDULE_ROW_MAJOR_STREAMING, COMPUTE_SCHEDULE_FP8_QUADRANT, COMPUTE_SCHEDULE_FP8_DEEP_PIPELINE, + COMPUTE_SCHEDULE_FP4_COL_BAND, ) use_ws_tdm_split_signal_overlap = ( wave_specialized_tdm @@ -668,7 +638,6 @@ def _pick_compute_schedule_kind(): _bank_half_wm = wmma_m_rep // 2 
_bank_half_wn = wmma_n_rep // 2 _bank_group_size = _bank_half_wm * _bank_half_wn - _bank_half_b_scale_rep = b_scale_load_rep // 2 _bank_group_to_row_major = [] for _wm in range(_bank_half_wm): for _wn in range(_bank_half_wn): @@ -872,38 +841,19 @@ def make_desc_as(memref, k_base): ) def make_desc_bs(memref, k_base): - if const_expr(use_n4k4_bscale): - # 32x4: copy this tile's 32-N atoms x K-blocks slice of the - # preshuffled [N//32, (K//128)*128] B-scale tensor. Each row is one - # 32-N atom group; contiguous dim1 = tile_k//128 * 128B atomics. - a_off = blk_n // arith.index(32) - col_off = (k_base // arith.index(WMMA_K)) * arith.index(bs32_atom_bytes) - return _make_tdm_desc( - global_ptr=arg_b_scale, - lds_memref=memref, - global_offset=(a_off, col_off), - tensor_shape=(N // 32, bs32_global_row_stride), - strides=(bs32_global_row_stride, 1), - tile_shape=(bs32_tile_atoms_pad, bs32_lds_row_stride), - elem_bytes=1, - pad_interval=0, - pad_amount=0, - num_warps=tdm_desc_num_warps, - workgroup_mask=b_mcast_mask, - atomic_barrier_enable=atomic_barrier_enable, - early_timeout=True, - oob_outer_bound=N // 32, - ) - k_scale_off = k_base // arith.index(SCALE_BLOCK) - outer_off = blk_n // arith.index(b_scale_load_rep) - inner_off = k_scale_off * arith.index(b_scale_load_rep) + # 32x4: copy this tile's 32-N atoms x K-blocks slice of the preshuffled + # [N//32, (K//128)*128] B-scale tensor. Each row is one 32-N atom group; + # contiguous dim1 = tile_k//128 * 128B atomics. (mxscale only; ptpc never + # calls this -- its desc_bs is a placeholder.) 
+ a_off = blk_n // arith.index(32) + col_off = (k_base // arith.index(WMMA_K)) * arith.index(bs32_atom_bytes) return _make_tdm_desc( global_ptr=arg_b_scale, lds_memref=memref, - global_offset=(outer_off, inner_off), - tensor_shape=(WMMA_M * n_warp, interleaved_scale_cols_b), - strides=(b_scale_load_rep * K_scale, 1), - tile_shape=(WMMA_M * n_warp, interleaved_scale_cols_b), + global_offset=(a_off, col_off), + tensor_shape=(N // 32, bs32_global_row_stride), + strides=(bs32_global_row_stride, 1), + tile_shape=(bs32_tile_atoms_pad, bs32_lds_row_stride), elem_bytes=1, pad_interval=0, pad_amount=0, @@ -911,6 +861,7 @@ def make_desc_bs(memref, k_base): workgroup_mask=b_mcast_mask, atomic_barrier_enable=atomic_barrier_enable, early_timeout=True, + oob_outer_bound=N // 32, ) if const_expr(wave_specialized_tdm): @@ -1046,10 +997,6 @@ def _precompute_scale_lane_bases(lds_ptr, warp_base, reps, interleaved_cols): if const_expr(is_fp4 or is_a8w4): # FP4/A8W4: always add lane_kgrp offset (no opsel on BScale) base = base + lane_kgrp * arith.index(SCALES_PER_WMMA) - else: - # FP8: conditional on opsel - if const_expr(use_scale_opsel): - base = base + lane_kgrp * arith.index(SCALES_PER_WMMA) return lds_ptr, [base] def _precompute_bs32_bases(lds_ptr): @@ -1105,31 +1052,15 @@ def load_scale_b128(lds_buffer, scale_base, reps, ks=0): results.append(vecs[i // 4][i % 4]) return results - def load_scale_slice_b128(lds_buffer, scale_base, full_reps, rep_start, rep_count, ks=0): - """Load a contiguous slice of packed scale VGPRs for one K-subtile.""" - ks_byte_off = (ks * full_reps + rep_start) * SCALES_PER_WMMA - eff_base = scale_base if ks_byte_off == 0 else scale_base + arith.index(ks_byte_off) - num_loads = (rep_count + 3) // 4 - vecs = [] - for ld in range_constexpr(num_loads): - off = eff_base if ld == 0 else eff_base + arith.index(ld * 16) - vecs.append(fx.Vector(lds_load_b128_raw(lds_buffer, off))) - results = [] - for i in range_constexpr(rep_count): - results.append(vecs[i // 4][i 
% 4]) - return results - # Holds the current tile's prefetched VGPR scales (a_flat, b_flat), each # ordered [k_wmma_step][rep]. compute_tile sets it before emitting; the # general-vgpr branch of _scales_for_emit slices it per K-subtile. Set-then- # consume is sequential at emit time (same pattern as epi_addrs_box). _vgpr_scale_box = [None] - def _load_b_scale_lds(bs_buf, bs_bases, ks): - """Load B-scale from LDS, dispatching to the 32x4 or legacy layout.""" - if const_expr(use_n4k4_bscale): - return load_bs32_bscale(bs_buf, bs_bases, ks) - return load_scale_b128(bs_buf, bs_bases[0], b_scale_load_rep, ks) + def _load_b_scale_lds(bs_buf, bs_atom0, ks): + """Load 32x4 B-scale from LDS (mxscale only; ptpc reads no K-loop B-scale).""" + return load_bs32_bscale(bs_buf, bs_atom0, ks) def _scales_for_emit(as_buf, as_bases, bs_buf, bs_bases, ks): """Load both scale tensors and apply op_sel downsampling per format. @@ -1146,13 +1077,8 @@ def _scales_for_emit(as_buf, as_bases, bs_buf, bs_bases, ks): a = pf_a[ks * nat_as_load : (ks + 1) * nat_as_load] b = _load_b_scale_lds(bs_buf, bs_bases, ks) return a, b - a_all = load_scale_b128(as_buf, as_bases[0], wmma_m_rep, ks) - b_all = _load_b_scale_lds(bs_buf, bs_bases, ks) - if const_expr(use_scale_opsel): - a = a_all[::2] - b = b_all if const_expr(is_fp4) else b_all[::2] - else: - a, b = a_all, b_all + a = load_scale_b128(as_buf, as_bases[0], wmma_m_rep, ks) + b = _load_b_scale_lds(bs_buf, bs_bases, ks) return a, b def _load_b_and_scales(b_buf, b_bases, bs_buf, bs_bases, as_buf, as_bases, ks): @@ -1197,10 +1123,7 @@ def _emit_wmma(accs, wm, wn, a_frag, b_frag, a_scales, b_scales): fmtB=0, ) return - if const_expr(use_scale_opsel): - a_scale_idx = wm // 2 - a_opsel = wm % 2 - elif const_expr(use_natural_ascale_opsel): + if const_expr(use_natural_ascale_opsel): # Natural A M op_sel pairs (j, j+rep/2): kgrp1 carries the second half. 
a_scale_idx = wm % nat_as_half a_opsel = wm // nat_as_half @@ -1209,22 +1132,23 @@ def _emit_wmma(accs, wm, wn, a_frag, b_frag, a_scales, b_scales): a_opsel = 0 if const_expr(is_fp4): - # 32x16 WMMA with A/B swap: SRC0=B, SRC1=A + # 32x16 WMMA with A/B swap: SRC0=B, SRC1=A. 32x4 reads one 32-N atom + # per WMMA (idx wn). accs[idx] = rocdl.wmma_scale_f32_32x16x128_f4( T.vec(16, T.f32), b_frag, a_frag, accs[idx], - b_scales[wn * 2], + b_scales[wn], a_scales[a_scale_idx], scaleAType=0, scaleBType=a_opsel, ) else: # 16x16x128 WMMA: A8W4 (fmtA=FP4) or FP8 (fmtA=FP8). op_sel pairs - # adjacent 16-N halves (legacy use_scale_opsel or 32x4 even rep); - # else one scale per WMMA (32x4 odd rep, or no op_sel). - if const_expr(use_scale_opsel or bs32_opsel): + # adjacent 16-N halves (32x4 even rep); else one scale per WMMA + # (32x4 odd rep, or no op_sel). + if const_expr(bs32_opsel): b_scale_idx = wn // 2 b_opsel = wn % 2 else: @@ -1342,9 +1266,7 @@ def compute_tile( if const_expr(use_n4k4_bscale): bs_buf, bs_bases = _precompute_bs32_bases(lds_bs) else: - bs_buf, bs_bases = _precompute_scale_lane_bases( - lds_bs, warp_n_base, b_scale_load_rep, interleaved_scale_cols_b - ) + bs_buf, bs_bases = lds_bs, None # ptpc: B-scale in epilogue, bases unused if const_expr(k_wmma_steps == 1): b_frags, b_scales, a_scales = _load_b_and_scales(b_buf, b_bases, bs_buf, bs_bases, as_buf, as_bases, 0) @@ -1387,19 +1309,33 @@ def compute_tile_fp4_bank_friendly( lds_bs, emit_filler=None, mid_compute_callback=None, + scale_k_base=None, + pf_a_scales=None, + pf_b_scales=None, ): current_accs = list(accs_in) + if const_expr(use_natural_ascale): + # A-scale from the VGPR ring (loop-prefetched, else inline tail load). 
+ if const_expr(pf_a_scales is not None): + _vgpr_scale_box[0] = (pf_a_scales, pf_b_scales) + else: + rocdl.sched_barrier(0) + _vgpr_scale_box[0] = _bvs_prefetch(scale_k_base) a_buf, a_bases = _precompute_a_lane_bases(lds_a) b_buf, b_bases = _precompute_b_lane_bases(lds_b) - as_buf, as_bases = _precompute_scale_lane_bases(lds_as, warp_m_base, wmma_m_rep, interleaved_scale_cols_a) - bs_buf, bs_bases = _precompute_scale_lane_bases( - lds_bs, warp_n_base, b_scale_load_rep, interleaved_scale_cols_b - ) - _b_half_scale_loads = (_bank_half_b_scale_rep + 3) // 4 + if const_expr(use_natural_ascale): + as_buf, as_bases = None, None # A-scale from the VGPR ring, not LDS + else: + as_buf, as_bases = _precompute_scale_lane_bases( + lds_as, warp_m_base, wmma_m_rep, interleaved_scale_cols_a + ) + bs_buf, bs_bases = _precompute_bs32_bases(lds_bs) + _b_half_scale_loads = _bank_half_wn # 32x4: one b32 per 32-N atom/WMMA def _fp4_get_a_scale_and_opsel(a_scales_all, wm_idx): - if const_expr(use_scale_opsel): - return a_scales_all[(wm_idx // 2) * 2], wm_idx % 2 + if const_expr(use_natural_ascale_opsel): + # Natural M op_sel pairs (wm, wm+rep/2): kgrp1 carries the second half. + return a_scales_all[wm_idx % nat_as_half], wm_idx // nat_as_half return a_scales_all[wm_idx], 0 def _load_a_group(wm_base, wm_count, ks): @@ -1410,11 +1346,19 @@ def _load_b_half(wn_base, ks): load_b_frag(b_buf, b_bases, wn_base + wn_local, ks) for wn_local in range_constexpr(_bank_half_wn) ] - def _load_b_half_bundle(wn_base, rep_start, ks): + def _load_bs32_b_half(atom0, wn_base, ks): + # 32x4: load this N-half's atoms, one ds_load_b32 per 32-N WMMA (no op_sel). 
+ stride = arith.index(bs32_lds_row_stride) + ks_off = arith.index(ks * bs32_atom_bytes) + lane = (lane_kgrp * arith.index(16) + lane16) * arith.index(4) + return [ + lds_load_b32_raw(bs_buf, (atom0 + arith.index(wn_base + wn_local)) * stride + ks_off + lane) + for wn_local in range_constexpr(_bank_half_wn) + ] + + def _load_b_half_bundle(wn_base, ks): b_frags = _load_b_half(wn_base, ks) - b_scales = load_scale_slice_b128( - bs_buf, bs_bases[0], b_scale_load_rep, rep_start, _bank_half_b_scale_rep, ks - ) + b_scales = _load_bs32_b_half(bs_bases, wn_base, ks) return b_frags, b_scales def _emit_group_rows( @@ -1436,7 +1380,7 @@ def _emit_group_rows( b_frags[wn_local], a_frag, current_accs[idx], - b_scales[wn_local * 2], + b_scales[wn_local], a_scale, scaleAType=0, scaleBType=a_opsel, @@ -1455,11 +1399,15 @@ def _emit_group(group_base, wm_base, a_frags, b_frags, a_scales, b_scales, emit_ emit_filler_now=emit_filler_now, ) - b_left_frags, b_left_scales = _load_b_half_bundle(0, 0, 0) + b_left_frags, b_left_scales = _load_b_half_bundle(0, 0) for ks in range_constexpr(k_wmma_steps): is_last_ks = ks == k_wmma_steps - 1 - a_scales_all = load_scale_b128(as_buf, as_bases[0], wmma_m_rep, ks) + if const_expr(use_natural_ascale): + pf_a, _ = _vgpr_scale_box[0] + a_scales_all = pf_a[ks * nat_as_load : (ks + 1) * nat_as_load] + else: + a_scales_all = load_scale_b128(as_buf, as_bases[0], wmma_m_rep, ks) a_top_frags = _load_a_group(0, _bank_half_wm, ks) a_bottom_frags = _load_a_group(_bank_half_wm, _bank_half_wm, ks) @@ -1480,7 +1428,7 @@ def _emit_group(group_base, wm_base, a_frags, b_frags, a_scales, b_scales, emit_ rocdl.sched_barrier(0) mid_compute_callback() - b_right_frags, b_right_scales = _load_b_half_bundle(_bank_half_wn, _bank_half_b_scale_rep, ks) + b_right_frags, b_right_scales = _load_b_half_bundle(_bank_half_wn, ks) # Hold only the next B half outstanding while the second # quadrant consumes the current left-half fragments. 
@@ -1496,7 +1444,7 @@ def _emit_group(group_base, wm_base, a_frags, b_frags, a_scales, b_scales, emit_ ) if const_expr(not is_last_ks): - next_left_frags, next_left_scales = _load_b_half_bundle(0, 0, ks + 1) + next_left_frags, next_left_scales = _load_b_half_bundle(0, ks + 1) # Older right-half loads must be ready before consuming # them, while the next ks left-half preload can remain in # flight under the final two quadrants. @@ -1537,17 +1485,30 @@ def compute_tile_fp8_quadrant( emit_filler=None, mid_compute_callback=None, late_compute_callback=None, + scale_k_base=None, + pf_a_scales=None, + pf_b_scales=None, ): current_accs = list(accs_in) + if const_expr(use_natural_ascale): + # A-scale from the VGPR ring (loop-prefetched, else inline tail load). + if const_expr(pf_a_scales is not None): + _vgpr_scale_box[0] = (pf_a_scales, pf_b_scales) + else: + rocdl.sched_barrier(0) + _vgpr_scale_box[0] = _bvs_prefetch(scale_k_base) a_buf, a_bases = _precompute_a_lane_bases(lds_a) b_buf, b_bases = _precompute_b_lane_bases(lds_b) - as_buf, as_bases = _precompute_scale_lane_bases(lds_as, warp_m_base, wmma_m_rep, interleaved_scale_cols_a) + if const_expr(use_natural_ascale): + as_buf, as_bases = None, None # A-scale from the VGPR ring, not LDS + else: + as_buf, as_bases = _precompute_scale_lane_bases( + lds_as, warp_m_base, wmma_m_rep, interleaved_scale_cols_a + ) if const_expr(use_n4k4_bscale): bs_buf, bs_bases = _precompute_bs32_bases(lds_bs) else: - bs_buf, bs_bases = _precompute_scale_lane_bases( - lds_bs, warp_n_base, b_scale_load_rep, interleaved_scale_cols_b - ) + bs_buf, bs_bases = lds_bs, None # ptpc: B-scale in epilogue, bases unused _b_half_loads = _fp8_half_wn * _b_frag_loads_per_wn _b_left_bundle_loads = _b_half_loads + _fp8_b_scale_loads @@ -1562,20 +1523,16 @@ def _load_b_half(wn_base, ks): def _load_a_scales(ks): if const_expr(is_ptpc): return None # PTPC: scale applied in epilogue, not in K-loop - a_scales = load_scale_b128(as_buf, as_bases[0], wmma_m_rep, ks) 
- if const_expr(use_scale_opsel): - return a_scales[::2] - return a_scales + if const_expr(use_natural_ascale): + # A from the VGPR ring (slice this ks); M op_sel via nat_as_half in _emit. + pf_a, _ = _vgpr_scale_box[0] + return pf_a[ks * nat_as_load : (ks + 1) * nat_as_load] + return load_scale_b128(as_buf, as_bases[0], wmma_m_rep, ks) def _load_b_scales(ks): if const_expr(is_ptpc): return None # PTPC: scale applied in epilogue, not in K-loop - if const_expr(use_n4k4_bscale): - return _load_b_scale_lds(bs_buf, bs_bases, ks) # op_sel in _emit_wmma - b_scales = load_scale_b128(bs_buf, bs_bases[0], b_scale_load_rep, ks) - if const_expr(use_scale_opsel): - return b_scales[::2] - return b_scales + return _load_b_scale_lds(bs_buf, bs_bases, ks) # 32x4; op_sel in _emit_wmma def _load_b_left_bundle(ks): return _load_b_half(0, ks), _load_b_scales(ks) @@ -1750,9 +1707,7 @@ def compute_tile_fp8_deep_pipeline( if const_expr(use_n4k4_bscale): bs_buf, bs_bases = _precompute_bs32_bases(lds_bs) else: - bs_buf, bs_bases = _precompute_scale_lane_bases( - lds_bs, warp_n_base, b_scale_load_rep, interleaved_scale_cols_b - ) + bs_buf, bs_bases = lds_bs, None # ptpc: B-scale in epilogue, bases unused def load_a_pair(wm_pair, ks): wm_base = wm_pair * _fp8_pair_wm @@ -1918,9 +1873,9 @@ def hot_loop_scheduler(): def hot_loop_scheduler_fp4_bank_friendly(): _a_all_loads = wmma_m_rep * DS_LOADS_PER_A_FRAG - _a_scale_loads = (wmma_m_rep + 3) // 4 + _a_scale_loads = 0 if use_natural_ascale else (wmma_m_rep + 3) // 4 _b_half_loads = _bank_half_wn * 4 - _b_half_scale_loads = (_bank_half_b_scale_rep + 3) // 4 + _b_half_scale_loads = _bank_half_wn # 32x4: one b32 per 32-N atom/WMMA _group_wmma = _bank_group_size _right_half_loads = _b_half_loads + _b_half_scale_loads @@ -1939,7 +1894,7 @@ def hot_loop_scheduler_fp4_bank_friendly(): rocdl.sched_barrier(0) def hot_loop_scheduler_fp8_quadrant(): - _a_scale_loads = 0 if is_ptpc else (wmma_m_rep + 3) // 4 + _a_scale_loads = 0 if (is_ptpc or 
use_natural_ascale) else (wmma_m_rep + 3) // 4 _a_top_loads = _fp8_half_wm * DS_LOADS_PER_A_FRAG _a_bottom_loads = _a_top_loads _b_half_loads = _fp8_half_wn * _b_frag_loads_per_wn @@ -2028,6 +1983,9 @@ def compute_tile_scheduled( lds_bs, emit_filler=emit_filler, mid_compute_callback=mid_compute_callback, + scale_k_base=scale_k_base, + pf_a_scales=pf_a_scales, + pf_b_scales=pf_b_scales, ) if const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_FP8_QUADRANT): return compute_tile_fp8_quadrant( @@ -2039,6 +1997,9 @@ def compute_tile_scheduled( emit_filler=emit_filler, mid_compute_callback=mid_compute_callback, late_compute_callback=late_compute_callback, + scale_k_base=scale_k_base, + pf_a_scales=pf_a_scales, + pf_b_scales=pf_b_scales, ) if const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_FP8_DEEP_PIPELINE): return compute_tile_fp8_deep_pipeline( @@ -2915,7 +2876,6 @@ def _emit_buffer_store(): inst_prefetch, wave_specialized_tdm, split_k, - use_scale_opsel, expert_sched_mode, atomic_barrier_enable, ascale_load_path, @@ -3031,7 +2991,6 @@ def compile_ptpc_gemm( data_format=data_format, scale_mode="ptpc", wave_specialized_tdm=True, - use_scale_opsel=False, fp8_schedule="auto", use_tdm_store=(split_k == 1), N=N, diff --git a/tests/kernels/test_gemm_fp8fp4_gfx1250.py b/tests/kernels/test_gemm_fp8fp4_gfx1250.py index f6da0ef5..cb689ad2 100644 --- a/tests/kernels/test_gemm_fp8fp4_gfx1250.py +++ b/tests/kernels/test_gemm_fp8fp4_gfx1250.py @@ -361,7 +361,6 @@ def _run_mxscale_gemm_test( use_tdm_store, out_dtype, wave_specialized_tdm=False, - use_scale_opsel=False, l2_prefetch_distance=0, cluster_m=1, cluster_n=1, @@ -380,9 +379,6 @@ def _run_mxscale_gemm_test( if arch != "gfx1250": pytest.skip(f"WMMA_SCALE requires gfx1250, got {arch}") - if use_scale_opsel and is_fp4: - pytest.skip("FP4 32x16 WMMA scaleBType op_sel ignored by AM simulator") - if K % SCALE_BLOCK != 0: pytest.skip(f"K={K} must be divisible by SCALE_BLOCK={SCALE_BLOCK}") @@ -449,7 +445,6 @@ def 
_run_mxscale_gemm_test( # Preshuffle scales skt = tile_k // SCALE_BLOCK warp_tile_m = tile_m // m_warp - warp_tile_n = tile_n // n_warp _natural_ascale = use_natural_ascale_vgpr( data_format=data_format, tile_m=tile_m, @@ -459,7 +454,6 @@ def _run_mxscale_gemm_test( n_warp=n_warp, n=padded_n, ascale_load_path=ascale_load_path, - use_scale_opsel=use_scale_opsel, wave_specialized_tdm=wave_specialized_tdm, num_buffers=num_buffers, out_dtype=out_dtype, @@ -470,22 +464,7 @@ def _run_mxscale_gemm_test( pass else: a_scale = preshuffle_scale_for_load_path(a_scale, warp_tile_m, skt, row_align=tile_m) - if use_n4k4_bscale_layout( - data_format=data_format, - tile_m=tile_m, - tile_n=tile_n, - tile_k=tile_k, - m_warp=m_warp, - n_warp=n_warp, - n=padded_n, - use_scale_opsel=use_scale_opsel, - num_buffers=num_buffers, - out_dtype=out_dtype, - wave_specialized_tdm=wave_specialized_tdm, - ): - b_scale = preshuffle_scale(b_scale) - else: - b_scale = preshuffle_scale_for_load_path(b_scale, warp_tile_n, skt, row_align=tile_n) + b_scale = preshuffle_scale(b_scale) # Preshuffle B data K_packed = padded_k // padded_shape["pack_b"] @@ -517,7 +496,6 @@ def _run_mxscale_gemm_test( inst_prefetch=inst_prefetch, wave_specialized_tdm=wave_specialized_tdm, split_k=split_k, - use_scale_opsel=use_scale_opsel, expert_sched_mode=expert_sched_mode, ascale_load_path=ascale_load_path, ) @@ -638,7 +616,6 @@ def _extract_i64_metadata(compiled_ir: str, key: str) -> int: @pytest.mark.parametrize("num_buffers", [2, 3, 4]) @pytest.mark.parametrize("use_tdm_store", [True, False]) @pytest.mark.parametrize("wave_specialized_tdm", [True, False]) -@pytest.mark.parametrize("use_scale_opsel", [True, False]) @pytest.mark.parametrize("out_dtype", ["f32", "bf16"]) def test_mxfp4_gemm( M, @@ -653,7 +630,6 @@ def test_mxfp4_gemm( use_tdm_store, out_dtype, wave_specialized_tdm, - use_scale_opsel, ): _run_mxscale_gemm_test( "fp4", @@ -669,7 +645,6 @@ def test_mxfp4_gemm( use_tdm_store, out_dtype, 
wave_specialized_tdm=wave_specialized_tdm, - use_scale_opsel=use_scale_opsel, ) @@ -683,7 +658,6 @@ def test_mxfp4_gemm( ) @pytest.mark.parametrize("num_buffers", [2, 3]) @pytest.mark.parametrize("use_tdm_store", [True, False]) -@pytest.mark.parametrize("use_scale_opsel", [True, False]) @pytest.mark.parametrize("out_dtype", ["f32", "bf16"]) def test_mxfp8_gemm( M, @@ -697,7 +671,6 @@ def test_mxfp8_gemm( num_buffers, use_tdm_store, out_dtype, - use_scale_opsel, ): _run_mxscale_gemm_test( "fp8", @@ -713,7 +686,6 @@ def test_mxfp8_gemm( use_tdm_store, out_dtype, l2_prefetch_distance=2, - use_scale_opsel=use_scale_opsel, ) @@ -753,11 +725,8 @@ def test_mxfp8_gemm_splitk(split_k, out_dtype): ) @pytest.mark.parametrize("num_buffers", [2, 3]) @pytest.mark.parametrize("use_tdm_store", [True, False]) -@pytest.mark.parametrize("use_scale_opsel", [True, False]) @pytest.mark.parametrize("out_dtype", ["f32", "bf16"]) -def test_a8w4_gemm( - M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp, num_buffers, use_tdm_store, out_dtype, use_scale_opsel -): +def test_a8w4_gemm(M, N, K, tile_m, tile_n, tile_k, m_warp, n_warp, num_buffers, use_tdm_store, out_dtype): _run_mxscale_gemm_test( "a8w4", M, @@ -772,7 +741,6 @@ def test_a8w4_gemm( use_tdm_store, out_dtype, l2_prefetch_distance=2, - use_scale_opsel=use_scale_opsel, ) @@ -799,7 +767,6 @@ def test_a8w4_gemm_irregular_m_tile16(M, N, K, use_tdm_store): use_tdm_store=use_tdm_store, out_dtype="bf16", l2_prefetch_distance=2, - use_scale_opsel=False, ) @@ -877,7 +844,6 @@ def test_mxscale_n4k4_bscale(data_format, M, N, K, tile_n, tile_k, n_warp, num_b out_dtype=out_dtype, wave_specialized_tdm=ws, l2_prefetch_distance=0, - use_scale_opsel=False, ) @@ -941,7 +907,6 @@ def test_mxscale_natural_ascale(data_format, M, tile_m, tile_n, tile_k, m_warp, out_dtype="bf16", wave_specialized_tdm=True, l2_prefetch_distance=0, - use_scale_opsel=False, ) @@ -960,7 +925,6 @@ def test_mxscale_deep_pipeline(data_format, ascale_load_path): num_buffers=4, 
out_dtype="bf16", wave_specialized_tdm=True, - use_scale_opsel=False, fp8_schedule="auto", ), "config does not hit the deep-pipeline schedule" _run_mxscale_gemm_test( @@ -978,7 +942,6 @@ def test_mxscale_deep_pipeline(data_format, ascale_load_path): out_dtype="bf16", wave_specialized_tdm=True, l2_prefetch_distance=0, - use_scale_opsel=False, ascale_load_path=ascale_load_path, ) @@ -1838,10 +1801,7 @@ def _run_benchmark(args): if needs_pad: print(f" Kernel pad: M={padded_m}, N={padded_n}, K={padded_k}") print(f" Tile: ({tile_m}, {tile_n}, {tile_k}), warps=({args.m_warp}x{args.n_warp})") - print( - f" Buffers={args.num_buffers}, out={args.out_dtype}, " - f"opsel={args.use_scale_opsel}, inst_prefetch={args.inst_prefetch}" - ) + print(f" Buffers={args.num_buffers}, out={args.out_dtype}, " f"inst_prefetch={args.inst_prefetch}") if args.warmup < 0: raise ValueError(f"--warmup must be >= 0, got {args.warmup}") if args.iters <= 0: @@ -1867,8 +1827,6 @@ def _run_benchmark(args): _ptpc_ignored.append("--no-tdm-store") if not args.wave_spec_tdm: _ptpc_ignored.append("--no-wave-spec-tdm") - if args.use_scale_opsel: - _ptpc_ignored.append("--use-scale-opsel") if _ptpc_ignored: print(f" Note: PTPC ignores (forced internally): {', '.join(_ptpc_ignored)}") print("=" * 72) @@ -1978,7 +1936,6 @@ def _run_benchmark(args): inst_prefetch=args.inst_prefetch, wave_specialized_tdm=args.wave_spec_tdm, split_k=args.split_k, - use_scale_opsel=args.use_scale_opsel, expert_sched_mode=args.expert_sched_mode, atomic_barrier_enable=args.atomic_barrier_enable, ) @@ -2231,7 +2188,6 @@ def _run_graph_verify(args): inst_prefetch=args.inst_prefetch, wave_specialized_tdm=args.wave_spec_tdm, split_k=args.split_k, - use_scale_opsel=args.use_scale_opsel, expert_sched_mode=args.expert_sched_mode, atomic_barrier_enable=args.atomic_barrier_enable, ) @@ -2344,7 +2300,6 @@ def launch(): parser.add_argument("--inst-prefetch", action="store_true", default=False) parser.add_argument("--no-wave-spec-tdm", 
dest="wave_spec_tdm", action="store_false", default=True) parser.add_argument("--waves-per-eu", type=int, default=None) - parser.add_argument("--use-scale-opsel", action="store_true", default=False) parser.add_argument("--disable-expert-sched-mode", dest="expert_sched_mode", action="store_false", default=True) parser.add_argument( "--atomic-barrier-enable", @@ -2443,7 +2398,6 @@ def _run_correctness_test(): out_dtype=args.out_dtype, wave_specialized_tdm=args.wave_spec_tdm, split_k=args.split_k, - use_scale_opsel=args.use_scale_opsel, l2_prefetch_distance=args.l2_prefetch_distance, cluster_m=args.cluster_m, cluster_n=args.cluster_n, From d78de457e74a5466352b6ccc703cd2be2049342b Mon Sep 17 00:00:00 2001 From: aoli26 Date: Tue, 16 Jun 2026 09:39:50 +0000 Subject: [PATCH 15/16] gfx1250 mxscale GEMM: unify scale paths (A=vgpr, B=32x4), drop legacy paths - B-scale: every scheduler (row-major / quadrant / deep / fp4 bank-friendly) uses the 32x4 preshuffle_scale layout; op_sel on even-rep fp8/a8w4. - A-scale: always the natural buffer_load->VGPR ring (M op_sel for pow2 rep); the LDS scaffolding is kept for a future TDM->LDS A-scale path. - Drop legacy paths: use_scale_opsel, legacy B preshuffle, the nws cooperative TDM loop, and A-tdm. TDM is always wave-specialized (requires num_warps>=2). - ptpc-fp8 uses the dedicated no-scale WMMA (wmma_f32_16x16x128_fp8_fp8); fix the ptpc row-major s_wait_dscnt scale-ds over-count. - Consolidate predicates to is_mxscale; drop "n4k4"/"natural" naming; remove dead code (lds_load_b64_raw, _load_a_and_scales). Verified: ptpc/a8w4 and the unchanged paths stay ISA byte-identical, ptpc-fp8 ISA change is the intended no-scale WMMA; mxscale schedulers pass offline parity + GPU cosine across schedulers, tiles, and dtypes. 
--- kernels/gemm_common_gfx1250.py | 14 +- kernels/gemm_fp8fp4_gfx1250.py | 1007 ++++++--------------- tests/kernels/test_gemm_fp8fp4_gfx1250.py | 188 +--- 3 files changed, 339 insertions(+), 870 deletions(-) diff --git a/kernels/gemm_common_gfx1250.py b/kernels/gemm_common_gfx1250.py index 863fea77..9d9eb5ae 100644 --- a/kernels/gemm_common_gfx1250.py +++ b/kernels/gemm_common_gfx1250.py @@ -98,24 +98,12 @@ def lds_load_b32_raw(lds_base_idx, byte_offset): Unlike :func:`lds_load_b128_raw`, this only requires 4-byte alignment, so it suits scale layouts where consumed words sit at 4-byte (not 16-byte) granular - offsets (e.g. the N4K4 B-scale layout's per-N-block reads). + offsets (e.g. the 32x4 B-scale layout's one-i32-per-atom reads). """ ptr_val = _raw_lds_ptr(lds_base_idx, byte_offset) return llvm_dialect.load(ir.IntegerType.get_signless(32), ptr_val) -def lds_load_b64_raw(lds_base_idx, byte_offset): - """Load 8 bytes (``vector<2xi32>``) from LDS using a pre-extracted base index. - - Requires 8-byte alignment. Sits between :func:`lds_load_b32_raw` and - :func:`lds_load_b128_raw` for layouts whose contiguous read width is 2 words - (e.g. the N4K4 B-scale layout when ``wmma_n_rep`` is even but not a multiple - of 4, where each aligned batch covers exactly 2 N-blocks). 
- """ - ptr_val = _raw_lds_ptr(lds_base_idx, byte_offset) - return llvm_dialect.load(ir.VectorType.get([2], ir.IntegerType.get_signless(32)), ptr_val) - - def lds_transpose_load_raw(result_type, lds_base_idx, byte_offset): """Transpose-load 16 bytes from LDS using a pre-extracted base index.""" from flydsl._mlir.dialects import rocdl as _rocdl diff --git a/kernels/gemm_fp8fp4_gfx1250.py b/kernels/gemm_fp8fp4_gfx1250.py index 1954fbb3..99d64b33 100644 --- a/kernels/gemm_fp8fp4_gfx1250.py +++ b/kernels/gemm_fp8fp4_gfx1250.py @@ -23,7 +23,6 @@ from kernels.gemm_common_gfx1250 import ( extract_lds_base_idx, get_lds_memref, - issue_tdm_loads, lds_load_b32_raw, lds_load_b128_raw, pipeline_fence, @@ -96,7 +95,6 @@ def _is_fp8_deep_pipeline( n_warp, num_buffers, out_dtype, - wave_specialized_tdm, fp8_schedule, ): """Whether this config takes the FP8/A8W4 deep-pipeline schedule (fixed @@ -111,80 +109,11 @@ def _is_fp8_deep_pipeline( and m_warp == 2 and n_warp == 2 and num_buffers == 4 - and wave_specialized_tdm and out_dtype == "bf16" ) return fp8_schedule == "deep-pipeline" or (fp8_schedule == "auto" and eligible) -def use_n4k4_bscale_layout( - *, - data_format, - tile_m, - tile_n, - tile_k, - m_warp, - n_warp, - n, - scale_mode="mxscale", - num_buffers=2, - out_dtype="f32", - wave_specialized_tdm=False, - fp8_schedule="auto", -): - """B-scale uses the 32x4 `preshuffle_scale` layout on every mxscale schedule - (row-major, quadrant, deep-pipeline, fp4 bank-friendly); ptpc has no K-loop - B-scale. The N/tile_n%32 and tile_k%128 requirements are enforced in - compile_fp8fp4_gemm. 
(The gate name keeps "n4k4" until the naming cleanup TODO; - extra kwargs are kept for signature compatibility.)""" - if scale_mode != "mxscale": - return False - if data_format not in ("fp8", "a8w4", "fp4"): - return False - return True - - -def use_natural_ascale_vgpr( - *, - data_format, - tile_m, - tile_n, - tile_k, - m_warp, - n_warp, - n, - scale_mode="mxscale", - ascale_load_path="vgpr", - wave_specialized_tdm=False, - num_buffers=2, - out_dtype="bf16", - fp8_schedule="auto", -): - """Whether A-scale uses the natural (un-reshuffled) buffer_load->VGPR path. - - A read straight from runtime ``A_scale[M, K//32]`` into VGPRs (loop-ahead - prefetched), paired with 32x4 B (TDM). Used by every ws schedule (row-major, - quadrant, fp4 bank-friendly, deep-with-explicit-vgpr). Requires wave-spec TDM.""" - if data_format not in ("fp8", "a8w4", "fp4"): - return False - if ascale_load_path != "vgpr": - return False - if not wave_specialized_tdm: - return False - if not use_n4k4_bscale_layout( - data_format=data_format, - tile_m=tile_m, - tile_n=tile_n, - tile_k=tile_k, - m_warp=m_warp, - n_warp=n_warp, - n=n, - scale_mode=scale_mode, - ): - return False - return True - - @functools.lru_cache(maxsize=256) def compile_fp8fp4_gemm( *, @@ -205,11 +134,9 @@ def compile_fp8fp4_gemm( use_tdm_store: bool = True, out_dtype: str = "f32", inst_prefetch: bool = False, - wave_specialized_tdm: bool = False, split_k: int = 1, expert_sched_mode: bool = True, atomic_barrier_enable: bool = False, - ascale_load_path: str = "auto", fp8_schedule: str = "auto", ): """Compile an FP4/FP8/A8W4 GEMM kernel with TDM async copy. @@ -236,27 +163,6 @@ def compile_fp8fp4_gemm( if scale_mode == "ptpc" and data_format not in ("fp8", "a8w4"): raise ValueError("scale_mode='ptpc' currently only supports data_format='fp8' or 'a8w4'") - # Deep-pipeline defaults to TDM A-scale: the natural A-scale VGPR ring adds - # register pressure that spills on mxfp8 deep. 
vgpr is the default everywhere - # else and stays explicitly selectable for deep until that spill is resolved. - if ascale_load_path == "auto": - ascale_load_path = ( - "tdm" - if _is_fp8_deep_pipeline( - data_format=data_format, - tile_m=tile_m, - tile_n=tile_n, - tile_k=tile_k, - m_warp=m_warp, - n_warp=n_warp, - num_buffers=num_buffers, - out_dtype=out_dtype, - wave_specialized_tdm=wave_specialized_tdm, - fp8_schedule=fp8_schedule, - ) - else "vgpr" - ) - is_fp4 = data_format == "fp4" is_a8w4 = data_format == "a8w4" is_ptpc = scale_mode == "ptpc" @@ -287,14 +193,6 @@ def compile_fp8fp4_gemm( if block_threads > 1024: raise ValueError(f"block_threads must be <= 1024, got {block_threads}") - # Wave-specialized TDM dedicates one loader wave per TDM tensor. - # Scales bypass TDM (no dedicated loader waves) for ptpc or the buffer->VGPR - # scale path, leaving only A + B -> 2 waves; otherwise A + B + A_scale + - # B_scale -> 4 waves. - _drop_scale_loader_waves = is_ptpc - # Min loader-wave check is finalized after use_natural_ascale is known (natural - # A-scale frees its wave, allowing >=2; see below). - # ── Format-dependent compile-time constants ── # A8W4: activation is FP8 (PACK_FACTOR_A=1), weight is FP4 (PACK_FACTOR_B=2) if is_a8w4: @@ -371,27 +269,16 @@ def compile_fp8fp4_gemm( # FP4 A/B swap: BScale rep derived from WMMA_M, not WMMA_N_EFF b_scale_load_rep = warp_tile_n // WMMA_M if is_fp4 else wmma_n_rep - use_n4k4_bscale = use_n4k4_bscale_layout( - data_format=data_format, - tile_m=tile_m, - tile_n=tile_n, - tile_k=tile_k, - m_warp=m_warp, - n_warp=n_warp, - n=N, - scale_mode=scale_mode, - num_buffers=num_buffers, - out_dtype=out_dtype, - wave_specialized_tdm=wave_specialized_tdm, - fp8_schedule=fp8_schedule, - ) + # mxscale carries per-K-block scales (A=buffer_load->VGPR ring, B=32x4 TDM); + # ptpc has no K-loop scale (per-token/per-channel fp32 applied in the epilogue). 
+ is_mxscale = not is_ptpc # 32x4 B-scale layout (preshuffle_scale): [N//32, K//128, 32, 4]. A 128B atomic # (32 N-rows x 4 K-scales) = one 32-lane WMMA scale VGPR. op_sel pairs an atom's # two 16-N halves into one b32 load (1 load -> 2 WMMAs); enabled for every even- # rep FP8/A8W4 schedule. Odd rep / warp_tile_n=16 owns half an atom (runtime # 16-half select) and fp4 is already 1 atom = 1 WMMA, so both stay op_sel=0. bs32_opsel = False - if use_n4k4_bscale: + if is_mxscale: bs32_atom_bytes = 128 # 32 rows x 4 K-scales bs32_global_row_stride = (K // WMMA_K) * bs32_atom_bytes # bytes per atom row (= K) bs32_lds_row_stride = k_wmma_steps * bs32_atom_bytes # LDS bytes per atom row @@ -402,47 +289,26 @@ def compile_fp8fp4_gemm( bs32_opsel = (not is_fp4) and (wmma_n_rep % 2 == 0) bs32_n_load = (wmma_n_rep // 2) if bs32_opsel else wmma_n_rep # b32 loads per ks - # A-scale natural buffer_load->VGPR (no reshuffle), paired with N4K4 B (TDM). - # TEMP: ascale_load_path='tdm' fallback + non-ws kept until those legacy paths - # are retired; the natural vgpr path is the target. - if ascale_load_path not in ("vgpr", "tdm"): - raise ValueError(f"ascale_load_path must be 'vgpr' or 'tdm', got {ascale_load_path!r}") - use_natural_ascale = use_natural_ascale_vgpr( - data_format=data_format, - tile_m=tile_m, - tile_n=tile_n, - tile_k=tile_k, - m_warp=m_warp, - n_warp=n_warp, - n=N, - scale_mode=scale_mode, - ascale_load_path=ascale_load_path, - wave_specialized_tdm=wave_specialized_tdm, - num_buffers=num_buffers, - out_dtype=out_dtype, - fp8_schedule=fp8_schedule, - ) - # M op_sel pairs A-blocks (wm, wm+rep/2) into one VGPR via lane_kgrp; power-of-2 rep. 
- use_natural_ascale_opsel = use_natural_ascale and wmma_m_rep >= 2 and (wmma_m_rep & (wmma_m_rep - 1)) == 0 - nat_as_half = wmma_m_rep // 2 - nat_as_load = nat_as_half if use_natural_ascale_opsel else wmma_m_rep - # Natural TDM tensors = {A-data, B-data, B-scale}; at exactly 2 waves wave0 also - # issues B-scale (secondary), so the natural path needs only >=2 loader waves. - natural_two_wave = use_natural_ascale and num_warps == 2 - - _min_wave_spec_warps = 2 if (_drop_scale_loader_waves or use_natural_ascale) else 4 - if wave_specialized_tdm and num_warps < _min_wave_spec_warps: - raise ValueError(f"wave_specialized_tdm requires at least {_min_wave_spec_warps} waves, got {num_warps}") + # A-scale M op_sel pairs A-blocks (wm, wm+rep/2) into one VGPR via lane_kgrp; power-of-2 rep. + ascale_opsel = is_mxscale and wmma_m_rep >= 2 and (wmma_m_rep & (wmma_m_rep - 1)) == 0 + ascale_half = wmma_m_rep // 2 + ascale_load = ascale_half if ascale_opsel else wmma_m_rep + # TDM tensors = {A-data, B-data, B-scale}; at exactly 2 waves wave0 also + # issues B-scale (secondary), so this needs only >=2 loader waves. + two_wave_bscale = is_mxscale and num_warps == 2 + + # (mxscale = {A-data, B-data, B-scale}; ptpc = {A-data, B-data}). + if num_warps < 2: + raise ValueError(f"wave-specialized TDM requires at least 2 waves, got {num_warps}") _b_frag_loads_per_wn = 2 if is_a8w4 else 4 _a_frag_loads_per_wm = 2 if is_fp4 else 4 - # _scale_ds_loads counts scale ds_loads issued alongside A/B fragment loads in - # the streaming schedule (used for the partial-drain s_wait_dscnt bookkeeping). - # The general VGPR scale path holds scales in registers (no ds_load), so it - # contributes zero. Finalized below once use_natural_ascale is known. - _a_scale_ds = (wmma_m_rep + 3) // 4 - # 32x4 B-scale issues bs32_n_load b32 ds_loads per ks; legacy packs into b128s. 
- _b_scale_ds = bs32_n_load if use_n4k4_bscale else (b_scale_load_rep + 3) // 4 + # Scale ds_loads issued alongside A/B fragment loads in the streaming schedule + # (for the partial-drain s_wait_dscnt bookkeeping). A-scale is never a K-loop + # ds_load (VGPR ring for mxscale, epilogue for ptpc) -> 0. B-scale is one b32 per + # atom/ks for mxscale (32x4 LDS) and 0 for ptpc (epilogue). + _a_scale_ds = 0 + _b_scale_ds = bs32_n_load if is_mxscale else 0 _scale_ds_loads = _a_scale_ds + _b_scale_ds _a_frag_ds = wmma_m_rep * _a_frag_loads_per_wm _bs_ds_loads = wmma_n_rep * _b_frag_loads_per_wn + _scale_ds_loads @@ -453,13 +319,12 @@ def compile_fp8fp4_gemm( lds_a_data_bytes = tile_m * lds_a_stride_bytes lds_b_data_bytes = tile_n * packed_tile_k_b _scale_guard_bytes = 16 - # Natural A-scale lives in VGPRs (buffer_load), so it needs no LDS. - lds_a_scale_bytes = 0 if (is_ptpc or use_natural_ascale) else tile_m * scale_k_per_tile + _scale_guard_bytes - if use_n4k4_bscale: - lds_b_scale_bytes = bs32_tile_atoms_pad * bs32_lds_row_stride + _scale_guard_bytes - else: - lds_b_scale_bytes = 0 if is_ptpc else tile_n * scale_k_per_tile + _scale_guard_bytes - interleaved_scale_cols_a = wmma_m_rep * scale_k_per_tile + # A-scale is in VGPRs now (lds_a_scale_bytes=0); the LDS scaffolding (stage region, + # stages_as, lds_as) is kept so a future TDM->LDS A-scale path can be re-enabled by + # allocating nonzero bytes here. + lds_a_scale_bytes = 0 + # B-scale: 32x4 LDS for mxscale; ptpc has none (scale in epilogue). + lds_b_scale_bytes = (bs32_tile_atoms_pad * bs32_lds_row_stride + _scale_guard_bytes) if is_mxscale else 0 def _align_up(value: int, align: int) -> int: if value % align == 0: @@ -470,7 +335,7 @@ def _align_up(value: int, align: int) -> int: # deriving per-wave offsets from ``wave_id``. In wave-specialized mode we # dedicate one loader wave to each tensor (A/B/A_scale/B_scale), so each # active loader wave must issue a full-tile descriptor by itself. 
- tdm_desc_num_warps = 1 if wave_specialized_tdm else num_warps + tdm_desc_num_warps = 1 # All pipeline stages share the same intra-stage layout in the generic # arena path. The active gfx1250 FP8 TDM tile uses a separate reference @@ -503,18 +368,12 @@ def _align_up(value: int, align: int) -> int: ), ) - if use_natural_ascale: - # A-scale leaves LDS (VGPR); B-scale stays an N4K4 ds_load. - _a_scale_ds = 0 - _scale_ds_loads = _b_scale_ds - _bs_ds_loads = wmma_n_rep * _b_frag_loads_per_wn + _scale_ds_loads - _as_ds_loads = _a_frag_ds + _scale_ds_loads # Scale prefetch depth (K-tiles ahead) for the A-scale VGPR ring: prefetch # deeper so each scale buffer_load overlaps an earlier tile's TDM wait. - _bvs_D_default = 3 if use_natural_ascale else 1 + _bvs_D_default = 3 if is_mxscale else 1 _bvs_D = max(1, int(os.environ.get("FLYDSL_BUFFER_VGPR_SCALE_DEPTH", str(_bvs_D_default)))) - # The buffer_load->VGPR scale ring is built only for the natural A-scale path. - _bvs_active = use_natural_ascale + # The buffer_load->VGPR A-scale ring is built for mxscale only. + _bvs_active = is_mxscale stage_phys_order = [i for i in range(num_buffers) if i != _last_compute_stage] stage_phys_order.append(_last_compute_stage) @@ -550,11 +409,8 @@ def _align_up(value: int, align: int) -> int: check_smem_capacity(arena_total_bytes, gpu_arch) # TENSORcnt is tracked per-wave in hardware. Wave-specialized TDM issues one - # tensor_load per wave per step; otherwise all 4 (A/B/A_scale/B_scale). - if wave_specialized_tdm: - TDM_LOADS_PER_STEP = 1 - else: - TDM_LOADS_PER_STEP = 4 + # tensor_load per wave per step. 
+ TDM_LOADS_PER_STEP = 1 tail_plan = [(ls, cs, o * TDM_LOADS_PER_STEP // 2 if o > 0 else o) for ls, cs, o in _base_tail_plan] # Pre-compute epilogue sub-tile layout (unified for FP4 vec16 and FP8 vec8) @@ -589,13 +445,11 @@ def _align_up(value: int, align: int) -> int: and m_warp == 2 and n_warp == 2 and num_buffers == 4 - and wave_specialized_tdm and out_dtype == "bf16" ) if fp8_schedule == "deep-pipeline" and not fp8_deep_pipeline_eligible: raise ValueError( - "fp8_schedule='deep-pipeline' requires fp8 256x256x128, " - "m_warp=n_warp=2, num_buffers=4, wave_specialized_tdm=True, out_dtype='bf16'" + "fp8_schedule='deep-pipeline' requires fp8 256x256x128, " "m_warp=n_warp=2, num_buffers=4, out_dtype='bf16'" ) def _pick_compute_schedule_kind(): @@ -620,7 +474,7 @@ def _pick_compute_schedule_kind(): use_fp8_quadrant_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_FP8_QUADRANT use_fp8_deep_pipeline_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_FP8_DEEP_PIPELINE - if use_n4k4_bscale: + if is_mxscale: assert compute_schedule_kind in ( COMPUTE_SCHEDULE_ROW_MAJOR_STREAMING, COMPUTE_SCHEDULE_FP8_QUADRANT, @@ -628,10 +482,7 @@ def _pick_compute_schedule_kind(): COMPUTE_SCHEDULE_FP4_COL_BAND, ) use_ws_tdm_split_signal_overlap = ( - wave_specialized_tdm - and (use_fp8_quadrant_schedule or use_fp8_deep_pipeline_schedule) - and num_buffers == 4 - and use_cluster + (use_fp8_quadrant_schedule or use_fp8_deep_pipeline_schedule) and num_buffers == 4 and use_cluster ) if use_fp4_bank_friendly_schedule: @@ -656,7 +507,7 @@ def _pick_compute_schedule_kind(): _fp8_half_wm = wmma_m_rep // 2 _fp8_half_wn = wmma_n_rep // 2 _fp8_group_size = _fp8_half_wm * _fp8_half_wn - if use_n4k4_bscale: + if is_mxscale: _fp8_b_scale_loads = bs32_n_load # 32x4: one b32 per atom-or-WMMA per ks else: _fp8_b_scale_loads = 0 if is_ptpc else (b_scale_load_rep + 3) // 4 @@ -667,7 +518,7 @@ def _pick_compute_schedule_kind(): _fp8_wn_pairs = wmma_n_rep // _fp8_pair_wn _fp8_pair_a_loads = _fp8_pair_wm 
* DS_LOADS_PER_A_FRAG _fp8_pair_b_loads = _fp8_pair_wn * _b_frag_loads_per_wn - # Scale ds_loads issued at the loop top: a-scale (0 when natural-VGPR/ptpc) + + # Scale ds_loads issued at the loop top: a-scale (0, A is in VGPRs/ptpc) + # b-scale (bs32_n_load for 32x4). Uses the finalized module-level ds counts. _fp8_scale_loads = 0 if is_ptpc else (_a_scale_ds + _b_scale_ds) @@ -740,32 +591,31 @@ def _load_contig_i32(rsrc, base_idx, n, soff): out[start + c] = rv[c] return out - if const_expr(use_natural_ascale): - # Natural A-scale: read A_scale[M, K//32] straight into VGPRs (no reshuffle). + if const_expr(is_mxscale): + # A-scale: read A_scale[M, K//32] straight into VGPRs (no reshuffle). # A row's K-scales are contiguous -> one wide load per M-block grabs all ks. # kt rides the scalar soffset so the per-lane voffset is K-tile-invariant # (CSE'd -> loads fully hidden). M op_sel: kgrp1 reads block wm+rep/2. - _nat_as_rsrc = buffer_ops.create_buffer_resource(arg_a_scale, max_size=False) - _nat_row_i32 = K_scale // 4 # i32 elements per A_scale row (K_scale = K//32, %4==0) - _nat_row0 = blk_m + warp_m_base + lane16 - if const_expr(use_natural_ascale_opsel): - _nat_row0 = _nat_row0 + lane_kgrp * arith.index(nat_as_half * WMMA_M) - _vs_tile_a = k_wmma_steps * nat_as_load - _vs_tile_b = 0 - - def _nat_as_load(k_base): + _ascale_rsrc = buffer_ops.create_buffer_resource(arg_a_scale, max_size=False) + _ascale_row_i32 = K_scale // 4 # i32 elements per A_scale row (K_scale = K//32, %4==0) + _ascale_row0 = blk_m + warp_m_base + lane16 + if const_expr(ascale_opsel): + _ascale_row0 = _ascale_row0 + lane_kgrp * arith.index(ascale_half * WMMA_M) + _vs_tile_a = k_wmma_steps * ascale_load + + def _load_ascale(k_base): kt = k_base // arith.index(tile_k) soff = arith.index_cast(T.i32, kt * arith.index(scale_k_per_tile)) - vals = [None] * (k_wmma_steps * nat_as_load) - for i in range_constexpr(nat_as_load): - vidx = (_nat_row0 + arith.index(i * WMMA_M)) * arith.index(_nat_row_i32) - 
ks_vals = _load_contig_i32(_nat_as_rsrc, vidx, k_wmma_steps, soff) + vals = [None] * (k_wmma_steps * ascale_load) + for i in range_constexpr(ascale_load): + vidx = (_ascale_row0 + arith.index(i * WMMA_M)) * arith.index(_ascale_row_i32) + ks_vals = _load_contig_i32(_ascale_rsrc, vidx, k_wmma_steps, soff) for ks in range_constexpr(k_wmma_steps): - vals[ks * nat_as_load + i] = ks_vals[ks] + vals[ks * ascale_load + i] = ks_vals[ks] return vals - def _bvs_prefetch(k_base): - return _nat_as_load(k_base), [] + # Prefetch one K-tile's A-scale VGPRs + _bvs_prefetch = _load_ascale m_idx = fx.Index(i32_m) # Runtime leading-dim strides (strided A/C). Dense callers pass lda == K, @@ -820,26 +670,6 @@ def make_desc_b(memref, k_base): early_timeout=True, ) - def make_desc_as(memref, k_base): - k_scale_off = k_base // arith.index(SCALE_BLOCK) - outer_off = blk_m // arith.index(wmma_m_rep) - inner_off = k_scale_off * arith.index(wmma_m_rep) - return _make_tdm_desc( - global_ptr=arg_a_scale, - lds_memref=memref, - global_offset=(outer_off, inner_off), - tensor_shape=(WMMA_M * m_warp, interleaved_scale_cols_a), - strides=(wmma_m_rep * K_scale, 1), - tile_shape=(WMMA_M * m_warp, interleaved_scale_cols_a), - elem_bytes=1, - pad_interval=0, - pad_amount=0, - num_warps=tdm_desc_num_warps, - workgroup_mask=a_mcast_mask, - atomic_barrier_enable=atomic_barrier_enable, - early_timeout=True, - ) - def make_desc_bs(memref, k_base): # 32x4: copy this tile's 32-N atoms x K-blocks slice of the preshuffled # [N//32, (K//128)*128] B-scale tensor. 
Each row is one 32-N atom group; @@ -864,16 +694,15 @@ def make_desc_bs(memref, k_base): oob_outer_bound=N // 32, ) - if const_expr(wave_specialized_tdm): - tdm_wave_id = rocdl.wave_id() - tdm_wave_is_a = tdm_wave_id == fx.Int32(0) - tdm_wave_is_b = tdm_wave_id == fx.Int32(1) - tdm_wave_is_as = tdm_wave_id == fx.Int32(2) + tdm_wave_id = rocdl.wave_id() + tdm_wave_is_a = tdm_wave_id == fx.Int32(0) + tdm_wave_is_b = tdm_wave_id == fx.Int32(1) + tdm_wave_is_as = tdm_wave_id == fx.Int32(2) - def _select_wave_tdm_value(a_value, b_value, as_value, bs_value): - result = arith.select(tdm_wave_is_as, as_value, bs_value) - result = arith.select(tdm_wave_is_b, b_value, result) - return arith.select(tdm_wave_is_a, a_value, result) + def _select_wave_tdm_value(a_value, b_value, as_value, bs_value): + result = arith.select(tdm_wave_is_as, as_value, bs_value) + result = arith.select(tdm_wave_is_b, b_value, result) + return arith.select(tdm_wave_is_a, a_value, result) elem_ty_lds = T.f16 @@ -990,15 +819,6 @@ def load_b_frag(lds_buffer, b_lane_bases, wn, ks): v23 = v2.shuffle(v3, list(range(8))) return v01.shuffle(v23, list(range(16))) - def _precompute_scale_lane_bases(lds_ptr, warp_base, reps, interleaved_cols): - """Precompute scale lane bases (byte offsets).""" - warp_lds_row = warp_base // arith.index(reps) + lane16 - base = warp_lds_row * arith.index(interleaved_cols) - if const_expr(is_fp4 or is_a8w4): - # FP4/A8W4: always add lane_kgrp offset (no opsel on BScale) - base = base + lane_kgrp * arith.index(SCALES_PER_WMMA) - return lds_ptr, [base] - def _precompute_bs32_bases(lds_ptr): """Tile-local 32-N atom base for the warp's 32x4 B-scale read. 
@@ -1038,59 +858,31 @@ def load_bs32_bscale(lds_buffer, atom0, ks): results.append(lds_load_b32_raw(lds_buffer, off)) return results - def load_scale_b128(lds_buffer, scale_base, reps, ks=0): - """Load all wmma_rep scales via ds_load_b128(s) for K-subtile *ks*.""" - ks_byte_off = ks * reps * SCALES_PER_WMMA - eff_base = scale_base if ks_byte_off == 0 else scale_base + arith.index(ks_byte_off) - num_loads = (reps + 3) // 4 - vecs = [] - for ld in range_constexpr(num_loads): - off = eff_base if ld == 0 else eff_base + arith.index(ld * 16) - vecs.append(fx.Vector(lds_load_b128_raw(lds_buffer, off))) - results = [] - for i in range_constexpr(reps): - results.append(vecs[i // 4][i % 4]) - return results - - # Holds the current tile's prefetched VGPR scales (a_flat, b_flat), each - # ordered [k_wmma_step][rep]. compute_tile sets it before emitting; the - # general-vgpr branch of _scales_for_emit slices it per K-subtile. Set-then- - # consume is sequential at emit time (same pattern as epi_addrs_box). + # Holds the current tile's prefetched A-scale VGPRs (a_flat, ordered + # [k_wmma_step][rep]). compute_tile sets it before emitting; _scales_for_emit + # slices it per K-subtile. Set-then-consume is sequential at emit time (same + # pattern as epi_addrs_box). _vgpr_scale_box = [None] def _load_b_scale_lds(bs_buf, bs_atom0, ks): """Load 32x4 B-scale from LDS (mxscale only; ptpc reads no K-loop B-scale).""" return load_bs32_bscale(bs_buf, bs_atom0, ks) - def _scales_for_emit(as_buf, as_bases, bs_buf, bs_bases, ks): - """Load both scale tensors and apply op_sel downsampling per format. - - FP4 BScale has no op_sel (scaleAType=0 fixed); only AScale halves. - FP8/A8W4 16x16 supports op_sel on both. 
- """ + def _scales_for_emit(bs_buf, bs_bases, ks): + """Load both scale tensors for K-subtile *ks*: A from the VGPR ring + (M op_sel handled in _emit via ascale_half), B from the 32x4 LDS layout.""" if const_expr(is_ptpc): return None, None - if const_expr(use_natural_ascale): - # A from the natural VGPR ring (slice this ks; M op_sel handled in - # _emit via nat_as_half); B from the N4K4 LDS layout. - pf_a, _ = _vgpr_scale_box[0] - a = pf_a[ks * nat_as_load : (ks + 1) * nat_as_load] - b = _load_b_scale_lds(bs_buf, bs_bases, ks) - return a, b - a = load_scale_b128(as_buf, as_bases[0], wmma_m_rep, ks) + pf_a = _vgpr_scale_box[0] + a = pf_a[ks * ascale_load : (ks + 1) * ascale_load] b = _load_b_scale_lds(bs_buf, bs_bases, ks) return a, b - def _load_b_and_scales(b_buf, b_bases, bs_buf, bs_bases, as_buf, as_bases, ks): + def _load_b_and_scales(b_buf, b_bases, bs_buf, bs_bases, ks): b_frags = [load_b_frag(b_buf, b_bases, wn, ks) for wn in range_constexpr(wmma_n_rep)] - a_scales, b_scales = _scales_for_emit(as_buf, as_bases, bs_buf, bs_bases, ks) + a_scales, b_scales = _scales_for_emit(bs_buf, bs_bases, ks) return b_frags, b_scales, a_scales - def _load_a_and_scales(a_buf, a_bases, as_buf, as_bases, bs_buf, bs_bases, ks): - a_frags = [load_a_frag(a_buf, a_bases[wm], ks) for wm in range_constexpr(wmma_m_rep)] - a_scales, b_scales = _scales_for_emit(as_buf, as_bases, bs_buf, bs_bases, ks) - return a_frags, a_scales, b_scales - def _emit_wmma(accs, wm, wn, a_frag, b_frag, a_scales, b_scales): """Emit one WMMA instruction (format-specific).""" idx = wm * wmma_n_rep + wn @@ -1107,26 +899,13 @@ def _emit_wmma(accs, wm, wn, a_frag, b_frag, a_scales, b_scales): fmtB=0, ) else: - # PTPC-FP8 needs no per-K scaling. We emit the scaled f8f6f4 op - # with an identity E8M0 scale (0x7F = 2^0 = 1.0) for toolchain - # compatibility; it is numerically equivalent to the dedicated - # no-scale op. 
Future: switch to the equivalent no-scale wmma: - # accs[idx] = rocdl.wmma_f32_16x16x128_fp8_fp8(T.vec(8, T.f32), b_frag, a_frag, accs[idx]) - accs[idx] = rocdl.wmma_scale_f32_16x16x128_f8f6f4( - T.vec(8, T.f32), - b_frag, - a_frag, - accs[idx], - 0x7F7F7F7F, - 0x7F7F7F7F, - fmtA=0, - fmtB=0, - ) + # PTPC-FP8 needs no per-K scaling: dedicated no-scale E4M3 WMMA. + accs[idx] = rocdl.wmma_f32_16x16x128_fp8_fp8(T.vec(8, T.f32), b_frag, a_frag, accs[idx]) return - if const_expr(use_natural_ascale_opsel): - # Natural A M op_sel pairs (j, j+rep/2): kgrp1 carries the second half. - a_scale_idx = wm % nat_as_half - a_opsel = wm // nat_as_half + if const_expr(ascale_opsel): + # A-scale M op_sel pairs (j, j+rep/2): kgrp1 carries the second half. + a_scale_idx = wm % ascale_half + a_opsel = wm // ascale_half else: a_scale_idx = wm a_opsel = 0 @@ -1205,8 +984,8 @@ def _emit_rows(start_wm, a_frags): _use_partial_drain = next_bs_info is not None and _front_wm * wmma_n_rep >= 4 if const_expr(_use_partial_drain): - nb_buf, nb_bases, nbs_buf, nbs_bases, nas_buf, nas_bases, n_ks = next_bs_info - next_result = _load_b_and_scales(nb_buf, nb_bases, nbs_buf, nbs_bases, nas_buf, nas_bases, n_ks) + nb_buf, nb_bases, nbs_buf, nbs_bases, n_ks = next_bs_info + next_result = _load_b_and_scales(nb_buf, nb_bases, nbs_buf, nbs_bases, n_ks) rocdl.s_wait_dscnt(_bs_ds_loads) else: rocdl.s_wait_dscnt(0) @@ -1226,8 +1005,8 @@ def _emit_rows(start_wm, a_frags): if const_expr(_use_partial_drain): return accs, next_result if const_expr(next_bs_info is not None): - nb_buf, nb_bases, nbs_buf, nbs_bases, nas_buf, nas_bases, n_ks = next_bs_info - next_result = _load_b_and_scales(nb_buf, nb_bases, nbs_buf, nbs_bases, nas_buf, nas_bases, n_ks) + nb_buf, nb_bases, nbs_buf, nbs_bases, n_ks = next_bs_info + next_result = _load_b_and_scales(nb_buf, nb_bases, nbs_buf, nbs_bases, n_ks) return accs, next_result return accs @@ -1242,14 +1021,13 @@ def compute_tile( mid_compute_callback=None, scale_k_base=None, 
pf_a_scales=None, - pf_b_scales=None, ): current_accs = list(accs_in) - if const_expr(use_natural_ascale): + if const_expr(is_mxscale): # A-scale comes from VGPR: use the loop-prefetched ring when provided, # else issue the buffer_loads inline (tail path) for scale_k_base. if const_expr(pf_a_scales is not None): - _vgpr_scale_box[0] = (pf_a_scales, pf_b_scales) + _vgpr_scale_box[0] = pf_a_scales else: # Inline tail load: barrier so the buffer_loads can't be hoisted # above the caller's pipeline fence (mirrors the main-loop path). @@ -1257,19 +1035,13 @@ def compute_tile( _vgpr_scale_box[0] = _bvs_prefetch(scale_k_base) a_buf, a_bases = _precompute_a_lane_bases(lds_a) b_buf, b_bases = _precompute_b_lane_bases(lds_b) - if const_expr(use_natural_ascale): - as_buf, as_bases = None, None # A-scale from the VGPR ring, not LDS - else: - as_buf, as_bases = _precompute_scale_lane_bases( - lds_as, warp_m_base, wmma_m_rep, interleaved_scale_cols_a - ) - if const_expr(use_n4k4_bscale): + if const_expr(is_mxscale): bs_buf, bs_bases = _precompute_bs32_bases(lds_bs) else: bs_buf, bs_bases = lds_bs, None # ptpc: B-scale in epilogue, bases unused if const_expr(k_wmma_steps == 1): - b_frags, b_scales, a_scales = _load_b_and_scales(b_buf, b_bases, bs_buf, bs_bases, as_buf, as_bases, 0) + b_frags, b_scales, a_scales = _load_b_and_scales(b_buf, b_bases, bs_buf, bs_bases, 0) current_accs = _a_streaming_compute( current_accs, a_buf, @@ -1282,7 +1054,7 @@ def compute_tile( mid_compute_callback=mid_compute_callback, ) else: - prev_b, prev_bs, prev_as = _load_b_and_scales(b_buf, b_bases, bs_buf, bs_bases, as_buf, as_bases, 0) + prev_b, prev_bs, prev_as = _load_b_and_scales(b_buf, b_bases, bs_buf, bs_bases, 0) for ks in range_constexpr(k_wmma_steps - 1): _mid_cb = mid_compute_callback if ks == 0 else None current_accs, (prev_b, prev_bs, prev_as) = _a_streaming_compute( @@ -1293,7 +1065,7 @@ def compute_tile( prev_bs, prev_as, ks, - next_bs_info=(b_buf, b_bases, bs_buf, bs_bases, as_buf, 
as_bases, ks + 1), + next_bs_info=(b_buf, b_bases, bs_buf, bs_bases, ks + 1), mid_compute_callback=_mid_cb, ) current_accs = _a_streaming_compute( @@ -1311,31 +1083,24 @@ def compute_tile_fp4_bank_friendly( mid_compute_callback=None, scale_k_base=None, pf_a_scales=None, - pf_b_scales=None, ): current_accs = list(accs_in) - if const_expr(use_natural_ascale): + if const_expr(is_mxscale): # A-scale from the VGPR ring (loop-prefetched, else inline tail load). if const_expr(pf_a_scales is not None): - _vgpr_scale_box[0] = (pf_a_scales, pf_b_scales) + _vgpr_scale_box[0] = pf_a_scales else: rocdl.sched_barrier(0) _vgpr_scale_box[0] = _bvs_prefetch(scale_k_base) a_buf, a_bases = _precompute_a_lane_bases(lds_a) b_buf, b_bases = _precompute_b_lane_bases(lds_b) - if const_expr(use_natural_ascale): - as_buf, as_bases = None, None # A-scale from the VGPR ring, not LDS - else: - as_buf, as_bases = _precompute_scale_lane_bases( - lds_as, warp_m_base, wmma_m_rep, interleaved_scale_cols_a - ) bs_buf, bs_bases = _precompute_bs32_bases(lds_bs) _b_half_scale_loads = _bank_half_wn # 32x4: one b32 per 32-N atom/WMMA def _fp4_get_a_scale_and_opsel(a_scales_all, wm_idx): - if const_expr(use_natural_ascale_opsel): - # Natural M op_sel pairs (wm, wm+rep/2): kgrp1 carries the second half. - return a_scales_all[wm_idx % nat_as_half], wm_idx // nat_as_half + if const_expr(ascale_opsel): + # A-scale M op_sel pairs (wm, wm+rep/2): kgrp1 carries the second half. 
+ return a_scales_all[wm_idx % ascale_half], wm_idx // ascale_half return a_scales_all[wm_idx], 0 def _load_a_group(wm_base, wm_count, ks): @@ -1403,11 +1168,8 @@ def _emit_group(group_base, wm_base, a_frags, b_frags, a_scales, b_scales, emit_ for ks in range_constexpr(k_wmma_steps): is_last_ks = ks == k_wmma_steps - 1 - if const_expr(use_natural_ascale): - pf_a, _ = _vgpr_scale_box[0] - a_scales_all = pf_a[ks * nat_as_load : (ks + 1) * nat_as_load] - else: - a_scales_all = load_scale_b128(as_buf, as_bases[0], wmma_m_rep, ks) + pf_a = _vgpr_scale_box[0] + a_scales_all = pf_a[ks * ascale_load : (ks + 1) * ascale_load] a_top_frags = _load_a_group(0, _bank_half_wm, ks) a_bottom_frags = _load_a_group(_bank_half_wm, _bank_half_wm, ks) @@ -1487,25 +1249,18 @@ def compute_tile_fp8_quadrant( late_compute_callback=None, scale_k_base=None, pf_a_scales=None, - pf_b_scales=None, ): current_accs = list(accs_in) - if const_expr(use_natural_ascale): + if const_expr(is_mxscale): # A-scale from the VGPR ring (loop-prefetched, else inline tail load). 
if const_expr(pf_a_scales is not None): - _vgpr_scale_box[0] = (pf_a_scales, pf_b_scales) + _vgpr_scale_box[0] = pf_a_scales else: rocdl.sched_barrier(0) _vgpr_scale_box[0] = _bvs_prefetch(scale_k_base) a_buf, a_bases = _precompute_a_lane_bases(lds_a) b_buf, b_bases = _precompute_b_lane_bases(lds_b) - if const_expr(use_natural_ascale): - as_buf, as_bases = None, None # A-scale from the VGPR ring, not LDS - else: - as_buf, as_bases = _precompute_scale_lane_bases( - lds_as, warp_m_base, wmma_m_rep, interleaved_scale_cols_a - ) - if const_expr(use_n4k4_bscale): + if const_expr(is_mxscale): bs_buf, bs_bases = _precompute_bs32_bases(lds_bs) else: bs_buf, bs_bases = lds_bs, None # ptpc: B-scale in epilogue, bases unused @@ -1523,11 +1278,9 @@ def _load_b_half(wn_base, ks): def _load_a_scales(ks): if const_expr(is_ptpc): return None # PTPC: scale applied in epilogue, not in K-loop - if const_expr(use_natural_ascale): - # A from the VGPR ring (slice this ks); M op_sel via nat_as_half in _emit. - pf_a, _ = _vgpr_scale_box[0] - return pf_a[ks * nat_as_load : (ks + 1) * nat_as_load] - return load_scale_b128(as_buf, as_bases[0], wmma_m_rep, ks) + # A from the VGPR ring (slice this ks); M op_sel via ascale_half in _emit. + pf_a = _vgpr_scale_box[0] + return pf_a[ks * ascale_load : (ks + 1) * ascale_load] def _load_b_scales(ks): if const_expr(is_ptpc): @@ -1686,25 +1439,18 @@ def compute_tile_fp8_deep_pipeline( a0_prefetch=None, scale_k_base=None, pf_a_scales=None, - pf_b_scales=None, ): current_accs = list(accs_in) - if const_expr(use_natural_ascale): + if const_expr(is_mxscale): # A-scale from the VGPR ring (loop-prefetched, else inline tail load). 
if const_expr(pf_a_scales is not None): - _vgpr_scale_box[0] = (pf_a_scales, pf_b_scales) + _vgpr_scale_box[0] = pf_a_scales else: rocdl.sched_barrier(0) _vgpr_scale_box[0] = _bvs_prefetch(scale_k_base) a_buf, a_bases = _precompute_a_lane_bases(lds_a) b_buf, b_bases = _precompute_b_lane_bases(lds_b) - if const_expr(use_natural_ascale): - as_buf, as_bases = None, None # A-scale from the VGPR ring, not LDS - else: - as_buf, as_bases = _precompute_scale_lane_bases( - lds_as, warp_m_base, wmma_m_rep, interleaved_scale_cols_a - ) - if const_expr(use_n4k4_bscale): + if const_expr(is_mxscale): bs_buf, bs_bases = _precompute_bs32_bases(lds_bs) else: bs_buf, bs_bases = lds_bs, None # ptpc: B-scale in epilogue, bases unused @@ -1775,7 +1521,7 @@ def emit_panel_2x2_row(wm_pair, wn_pair, row_local, a_pair, b_pair, scale_pair): for ks in range_constexpr(k_wmma_steps): is_last_ks = ks == k_wmma_steps - 1 - a_scales, b_scales = _scales_for_emit(as_buf, as_bases, bs_buf, bs_bases, ks) + a_scales, b_scales = _scales_for_emit(bs_buf, bs_bases, ks) scale_pair = (a_scales, b_scales) b0 = load_b_pair(0, ks) @@ -1873,7 +1619,7 @@ def hot_loop_scheduler(): def hot_loop_scheduler_fp4_bank_friendly(): _a_all_loads = wmma_m_rep * DS_LOADS_PER_A_FRAG - _a_scale_loads = 0 if use_natural_ascale else (wmma_m_rep + 3) // 4 + _a_scale_loads = 0 # A-scale is in VGPRs, not ds_load'd _b_half_loads = _bank_half_wn * 4 _b_half_scale_loads = _bank_half_wn # 32x4: one b32 per 32-N atom/WMMA _group_wmma = _bank_group_size @@ -1894,7 +1640,7 @@ def hot_loop_scheduler_fp4_bank_friendly(): rocdl.sched_barrier(0) def hot_loop_scheduler_fp8_quadrant(): - _a_scale_loads = 0 if (is_ptpc or use_natural_ascale) else (wmma_m_rep + 3) // 4 + _a_scale_loads = 0 # A-scale is in VGPRs (mxscale) or epilogue (ptpc), not ds_load'd _a_top_loads = _fp8_half_wm * DS_LOADS_PER_A_FRAG _a_bottom_loads = _a_top_loads _b_half_loads = _fp8_half_wn * _b_frag_loads_per_wn @@ -1972,7 +1718,6 @@ def compute_tile_scheduled( 
a0_prefetch=None, scale_k_base=None, pf_a_scales=None, - pf_b_scales=None, ): if const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_FP4_COL_BAND): return compute_tile_fp4_bank_friendly( @@ -1985,7 +1730,6 @@ def compute_tile_scheduled( mid_compute_callback=mid_compute_callback, scale_k_base=scale_k_base, pf_a_scales=pf_a_scales, - pf_b_scales=pf_b_scales, ) if const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_FP8_QUADRANT): return compute_tile_fp8_quadrant( @@ -1999,7 +1743,6 @@ def compute_tile_scheduled( late_compute_callback=late_compute_callback, scale_k_base=scale_k_base, pf_a_scales=pf_a_scales, - pf_b_scales=pf_b_scales, ) if const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_FP8_DEEP_PIPELINE): return compute_tile_fp8_deep_pipeline( @@ -2014,7 +1757,6 @@ def compute_tile_scheduled( a0_prefetch=a0_prefetch, scale_k_base=scale_k_base, pf_a_scales=pf_a_scales, - pf_b_scales=pf_b_scales, ) return compute_tile( accs_in, @@ -2026,7 +1768,6 @@ def compute_tile_scheduled( mid_compute_callback=mid_compute_callback, scale_k_base=scale_k_base, pf_a_scales=pf_a_scales, - pf_b_scales=pf_b_scales, ) def hot_loop_scheduler_scheduled(): @@ -2240,11 +1981,12 @@ def _l2_prefetch(k_base): ] if const_expr(is_ptpc): # PTPC applies sa*sb in the epilogue from global memory: no scale LDS. - # Alias the scale stage handles to A/B so the shared plumbing stays - # valid; for PTPC they are never written (no scale TDM) or read. + # Alias the scale stage handles to A/B so the shared plumbing stays valid; + # for PTPC they are never written (no scale TDM) or read. stages_as = stages_a stages_bs = stages_b else: + # A-scale LDS region is currently 0-byte (A in VGPRs); kept as scaffolding. 
stages_as = [ SmemPtr(arena_base_ptr, stage_a_scale_off[i], elem_ty_lds, shape=(lds_a_scale_f16,)) for i in range_constexpr(num_buffers) @@ -2256,7 +1998,6 @@ def _l2_prefetch(k_base): stages_a_mem = [stages_a[i].get() for i in range_constexpr(num_buffers)] stages_b_mem = [stages_b[i].get() for i in range_constexpr(num_buffers)] - stages_as_mem = [stages_as[i].get() for i in range_constexpr(num_buffers)] stages_bs_mem = [stages_bs[i].get() for i in range_constexpr(num_buffers)] stages_a_idx = [extract_lds_base_idx(stages_a[i]) for i in range_constexpr(num_buffers)] @@ -2317,9 +2058,7 @@ def _pack_dg0(pred, lds_addr, addr_lo, addr_hi): stages_a_lds_addr.append(_dg0_lane(make_desc_a(stages_a_mem[i], arith.index(0)), 1)) stages_b_lds_addr.append(_dg0_lane(make_desc_b(stages_b_mem[i], arith.index(0)), 1)) if const_expr(not is_ptpc): - # Natural A-scale has no TDM (VGPR); B-scale keeps its TDM descriptor. - if const_expr(not use_natural_ascale): - stages_as_lds_addr.append(_dg0_lane(make_desc_as(stages_as_mem[i], arith.index(0)), 1)) + # A-scale has no TDM (VGPR ring); B-scale keeps its TDM descriptor. stages_bs_lds_addr.append(_dg0_lane(make_desc_bs(stages_bs_mem[i], arith.index(0)), 1)) desc_a_init = make_desc_a(stages_a_mem[0], split_k_base) @@ -2332,96 +2071,73 @@ def _pack_dg0(pred, lds_addr, addr_lo, addr_hi): desc_as_init = desc_a_init desc_bs_init = desc_b_init else: - if const_expr(use_natural_ascale): - # A-scale on VGPR: alias its (never-issued) TDM slot to A; wave2 carries B-scale. - stages_as_lds_addr = stages_a_lds_addr - desc_as_init = desc_a_init - else: - desc_as_init = make_desc_as(stages_as_mem[0], split_k_base) + # A-scale on VGPR: alias its (never-issued) TDM slot to A; wave2 carries B-scale. 
+ stages_as_lds_addr = stages_a_lds_addr + desc_as_init = desc_a_init desc_bs_init = make_desc_bs(stages_bs_mem[0], split_k_base) adv_a_i32 = fx.Int32(tile_k // PACK_FACTOR_A) adv_b_i32 = fx.Int32(packed_tile_k_b * 16) adv_as_i32 = fx.Int32(tile_k // SCALE_BLOCK * wmma_m_rep) - # N4K4 advances by one tile's worth of K-blocks (k_wmma_steps*256B) per - # K-step; the legacy interleaved layout advances by scale_k_per_tile*rep. - adv_bs_i32 = fx.Int32(bs32_lds_row_stride if use_n4k4_bscale else tile_k // SCALE_BLOCK * b_scale_load_rep) - - pred_const = fx.Int32(1) - if const_expr(wave_specialized_tdm): - _drop_scale_waves = is_ptpc - if const_expr(use_natural_ascale): - # wave0=A, wave1=B, wave2=B-scale (>=3 waves); A-scale is VGPR. At 2 - # waves wave2 doesn't exist and B-scale rides wave0 as a secondary. - _active_wave_limit = min(num_warps, 3) - else: - _active_wave_limit = 2 if _drop_scale_waves else 4 - active_pred_const = arith.select(tdm_wave_id < fx.Int32(_active_wave_limit), fx.Int32(1), fx.Int32(0)) - - def _select4(values): - return _select_wave_tdm_value(values[0], values[1], values[2], values[3]) - - def _desc_lanes(descs, lane): - return [_dg0_lane(desc, lane) for desc in descs] - - def _select_active_tdm(stage_lds_addrs, descs, advs): - active_stages = [ - _select_wave_tdm_value( - stage_lds_addrs[0][i], - stage_lds_addrs[1][i], - stage_lds_addrs[2][i], - stage_lds_addrs[3][i], - ) - for i in range_constexpr(num_buffers) - ] - return ( - active_stages, - _select4(_desc_lanes(descs, 2)), - _select4(_desc_lanes(descs, 3)), - _select4([desc.dgroup1 for desc in descs]), - _select4(advs), - ) - + # 32x4 B-scale advances one tile's K-blocks per K-step (ptpc's slot is aliased + # to B-data and never issued, so its value is unused). + adv_bs_i32 = fx.Int32(bs32_lds_row_stride if is_mxscale else tile_k // SCALE_BLOCK * b_scale_load_rep) + + _drop_scale_waves = is_ptpc + if const_expr(is_mxscale): + # wave0=A, wave1=B, wave2=B-scale (>=3 waves); A-scale is VGPR. 
At 2 + # waves wave2 doesn't exist and B-scale rides wave0 as a secondary. + _active_wave_limit = min(num_warps, 3) else: - active_pred_const = pred_const - - if const_expr(wave_specialized_tdm): - if const_expr(use_natural_ascale): - # Remap: wave2 (the old A-scale slot) now issues B-scale; wave3 is the - # padded 4th slot (predicated off by _active_wave_limit=3). - _tdm_stage_sel = (stages_a_lds_addr, stages_b_lds_addr, stages_bs_lds_addr, stages_bs_lds_addr) - _tdm_desc_sel = (desc_a_init, desc_b_init, desc_bs_init, desc_bs_init) - _tdm_adv_sel = (adv_a_i32, adv_b_i32, adv_bs_i32, adv_bs_i32) - else: - _tdm_stage_sel = (stages_a_lds_addr, stages_b_lds_addr, stages_as_lds_addr, stages_bs_lds_addr) - _tdm_desc_sel = (desc_a_init, desc_b_init, desc_as_init, desc_bs_init) - _tdm_adv_sel = (adv_a_i32, adv_b_i32, adv_as_i32, adv_bs_i32) - active_stage_lds_addr, active_addr_lo, active_addr_hi, active_dgroup1, active_adv_i32 = _select_active_tdm( - _tdm_stage_sel, _tdm_desc_sel, _tdm_adv_sel + _active_wave_limit = 2 if _drop_scale_waves else 4 + active_pred_const = arith.select(tdm_wave_id < fx.Int32(_active_wave_limit), fx.Int32(1), fx.Int32(0)) + + def _select4(values): + return _select_wave_tdm_value(values[0], values[1], values[2], values[3]) + + def _desc_lanes(descs, lane): + return [_dg0_lane(desc, lane) for desc in descs] + + def _select_active_tdm(stage_lds_addrs, descs, advs): + active_stages = [ + _select_wave_tdm_value( + stage_lds_addrs[0][i], + stage_lds_addrs[1][i], + stage_lds_addrs[2][i], + stage_lds_addrs[3][i], + ) + for i in range_constexpr(num_buffers) + ] + return ( + active_stages, + _select4(_desc_lanes(descs, 2)), + _select4(_desc_lanes(descs, 3)), + _select4([desc.dgroup1 for desc in descs]), + _select4(advs), ) - if const_expr(natural_two_wave): - # Secondary TDM: B-scale issued by wave0 only (2-wave packs A-data + - # B-scale onto wave0). Static, wave-independent; carried addr_lo below. 
- sec_pred_const = arith.select(tdm_wave_id == fx.Int32(0), fx.Int32(1), fx.Int32(0)) - sec_stage_lds_addr = stages_bs_lds_addr - sec_addr_hi = _dg0_lane(desc_bs_init, 3) - sec_dgroup1 = desc_bs_init.dgroup1 - sec_adv_i32 = adv_bs_i32 - sec_addr_lo_init = _dg0_lane(desc_bs_init, 2) + + if const_expr(is_mxscale): + # Remap: wave2 (the old A-scale slot) now issues B-scale; wave3 is the + # padded 4th slot (predicated off by _active_wave_limit=3). + _tdm_stage_sel = (stages_a_lds_addr, stages_b_lds_addr, stages_bs_lds_addr, stages_bs_lds_addr) + _tdm_desc_sel = (desc_a_init, desc_b_init, desc_bs_init, desc_bs_init) + _tdm_adv_sel = (adv_a_i32, adv_b_i32, adv_bs_i32, adv_bs_i32) else: - addr_lo_a = _dg0_lane(desc_a_init, 2) - addr_hi_a = _dg0_lane(desc_a_init, 3) - addr_lo_b = _dg0_lane(desc_b_init, 2) - addr_hi_b = _dg0_lane(desc_b_init, 3) - addr_lo_as = _dg0_lane(desc_as_init, 2) - addr_hi_as = _dg0_lane(desc_as_init, 3) - addr_lo_bs = _dg0_lane(desc_bs_init, 2) - addr_hi_bs = _dg0_lane(desc_bs_init, 3) - - dgroup1_a = desc_a_init.dgroup1 - dgroup1_b = desc_b_init.dgroup1 - dgroup1_as = desc_as_init.dgroup1 - dgroup1_bs = desc_bs_init.dgroup1 + _tdm_stage_sel = (stages_a_lds_addr, stages_b_lds_addr, stages_as_lds_addr, stages_bs_lds_addr) + _tdm_desc_sel = (desc_a_init, desc_b_init, desc_as_init, desc_bs_init) + _tdm_adv_sel = (adv_a_i32, adv_b_i32, adv_as_i32, adv_bs_i32) + active_stage_lds_addr, active_addr_lo, active_addr_hi, active_dgroup1, active_adv_i32 = _select_active_tdm( + _tdm_stage_sel, _tdm_desc_sel, _tdm_adv_sel + ) + if const_expr(two_wave_bscale): + # Secondary TDM: B-scale issued by wave0 only (2-wave packs A-data + + # B-scale onto wave0). Static, wave-independent; carried addr_lo below. 
+ sec_pred_const = arith.select(tdm_wave_id == fx.Int32(0), fx.Int32(1), fx.Int32(0)) + sec_stage_lds_addr = stages_bs_lds_addr + sec_addr_hi = _dg0_lane(desc_bs_init, 3) + sec_dgroup1 = desc_bs_init.dgroup1 + sec_adv_i32 = adv_bs_i32 + sec_addr_lo_init = _dg0_lane(desc_bs_init, 2) def _pipeline_fence(outstanding=0): pipeline_fence(outstanding=outstanding, use_cluster=use_cluster) @@ -2429,59 +2145,36 @@ def _pipeline_fence(outstanding=0): def _pipeline_fence_signal(outstanding=0): pipeline_fence_signal(outstanding=outstanding, use_cluster=use_cluster) - if const_expr(wave_specialized_tdm): - - def _issue_active_tdm(load_stage, addr_box, k_prefetch=None, sec_box=None): - dg0 = _pack_dg0(active_pred_const, active_stage_lds_addr[load_stage], addr_box[0], active_addr_hi) - tdm_ops.tensor_load_2d(tdm_ops.TDMDescriptor2D(dg0, active_dgroup1)) - addr_box[0] = addr_box[0] + active_adv_i32 - if const_expr(natural_two_wave): - # wave0's second descriptor: B-scale (predicated to wave0). - dg0s = _pack_dg0(sec_pred_const, sec_stage_lds_addr[load_stage], sec_box[0], sec_addr_hi) - tdm_ops.tensor_load_2d(tdm_ops.TDMDescriptor2D(dg0s, sec_dgroup1)) - sec_box[0] = sec_box[0] + sec_adv_i32 - if k_prefetch is not None: - _l2_prefetch(k_prefetch) + def _issue_active_tdm(load_stage, addr_box, k_prefetch=None, sec_box=None): + dg0 = _pack_dg0(active_pred_const, active_stage_lds_addr[load_stage], addr_box[0], active_addr_hi) + tdm_ops.tensor_load_2d(tdm_ops.TDMDescriptor2D(dg0, active_dgroup1)) + addr_box[0] = addr_box[0] + active_adv_i32 + if const_expr(two_wave_bscale): + # wave0's second descriptor: B-scale (predicated to wave0). 
+ dg0s = _pack_dg0(sec_pred_const, sec_stage_lds_addr[load_stage], sec_box[0], sec_addr_hi) + tdm_ops.tensor_load_2d(tdm_ops.TDMDescriptor2D(dg0s, sec_dgroup1)) + sec_box[0] = sec_box[0] + sec_adv_i32 + if k_prefetch is not None: + _l2_prefetch(k_prefetch) # Prologue - if const_expr(wave_specialized_tdm): - if const_expr(natural_two_wave): - active_sec_lo = sec_addr_lo_init - for i in range_constexpr(pre_loaded): - addr_box = [active_addr_lo] - if const_expr(natural_two_wave): - sec_box = [active_sec_lo] - _issue_active_tdm(i, addr_box, sec_box=sec_box) - active_sec_lo = sec_box[0] - else: - _issue_active_tdm(i, addr_box) - active_addr_lo = addr_box[0] - else: - for i in range_constexpr(pre_loaded): - dg0_a = _pack_dg0(pred_const, stages_a_lds_addr[i], addr_lo_a, addr_hi_a) - dg0_b = _pack_dg0(pred_const, stages_b_lds_addr[i], addr_lo_b, addr_hi_b) - dg0_as = _pack_dg0(pred_const, stages_as_lds_addr[i], addr_lo_as, addr_hi_as) - dg0_bs = _pack_dg0(pred_const, stages_bs_lds_addr[i], addr_lo_bs, addr_hi_bs) - issue_tdm_loads( - tdm_ops.TDMDescriptor2D(dg0_a, dgroup1_a), - tdm_ops.TDMDescriptor2D(dg0_b, dgroup1_b), - tdm_ops.TDMDescriptor2D(dg0_as, dgroup1_as), - tdm_ops.TDMDescriptor2D(dg0_bs, dgroup1_bs), - wave_specialized=wave_specialized_tdm, - ) - - addr_lo_a = addr_lo_a + adv_a_i32 - addr_lo_b = addr_lo_b + adv_b_i32 - addr_lo_as = addr_lo_as + adv_as_i32 - addr_lo_bs = addr_lo_bs + adv_bs_i32 - + if const_expr(two_wave_bscale): + active_sec_lo = sec_addr_lo_init + for i in range_constexpr(pre_loaded): + addr_box = [active_addr_lo] + if const_expr(two_wave_bscale): + sec_box = [active_sec_lo] + _issue_active_tdm(i, addr_box, sec_box=sec_box) + active_sec_lo = sec_box[0] + else: + _issue_active_tdm(i, addr_box) + active_addr_lo = addr_box[0] if const_expr(_bvs_active and loop_iters > 0): # Prologue: prefetch the first _bvs_D K-tiles (global->VGPR). Carried as # FLAT lists of i32 (list-of-tuples can't be loop-carried). 
Only when the # main loop runs; a tail-only problem (loop_iters == 0) loads inline. _bvs_pf = [_bvs_prefetch(split_k_base + arith.index(_d * tile_k)) for _d in range(_bvs_D)] - _bvs_ra = [_v for (_a, _b) in _bvs_pf for _v in _a] - _bvs_rb = [_v for (_a, _b) in _bvs_pf for _v in _b] + _bvs_ra = [_v for _a in _bvs_pf for _v in _a] _pipeline_fence(outstanding=TDM_LOADS_PER_STEP * (num_buffers - 2)) @@ -2493,175 +2186,96 @@ def _issue_active_tdm(load_stage, addr_box, k_prefetch=None, sec_box=None): _pipeline_fence_signal(outstanding=_fence_outstanding) if const_expr(loop_iters > 0): - if const_expr(wave_specialized_tdm): - init_args = list(accs) + [active_addr_lo] - if const_expr(natural_two_wave): - init_args = init_args + [active_sec_lo] + init_args = list(accs) + [active_addr_lo] + if const_expr(two_wave_bscale): + init_args = init_args + [active_sec_lo] + if const_expr(_bvs_active): + init_args = init_args + _bvs_ra + + for loop_iter, state in range(0, loop_iters, 1, init=init_args): + accs_in = list(state[:n_accs]) + cur_addr_lo = state[n_accs] + _state_off = n_accs + 1 + if const_expr(two_wave_bscale): + cur_sec_lo = state[_state_off] + _state_off = _state_off + 1 if const_expr(_bvs_active): - init_args = init_args + _bvs_ra + _bvs_rb - - for loop_iter, state in range(0, loop_iters, 1, init=init_args): - accs_in = list(state[:n_accs]) - cur_addr_lo = state[n_accs] - _state_off = n_accs + 1 - if const_expr(natural_two_wave): - cur_sec_lo = state[_state_off] - _state_off = _state_off + 1 - if const_expr(_bvs_active): - _ra0 = _state_off - _ring_a = list(state[_ra0 : _ra0 + _bvs_D * _vs_tile_a]) - _rb0 = _ra0 + _bvs_D * _vs_tile_a - _ring_b = list(state[_rb0 : _rb0 + _bvs_D * _vs_tile_b]) - _state_off = _rb0 + _bvs_D * _vs_tile_b - - for buf_idx in range_constexpr(num_buffers): - load_stage = (buf_idx + num_buffers - 1) % num_buffers - - addr_box = [cur_addr_lo] - sec_box = [cur_sec_lo] if natural_two_wave else None - - def _mid_tdm_ws( - _ls=load_stage, - 
_ab=addr_box, - _sb=sec_box, - _k_off=( - split_k_base - + loop_iter * arith.index(num_buffers * tile_k) - + arith.index(buf_idx * tile_k) - ), - ): - _issue_active_tdm(_ls, _ab, k_prefetch=_k_off, sec_box=_sb) - - if const_expr(not use_ws_tdm_split_signal_overlap): - _pipeline_fence_signal(outstanding=_fence_outstanding) - pipeline_fence_wait(use_cluster=use_cluster) - - _late_tdm_ws_fence_signal = None - if const_expr(use_ws_tdm_split_signal_overlap): + _ra0 = _state_off + _ring_a = list(state[_ra0 : _ra0 + _bvs_D * _vs_tile_a]) + _state_off = _ra0 + _bvs_D * _vs_tile_a + + for buf_idx in range_constexpr(num_buffers): + load_stage = (buf_idx + num_buffers - 1) % num_buffers + + addr_box = [cur_addr_lo] + sec_box = [cur_sec_lo] if two_wave_bscale else None + + def _mid_tdm_ws( + _ls=load_stage, + _ab=addr_box, + _sb=sec_box, + _k_off=( + split_k_base + loop_iter * arith.index(num_buffers * tile_k) + arith.index(buf_idx * tile_k) + ), + ): + _issue_active_tdm(_ls, _ab, k_prefetch=_k_off, sec_box=_sb) + + if const_expr(not use_ws_tdm_split_signal_overlap): + _pipeline_fence_signal(outstanding=_fence_outstanding) + pipeline_fence_wait(use_cluster=use_cluster) - def _late_tdm_ws_split_signal(): - _pipeline_fence_signal(outstanding=_fence_outstanding) + _late_tdm_ws_fence_signal = None + if const_expr(use_ws_tdm_split_signal_overlap): - _late_tdm_ws_fence_signal = _late_tdm_ws_split_signal + def _late_tdm_ws_split_signal(): + _pipeline_fence_signal(outstanding=_fence_outstanding) - a0_prefetch = maybe_prefetch_fp8_deep_a0(stages_a_idx[buf_idx]) - rocdl.sched_barrier(0) - # Consume scale prefetched _bvs_D K-tiles ago; issue the - # K-tile +_bvs_D prefetch now (overlaps this tile's WMMAs). - # NOTE: must stay AFTER the fence; issuing the scale - # buffer_loads before the cluster barrier hangs the vgpr path. 
- if const_expr(_bvs_active): - _cur_a = _ring_a[:_vs_tile_a] - _cur_b = _ring_b[:_vs_tile_b] - _next_kb = ( - split_k_base - + loop_iter * arith.index(num_buffers * tile_k) - + arith.index((buf_idx + _bvs_D) * tile_k) - ) - _na, _nb2 = _bvs_prefetch(_next_kb) - _ring_a = _ring_a[_vs_tile_a:] + list(_na) - _ring_b = _ring_b[_vs_tile_b:] + list(_nb2) - else: - _cur_a = None - _cur_b = None - - accs_in = compute_tile_scheduled( - accs_in, - stages_a_idx[buf_idx], - stages_b_idx[buf_idx], - stages_as_idx[buf_idx], - stages_bs_idx[buf_idx], - mid_compute_callback=_mid_tdm_ws, - late_compute_callback=_late_tdm_ws_fence_signal, - a0_prefetch=a0_prefetch, - pf_a_scales=_cur_a, - pf_b_scales=_cur_b, - ) - cur_addr_lo = addr_box[0] - if const_expr(natural_two_wave): - cur_sec_lo = sec_box[0] - hot_loop_scheduler_scheduled() + _late_tdm_ws_fence_signal = _late_tdm_ws_split_signal + a0_prefetch = maybe_prefetch_fp8_deep_a0(stages_a_idx[buf_idx]) + rocdl.sched_barrier(0) + # Consume scale prefetched _bvs_D K-tiles ago; issue the + # K-tile +_bvs_D prefetch now (overlaps this tile's WMMAs). + # NOTE: must stay AFTER the fence; issuing the scale + # buffer_loads before the cluster barrier hangs the vgpr path. 
if const_expr(_bvs_active): - _bvs_yield = _ring_a + _ring_b - else: - _bvs_yield = [] - _sec_yield = [cur_sec_lo] if natural_two_wave else [] - results = yield list(accs_in) + [cur_addr_lo] + _sec_yield + _bvs_yield - - accs = list(results[:n_accs]) - active_addr_lo = results[n_accs] - if const_expr(natural_two_wave): - active_sec_lo = results[n_accs + 1] - else: - init_args = list(accs) + [addr_lo_a, addr_lo_b, addr_lo_as, addr_lo_bs] - - for loop_iter, state in range(0, loop_iters, 1, init=init_args): - accs_in = list(state[:n_accs]) - cur_lo_a = state[n_accs] - cur_lo_b = state[n_accs + 1] - cur_lo_as = state[n_accs + 2] - cur_lo_bs = state[n_accs + 3] - - for buf_idx in range_constexpr(num_buffers): - load_stage = (buf_idx + num_buffers - 1) % num_buffers - - _pipeline_fence_signal(outstanding=_fence_outstanding) - pipeline_fence_wait(use_cluster=use_cluster) - - addr_boxes = [[cur_lo_a], [cur_lo_b], [cur_lo_as], [cur_lo_bs]] - - def _mid_tdm_nws( - _ls=load_stage, - _ab=addr_boxes, - _k_off=( - split_k_base - + loop_iter * arith.index(num_buffers * tile_k) - + arith.index(buf_idx * tile_k) - ), - ): - dg0_a = _pack_dg0(pred_const, stages_a_lds_addr[_ls], _ab[0][0], addr_hi_a) - dg0_b = _pack_dg0(pred_const, stages_b_lds_addr[_ls], _ab[1][0], addr_hi_b) - dg0_as = _pack_dg0(pred_const, stages_as_lds_addr[_ls], _ab[2][0], addr_hi_as) - dg0_bs = _pack_dg0(pred_const, stages_bs_lds_addr[_ls], _ab[3][0], addr_hi_bs) - issue_tdm_loads( - tdm_ops.TDMDescriptor2D(dg0_a, dgroup1_a), - tdm_ops.TDMDescriptor2D(dg0_b, dgroup1_b), - tdm_ops.TDMDescriptor2D(dg0_as, dgroup1_as), - tdm_ops.TDMDescriptor2D(dg0_bs, dgroup1_bs), - wave_specialized=wave_specialized_tdm, - ) - _ab[0][0] = _ab[0][0] + adv_a_i32 - _ab[1][0] = _ab[1][0] + adv_b_i32 - _ab[2][0] = _ab[2][0] + adv_as_i32 - _ab[3][0] = _ab[3][0] + adv_bs_i32 - _l2_prefetch(_k_off) - - a0_prefetch = maybe_prefetch_fp8_deep_a0(stages_a_idx[buf_idx]) - rocdl.sched_barrier(0) - accs_in = compute_tile_scheduled( - accs_in, - 
stages_a_idx[buf_idx], - stages_b_idx[buf_idx], - stages_as_idx[buf_idx], - stages_bs_idx[buf_idx], - mid_compute_callback=_mid_tdm_nws, - a0_prefetch=a0_prefetch, + _cur_a = _ring_a[:_vs_tile_a] + _next_kb = ( + split_k_base + + loop_iter * arith.index(num_buffers * tile_k) + + arith.index((buf_idx + _bvs_D) * tile_k) ) - cur_lo_a = addr_boxes[0][0] - cur_lo_b = addr_boxes[1][0] - cur_lo_as = addr_boxes[2][0] - cur_lo_bs = addr_boxes[3][0] - hot_loop_scheduler_scheduled() - - results = yield list(accs_in) + [cur_lo_a, cur_lo_b, cur_lo_as, cur_lo_bs] - - accs = list(results[:n_accs]) - addr_lo_a = results[n_accs] - addr_lo_b = results[n_accs + 1] - addr_lo_as = results[n_accs + 2] - addr_lo_bs = results[n_accs + 3] + _ring_a = _ring_a[_vs_tile_a:] + list(_bvs_prefetch(_next_kb)) + else: + _cur_a = None + + accs_in = compute_tile_scheduled( + accs_in, + stages_a_idx[buf_idx], + stages_b_idx[buf_idx], + stages_as_idx[buf_idx], + stages_bs_idx[buf_idx], + mid_compute_callback=_mid_tdm_ws, + late_compute_callback=_late_tdm_ws_fence_signal, + a0_prefetch=a0_prefetch, + pf_a_scales=_cur_a, + ) + cur_addr_lo = addr_box[0] + if const_expr(two_wave_bscale): + cur_sec_lo = sec_box[0] + hot_loop_scheduler_scheduled() + if const_expr(_bvs_active): + _bvs_yield = _ring_a + else: + _bvs_yield = [] + _sec_yield = [cur_sec_lo] if two_wave_bscale else [] + results = yield list(accs_in) + [cur_addr_lo] + _sec_yield + _bvs_yield + + accs = list(results[:n_accs]) + active_addr_lo = results[n_accs] + if const_expr(two_wave_bscale): + active_sec_lo = results[n_accs + 1] # Tail — same acc_mixed pattern: fence at top, TDM mid-compute. 
if const_expr(loop_iters > 0 and use_ws_tdm_split_signal_overlap): pipeline_fence_wait(use_cluster=use_cluster) @@ -2687,10 +2301,10 @@ def _bvs_tail_kb(): _bvs_tail_kt[0] += 1 return kb - # Natural A-scale: prefetch the tail's scales _bvs_D K-tiles ahead so each + # A-scale: prefetch the tail's scales _bvs_D K-tiles ahead so each # scale buffer_load overlaps an earlier tile's TDM wait instead of stalling # the WMMA inline. - _bvs_tail_pf = use_natural_ascale + _bvs_tail_pf = is_mxscale _bvs_tail_ring = [] _bvs_tail_issue_kt = [loop_iters * num_buffers] @@ -2701,12 +2315,11 @@ def _bvs_tail_issue_one(): _bvs_tail_issue_kt[0] += 1 def _bvs_tail_scales(): - # Per-tile (scale_k_base, pf_a_scales, pf_b_scales): consume the prefetch - # ring on the natural path, else fall back to the inline-load k_base. + # Per-tile (scale_k_base, pf_a_scales): consume the A-scale prefetch ring + # for mxscale, else fall back to the inline-load k_base. if const_expr(_bvs_tail_pf): - _cur_a, _cur_b = _bvs_tail_ring.pop(0) - return None, _cur_a, _cur_b - return _bvs_tail_kb(), None, None + return None, _bvs_tail_ring.pop(0) + return _bvs_tail_kb(), None if const_expr(_bvs_tail_pf): # Prime the ring before the first tail fence so even tile 0's scale @@ -2716,7 +2329,7 @@ def _bvs_tail_scales(): _bvs_tail_issue_one() for _load_stage, _compute_stage, _outstanding in tail_plan: - _entry_kb, _pf_a_scales, _pf_b_scales = _bvs_tail_scales() + _entry_kb, _pf_a_scales = _bvs_tail_scales() if const_expr(_outstanding == -1): if const_expr(_tail_had_load): _pipeline_fence(outstanding=0) @@ -2732,7 +2345,6 @@ def _bvs_tail_scales(): a0_prefetch=a0_prefetch, scale_k_base=_entry_kb, pf_a_scales=_pf_a_scales, - pf_b_scales=_pf_b_scales, ) else: @@ -2751,7 +2363,6 @@ def _emit_epi_addrs(): a0_prefetch=a0_prefetch, scale_k_base=_entry_kb, pf_a_scales=_pf_a_scales, - pf_b_scales=_pf_b_scales, ) else: _pipeline_fence_signal(outstanding=_outstanding) @@ -2760,35 +2371,13 @@ def _emit_epi_addrs(): _tail_mid_cb 
= None if const_expr(_load_stage is not None): _tail_had_load = True - if const_expr(wave_specialized_tdm): - _tail_addr_box = [active_addr_lo] - _tail_sec_box = [active_sec_lo] if natural_two_wave else None + _tail_addr_box = [active_addr_lo] + _tail_sec_box = [active_sec_lo] if two_wave_bscale else None - def _tail_mid_ws(_ls=_load_stage, _ab=_tail_addr_box, _sb=_tail_sec_box): - _issue_active_tdm(_ls, _ab, sec_box=_sb) + def _tail_mid_ws(_ls=_load_stage, _ab=_tail_addr_box, _sb=_tail_sec_box): + _issue_active_tdm(_ls, _ab, sec_box=_sb) - _tail_mid_cb = _tail_mid_ws - else: - _tail_ab = [[addr_lo_a], [addr_lo_b], [addr_lo_as], [addr_lo_bs]] - - def _tail_mid_nws(_ls=_load_stage, _ab=_tail_ab): - dg0_a = _pack_dg0(pred_const, stages_a_lds_addr[_ls], _ab[0][0], addr_hi_a) - dg0_b = _pack_dg0(pred_const, stages_b_lds_addr[_ls], _ab[1][0], addr_hi_b) - dg0_as = _pack_dg0(pred_const, stages_as_lds_addr[_ls], _ab[2][0], addr_hi_as) - dg0_bs = _pack_dg0(pred_const, stages_bs_lds_addr[_ls], _ab[3][0], addr_hi_bs) - issue_tdm_loads( - tdm_ops.TDMDescriptor2D(dg0_a, dgroup1_a), - tdm_ops.TDMDescriptor2D(dg0_b, dgroup1_b), - tdm_ops.TDMDescriptor2D(dg0_as, dgroup1_as), - tdm_ops.TDMDescriptor2D(dg0_bs, dgroup1_bs), - wave_specialized=wave_specialized_tdm, - ) - _ab[0][0] = _ab[0][0] + adv_a_i32 - _ab[1][0] = _ab[1][0] + adv_b_i32 - _ab[2][0] = _ab[2][0] + adv_as_i32 - _ab[3][0] = _ab[3][0] + adv_bs_i32 - - _tail_mid_cb = _tail_mid_nws + _tail_mid_cb = _tail_mid_ws a0_prefetch = maybe_prefetch_fp8_deep_a0(stages_a_idx[_compute_stage]) rocdl.sched_barrier(0) @@ -2803,19 +2392,12 @@ def _tail_mid_nws(_ls=_load_stage, _ab=_tail_ab): a0_prefetch=a0_prefetch, scale_k_base=_entry_kb, pf_a_scales=_pf_a_scales, - pf_b_scales=_pf_b_scales, ) if const_expr(_load_stage is not None): - if const_expr(wave_specialized_tdm): - active_addr_lo = _tail_addr_box[0] - if const_expr(natural_two_wave): - active_sec_lo = _tail_sec_box[0] - else: - addr_lo_a = _tail_ab[0][0] - addr_lo_b = 
_tail_ab[1][0] - addr_lo_as = _tail_ab[2][0] - addr_lo_bs = _tail_ab[3][0] + active_addr_lo = _tail_addr_box[0] + if const_expr(two_wave_bscale): + active_sec_lo = _tail_sec_box[0] hot_loop_scheduler_scheduled() @@ -2874,11 +2456,9 @@ def _emit_buffer_store(): use_tdm_store, out_dtype, inst_prefetch, - wave_specialized_tdm, split_k, expert_sched_mode, atomic_barrier_enable, - ascale_load_path, fp8_schedule, ) @@ -2985,12 +2565,11 @@ def compile_ptpc_gemm( the epilogue in fp32. split_k>1 is supported (atomic add path). data_format: "fp8" (FP8 act + FP8 weight) or "a8w4" (FP8 act + FP4 weight). - wave_specialized_tdm=True requires m_warp*n_warp >= 2. + Requires m_warp*n_warp >= 2 (wave-specialized TDM). """ return compile_fp8fp4_gemm( data_format=data_format, scale_mode="ptpc", - wave_specialized_tdm=True, fp8_schedule="auto", use_tdm_store=(split_k == 1), N=N, diff --git a/tests/kernels/test_gemm_fp8fp4_gfx1250.py b/tests/kernels/test_gemm_fp8fp4_gfx1250.py index cb689ad2..7dcaff02 100644 --- a/tests/kernels/test_gemm_fp8fp4_gfx1250.py +++ b/tests/kernels/test_gemm_fp8fp4_gfx1250.py @@ -28,8 +28,6 @@ _is_fp8_deep_pipeline, compile_mxscale_gemm, compile_ptpc_gemm, - use_n4k4_bscale_layout, - use_natural_ascale_vgpr, ) from tests.kernels.utils import fp4_utils # noqa: E402 @@ -40,33 +38,6 @@ SCALE_BLOCK = 32 -def preshuffle_e8m0_scale( - scale: torch.Tensor, - warp_tile: int, - scale_k_per_tile: int = 4, - WMMA_DIM: int = 16, - row_align: int = None, -) -> torch.Tensor: - """Preshuffle E8M0 scale: byte swap + interleave for WMMA TDM/LDS access.""" - rows, K_scale = scale.shape - assert K_scale % 4 == 0, f"K_scale must be divisible by 4, got {K_scale}" - # Accept an unpadded row count (M for a_scale / N for b_scale): pad rows to - # row_align (the GEMM reads tile_m-granular tiles, so callers pass row_align=tile_m) - # with E8M0 127 (=1.0). Padding rows feed only discarded output rows. No-op when - # already aligned. 
Defaults to warp_tile (the minimum the reshape needs). - align = row_align if row_align is not None else warp_tile - if rows % align != 0: - pad = _align_up(rows, align) - rows - scale = torch.cat([scale, torch.full((pad, K_scale), 127, dtype=scale.dtype, device=scale.device)], dim=0) - SCALES_PER_WMMA = 4 - wmma_rep = warp_tile // WMMA_DIM - k_groups = K_scale // scale_k_per_tile - k_wmma_steps = scale_k_per_tile // SCALES_PER_WMMA - g = scale.view(-1, wmma_rep, WMMA_DIM, k_groups, k_wmma_steps, SCALES_PER_WMMA) - g = g.permute(0, 2, 3, 4, 1, 5).contiguous() - return g.reshape(-1, k_groups * k_wmma_steps * wmma_rep * SCALES_PER_WMMA) - - def preshuffle_scale(scale: torch.Tensor) -> torch.Tensor: """32x4 scale layout (A or B): [R, Ks] -> [R//32, K] (Ks = K//32). @@ -82,11 +53,6 @@ def preshuffle_scale(scale: torch.Tensor) -> torch.Tensor: return x.reshape(R // 32, -1) # [R//32, K] -def preshuffle_scale_for_load_path(scale, warp_tile, skt, *, row_align=None): - """Host scale preshuffle for the TDM/LDS interleaved layout.""" - return preshuffle_e8m0_scale(scale, warp_tile, scale_k_per_tile=skt, row_align=row_align) - - def random_fp8_data(rows: int, cols: int, *, device="cpu") -> torch.Tensor: """Generate random FP8/E4M3 data as uint8. 
Avoids NaN (0x7F/0xFF).""" return torch.randint(0, 126, (rows, cols), dtype=torch.uint8, device=device) @@ -360,7 +326,6 @@ def _run_mxscale_gemm_test( num_buffers, use_tdm_store, out_dtype, - wave_specialized_tdm=False, l2_prefetch_distance=0, cluster_m=1, cluster_n=1, @@ -368,7 +333,6 @@ def _run_mxscale_gemm_test( waves_per_eu=None, expert_sched_mode=True, split_k=1, - ascale_load_path="vgpr", return_launch_fn=False, ): """Unified test body for FP4 and FP8.""" @@ -442,28 +406,8 @@ def _run_mxscale_gemm_test( a, b, a_scale, b_scale = _pad_mxscale_inputs(a, b, a_scale, b_scale, padded_shape) - # Preshuffle scales - skt = tile_k // SCALE_BLOCK - warp_tile_m = tile_m // m_warp - _natural_ascale = use_natural_ascale_vgpr( - data_format=data_format, - tile_m=tile_m, - tile_n=tile_n, - tile_k=tile_k, - m_warp=m_warp, - n_warp=n_warp, - n=padded_n, - ascale_load_path=ascale_load_path, - wave_specialized_tdm=wave_specialized_tdm, - num_buffers=num_buffers, - out_dtype=out_dtype, - ) - if _natural_ascale: - # Natural path reads A_scale[M, K//32] straight from VGPRs -- no reshuffle, - # the (already row-padded) tensor is uploaded as-is. - pass - else: - a_scale = preshuffle_scale_for_load_path(a_scale, warp_tile_m, skt, row_align=tile_m) + # mxscale scales: A-scale is always read from its natural A_scale[M, K//32] layout + # straight into VGPRs (no reshuffle -- upload as-is); B-scale is always 32x4. b_scale = preshuffle_scale(b_scale) # Preshuffle B data @@ -494,10 +438,8 @@ def _run_mxscale_gemm_test( use_tdm_store=use_tdm_store, out_dtype=kernel_out_dtype, inst_prefetch=inst_prefetch, - wave_specialized_tdm=wave_specialized_tdm, split_k=split_k, expert_sched_mode=expert_sched_mode, - ascale_load_path=ascale_load_path, ) # Keep 2D — dynamic_layout=True packs shape as i32; flattening overflows for M*K >= 2^31. 
@@ -615,7 +557,6 @@ def _extract_i64_metadata(compiled_ir: str, key: str) -> int: ) @pytest.mark.parametrize("num_buffers", [2, 3, 4]) @pytest.mark.parametrize("use_tdm_store", [True, False]) -@pytest.mark.parametrize("wave_specialized_tdm", [True, False]) @pytest.mark.parametrize("out_dtype", ["f32", "bf16"]) def test_mxfp4_gemm( M, @@ -629,7 +570,6 @@ def test_mxfp4_gemm( num_buffers, use_tdm_store, out_dtype, - wave_specialized_tdm, ): _run_mxscale_gemm_test( "fp4", @@ -644,7 +584,6 @@ def test_mxfp4_gemm( num_buffers, use_tdm_store, out_dtype, - wave_specialized_tdm=wave_specialized_tdm, ) @@ -770,65 +709,61 @@ def test_a8w4_gemm_irregular_m_tile16(M, N, K, use_tdm_store): ) -# ── Tile-independent N4K4 B-scale coverage ── +# ── Tile-independent 32x4 B-scale coverage ── # tile_m=16, m_warp=1 -> wmma_m_rep=1 (odd) -> the default row-major streaming -# schedule, which is the (phase-1) N4K4 B-scale path. The sweep covers every +# schedule, exercising the 32x4 B-scale path. The sweep covers every # tile_n/n_warp that maps to a distinct read shape (b32/b64/b128 per_load and # group counts 1/2/4 and the non-power-of-2 group count 3 that exercises the # TDM warp-distribution power-of-two padding), both data formats, k_wmma_steps # 1/2/4, wave-spec on/off, f32/bf16, multi-buffer, and ragged/decode M. 
-_N4K4_N_FOR_TN = {32: 128, 64: 128, 128: 256, 192: 384, 256: 512} -_N4K4_TN_NW = [ - (32, 1), (32, 2), (64, 1), (64, 2), (64, 4), - (128, 1), (128, 2), (128, 4), (192, 1), (192, 2), (192, 4), - (256, 1), (256, 2), (256, 4), -] # fmt: skip - - -def _gen_n4k4_configs(): +_BS32_N_FOR_TN = {32: 128, 64: 128, 128: 256, 192: 384, 256: 512} +_BS32_TN_NW = [ + (32, 2), + (64, 2), + (64, 4), + (128, 2), + (128, 4), + (192, 2), + (192, 4), + (256, 2), + (256, 4), +] # fmt: skip (n_warp>=2: wave-specialized TDM requires >=2 waves) + + +def _gen_bs32_configs(): cfgs, seen = [], set() - def add(fmt, M, tile_n, n_warp, tile_k, nbuf, od, ws): - N = _N4K4_N_FOR_TN[tile_n] + def add(fmt, M, tile_n, n_warp, tile_k, nbuf, od): + N = _BS32_N_FOR_TN[tile_n] K = tile_k * max(nbuf, 2) # >= nbuf K-tiles for double/triple buffering - key = (fmt, M, N, K, tile_n, tile_k, n_warp, nbuf, od, ws) + key = (fmt, M, N, K, tile_n, tile_k, n_warp, nbuf, od) if key not in seen: seen.add(key) cfgs.append(key) for fmt in ("fp8", "a8w4"): - # 1) full tile_n x n_warp shape sweep (all rep/group/per_load cases), - # non-wave-spec so the cooperative TDM warp distribution is exercised. - for tn, nw in _N4K4_TN_NW: - add(fmt, 16, tn, nw, 256, 2, "bf16", False) - # 2) wave-spec (needs >=4 waves -> n_warp=4), M=1 decode-like. The real - # decode shape (tile_n=64) uses deep K + 4 buffers; larger tile_n keeps - # a modest tile so LDS fits while still exercising the wave-spec TDM. - add(fmt, 1, 64, 4, 512, 4, "bf16", True) + # 1) full tile_n x n_warp shape sweep (all rep/group/per_load cases). + for tn, nw in _BS32_TN_NW: + add(fmt, 16, tn, nw, 256, 2, "bf16") + # 2) M=1 decode-like. The real decode shape (tile_n=64) uses deep K + 4 buffers. + add(fmt, 1, 64, 4, 512, 4, "bf16") for tn in (128, 192, 256): - add(fmt, 1, tn, 4, 256, 2, "bf16", True) + add(fmt, 1, tn, 4, 256, 2, "bf16") # 3) k_wmma_steps 1/2/4 on the next_pow2 (192) and clean (256/64) shapes. 
for tn, nw in [(192, 4), (256, 4), (64, 4)]: for tk in (128, 512): - add(fmt, 16, tn, nw, tk, 2, "bf16", False) + add(fmt, 16, tn, nw, tk, 2, "bf16") # 4) f32 + triple buffering on a few shapes. for tn, nw in [(192, 4), (128, 2), (32, 2)]: - add(fmt, 16, tn, nw, 256, 3, "f32", False) + add(fmt, 16, tn, nw, 256, 3, "f32") # 5) ragged / decode / OOB M. for M in (1, 13, 33): - add(fmt, M, 256, 4, 256, 2, "bf16", False) + add(fmt, M, 256, 4, 256, 2, "bf16") return cfgs -@pytest.mark.parametrize( - "data_format, M, N, K, tile_n, tile_k, n_warp, num_buffers, out_dtype, ws", _gen_n4k4_configs() -) -def test_mxscale_n4k4_bscale(data_format, M, N, K, tile_n, tile_k, n_warp, num_buffers, out_dtype, ws): - # Guard: every config here must actually take the N4K4 B-scale layout, else - # the sweep would silently test the legacy path instead. - assert use_n4k4_bscale_layout( - data_format=data_format, tile_m=16, tile_n=tile_n, tile_k=tile_k, m_warp=1, n_warp=n_warp, n=N - ), f"config does not hit the N4K4 gate: {(data_format, tile_n, tile_k, n_warp, N)}" +@pytest.mark.parametrize("data_format, M, N, K, tile_n, tile_k, n_warp, num_buffers, out_dtype", _gen_bs32_configs()) +def test_mxscale_bscale_32x4(data_format, M, N, K, tile_n, tile_k, n_warp, num_buffers, out_dtype): _run_mxscale_gemm_test( data_format, M, @@ -842,7 +777,6 @@ def test_mxscale_n4k4_bscale(data_format, M, N, K, tile_n, tile_k, n_warp, num_b num_buffers, use_tdm_store=True, out_dtype=out_dtype, - wave_specialized_tdm=ws, l2_prefetch_distance=0, ) @@ -879,19 +813,8 @@ def _gen_natural_ascale_configs(): @pytest.mark.parametrize("data_format, M, tile_m, tile_n, tile_k, m_warp, n_warp, nbuf", _gen_natural_ascale_configs()) def test_mxscale_natural_ascale(data_format, M, tile_m, tile_n, tile_k, m_warp, n_warp, nbuf): - # Guard: every config must take BOTH the natural A-scale path and N4K4 B-scale. 
N = 2 * tile_n K = tile_k * nbuf - assert use_natural_ascale_vgpr( - data_format=data_format, - tile_m=tile_m, - tile_n=tile_n, - tile_k=tile_k, - m_warp=m_warp, - n_warp=n_warp, - n=N, - wave_specialized_tdm=True, - ), f"config does not hit the natural A-scale gate: {(data_format, tile_m, tile_n, tile_k, m_warp, n_warp)}" _run_mxscale_gemm_test( data_format, M, @@ -905,16 +828,14 @@ def test_mxscale_natural_ascale(data_format, M, tile_m, tile_n, tile_k, m_warp, nbuf, use_tdm_store=True, out_dtype="bf16", - wave_specialized_tdm=True, l2_prefetch_distance=0, ) -@pytest.mark.parametrize("ascale_load_path", ["vgpr", "tdm"]) @pytest.mark.parametrize("data_format", ["fp8", "a8w4"]) -def test_mxscale_deep_pipeline(data_format, ascale_load_path): - # Deep-pipeline (fixed 256x256x128 / nbuf4 / wave-spec): 32x4 B-scale + A-scale - # via natural VGPR (default) or 32x4 TDM. Guard: must hit the deep schedule. +def test_mxscale_deep_pipeline(data_format): + # Deep-pipeline (fixed 256x256x128 / nbuf4): 32x4 B-scale + A-scale via natural + # VGPR ring. Guard: must hit the deep schedule. 
assert _is_fp8_deep_pipeline( data_format=data_format, tile_m=256, @@ -924,7 +845,6 @@ def test_mxscale_deep_pipeline(data_format, ascale_load_path): n_warp=2, num_buffers=4, out_dtype="bf16", - wave_specialized_tdm=True, fp8_schedule="auto", ), "config does not hit the deep-pipeline schedule" _run_mxscale_gemm_test( @@ -940,9 +860,7 @@ def test_mxscale_deep_pipeline(data_format, ascale_load_path): 4, use_tdm_store=True, out_dtype="bf16", - wave_specialized_tdm=True, l2_prefetch_distance=0, - ascale_load_path=ascale_load_path, ) @@ -1020,11 +938,9 @@ def test_mxscale_gemm_cudagraph(data_format, M, N, K, tile_m, tile_n, tile_k, m_ a_scale = fp4_utils.random_e8m0(M, K // SCALE_BLOCK) b_scale = fp4_utils.random_e8m0(N, K // SCALE_BLOCK) - skt = tile_k // SCALE_BLOCK - warp_tile_m = tile_m // m_warp - warp_tile_n = tile_n // n_warp - a_scale_ps = preshuffle_e8m0_scale(a_scale, warp_tile_m, scale_k_per_tile=skt) - b_scale_ps = preshuffle_e8m0_scale(b_scale, warp_tile_n, scale_k_per_tile=skt) + # A-scale natural (as-is, VGPR ring); B-scale 32x4. + a_scale_ps = a_scale + b_scale_ps = preshuffle_scale(b_scale) pack_b = 2 if is_fp4 else 1 b_ps = fp4_utils.preshuffle_b_16x16(b, N, K // pack_b) @@ -1046,7 +962,6 @@ def test_mxscale_gemm_cudagraph(data_format, M, N, K, tile_m, tile_n, tile_k, m_ num_buffers=2, use_tdm_store=True, out_dtype="bf16", - wave_specialized_tdm=False, split_k=1, ) @@ -1595,11 +1510,10 @@ def _run_mxscale_mpad( a_scale = fp4_utils.random_e8m0(M, K // SCALE_BLOCK) # real M, unpadded b_scale = fp4_utils.random_e8m0(N, K // SCALE_BLOCK) ref = reference_mxfp8_gemm(a, b, a_scale, b_scale, M, N, K) - skt = tile_k // SCALE_BLOCK - # a_scale stays UNPADDED host-side; preshuffle pads rows to tile_m (the GEMM - # reads tile_m-granular scale tiles for the partial last M-tile). N is aligned. 
- as_ps = preshuffle_e8m0_scale(a_scale, tile_m // m_warp, scale_k_per_tile=skt, row_align=tile_m) - bs_ps = preshuffle_e8m0_scale(b_scale, tile_n // n_warp, scale_k_per_tile=skt) + # A-scale natural (as-is, unpadded): the VGPR ring reads A_scale[M, K//32] and + # OOB rows of the partial last M-tile read 0 via buffer bounds. B-scale 32x4. + as_ps = a_scale + bs_ps = preshuffle_scale(b_scale) b_ps = fp4_utils.preshuffle_b_16x16(b, N, K) c_gpu = torch.zeros(M, N, dtype=_DT[out_dtype], device="cuda") # real M launch = compile_mxscale_gemm( @@ -1825,8 +1739,6 @@ def _run_benchmark(args): _ptpc_ignored = [] if args.no_tdm_store: _ptpc_ignored.append("--no-tdm-store") - if not args.wave_spec_tdm: - _ptpc_ignored.append("--no-wave-spec-tdm") if _ptpc_ignored: print(f" Note: PTPC ignores (forced internally): {', '.join(_ptpc_ignored)}") print("=" * 72) @@ -1875,9 +1787,7 @@ def _run_benchmark(args): a, b, a_scale, b_scale = _pad_mxscale_inputs(a, b, a_scale, b_scale, padded_shape) - skt = tile_k // SCALE_BLOCK - a_scale = preshuffle_scale_for_load_path(a_scale, warp_tile_m, skt, row_align=tile_m) - b_scale = preshuffle_scale_for_load_path(b_scale, warp_tile_n, skt, row_align=tile_n) + b_scale = preshuffle_scale(b_scale) K_packed = padded_k // PACK_B b = fp4_utils.preshuffle_b_16x16(b, padded_n, K_packed) @@ -1934,7 +1844,6 @@ def _run_benchmark(args): use_tdm_store=use_tdm_store, out_dtype=kernel_out_dtype, inst_prefetch=args.inst_prefetch, - wave_specialized_tdm=args.wave_spec_tdm, split_k=args.split_k, expert_sched_mode=args.expert_sched_mode, atomic_barrier_enable=args.atomic_barrier_enable, @@ -2151,11 +2060,7 @@ def _run_graph_verify(args): a, b, a_scale, b_scale = _pad_mxscale_inputs(a, b, a_scale, b_scale, padded_shape) - skt = tile_k // SCALE_BLOCK - warp_tile_m = tile_m // args.m_warp - warp_tile_n = tile_n // args.n_warp - a_scale = preshuffle_scale_for_load_path(a_scale, warp_tile_m, skt, row_align=tile_m) - b_scale = preshuffle_scale_for_load_path(b_scale, 
warp_tile_n, skt, row_align=tile_n) + b_scale = preshuffle_scale(b_scale) K_packed = padded_k // padded_shape["pack_b"] b = fp4_utils.preshuffle_b_16x16(b, padded_n, K_packed) @@ -2186,7 +2091,6 @@ def _run_graph_verify(args): use_tdm_store=use_tdm_store, out_dtype=kernel_out_dtype, inst_prefetch=args.inst_prefetch, - wave_specialized_tdm=args.wave_spec_tdm, split_k=args.split_k, expert_sched_mode=args.expert_sched_mode, atomic_barrier_enable=args.atomic_barrier_enable, @@ -2298,7 +2202,6 @@ def launch(): parser.add_argument("--no-tdm-store", action="store_true", default=False) parser.add_argument("--out-dtype", type=str, default="bf16", choices=["f32", "bf16", "f16"]) parser.add_argument("--inst-prefetch", action="store_true", default=False) - parser.add_argument("--no-wave-spec-tdm", dest="wave_spec_tdm", action="store_false", default=True) parser.add_argument("--waves-per-eu", type=int, default=None) parser.add_argument("--disable-expert-sched-mode", dest="expert_sched_mode", action="store_false", default=True) parser.add_argument( @@ -2396,7 +2299,6 @@ def _run_correctness_test(): num_buffers=args.num_buffers, use_tdm_store=use_tdm_store, out_dtype=args.out_dtype, - wave_specialized_tdm=args.wave_spec_tdm, split_k=args.split_k, l2_prefetch_distance=args.l2_prefetch_distance, cluster_m=args.cluster_m, From 7f782e4e989e40cc7d84cef36d77ee968214e3fa Mon Sep 17 00:00:00 2001 From: aoli26 Date: Tue, 16 Jun 2026 15:44:25 +0000 Subject: [PATCH 16/16] fp4 row-majtors (drop bank-group permutation) --- kernels/gemm_fp8fp4_gfx1250.py | 109 ++++++++++++--------------------- 1 file changed, 40 insertions(+), 69 deletions(-) diff --git a/kernels/gemm_fp8fp4_gfx1250.py b/kernels/gemm_fp8fp4_gfx1250.py index 99d64b33..ef0643b5 100644 --- a/kernels/gemm_fp8fp4_gfx1250.py +++ b/kernels/gemm_fp8fp4_gfx1250.py @@ -433,7 +433,7 @@ def _align_up(value: int, align: int) -> int: _sub_tiles.append((acc_idx, 0, m_off, n_sub)) COMPUTE_SCHEDULE_ROW_MAJOR_STREAMING = "row_major_streaming" - 
COMPUTE_SCHEDULE_FP4_COL_BAND = "fp4_col_band" + COMPUTE_SCHEDULE_FP4_QUADRANT = "fp4_quadrant" COMPUTE_SCHEDULE_FP8_QUADRANT = "fp8_quadrant" COMPUTE_SCHEDULE_FP8_DEEP_PIPELINE = "fp8_deep_pipeline" @@ -455,12 +455,10 @@ def _align_up(value: int, align: int) -> int: def _pick_compute_schedule_kind(): if wmma_m_rep % 2 != 0 or wmma_n_rep % 2 != 0 or n_accs < 8: return COMPUTE_SCHEDULE_ROW_MAJOR_STREAMING - # Quadrant schedules split B into left/right halves and compute - # top-left, bottom-left, top-right, bottom-right. FP4 additionally - # changes accumulator layout for bank friendliness; FP8 keeps row-major - # accumulators and uses the split to increase LDS-load-to-WMMA distance. + # Quadrant: split B left/right, compute the 4 quadrants to widen the + # LDS-load-to-WMMA distance. FP4/FP8 differ only in per-format wait tuning. if is_fp4: - return COMPUTE_SCHEDULE_FP4_COL_BAND + return COMPUTE_SCHEDULE_FP4_QUADRANT # A8W4 (FP8 act + FP4 weight) shares FP8's accumulator layout and operand # path, so it reuses the FP8 schedules. 
if data_format in ("fp8", "a8w4"): @@ -470,7 +468,7 @@ def _pick_compute_schedule_kind(): return COMPUTE_SCHEDULE_ROW_MAJOR_STREAMING compute_schedule_kind = _pick_compute_schedule_kind() - use_fp4_bank_friendly_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_FP4_COL_BAND + use_fp4_quadrant_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_FP4_QUADRANT use_fp8_quadrant_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_FP8_QUADRANT use_fp8_deep_pipeline_schedule = compute_schedule_kind == COMPUTE_SCHEDULE_FP8_DEEP_PIPELINE @@ -479,29 +477,16 @@ def _pick_compute_schedule_kind(): COMPUTE_SCHEDULE_ROW_MAJOR_STREAMING, COMPUTE_SCHEDULE_FP8_QUADRANT, COMPUTE_SCHEDULE_FP8_DEEP_PIPELINE, - COMPUTE_SCHEDULE_FP4_COL_BAND, + COMPUTE_SCHEDULE_FP4_QUADRANT, ) use_ws_tdm_split_signal_overlap = ( (use_fp8_quadrant_schedule or use_fp8_deep_pipeline_schedule) and num_buffers == 4 and use_cluster ) - if use_fp4_bank_friendly_schedule: - _bank_half_wm = wmma_m_rep // 2 - _bank_half_wn = wmma_n_rep // 2 - _bank_group_size = _bank_half_wm * _bank_half_wn - _bank_group_to_row_major = [] - for _wm in range(_bank_half_wm): - for _wn in range(_bank_half_wn): - _bank_group_to_row_major.append(_wm * wmma_n_rep + _wn) - for _wm in range(_bank_half_wm, wmma_m_rep): - for _wn in range(_bank_half_wn): - _bank_group_to_row_major.append(_wm * wmma_n_rep + _wn) - for _wm in range(_bank_half_wm): - for _wn in range(_bank_half_wn, wmma_n_rep): - _bank_group_to_row_major.append(_wm * wmma_n_rep + _wn) - for _wm in range(_bank_half_wm, wmma_m_rep): - for _wn in range(_bank_half_wn, wmma_n_rep): - _bank_group_to_row_major.append(_wm * wmma_n_rep + _wn) + if use_fp4_quadrant_schedule: + _fp4_half_wm = wmma_m_rep // 2 + _fp4_half_wn = wmma_n_rep // 2 + _fp4_group_size = _fp4_half_wm * _fp4_half_wn if use_fp8_quadrant_schedule or use_fp8_deep_pipeline_schedule: _fp8_half_wm = wmma_m_rep // 2 @@ -563,7 +548,7 @@ def kernel_mxscale_gemm( layout_thr = fx.make_layout((m_warp, n_warp, 2, 16), 
(WAVE_SIZE, m_warp * WAVE_SIZE, 16, 1)) else: layout_thr = fx.make_layout((m_warp, n_warp, 2, 16), (n_warp * WAVE_SIZE, WAVE_SIZE, 16, 1)) - thr_coord = idx2crd(tx, layout_thr) + thr_coord = idx2crd(fx.Int32(tx), layout_thr) wave_m_idx, wave_n_idx, lane_kgrp, lane16 = ( fx.get(thr_coord, 0), fx.get(thr_coord, 1), @@ -1073,7 +1058,7 @@ def compute_tile( ) return current_accs - def compute_tile_fp4_bank_friendly( + def compute_tile_fp4_quadrant( accs_in, lds_a, lds_b, @@ -1095,7 +1080,7 @@ def compute_tile_fp4_bank_friendly( a_buf, a_bases = _precompute_a_lane_bases(lds_a) b_buf, b_bases = _precompute_b_lane_bases(lds_b) bs_buf, bs_bases = _precompute_bs32_bases(lds_bs) - _b_half_scale_loads = _bank_half_wn # 32x4: one b32 per 32-N atom/WMMA + _b_half_scale_loads = _fp4_half_wn # 32x4: one b32 per 32-N atom/WMMA def _fp4_get_a_scale_and_opsel(a_scales_all, wm_idx): if const_expr(ascale_opsel): @@ -1108,7 +1093,7 @@ def _load_a_group(wm_base, wm_count, ks): def _load_b_half(wn_base, ks): return [ - load_b_frag(b_buf, b_bases, wn_base + wn_local, ks) for wn_local in range_constexpr(_bank_half_wn) + load_b_frag(b_buf, b_bases, wn_base + wn_local, ks) for wn_local in range_constexpr(_fp4_half_wn) ] def _load_bs32_b_half(atom0, wn_base, ks): @@ -1118,7 +1103,7 @@ def _load_bs32_b_half(atom0, wn_base, ks): lane = (lane_kgrp * arith.index(16) + lane16) * arith.index(4) return [ lds_load_b32_raw(bs_buf, (atom0 + arith.index(wn_base + wn_local)) * stride + ks_off + lane) - for wn_local in range_constexpr(_bank_half_wn) + for wn_local in range_constexpr(_fp4_half_wn) ] def _load_b_half_bundle(wn_base, ks): @@ -1127,7 +1112,7 @@ def _load_b_half_bundle(wn_base, ks): return b_frags, b_scales def _emit_group_rows( - group_base, wm_base, a_frags, b_frags, a_scales, b_scales, row_start, row_count, emit_filler_now=False + wn_base, wm_base, a_frags, b_frags, a_scales, b_scales, row_start, row_count, emit_filler_now=False ): if const_expr(emit_filler_now and emit_filler is not None): 
rocdl.sched_barrier(0) @@ -1137,9 +1122,8 @@ def _emit_group_rows( a_frag = a_frags[wm_local] global_wm = wm_base + wm_local a_scale, a_opsel = _fp4_get_a_scale_and_opsel(a_scales, global_wm) - row_base = group_base + wm_local * _bank_half_wn - for wn_local in range_constexpr(_bank_half_wn): - idx = row_base + wn_local + for wn_local in range_constexpr(_fp4_half_wn): + idx = global_wm * wmma_n_rep + (wn_base + wn_local) # row-major slot current_accs[idx] = rocdl.wmma_scale_f32_32x16x128_f4( T.vec(16, T.f32), b_frags[wn_local], @@ -1151,16 +1135,16 @@ def _emit_group_rows( scaleBType=a_opsel, ) - def _emit_group(group_base, wm_base, a_frags, b_frags, a_scales, b_scales, emit_filler_now=False): + def _emit_group(wn_base, wm_base, a_frags, b_frags, a_scales, b_scales, emit_filler_now=False): _emit_group_rows( - group_base, + wn_base, wm_base, a_frags, b_frags, a_scales, b_scales, 0, - _bank_half_wm, + _fp4_half_wm, emit_filler_now=emit_filler_now, ) @@ -1171,11 +1155,11 @@ def _emit_group(group_base, wm_base, a_frags, b_frags, a_scales, b_scales, emit_ pf_a = _vgpr_scale_box[0] a_scales_all = pf_a[ks * ascale_load : (ks + 1) * ascale_load] - a_top_frags = _load_a_group(0, _bank_half_wm, ks) - a_bottom_frags = _load_a_group(_bank_half_wm, _bank_half_wm, ks) + a_top_frags = _load_a_group(0, _fp4_half_wm, ks) + a_bottom_frags = _load_a_group(_fp4_half_wm, _fp4_half_wm, ks) # Wait for bottom-A loads; top-A stays in flight during Q1. - rocdl.s_wait_dscnt(_bank_half_wm * DS_LOADS_PER_A_FRAG) + rocdl.s_wait_dscnt(_fp4_half_wm * DS_LOADS_PER_A_FRAG) _emit_group( 0, @@ -1190,15 +1174,15 @@ def _emit_group(group_base, wm_base, a_frags, b_frags, a_scales, b_scales, emit_ rocdl.sched_barrier(0) mid_compute_callback() - b_right_frags, b_right_scales = _load_b_half_bundle(_bank_half_wn, ks) + b_right_frags, b_right_scales = _load_b_half_bundle(_fp4_half_wn, ks) # Hold only the next B half outstanding while the second # quadrant consumes the current left-half fragments. 
- rocdl.s_wait_dscnt(_bank_half_wn * 4 + _b_half_scale_loads) + rocdl.s_wait_dscnt(_fp4_half_wn * 4 + _b_half_scale_loads) _emit_group( - _bank_group_size, - _bank_half_wm, + 0, + _fp4_half_wm, a_bottom_frags, b_left_frags, a_scales_all, @@ -1210,12 +1194,12 @@ def _emit_group(group_base, wm_base, a_frags, b_frags, a_scales, b_scales, emit_ # Older right-half loads must be ready before consuming # them, while the next ks left-half preload can remain in # flight under the final two quadrants. - rocdl.s_wait_dscnt(_bank_half_wn * 4 + _b_half_scale_loads) + rocdl.s_wait_dscnt(_fp4_half_wn * 4 + _b_half_scale_loads) else: rocdl.s_wait_dscnt(0) _emit_group( - _bank_group_size * 2, + _fp4_half_wn, 0, a_top_frags, b_right_frags, @@ -1223,8 +1207,8 @@ def _emit_group(group_base, wm_base, a_frags, b_frags, a_scales, b_scales, emit_ b_right_scales, ) _emit_group( - _bank_group_size * 3, - _bank_half_wm, + _fp4_half_wn, + _fp4_half_wm, a_bottom_frags, b_right_frags, a_scales_all, @@ -1617,12 +1601,12 @@ def hot_loop_scheduler(): rocdl.sched_dsrd(wmma_n_rep * _b_loads_per_frag + _scale_dsrd) rocdl.sched_barrier(0) - def hot_loop_scheduler_fp4_bank_friendly(): + def hot_loop_scheduler_fp4_quadrant(): _a_all_loads = wmma_m_rep * DS_LOADS_PER_A_FRAG _a_scale_loads = 0 # A-scale is in VGPRs, not ds_load'd - _b_half_loads = _bank_half_wn * 4 - _b_half_scale_loads = _bank_half_wn # 32x4: one b32 per 32-N atom/WMMA - _group_wmma = _bank_group_size + _b_half_loads = _fp4_half_wn * 4 + _b_half_scale_loads = _fp4_half_wn # 32x4: one b32 per 32-N atom/WMMA + _group_wmma = _fp4_group_size _right_half_loads = _b_half_loads + _b_half_scale_loads for _ks in range_constexpr(k_wmma_steps): @@ -1719,8 +1703,8 @@ def compute_tile_scheduled( scale_k_base=None, pf_a_scales=None, ): - if const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_FP4_COL_BAND): - return compute_tile_fp4_bank_friendly( + if const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_FP4_QUADRANT): + return 
compute_tile_fp4_quadrant( accs_in, lds_a, lds_b, @@ -1771,8 +1755,8 @@ def compute_tile_scheduled( ) def hot_loop_scheduler_scheduled(): - if const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_FP4_COL_BAND): - hot_loop_scheduler_fp4_bank_friendly() + if const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_FP4_QUADRANT): + hot_loop_scheduler_fp4_quadrant() elif const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_FP8_DEEP_PIPELINE): hot_loop_scheduler_fp8_deep_pipeline() elif const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_FP8_QUADRANT): @@ -1883,17 +1867,6 @@ def epilogue_atomic_adds(final_accs, addrs): scf.YieldOp([]) addr_idx += n_slots - def grouped_accs_to_row_major(accs_grouped): - row_major = [None] * n_accs - for group_idx in range_constexpr(n_accs): - row_major[_bank_group_to_row_major[group_idx]] = accs_grouped[group_idx] - return row_major - - def finalize_acc_layout(accs_in): - if const_expr(compute_schedule_kind == COMPUTE_SCHEDULE_FP4_COL_BAND): - return grouped_accs_to_row_major(accs_in) - return accs_in - def epilogue_load_ptpc_scales(): # PTPC scales: sa[M] per-token (scalar per wm), sb[N] per-channel # (8 contiguous N cols per wn). Both fp32, constant along K. @@ -2401,8 +2374,6 @@ def _tail_mid_ws(_ls=_load_stage, _ab=_tail_addr_box, _sb=_tail_sec_box): hot_loop_scheduler_scheduled() - accs = finalize_acc_layout(accs) - if const_expr(is_ptpc): _load_ptpc_scales_once() _ptpc_sa, _ptpc_sb = _ptpc_scale_box[0]