
Add U8 copy operation for K16 MMA #374

Open
wants to merge 28 commits into base: sycl-develop

Commits (28)
a6c8e53  spirv APIs (jiyang1011, Mar 12, 2025)
73bef6e  mma spirv api (jiyang1011, Apr 7, 2025)
6e12cb6  Merge branch 'sycl-develop' into jiyang/spirv_api (jiyang1011, Apr 14, 2025)
626fd13  Merge branch 'sycl-develop' into jiyang/spirv_api (jiyang1011, Apr 22, 2025)
cf6a41b  Merge branch 'sycl-develop' into jiyang/spirv_api (jiyang1011, Apr 29, 2025)
d9f8303  remove -1 from OCL API (jiyang1011, Apr 29, 2025)
c1cddb6  Merge branch 'sycl-develop' into jiyang/spirv_api (aacostadiaz, May 6, 2025)
5537fd7  rebase (aacostadiaz, May 6, 2025)
c89a875  Disable spirv functions for PVC (aacostadiaz, May 6, 2025)
5e26dd3  move spirv definitions (aacostadiaz, May 6, 2025)
8c67947  fix (aacostadiaz, May 6, 2025)
1af7011  Merge branch 'sycl-develop' into jiyang/spirv_api (aacostadiaz, May 6, 2025)
879eb35  Refactor (aacostadiaz, May 8, 2025)
9864ab2  Fix cmake (aacostadiaz, May 8, 2025)
39e549d  Re-enable test (aacostadiaz, May 8, 2025)
d6c9358  Fix mma builtin (aacostadiaz, May 8, 2025)
ec9d0a7  Fix copy builtin (aacostadiaz, May 8, 2025)
7144422  Revert minor changes (aacostadiaz, May 9, 2025)
3d30536  Merge branch 'sycl-develop' into jiyang/spirv_api (aacostadiaz, May 12, 2025)
4bbaaa6  Use builtin for prefetch (aacostadiaz, May 12, 2025)
304de17  Remove FP16 MMA with FP16 accumulator (aacostadiaz, May 13, 2025)
a2c45b1  Add U8 copy operation for K16 MMA (aacostadiaz, May 14, 2025)
1e2595a  Merge remote-tracking branch 'codeplay/sycl-develop' into aacosta/pac… (aacostadiaz, May 27, 2025)
b962239  fix merge conflict (aacostadiaz, May 27, 2025)
d8e855e  Revert changes in the tests (aacostadiaz, May 27, 2025)
d0e2c94  Update GEMM FP8 example (aacostadiaz, May 27, 2025)
d346207  Merge branch 'sycl-develop' into aacosta/packed-copy (aacostadiaz, May 27, 2025)
ba60f3a  Merge branch 'sycl-develop' into aacosta/packed-copy (joeatodd, May 29, 2025)
@@ -535,7 +535,7 @@ int main(int argc, const char** argv)
using ElementScale = MmaType;

// Note: XE_2D_U16x32x32_LD_N is incompatible with our bf16 MMA atoms
using GmemTiledCopyA = XE_2D_U8x32x32_LD_V; // U8 (1-byte) block copy for A (narrower type)
using GmemTiledCopyA = XE_2D_U8x32x32_LD_N; // U8 (1-byte) block copy for A (narrower type)
using GmemTiledCopyB = XE_2D_U16x32x32_LD_V; // U16 (2-byte) block copy for B (wider type)
static_assert(sizeof(ElementInputA) == 1, "ElementA width must match GmemTiledCopyA U8");

2 changes: 1 addition & 1 deletion examples/sycl/08_bmg_gemm_f8/08_bmg_gemm_f8.cpp
@@ -346,7 +346,7 @@ int launcher(Options& options)
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = cutlass::layout::RowMajor;

using GmemTiledCopyA = XE_2D_U8x32x32_LD_V;
using GmemTiledCopyA = XE_2D_U8x32x32_LD_N;
using GmemTiledCopyB = XE_2D_U8x32x32_LD_V;

using TileShape = Shape<_256, _256, _32>;
43 changes: 31 additions & 12 deletions include/cute/arch/copy_xe_U8.hpp
@@ -35,7 +35,7 @@

namespace cute
{
struct XE_2D_U8x1x32_LD_N {
struct XE_2D_Packed_U8x1x32_LD_N {
using BlockShape = Shape<_1, _32>;
using inst_dtype = int8_t;

@@ -65,7 +65,7 @@ struct XE_2D_U8x1x32_LD_N {
};
};

struct XE_2D_U8x2x32_LD_N {
struct XE_2D_Packed_U8x2x32_LD_N {
using BlockShape = Shape<_2, _32>;
using inst_dtype = int8_t;

@@ -111,7 +111,7 @@ struct XE_2D_U8x2x32_ST_N {
}
};

struct XE_2D_U8x4x32_LD_N {
struct XE_2D_Packed_U8x4x32_LD_N {
using BlockShape = Shape<_4, _32>;

template <class T>
@@ -140,7 +140,7 @@ struct XE_2D_U8x4x32_LD_N {
};
};

struct XE_2D_U8x8x32_LD_N {
struct XE_2D_Packed_U8x8x32_LD_N {
using BlockShape = Shape<_8, _32>;

template <class T>
@@ -169,7 +169,7 @@ struct XE_2D_U8x8x32_LD_N {
};
};

struct XE_2D_U8x16x32_LD_N {
struct XE_2D_Packed_U8x16x32_LD_N {
using BlockShape = Shape<_16, _32>;

template <class T>
@@ -198,7 +198,7 @@ struct XE_2D_U8x16x32_LD_N {
};
};

struct XE_2D_U8x32x32_LD_N {
struct XE_2D_Packed_U8x32x32_LD_N {
using BlockShape = Shape<_32, _32>;

template <class T>
@@ -214,7 +214,26 @@
}
};

struct XE_2D_U8x1x64_LD_N {
struct XE_2D_U8x32x32_LD_N {
using BlockShape = Shape<_32, _32>;

template <class T>
CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width,
int height, int pitch, intel::coord_t coord,
T *dst) {
#if defined(CUTE_ARCH_COPY_XE_ENABLED)
static_assert(sizeof(T) == 1, "Expected T to have size 1");
// detail::XeSubgroup2DBlockLoad<1, 16, 32, 2>{}(baseoffset, width, height, pitch, coord, dst);
// Use the transform (VNNI) version as it provides better performance when loading the A matrix for
// GEMM FP8 and GEMM mixed-precision types.
Comment on lines +217 to +228

sanchitintel commented on May 27, 2025:

Hi @aacostadiaz,

Please help resolve a couple of doubts.

The DstLayout in the atom traits for this copy atom is Layout<Shape <_16,Shape <_8, _2, _32>>, Stride<_16,Stride< _1,_128,_256>>>, which seems to correspond to a plain layout. So, does this mean that when the data is copied from global memory, it is first transformed into the VNNI layout before being written to the registers, and later converted to DstLayout? If yes, can you please point out where/how this is handled in the code?

Also, I don't see any shfl-based instructions in the generated assembly dump. So is it possible that the shuffle (for the VNNI -> plain layout conversion) happens not directly via lane registers -> lane registers (I understand this isn't possible on NVIDIA GPUs, but it is apparently possible on Intel GPUs, based on the documentation), but rather via lane registers -> shared local memory -> lane registers?

Thanks!

cc @pengzhao-intel @yuankuns

aacostadiaz (Collaborator, Author) replied:
The Copy trait is used to describe how a copy operation works so that the rest of the code can understand it. It does not change how the actual copy operation works.

In this case, for the VNNI copies, the transformation happens inside the builtin/SPIR-V function; there is no transformation inside CUTLASS for that. We just use these builtin/SPIR-V functions, and the copy traits describe how they work.

detail::XeSubgroup2DBlockLoadTransform<1, 16, 32, 2>{}(baseoffset, width, height, pitch, coord, dst);
#else
CUTE_INVALID_CONTROL_PATH("Trying to use block loads on non-Xe hardware");
#endif
}
};
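
To make the discussion above concrete, here is a rough, reference-only sketch (not part of this PR) of the 8-bit VNNI packing that the transform load performs during the 2D block read: each group of four consecutive K-rows of a column is packed into one 32-bit dword, which is the operand layout the DPAS unit consumes. The function name, the host-side std::vector code, and the assumption that the row count is a multiple of 4 are illustrative only; the builtin does this packing as part of the load and distributes the dwords across subgroup lanes.

    #include <cstdint>
    #include <vector>

    // Reference-only VNNI packing for 8-bit data: row r+0 of a column lands in
    // the lowest byte of the dword, row r+3 in the highest byte.
    std::vector<uint32_t> vnni_pack_u8(const std::vector<uint8_t>& tile, int rows, int cols) {
      std::vector<uint32_t> packed((rows / 4) * cols);
      for (int r = 0; r < rows; r += 4) {        // assumes rows % 4 == 0
        for (int c = 0; c < cols; ++c) {
          uint32_t dword = 0;
          for (int i = 0; i < 4; ++i) {
            dword |= uint32_t(tile[(r + i) * cols + c]) << (8 * i);
          }
          packed[(r / 4) * cols + c] = dword;
        }
      }
      return packed;
    }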

struct XE_2D_Packed_U8x1x64_LD_N {
using BlockShape = Shape<_1, _64>;

template <class T>
@@ -243,7 +262,7 @@ struct XE_2D_U8x1x64_LD_N {
};
};

struct XE_2D_U8x2x64_LD_N {
struct XE_2D_Packed_U8x2x64_LD_N {
using BlockShape = Shape<_2, _64>;

template <class T>
@@ -272,7 +291,7 @@ struct XE_2D_U8x2x64_LD_N {
};
};

struct XE_2D_U8x4x64_LD_N {
struct XE_2D_Packed_U8x4x64_LD_N {
using BlockShape = Shape<_4, _64>;

template <class T>
@@ -301,7 +320,7 @@ struct XE_2D_U8x4x64_LD_N {
};
};

struct XE_2D_U8x8x64_LD_N {
struct XE_2D_Packed_U8x8x64_LD_N {
using BlockShape = Shape<_8, _64>;

template <class T>
@@ -330,7 +349,7 @@ struct XE_2D_U8x8x64_LD_N {
};
};

struct XE_2D_U8x16x64_LD_N {
struct XE_2D_Packed_U8x16x64_LD_N {
using BlockShape = Shape<_16, _64>;

template <class T>
@@ -359,7 +378,7 @@ struct XE_2D_U8x16x64_LD_N {
};
};

struct XE_2D_U8x32x64_LD_N {
struct XE_2D_Packed_U8x32x64_LD_N {
using BlockShape = Shape<_32, _64>;

template <class T>
16 changes: 15 additions & 1 deletion include/cute/arch/copy_xe_builtin.hpp
@@ -146,7 +146,10 @@ SYCL_DEVICE_BUILTIN(
cute::intel::ushort64 __builtin_IB_subgroup_block_read_flat_u8_m32k32v2(
intptr_t baseoffset, int width_minus_one, int height_minus_one,
int pitch_minus_one, cute::intel::coord_t coord));

SYCL_DEVICE_BUILTIN(
cute::intel::uchar64 __builtin_IB_subgroup_block_read_flat_u8_m32k16v2(
long baseoffset, int width_minus_one, int height_minus_one,
int pitch_minus_one, cute::intel::coord_t coord));

// 8bits VNNI transform No transpose
SYCL_DEVICE_BUILTIN(
@@ -523,6 +526,17 @@ struct XeSubgroup2DBlockLoad<1, 32, 32, 1> {
}
};

template<>
struct XeSubgroup2DBlockLoad<1, 16, 32, 2> {
template<typename T>
CUTE_HOST_DEVICE void
operator()(const void* srcBasePointer, int memoryWidth, int memoryHeight, int memoryPitch,
cute::intel::coord_t coordinate, T* dstPointer) {
*reinterpret_cast<intel::uchar64 *>(dstPointer) = __builtin_IB_subgroup_block_read_flat_u8_m32k16v2(
(intptr_t)(srcBasePointer), memoryWidth - 1, memoryHeight - 1, memoryPitch - 1, coordinate);
}
};

template<>
struct XeSubgroup2DBlockLoad<1, 32, 1, 2> {
template<typename T>
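
For reference, a hedged sketch (not part of this PR) of how the new XeSubgroup2DBlockLoad<1, 16, 32, 2> specialization might be called from device code when CUTE_ARCH_COPY_XE_ENABLED is defined. The helper name, the parameter names, and the cute::detail qualification are assumptions for illustration; what is taken from the diff is the functor signature and the fact that it forwards width, height, and pitch minus one to __builtin_IB_subgroup_block_read_flat_u8_m32k16v2.

    #include "cute/arch/copy_xe_builtin.hpp"

    // Illustrative helper (assumed name): load a u8 block into a 64-byte
    // per-work-item fragment using the new <1, 16, 32, 2> specialization.
    CUTE_HOST_DEVICE void load_u8_m32k16v2(const void* base,        // u8 matrix base pointer
                                           int width_bytes,         // surface width in bytes
                                           int height_rows,         // surface height in rows
                                           int pitch_bytes,         // row pitch in bytes
                                           cute::intel::coord_t coord,
                                           cute::intel::uchar64& frag) {
    #if defined(CUTE_ARCH_COPY_XE_ENABLED)
      // The functor subtracts 1 from width/height/pitch and writes the returned
      // uchar64 through the destination pointer.
      cute::detail::XeSubgroup2DBlockLoad<1, 16, 32, 2>{}(
          base, width_bytes, height_rows, pitch_bytes, coord, &frag);
    #else
      // Non-Xe targets: nothing to do in this sketch.
      (void)base; (void)width_bytes; (void)height_rows; (void)pitch_bytes; (void)coord; (void)frag;
    #endif
    }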
16 changes: 16 additions & 0 deletions include/cute/arch/copy_xe_spirv.hpp
@@ -34,6 +34,11 @@
#include "cute/config.hpp"

// TODO(Codeplay): These builtins are not available on SPIRV
SYCL_EXTERNAL extern "C"
cute::intel::uchar64 __builtin_IB_subgroup_block_read_flat_u8_m32k16v2(
long baseoffset, int width_minus_one, int height_minus_one,
int pitch_minus_one, cute::intel::coord_t coord);

SYCL_EXTERNAL extern "C"
cute::intel::uint2 __builtin_IB_subgroup_block_read_flat_transpose_u32_k2(
intptr_t baseoffset, int width_minus_one, int height_minus_one,
@@ -271,6 +276,17 @@ struct XeSubgroup2DBlockStore {
}
};

template<>
struct XeSubgroup2DBlockLoad<1, 16, 32, 2> {
template<typename T>
CUTE_HOST_DEVICE void
operator()(const void* srcBasePointer, int memoryWidth, int memoryHeight, int memoryPitch,
cute::intel::coord_t coordinate, T* dstPointer) {
*reinterpret_cast<intel::uchar64 *>(dstPointer) = __builtin_IB_subgroup_block_read_flat_u8_m32k16v2(
(intptr_t)(srcBasePointer), memoryWidth - 1, memoryHeight - 1, memoryPitch - 1, coordinate);
}
};

template<>
struct XeSubgroup2DBlockLoadTranspose<4, 2, 16, 1> {
template<typename T>