Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions kernels/blockscale_preshuffle_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,12 +203,12 @@ def kernel_gemm(
# ---- Wave / lane decomposition ----
wave_size = 64
layout_wave_lane = fx.make_layout((4, wave_size), (64, 1))
coord_wave_lane = fx.idx2crd(tx, layout_wave_lane)
coord_wave_lane = fx.idx2crd(fx.Int32(tx), layout_wave_lane)
wave_id = fx.get(coord_wave_lane, 0)
lane_id = fx.get(coord_wave_lane, 1)

layout_lane16 = fx.make_layout((4, 16), (16, 1))
coord_lane16 = fx.idx2crd(lane_id, layout_lane16)
coord_lane16 = fx.idx2crd(fx.Int32(lane_id), layout_lane16)
lane_div_16 = fx.get(coord_lane16, 0)
lane_mod_16 = fx.get(coord_lane16, 1)

Expand Down Expand Up @@ -252,8 +252,8 @@ def load_b_packs_k64(base_k, ku: int, ni: int):
k0_base = base_k_bytes // c64_b
k0 = k0_base + ku
k1 = lane_div_16
coord_pack = (n_blk_list[ni], k0, k1, n_intra_list[ni], fx.Index(0))
idx_pack = crd2idx(coord_pack, layout_b)
coord_pack = (n_blk_list[ni], k0, k1, n_intra_list[ni], fx.Int32(0))
idx_pack = crd2idx(tuple(fx.Int32(c) for c in coord_pack), layout_b)
b16 = _buffer_load_vec(
buffer_ops,
vector,
Expand Down
40 changes: 20 additions & 20 deletions kernels/gemm_fp8fp4_gfx1250.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,7 @@ def kernel_mxscale_gemm(
layout_thr = fx.make_layout((m_warp, n_warp, 2, 16), (WAVE_SIZE, m_warp * WAVE_SIZE, 16, 1))
else:
layout_thr = fx.make_layout((m_warp, n_warp, 2, 16), (n_warp * WAVE_SIZE, WAVE_SIZE, 16, 1))
thr_coord = idx2crd(tx, layout_thr)
thr_coord = idx2crd(fx.Int32(tx), layout_thr)
Comment thread
xudoyuan marked this conversation as resolved.
wave_m_idx, wave_n_idx, lane_kgrp, lane16 = (
fx.get(thr_coord, 0),
fx.get(thr_coord, 1),
Expand All @@ -563,12 +563,12 @@ def kernel_mxscale_gemm(
_bvs_a_rsrc = buffer_ops.create_buffer_resource(arg_a_scale, max_size=False)
_bvs_b_rsrc = buffer_ops.create_buffer_resource(arg_b_scale, max_size=False)
_bvs_Kt = K // tile_k # total K-tiles
_bvs_mb_a = blk_m / arith.index(128) + wave_m_idx
_bvs_mb_b = blk_n / arith.index(128) + wave_n_idx
_bvs_mb_a = blk_m // arith.index(128) + wave_m_idx
_bvs_mb_b = blk_n // arith.index(128) + wave_n_idx
_bvs_lane4 = lane16 * arith.index(4)

def _bvs_load_scales(rsrc, mb, rep, k_base):
kt = k_base / arith.index(tile_k)
kt = k_base // arith.index(tile_k)
tile_i32 = (mb * arith.index(_bvs_Kt) + kt) * arith.index(128)
vals = []
for ld in range_constexpr(rep // 4): # rep=8 -> 2 groups of 4 i32
Expand All @@ -594,7 +594,7 @@ def _bvs_prefetch(k_base):
).result

def make_desc_a(memref, k_base):
k_packed_off = k_base / arith.index(PACK_FACTOR_A)
k_packed_off = k_base // arith.index(PACK_FACTOR_A)
return _make_tdm_desc(
global_ptr=arg_a,
lds_memref=memref,
Expand All @@ -612,11 +612,11 @@ def make_desc_a(memref, k_base):
)

def make_desc_b(memref, k_base):
k_packed_off = k_base / arith.index(PACK_FACTOR_B)
k_packed_off = k_base // arith.index(PACK_FACTOR_B)
return _make_tdm_desc(
global_ptr=arg_b,
lds_memref=memref,
global_offset=(blk_n / arith.index(16), k_packed_off * arith.index(16)),
global_offset=(blk_n // arith.index(16), k_packed_off * arith.index(16)),
tensor_shape=(N // 16, K_packed_b * 16),
strides=(K_packed_b * 16, 1),
tile_shape=(tile_n // 16, packed_tile_k_b * 16),
Expand All @@ -631,7 +631,7 @@ def make_desc_b(memref, k_base):

def make_desc_a_half(memref, k_base, m_half: int):
row_start = m_half * ab_split_a_rows
k_packed_off = k_base / arith.index(PACK_FACTOR_A)
k_packed_off = k_base // arith.index(PACK_FACTOR_A)
return _make_tdm_desc(
global_ptr=arg_a,
lds_memref=memref,
Expand All @@ -651,11 +651,11 @@ def make_desc_a_half(memref, k_base, m_half: int):

def make_desc_b_half(memref, k_base, n_half: int):
group_start = n_half * ab_split_b_groups
k_packed_off = k_base / arith.index(PACK_FACTOR_B)
k_packed_off = k_base // arith.index(PACK_FACTOR_B)
return _make_tdm_desc(
global_ptr=arg_b,
lds_memref=memref,
global_offset=(blk_n / arith.index(16) + arith.index(group_start), k_packed_off * arith.index(16)),
global_offset=(blk_n // arith.index(16) + arith.index(group_start), k_packed_off * arith.index(16)),
tensor_shape=(N // 16, K_packed_b * 16),
strides=(K_packed_b * 16, 1),
tile_shape=(ab_split_b_groups, packed_tile_k_b * 16),
Expand All @@ -670,8 +670,8 @@ def make_desc_b_half(memref, k_base, n_half: int):
)

def make_desc_as(memref, k_base):
k_scale_off = k_base / arith.index(SCALE_BLOCK)
outer_off = blk_m / arith.index(wmma_m_rep)
k_scale_off = k_base // arith.index(SCALE_BLOCK)
outer_off = blk_m // arith.index(wmma_m_rep)
inner_off = k_scale_off * arith.index(wmma_m_rep)
return _make_tdm_desc(
global_ptr=arg_a_scale,
Expand All @@ -690,8 +690,8 @@ def make_desc_as(memref, k_base):
)

def make_desc_bs(memref, k_base):
k_scale_off = k_base / arith.index(SCALE_BLOCK)
outer_off = blk_n / arith.index(b_scale_load_rep)
k_scale_off = k_base // arith.index(SCALE_BLOCK)
outer_off = blk_n // arith.index(b_scale_load_rep)
inner_off = k_scale_off * arith.index(b_scale_load_rep)
return _make_tdm_desc(
global_ptr=arg_b_scale,
Expand Down Expand Up @@ -837,7 +837,7 @@ def load_b_frag(lds_buffer, b_lane_bases, wn, ks):

def _precompute_scale_lane_bases(lds_ptr, warp_base, reps, interleaved_cols):
"""Precompute scale lane bases (byte offsets)."""
warp_lds_row = warp_base / arith.index(reps) + lane16
warp_lds_row = warp_base // arith.index(reps) + lane16
base = warp_lds_row * arith.index(interleaved_cols)
if const_expr(is_fp4 or is_a8w4):
# FP4/A8W4: always add lane_kgrp offset (no opsel on BScale)
Expand Down Expand Up @@ -1985,8 +1985,8 @@ def _l2_prefetch(k_base):
if const_expr(_effective_l2_pf <= 0):
return
pf_k = k_base + arith.index(_effective_l2_pf * tile_k)
pf_k_packed_a = pf_k / arith.index(PACK_FACTOR_A)
pf_k_packed_b = pf_k / arith.index(PACK_FACTOR_B)
pf_k_packed_a = pf_k // arith.index(PACK_FACTOR_A)
pf_k_packed_b = pf_k // arith.index(PACK_FACTOR_B)
tdm_ops.l2_prefetch_tile(
arg_a,
(blk_m, pf_k_packed_a),
Expand All @@ -1998,7 +1998,7 @@ def _l2_prefetch(k_base):
)
tdm_ops.l2_prefetch_tile(
arg_b,
(blk_n / arith.index(16), pf_k_packed_b * arith.index(16)),
(blk_n // arith.index(16), pf_k_packed_b * arith.index(16)),
(tile_n // 16, packed_tile_k_b * 16),
(K_packed_b * 16, 1),
elem_bytes=1,
Expand Down Expand Up @@ -2057,9 +2057,9 @@ def _l2_prefetch(k_base):
# Match the TDM-store descriptor offsets to the compute wave mapping.
if const_expr(use_fp8_deep_pipeline_schedule):
wave_m_sgpr = wave_id_idx % arith.index(m_warp)
wave_n_sgpr = wave_id_idx / arith.index(m_warp)
wave_n_sgpr = wave_id_idx // arith.index(m_warp)
else:
wave_m_sgpr = wave_id_idx / arith.index(n_warp)
wave_m_sgpr = wave_id_idx // arith.index(n_warp)
wave_n_sgpr = wave_id_idx % arith.index(n_warp)
d_warp_linear_sgpr = wave_m_sgpr * arith.index(n_warp) + wave_n_sgpr
d_warp_off_sgpr = d_warp_linear_sgpr * arith.index(warp_d_bytes) + arith.index(d_output_off)
Expand Down
2 changes: 1 addition & 1 deletion kernels/layernorm_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,7 @@ def _load_norm_input_value(index):
mean = sum_val / n_float
var = sumsq_val / n_float - mean * mean
var = (var < c_zero_f).select(c_zero_f, var)
rstd = (var + eps_c).rsqrt(fastmath=fm_fast)
rstd = fmath.rsqrt(var + eps_c, fastmath=fm_fast)

thread_row_max = c_zero_f
for base_idx_int in range_constexpr(0, N, BLOCK_THREADS):
Expand Down
12 changes: 7 additions & 5 deletions kernels/layout_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,15 @@ def idx2crd(idx, layout):
"""
parsed = _parse_layout(layout)

if hasattr(idx, "ir_value"):
idx = idx.ir_value()

if parsed is None or _has_dynamic_strides(parsed[1]):
result = fx.idx2crd(idx, layout)
result = fx.idx2crd(fx.Int32(idx), layout)
ndims = len(parsed[1]) if parsed else 1
return [_wrap(fx.get(result, i)) for i in range(ndims)]

if hasattr(idx, "type") and str(idx.type) != "index":
if isinstance(idx, ir.Value) and not isinstance(idx.type, ir.IndexType):
idx = arith.index_cast(T.index, idx)
shapes, strides = parsed
ndims = len(strides)
Expand Down Expand Up @@ -156,9 +159,8 @@ def crd2idx(crd, layout):
cv = raw
crd_i32.append(cv)
coord_val = fx.make_coord(*crd_i32)
result = fx.crd2idx(coord_val, layout)
scalar = fx.get_scalar(result)
if isinstance(scalar, ir.Value) and not isinstance(scalar.type, ir.IndexType):
scalar = fx.get_scalar(fx.crd2idx(coord_val, layout)).ir_value()
if not isinstance(scalar.type, ir.IndexType):
scalar = arith.index_cast(T.index, scalar)
return _wrap(scalar)

Expand Down
27 changes: 13 additions & 14 deletions kernels/mfma_preshuffle_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@


def crd2idx(crd, layout):
"""crd2idx returning an index-type scalar (unwraps fly.int_tuple)."""
result = fx.crd2idx(crd, layout)
scalar = fx.get_scalar(result)
if isinstance(scalar, ir.Value) and not isinstance(scalar.type, ir.IndexType):
scalar = _arith.IndexCastOp(T.index, scalar).result
return scalar
"""crd2idx returning an index-typed ir.Value (unwraps fly.int_tuple)."""
scalar = fx.get_scalar(fx.crd2idx(crd, layout)).ir_value()
Comment thread
sjfeng1999 marked this conversation as resolved.
if isinstance(scalar.type, ir.IndexType):
return scalar
return _arith.IndexCastOp(T.index, scalar).result


def swizzle_xor16(row, col, k_blocks16):
Expand Down Expand Up @@ -326,7 +325,7 @@ def load_b_raw_w4a16(
k2_base = lane_odd * fx.Index(half_bytes)

coord_pack = (n_blk, k0, k1_local, n_intra, fx.Index(0))
idx_pack = crd2idx(coord_pack, layout_b)
idx_pack = crd2idx(tuple(fx.Int32(c) for c in coord_pack), layout_b)
idx_bytes = idx_pack + k2_base

b4 = _buffer_load_vec(
Expand Down Expand Up @@ -464,7 +463,7 @@ def load_b_pack_k32(
k2_base = arith.constant((ki_step % 2) * half_bytes, index=True)

coord_pack = (n_blk, k0, k1, n_intra, fx.Index(0))
idx_pack = crd2idx(coord_pack, layout_b)
idx_pack = crd2idx(tuple(fx.Int32(c) for c in coord_pack), layout_b)

if unpack_int4:
idx_bytes = idx_pack + k2_base
Expand Down Expand Up @@ -527,7 +526,7 @@ def tile_chunk_coord_i32(
raise ValueError(f"chunk_i32 must be one of (1,2,4), got {chunk_i32!r}")
chunk_off_i32 = arith.constant(i * total_threads * chunk_i32, index=True)
tile_idx_i32 = tx_i32_base + chunk_off_i32
coord_local = fx.idx2crd(tile_idx_i32, layout_tile_div4)
coord_local = fx.idx2crd(fx.Int32(tile_idx_i32), layout_tile_div4)
row_local = fx.get(coord_local, 0)
col_local_i32 = fx.get(coord_local, 1)
return row_local, col_local_i32
Expand Down Expand Up @@ -580,7 +579,7 @@ def lds_store_16b_xor16(
col_swz_bytes = swizzle_xor16(row_local, col_local_bytes, k_blocks16)
col_swz = col_swz_bytes if elem_bytes == 1 else col_swz_bytes // 2
coord_store = (row_local, col_swz)
idx0 = crd2idx(coord_store, layout_lds) + lds_base
idx0 = crd2idx(tuple(fx.Int32(c) for c in coord_store), layout_lds) + lds_base
v16 = vector.bitcast(vec16_ty, vec_part_i32x4)
vector.store(v16, lds_memref, [idx0])

Expand All @@ -607,7 +606,7 @@ def lds_store_8b_xor16(
col_swz_bytes = swizzle_xor16(row_local, col_local_bytes, k_blocks16)
col_swz = col_swz_bytes if elem_bytes == 1 else col_swz_bytes // 2
coord_store = (row_local, col_swz)
idx0 = crd2idx(coord_store, layout_lds) + lds_base
idx0 = crd2idx(tuple(fx.Int32(c) for c in coord_store), layout_lds) + lds_base
v8 = vector.bitcast(vec8_ty, vec_part_i32x2)
vector.store(v8, lds_memref, [idx0])

Expand All @@ -634,7 +633,7 @@ def lds_store_4b_xor16(
col_swz_bytes = swizzle_xor16(row_local, col_local_bytes, k_blocks16)
col_swz = col_swz_bytes if elem_bytes == 1 else col_swz_bytes // 2
coord_store = (row_local, col_swz)
idx0 = crd2idx(coord_store, layout_lds) + lds_base
idx0 = crd2idx(tuple(fx.Int32(c) for c in coord_store), layout_lds) + lds_base
v4 = vector.bitcast(vec4_ty, vec_part_i32x1)
vector.store(v4, lds_memref, [idx0])

Expand All @@ -660,14 +659,14 @@ def lds_load_pack_k32(
col_base_swz = swizzle_xor16(curr_row_a_lds, col_base, k_blocks16)
if ck_lds128:
coord_a16 = (curr_row_a_lds, col_base_swz)
idx_a16 = crd2idx(coord_a16, layout_lds) + lds_base
idx_a16 = crd2idx(tuple(fx.Int32(c) for c in coord_a16), layout_lds) + lds_base
loaded_a16 = vector.load_op(vec16_ty, lds_memref, [idx_a16])
a_vec128 = vector.bitcast(vec2_i64_ty, loaded_a16)
return vector.extract(a_vec128, static_position=[half], dynamic_position=[])
else:
col_swizzled = col_base_swz + (half * 8)
coord_a = (curr_row_a_lds, col_swizzled)
idx_a = crd2idx(coord_a, layout_lds) + lds_base
idx_a = crd2idx(tuple(fx.Int32(c) for c in coord_a), layout_lds) + lds_base
loaded_a8 = vector.load_op(vec8_ty, lds_memref, [idx_a])
a_vec64 = vector.bitcast(vec1_i64_ty, loaded_a8)
return vector.extract(a_vec64, static_position=[0], dynamic_position=[])
Expand Down
18 changes: 9 additions & 9 deletions kernels/mixed_moe_gemm_2stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,10 +729,10 @@ def load_x_tile(base_k):
return parts

# Wave/lane decomposition (identical to stage2)
coord_wl = idx2crd(tx, layout_tx_wave_lane)
coord_wl = idx2crd(fx.Int32(tx), layout_tx_wave_lane)
wave_id = layout_get(coord_wl, 0)
lane_id = layout_get(coord_wl, 1)
coord_l16 = idx2crd(lane_id, layout_lane16)
coord_l16 = idx2crd(fx.Int32(lane_id), layout_lane16)
lane_div_16 = layout_get(coord_l16, 0)
lane_mod_16 = layout_get(coord_l16, 1)
row_a_lds = lane_mod_16
Expand Down Expand Up @@ -763,12 +763,12 @@ def load_x_tile(base_k):
global_n = by_n + n_tile_base + c_offset + lane_mod_16
# Gate/interleave: rows [expert_off, expert_off + 2*inter_dim)
gate_row_w = expert_off_idx + global_n
gate_coord = idx2crd(gate_row_w, layout_n_blk_intra)
gate_coord = idx2crd(fx.Int32(gate_row_w), layout_n_blk_intra)
gate_n_blk_list.append(layout_get(gate_coord, 0))
gate_n_intra_list.append(layout_get(gate_coord, 1))
if const_expr(not mock_gate_only and not gate_up_interleave):
up_row_w = gate_row_w + inter_idx
up_coord = idx2crd(up_row_w, layout_n_blk_intra)
up_coord = idx2crd(fx.Int32(up_row_w), layout_n_blk_intra)
up_n_blk_list.append(layout_get(up_coord, 0))
up_n_intra_list.append(layout_get(up_coord, 1))

Expand Down Expand Up @@ -799,7 +799,7 @@ def load_b_packs_k64(base_k, ku: int, n_blk, n_intra):
k0 = base_k_bytes // c64 + arith.constant(ku, index=True)
k1 = lane_div_16
coord_pack = (n_blk, k0, k1, n_intra, arith.constant(0, index=True))
idx_pack = crd2idx(coord_pack, layout_b)
idx_pack = crd2idx(tuple(fx.Int32(c) for c in coord_pack), layout_b)
vec_elems = kpack_bytes // int(b_elem_bytes)
b16 = _buffer_load_vec(
buffer_ops,
Expand Down Expand Up @@ -1015,7 +1015,7 @@ def prefetch_x_to_lds(base_k, lds_buffer):
def lds_load_packs_k64(curr_row_a_lds, col_base, lds_buffer):
col_base_swz_bytes = swizzle_xor16(curr_row_a_lds, col_base, k_blocks16)
col_base_swz = col_base_swz_bytes if elem_bytes == 1 else (col_base_swz_bytes / arith.index(2))
idx_a16 = crd2idx([curr_row_a_lds, col_base_swz], layout_lds)
idx_a16 = crd2idx([fx.Int32(curr_row_a_lds), fx.Int32(col_base_swz)], layout_lds)
Comment thread
sjfeng1999 marked this conversation as resolved.
loaded_a16 = vector.load_op(vec16_x, lds_buffer, [idx_a16])
a_i64x2 = vector.bitcast(vec2_i64, loaded_a16)
a0 = vector.extract(a_i64x2, static_position=[0], dynamic_position=[])
Expand Down Expand Up @@ -3074,10 +3074,10 @@ def load_x_tile(base_k):
return parts

# tx -> wave/lane (GEMM-style decomposition).
coord_wl = idx2crd(tx, layout_tx_wave_lane)
coord_wl = idx2crd(fx.Int32(tx), layout_tx_wave_lane)
wave_id = layout_get(coord_wl, 0)
lane_id = layout_get(coord_wl, 1)
coord_l16 = idx2crd(lane_id, layout_lane16)
coord_l16 = idx2crd(fx.Int32(lane_id), layout_lane16)
lane_div_16 = layout_get(coord_l16, 0)
lane_mod_16 = layout_get(coord_l16, 1)

Expand Down Expand Up @@ -3330,7 +3330,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_buffer):
def lds_load_packs_k64(curr_row_a_lds, col_base, lds_buffer):
col_base_swz_bytes = swizzle_xor16(curr_row_a_lds, col_base, k_blocks16)
col_base_swz = col_base_swz_bytes if elem_bytes == 1 else (col_base_swz_bytes / arith.index(2))
idx_a16 = crd2idx([curr_row_a_lds, col_base_swz], layout_lds)
idx_a16 = crd2idx([fx.Int32(curr_row_a_lds), fx.Int32(col_base_swz)], layout_lds)
Comment thread
sjfeng1999 marked this conversation as resolved.
loaded_a16 = vector.load_op(vec16_x, lds_buffer, [idx_a16])
a_i64x2 = vector.bitcast(vec2_i64, loaded_a16)
a0 = vector.extract(a_i64x2, static_position=[0], dynamic_position=[])
Expand Down
Loading
Loading