
HIP: Enable Matrix cores for MMQ Kernels, Enable stream-K for CDNA 3 #14624

Open: deepsek wants to merge 21 commits into ggml-org:master from amd-integration
Commits (21), changes shown from all commits:
68da4e5  Feat: Enable MFMA instr for Q4_K (deepsek, May 23, 2025)
79f348a  Fix: Missed template param (deepsek, May 24, 2025)
89ba8a6  Feat: Add MFMA instr for Q6_K, remove MMQ_NWARPS (deepsek, May 27, 2025)
e57e563  Merge branch 'ggml-org:master' into amd-integration (deepsek, May 27, 2025)
9784a51  Merge branch 'ggml-org:master' into amd-integration (deepsek, Jun 2, 2025)
dad79b3  Merge branch 'ggml-org:master' into amd-integration (deepsek, Jun 10, 2025)
ff60fa9  Perf: Fix Register Spilling Q6_K - Refactor kernel, launch_bound (deepsek, Jun 10, 2025)
e8eeb34  Perf: Refactor Q4_K, reduce register pressure (deepsek, Jun 12, 2025)
a161900  Perf: Throughput Increase 4k->6.9k t/s (deepsek, Jun 25, 2025)
75d386a  Perf: 7.1k tokens/sec (deepsek, Jul 2, 2025)
0215a80  Perf/Feat: Throughput 8.3k tokens/sec, Add support for all quants (deepsek, Jul 9, 2025)
aa35feb  Feat: Remove warnings, deprecated __AMDGCN_WAVEFRONT_SIZE (deepsek, Jul 10, 2025)
ba17f62  Merge branch 'master' into amd-integration (deepsek, Jul 10, 2025)
5ab1491  Feat: Enable stream-k for CDNA3 (deepsek, Jul 10, 2025)
fb2fd31  Fix: Remove Trailing Whitespaces (deepsek, Jul 14, 2025)
b55d44a  Fix: Unused Params Warnings, CUDA Build (deepsek, Jul 14, 2025)
ab7c007  -p512: 8.4k->9.5k - Account for DataPadding for writing tile_y (deepsek, Jul 15, 2025)
279b51e  refactor: PR code cleanup, amd_mma_available->amd_mfma_available (deepsek, Jul 22, 2025)
a2a336b  refactor: PR cleanup (deepsek, Jul 23, 2025)
b489b4e  Feat/Perf: redesign all quants use same tile size, mfma instr, nwarps (deepsek, Jul 24, 2025)
4e6d54f  Fix: CI fail - unused parameter Werror (deepsek, Jul 24, 2025)
4 changes: 2 additions & 2 deletions .devops/rocm.Dockerfile
@@ -1,8 +1,8 @@
ARG UBUNTU_VERSION=24.04

# This needs to generally match the container host's environment.
ARG ROCM_VERSION=6.3
ARG AMDGPU_VERSION=6.3
ARG ROCM_VERSION=6.4
ARG AMDGPU_VERSION=6.4

# Target the CUDA build image
ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete
16 changes: 13 additions & 3 deletions ggml/src/ggml-cuda/common.cuh
@@ -56,7 +56,7 @@
#define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16
#define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue
#define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a
#define GGML_CUDA_CC_CDNA (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers
#define GGML_CUDA_CC_CDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers
#define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum acc register renameing
#define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300

@@ -72,8 +72,9 @@
#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4)
#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA1)
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1)
#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1)

// Moore Threads
#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
@@ -226,6 +227,10 @@ typedef float2 dfloat2;
#define FP16_MMA_AVAILABLE
#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4)))

#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && defined(CDNA3)
#define AMD_MFMA_AVAILABLE
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && defined(CDNA3)

#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
#define NEW_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
@@ -288,6 +293,11 @@ static bool fp32_mma_hardware_available(const int cc) {
return GGML_CUDA_CC_IS_CDNA(cc);
}

// AMD CDNA3 matrix cores. Support for other CDNA generations will be added later.
static bool amd_mfma_available(const int cc) {
return cc >= GGML_CUDA_CC_OFFSET_AMD && GGML_CUDA_CC_IS_CDNA3(cc);
}

// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
static bool new_mma_available(const int cc) {
return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
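Editorial note on the two additions above: AMD_MFMA_AVAILABLE gates device code at compile time (it is only defined when compiling for CDNA3), while amd_mfma_available(cc) gates host-side kernel selection at run time. A minimal sketch of that pairing follows; the kernel and dispatch names are hypothetical illustrations, not code from this PR.

// Hypothetical illustration of how the compile-time/run-time pair is consumed
// (assumes the definitions from common.cuh above).
__global__ void example_mmq_kernel(/* ... */) {
#if defined(AMD_MFMA_AVAILABLE)
    // CDNA3-only path: issue MFMA (matrix core) instructions.
#else
    // Generic path for all other targets.
#endif // AMD_MFMA_AVAILABLE
}

static void example_dispatch(const int cc) {
    if (amd_mfma_available(cc)) {
        // Launch the MFMA-enabled kernel variant.
    } else {
        // Launch the fallback variant.
    }
}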
114 changes: 111 additions & 3 deletions ggml/src/ggml-cuda/mma.cuh
@@ -12,7 +12,8 @@
// The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile.
// All matrix tiles have ne physical 32 bit elements per warp.
//
// As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
// As described in the PTX documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
// The API in this file also assumes that the pointers for load_generic are aligned to 16 bytes; unaligned pointers are considered undefined behavior.

#include "common.cuh"

@@ -66,7 +67,44 @@ namespace ggml_cuda_mma {
struct tile {
static constexpr int I = I_;
static constexpr int J = J_;
static constexpr int ne = I * J / WARP_SIZE;

#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
static constexpr int ne = I * J / 64;
T x[ne] = {0};

static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
return threadIdx.x % 16;
} else if constexpr (I == 16 && J == 8) {
return threadIdx.x % 16;
} else if constexpr (I == 32 && J == 4) {
return threadIdx.x % 32;
} else if constexpr (I == 16 && J == 16) {
return 4 * (threadIdx.x / 16) + l;
} else if constexpr (I == 32 && J == 32) {
return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4);
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}

static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
return (2 * ((threadIdx.x / 16) % 2) + l);
} else if constexpr (I == 16 && J == 8) {
return 2 * (threadIdx.x / 16) + l;
} else if constexpr (I == 32 && J == 4) {
return 2 * (threadIdx.x / 32) + l;
} else if constexpr (I == 16 && J == 16) {
return threadIdx.x % 16;
} else if constexpr (I == 32 && J == 32) {
return threadIdx.x % 32;
} else {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}
#else
static constexpr int ne = I * J / 32;
T x[ne] = {0};

static __device__ __forceinline__ int get_i(const int l) {
@@ -94,6 +132,7 @@ namespace ggml_cuda_mma {
static_assert(I == -1 && J == -1, "template specialization not implemented");
}
}
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
};

template <int I_, int J_>
@@ -148,10 +187,23 @@ namespace ggml_cuda_mma {

template <int I, int J, typename T>
static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
#if defined(AMD_MFMA_AVAILABLE)
if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
#pragma unroll
for (int l = 0; l < t.ne; ++l) {
t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
}
} else {
int64_t * xi = (int64_t *) t.x;
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
xi[0] = xs[0];
}
Comment on lines +191 to +200
Collaborator:
Suggested change (replace the branch below with an early-return form):

Original:
if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8>
#pragma unroll
for (int l = 0; l < t.ne; ++l) {
t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
}
} else {
int64_t * xi = (int64_t *) t.x;
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
xi[0] = xs[0];
}

Suggested:
if constexpr (I != 64 || J != 2) {
int64_t * xi = (int64_t *) t.x;
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
xi[0] = xs[0];
return;
}

I think this would be simpler.

deepsek (Author), Jul 24, 2025:
Do you mean without the preprocessor directives?
This would affect the NV code path when we call load_generic, though. I see several places where load_generic is called:

static __device__ __forceinline__ void load_generic(...) {
        if constexpr (I != 64 || J != 2) {
            int64_t * xi = (int64_t *) t.x;
            const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
            xi[0] = xs[0];
            return;
        }

#pragma unroll
        for (int l = 0; l < t.ne; ++l) {
            t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
        }
}

Collaborator:
I basically meant to have the instructions for loading the data as 64 bits encapsulated in an #ifdef AMD_MFMA_AVAILABLE ... #endif block and to fall back to the generic implementation if the preconditions aren't met. But if this is going to be refactored anyway, it doesn't matter.
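For readers following the thread, here is a minimal sketch of the structure the reviewer appears to describe: the 64-bit load kept inside an #ifdef AMD_MFMA_AVAILABLE block with an early return, falling through to the generic element-wise loop otherwise. This is an editorial illustration of the suggestion (assuming the tile definitions above), not the code that was finally merged.

template <int I, int J, typename T>
static __device__ __forceinline__ void load_generic(tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
#if defined(AMD_MFMA_AVAILABLE)
    if constexpr (I != 64 || J != 2) {
        // Fast path: each lane loads its 8-byte slice in one 64-bit transaction.
        int64_t * xi = (int64_t *) t.x;
        const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
        xi[0] = xs[0];
        return;
    }
#endif // AMD_MFMA_AVAILABLE
    // Generic fallback: element-wise loads through get_i/get_j.
#pragma unroll
    for (int l = 0; l < t.ne; ++l) {
        t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
    }
}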

#else
#pragma unroll
for (int l = 0; l < t.ne; ++l) {
t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
}
#endif // defined(AMD_MFMA_AVAILABLE)
}

template <typename T>
@@ -186,7 +238,7 @@ namespace ggml_cuda_mma {
template <typename T>
static __device__ __forceinline__ void load_ldmatrix(
tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) {
#ifdef NEW_MMA_AVAILABLE
#if defined(NEW_MMA_AVAILABLE)
int * xi = (int * ) t.x;
const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2);
asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
@@ -393,4 +445,60 @@ namespace ggml_cuda_mma {
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}

static __device__ __forceinline__ void mma(
tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) {
#if defined(AMD_MFMA_AVAILABLE)
using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
int32x4_t * acc = (int32x4_t *) D.x;
#if defined(CDNA3)
acc[0] = __builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0],
((int64_t *) B.x)[0],
acc[0],
0, 0, 0);
#elif defined(CDNA2) || defined(CDNA)
acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0],
B.x[0],
acc[0],
0, 0, 0);
acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1],
B.x[1],
acc[0],
0, 0, 0);
#endif // defined(CDNA3)
#else
GGML_UNUSED(D);
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // AMD_MFMA_AVAILABLE
}

static __device__ __forceinline__ void mma(
tile<32, 32, int> & D, const tile<32, 4, int> & A, const tile<32, 4, int> & B) {
#if defined(AMD_MFMA_AVAILABLE)
using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int;
int32x16_t * acc = (int32x16_t *) D.x;
#if defined(CDNA3)
acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0],
((int64_t *) B.x)[0],
acc[0],
0, 0, 0);
#elif defined(CDNA2) || defined(CDNA)
acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0],
B.x[0],
acc[0],
0, 0, 0);
acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1],
B.x[1],
acc[0],
0, 0, 0);
#endif // defined(CDNA3)
#else
GGML_UNUSED(D);
GGML_UNUSED(A);
GGML_UNUSED(B);
NO_DEVICE_CODE;
#endif // AMD_MFMA_AVAILABLE
}
}
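To give a sense of how the new pieces fit together, below is a rough usage sketch. It is an editorial illustration only: the buffer names, sizes, and write-back layout are assumptions, and it presumes the tile, load_generic, and mma definitions from mma.cuh above. int8 operands staged in 16-byte-aligned shared memory are loaded as tile<16, 8, int> fragments and accumulated into a tile<16, 16, int> via the CDNA3 MFMA overload.

using namespace ggml_cuda_mma;

// Illustrative only. A_sm and B_sm are assumed to be 16-byte-aligned shared-memory
// buffers of repacked int8 data viewed as int, with 'stride' ints per row and
// 'k_ints' int columns to reduce over.
static __device__ void mfma_tile_gemm_sketch(
        const int * A_sm, const int * B_sm, const int stride, const int k_ints, int * out) {
    tile<16, 16, int> D;  // per-wavefront accumulator, zero-initialized by the tile type
    for (int k0 = 0; k0 < k_ints; k0 += tile<16, 8, int>::J) {
        tile<16, 8, int> A;
        tile<16, 8, int> B;
        load_generic(A, A_sm + k0, stride);  // one 64-bit load per lane on CDNA3
        load_generic(B, B_sm + k0, stride);
        mma(D, A, B);  // lowers to __builtin_amdgcn_mfma_i32_16x16x32_i8 on CDNA3
    }
    // Scatter each lane's accumulator fragment using the tile's index helpers.
    for (int l = 0; l < D.ne; ++l) {
        out[D.get_i(l)*16 + D.get_j(l)] = D.x[l];
    }
}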
10 changes: 6 additions & 4 deletions ggml/src/ggml-cuda/mmq.cu
@@ -109,7 +109,8 @@ void ggml_cuda_mul_mat_q(
const int64_t s03 = src0->nb[3] / ts_src0;
const int64_t s3 = dst->nb[3] / ts_dst;

const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;
const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
|| (GGML_CUDA_CC_IS_AMD(cc) && GGML_CUDA_CC_IS_CDNA3(cc)));

if (!ids) {
const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
@@ -250,8 +251,9 @@ void ggml_cuda_op_mul_mat_q(
// The stream-k decomposition is only faster for recent NVIDIA GPUs.
// Also its fixup needs to allocate a temporary buffer in the memory pool.
// There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) &&
ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && src1_ncols == ne11;
const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
|| (GGML_CUDA_CC_IS_AMD(cc) && GGML_CUDA_CC_IS_CDNA3(cc)))
&& src1_ncols == ne11;
const mmq_args args = {
src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, nullptr, dst_dd_i,
ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst,
@@ -304,7 +306,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
return false;
}

if (new_mma_available(cc)) {
if (new_mma_available(cc) || amd_mfma_available(cc)) {
return true;
}
