Add U8 copy operation for K16 MMA #374
base: sycl-develop
Conversation
# Conflicts:
#   include/cute/arch/xe_copy_1B.hpp
#   include/cute/arch/xe_copy_2B.hpp
#   include/cute/arch/xe_copy_4B.hpp

# Conflicts:
#   include/cute/arch/mma_xe.hpp
With FP8xFP8 GEMM, this config didn't work, but the corresponding code works for FP16xFP16 GEMM:

```cpp
using GmemTiledCopyA = XE_2D_U8x32x32_LD_N;
using GmemTiledCopyB = XE_2D_U8x32x32_LD_V;
using TileShape = Shape<_64, _256, _32>;
using TiledMma =
    typename TiledMMAHelper<MMA_Atom<XE_8x16x16_F32F16F16F32_TT>, Layout<TileShape>,
                            Layout<Shape<_2, _8, _1>, Stride<_8, _1, _0>>>::TiledMMA;
```

The compile-time error was:

It seems to be a bug, since the shapes are correct. Thanks!
…ked-copy

# Conflicts:
#   CMakeLists.txt
#   include/cute/arch/copy_xe_U16.hpp
#   include/cute/arch/copy_xe_U32.hpp
#   include/cute/arch/copy_xe_U4.hpp
#   include/cute/arch/copy_xe_U64.hpp
#   include/cute/arch/copy_xe_U8.hpp
#   include/cute/arch/copy_xe_builtin.hpp
#   include/cute/arch/copy_xe_spirv.hpp
#   include/cutlass/epilogue/collective/xe_epilogue.hpp
```cpp
struct XE_2D_U8x32x32_LD_N {
  using BlockShape = Shape<_32, _32>;

  template <class T>
  CUTE_HOST_DEVICE static void copy(const void *baseoffset, int width,
                                    int height, int pitch, intel::coord_t coord,
                                    T *dst) {
#if defined(CUTE_ARCH_COPY_XE_ENABLED)
    static_assert(sizeof(T) == 1, "Expected T to have size 1");
    // detail::XeSubgroup2DBlockLoad<1, 16, 32, 2>{}(baseoffset, width, height, pitch, coord, dst);
    // Use the transform (VNNI) version as it provides better performance when loading the A matrix for
    // GEMM FP8 and GEMM mixed-precision types.
```
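For context, the "transform" (VNNI) variant of the 2D block load repacks data so that each 32-bit register channel holds a few consecutive rows of the same column (four rows for 8-bit elements). The following standalone C++ sketch is not part of this PR; the tile size and function name are purely illustrative. It mimics that repacking on the host to show how a row-major u8 tile is reordered by a VNNI-style load:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Host-side illustration of VNNI ("transform") packing for 8-bit data:
// each group of 4 consecutive rows of a column is packed into one 32-bit
// channel, so an (H x W) u8 tile becomes (H/4 x W) dwords.
// Tile size and pack factor here are illustrative, not taken from the PR.
std::vector<uint8_t> vnni_pack_u8(const std::vector<uint8_t>& src, int H, int W) {
  constexpr int kPack = 4;                    // 4 x u8 per 32-bit channel
  std::vector<uint8_t> dst(src.size());
  for (int rg = 0; rg < H / kPack; ++rg)      // row group (one dword row)
    for (int c = 0; c < W; ++c)               // column
      for (int r = 0; r < kPack; ++r)         // row within the group
        dst[(rg * W + c) * kPack + r] = src[(rg * kPack + r) * W + c];
  return dst;
}

int main() {
  const int H = 8, W = 8;                     // small tile for readability
  std::vector<uint8_t> tile(H * W);
  for (int i = 0; i < H * W; ++i) tile[i] = static_cast<uint8_t>(i);

  auto packed = vnni_pack_u8(tile, H, W);
  // First dword of the packed tile holds column 0 of rows 0..3.
  std::printf("%d %d %d %d\n", packed[0], packed[1], packed[2], packed[3]);
  return 0;
}
```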
Hi @aacostadiaz, can you please elaborate on why loading A in VNNI format is faster? I assume it's later converted back to the plain layout, since the output is correct, so that layout conversion should have some overhead. Did you make this change on the basis of an empirical observation that it is faster, or is there a reason why this approach should be expected to perform better? Thanks!

BTW, the `DstLayout` in the atom traits for this copy atom is `Layout<Shape<_16, Shape<_8, _2, _32>>, Stride<_16, Stride<_1, _128, _256>>>`, which seems to correspond to the plain layout. So does this mean that the data is first transformed into VNNI layout as it is copied into the registers, and is later converted to `DstLayout` somehow? If so, can you please point out which part of the code handles that?
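To see the mapping that `DstLayout` describes, here is a minimal host-only sketch (assuming the CuTe headers from this repository are on the include path and compile standalone; this is not part of the PR) that prints the (lane, value) -> offset table:

```cpp
#include <cute/tensor.hpp>

int main() {
  using namespace cute;
  // The DstLayout quoted above: 16 lanes, each holding an (8, 2, 32) bundle of values.
  auto dst = Layout<Shape<_16, Shape<_8, _2, _32>>,
                    Stride<_16, Stride<_1, _128, _256>>>{};
  print_layout(dst);  // prints the full (lane, value) -> linear offset map
  return 0;
}
```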
Also, I don't see any `shfl`-based instructions in the generated assembly dump. Is it possible that the shuffle (for the VNNI -> plain layout conversion) is not happening directly via lane registers -> lane registers (I understand this isn't possible on Nvidia GPUs, but based on the documentation it seems to be possible on Intel GPUs), but rather via lane registers -> shared local memory -> lane registers?

Thanks!
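For illustration of the register-to-register path being asked about, here is a minimal SYCL sketch (assuming a SYCL 2020 / DPC++ toolchain; it is not part of this PR and does not reflect how the copy atom is implemented) showing a sub-group shuffle, i.e. the kind of lane-register exchange that would appear as shuffle instructions in an ISA dump rather than a round trip through shared local memory:

```cpp
#include <sycl/sycl.hpp>
#include <cstdio>

int main() {
  sycl::queue q;
  constexpr int N = 16;
  int *out = sycl::malloc_shared<int>(N, q);

  q.parallel_for(sycl::nd_range<1>{N, N}, [=](sycl::nd_item<1> it) {
    auto sg    = it.get_sub_group();
    int  lane  = sg.get_local_id()[0];
    int  width = sg.get_local_range()[0];
    int  mine  = lane * 10;  // value held in this lane's registers
    // Read the value held by the next lane directly from its registers
    // (no shared-local-memory staging involved).
    int from = sycl::select_from_group(sg, mine, sycl::id<1>((lane + 1) % width));
    out[it.get_global_id(0)] = from;
  }).wait();

  for (int i = 0; i < N; ++i) std::printf("%d ", out[i]);
  std::printf("\n");
  sycl::free(out, q);
  return 0;
}
```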
This PR adds the U8 copy operation that works correctly with the K16 MMA for FP8 GEMM and mixed-dtype GEMM.