Add U8 copy operation for K16 MMA #374

Open · wants to merge 27 commits into base: sycl-develop

Conversation

aacostadiaz (Collaborator)

This PR adds the U8 copy operation that works correctly with the K16 MMA for FP8 GEMM and mixed-dtype GEMM.

sanchitintel commented May 21, 2025

With FP8xFP8 GEMM, the following config didn't work, although the corresponding config does work for FP16xFP16 GEMM:

  using GmemTiledCopyA = XE_2D_U8x32x32_LD_N;
  using GmemTiledCopyB = XE_2D_U8x32x32_LD_V;

  using TileShape = Shape<_64, _256, _32>;

  using TiledMma =
      typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32F16F16F32_TT>, Layout<TileShape>,
      Layout<Shape<_2, _8, _1>, Stride<_8, _1, _0>>>::TiledMMA;

The compile-time error was:

include/cute/atom/copy_traits_xe.hpp:78:19: error: static assertion failed due to requirement 'size(cute::Layout<cute::tuple<cute::C<16>, cute::C<8>>, cute::tuple<cute::C<0>, cute::C<1>>>{}) % size(cute::tuple<cute::C<8>, cute::C<64>>{}) == 0'
   78 |     static_assert(size(LayoutIn{}) % size(BlockShape{}) == 0);

It seems to be a bug since the shapes are correct.
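
For reference, the failing check reduces to a size-divisibility test: with the types printed in the message, size(LayoutIn) is 16 * 8 = 128 while size(BlockShape) is 8 * 64 = 512, and 128 is not divisible by 512. Here is a standalone sketch of the same comparison (a minimal reproduction, assuming the CuTe headers are available; the *Repro aliases are illustrative names, not library types):

  // Sketch, not the library code: reproduce the size check from
  // copy_traits_xe.hpp:78 with the types printed in the error message.
  #include <cute/tensor.hpp>
  using namespace cute;

  using LayoutInRepro   = Layout<Shape<_16, _8>, Stride<_0, _1>>;  // size = 16 * 8 = 128
  using BlockShapeRepro = Shape<_8, _64>;                          // size = 8 * 64 = 512

  static_assert(size(LayoutInRepro{})   == 128, "matches the LayoutIn in the error");
  static_assert(size(BlockShapeRepro{}) == 512, "matches the BlockShape in the error");
  // 128 % 512 == 128, so the library's requirement
  //   size(LayoutIn{}) % size(BlockShape{}) == 0
  // fails exactly as reported above.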

Thanks!

…ked-copy

# Conflicts:
#	CMakeLists.txt
#	include/cute/arch/copy_xe_U16.hpp
#	include/cute/arch/copy_xe_U32.hpp
#	include/cute/arch/copy_xe_U4.hpp
#	include/cute/arch/copy_xe_U64.hpp
#	include/cute/arch/copy_xe_U8.hpp
#	include/cute/arch/copy_xe_builtin.hpp
#	include/cute/arch/copy_xe_spirv.hpp
#	include/cutlass/epilogue/collective/xe_epilogue.hpp
aacostadiaz removed the incremental (Incremental changes) label on May 27, 2025
Comment on lines +217 to +228
struct XE_2D_U8x32x32_LD_N {
  using BlockShape = Shape<_32, _32>;

  template <class T>
  CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width,
                                    int height, int pitch, intel::coord_t coord,
                                    T *dst) {
#if defined(CUTE_ARCH_COPY_XE_ENABLED)
    static_assert(sizeof(T) == 1, "Expected T to have size 1");
    // detail::XeSubgroup2DBlockLoad<1, 16, 32, 2>{}(baseoffset, width, height, pitch, coord, dst);
    // Use the transform (VNNI) version as it provides better performance when loading the A matrix for
    // GEMM FP8 and GEMM mixed-precision types.

sanchitintel commented May 27, 2025

Hi @aacostadiaz, can you please elaborate on why loading A in VNNI format is faster? I assume it's later being converted back to plain layout, since the output is correct, so that layout conversion should have had some overhead.

Did you make this change on the basis of an empirical observation of it being faster, or is there any reason why this approach should be expected to perform better? Thanks!

BTW, the DstLayout in the atom traits for this copy atom is Layout<Shape<_16, Shape<_8, _2, _32>>, Stride<_16, Stride<_1, _128, _256>>>, which seems to correspond to a plain layout. So, does this mean that when the data is copied, it is first transformed into the VNNI layout before being written to the registers, and is later converted to DstLayout somehow? If yes, can you please point out which part of the code handles that?
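
For anyone who wants to poke at this, here is a quick host-side sketch (assuming the CuTe headers are available; the variable names are just for illustration) that prints that DstLayout and evaluates it at a few coordinates, which makes the element ordering easy to see:

  #include <cute/tensor.hpp>
  #include <cstdio>

  int main() {
    using namespace cute;
    // The DstLayout quoted above, from the copy atom's traits.
    auto dst = Layout<Shape<_16, Shape<_8, _2, _32>>,
                      Stride<_16, Stride<_1, _128, _256>>>{};
    print(dst); printf("\n");   // prints the layout as shape:stride

    // Assuming the usual (thread, value) ordering of the modes, these are the
    // element offsets held by lane 0: 0..7, then 128..131, and so on.
    for (int v = 0; v < 12; ++v) {
      printf("lane 0, value %2d -> offset %3d\n", v, int(dst(0, v)));
    }
    return 0;
  }

The inner strides (_1, _128, _256) put consecutive values at consecutive element offsets, which is consistent with the plain (non-VNNI) arrangement described above.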

Also, I don't see any shfl-based instructions in the generated assembly dump. Is it possible that the shuffle (for the VNNI -> plain layout conversion) isn't happening directly lane registers -> lane registers (I understand this isn't possible on Nvidia GPUs, but it appears to be possible on Intel GPUs based on the documentation), but instead goes lane registers -> shared local memory -> lane registers?

Thanks!

cc @pengzhao-intel @yuankuns
