Skip to content
Draft
23 changes: 23 additions & 0 deletions kernels/gemm_common_gfx1250.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,29 @@ def lds_load_b128_raw(lds_base_idx, byte_offset):
return llvm_dialect.load(ir.VectorType.get([4], ir.IntegerType.get_signless(32)), ptr_val)


def lds_load_b32_raw(lds_base_idx, byte_offset):
    """Read one 32-bit word (4 bytes) from LDS through a raw LLVM load.

    The address is produced by ``_raw_lds_ptr`` from the pre-extracted base
    index and the given offset; both arguments are forwarded unchanged.

    Only 4-byte alignment is needed, in contrast to
    :func:`lds_load_b128_raw`. That makes this helper a fit for scale layouts
    whose consumed words land on 4-byte rather than 16-byte boundaries, such
    as the per-N-block reads of the N4K4 B-scale layout.

    Returns the value of an ``llvm_dialect.load`` typed as a signless ``i32``.
    """
    # Build the pointer first so any ops it emits precede the load.
    lds_addr = _raw_lds_ptr(lds_base_idx, byte_offset)
    word_type = ir.IntegerType.get_signless(32)
    return llvm_dialect.load(word_type, lds_addr)


def lds_load_b64_raw(lds_base_idx, byte_offset):
    """Read two 32-bit words (8 bytes) from LDS as ``vector<2xi32>``.

    The address is produced by ``_raw_lds_ptr`` from the pre-extracted base
    index and the given offset; both arguments are forwarded unchanged.

    The access must be 8-byte aligned. In width this helper falls between
    :func:`lds_load_b32_raw` and :func:`lds_load_b128_raw`, and is meant for
    layouts whose contiguous read width is 2 words. One example is the N4K4
    B-scale layout when ``wmma_n_rep`` is even but not a multiple of 4, where
    every aligned batch spans exactly 2 N-blocks.

    Returns the value of an ``llvm_dialect.load`` typed as ``vector<2xi32>``.
    """
    # Build the pointer first so any ops it emits precede the load.
    lds_addr = _raw_lds_ptr(lds_base_idx, byte_offset)
    word_type = ir.IntegerType.get_signless(32)
    pair_type = ir.VectorType.get([2], word_type)
    return llvm_dialect.load(pair_type, lds_addr)


def lds_transpose_load_raw(result_type, lds_base_idx, byte_offset):
"""Transpose-load 16 bytes from LDS using a pre-extracted base index."""
from flydsl._mlir.dialects import rocdl as _rocdl
Expand Down
Loading