Commit 8b5b491

feat: SM_75 Support

1 parent 7901983 commit 8b5b491

2 files changed (+89, -3 lines)
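This commit adds Turing (SM 7.5) support to the AWQ GEMM kernels. The tensor-core shape used previously, mma.sync.aligned.m16n8k16, requires SM 8.0, so builds with __CUDA_ARCH__ < 800 now decompose each m16n8k16 into two m16n8k8 instructions: the first consumes A fragment registers 0-1 with B register 0 (k = 0..7), the second A registers 2-3 with B register 1 (k = 8..15), both accumulating into the same C registers. The cp.async copy path is likewise an SM 8.0 feature and is now compiled only when __CUDA_ARCH__ >= 800, with older architectures taking the plain synchronous uint4 copy branch.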

awq/kernels/csrc/quantization/gemm_cuda_gen.cu

Lines changed: 52 additions & 2 deletions
@@ -200,22 +200,72 @@ __global__ void __launch_bounds__(128) gemm_forward_4bit_cuda_m128n64k32(int spl
 
     for (int i_0_3 = 0; i_0_3 < 4; ++i_0_3) {
       for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4) {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+        {
+          __asm__ __volatile__(
+              "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+              "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+              : "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3])
+              : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[1]),
+                "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]),
+                "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3])
+          );
+        }
+
+        {
+          __asm__ __volatile__(
+              "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+              "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+              : "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3])
+              : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[3]),
+                "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]),
+                "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3])
+          );
+        }
 
+        {
+          __asm__ __volatile__(
+              "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+              "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+              : "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[0]), "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[1]), "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[2]), "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[3])
+              : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[1]),
+                "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]),
+                "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[0]), "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[1]), "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[2]), "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[3])
+          );
+        }
+        {
+          __asm__ __volatile__(
+              "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+              "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
+              : "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[0]), "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[1]), "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[2]), "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[3])
+              : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[3]),
+                "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]),
+                "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[0]), "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[1]), "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[2]), "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[3])
+          );
+        }
+#else
         {
           __asm__ __volatile__(
               "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
               "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
               : "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "=f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3])
-              : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[1]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3]));
+              : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[1]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[3]),
+                "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]),
+                "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[0]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[1]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[2]), "f"(((float *)(C_warp + ((i_0_3 * 16) + (j_0_4 * 8))))[3])
+          );
         }
 
         {
           __asm__ __volatile__(
               "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
               "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
               : "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[0]), "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[1]), "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[2]), "=f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[3])
-              : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[1]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[0]), "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[1]), "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[2]), "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[3]));
+              : "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[0]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[1]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[2]), "r"(((unsigned *)(A_shared_warp + (i_0_3 * 8)))[3]),
+                "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]),
+                "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[0]), "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[1]), "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[2]), "f"(((float *)(C_warp + (((i_0_3 * 16) + (j_0_4 * 8)) + 4)))[3])
+          );
         }
+#endif
       }
     }
   }
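The split above works because the PTX fragment layouts nest: the four A registers of an m16n8k16 mma hold the k = 0..7 half of the tile in registers 0-1 and the k = 8..15 half in registers 2-3, and the two B registers hold the matching K-halves, so issuing two m16n8k8 instructions over those halves with a shared accumulator reproduces the k16 contraction exactly. Below is a minimal standalone sketch (not part of this commit; the file and kernel names are illustrative) that demonstrates the equivalence, assuming an sm_80 device so that both instruction shapes are available:

// mma_split_demo.cu (illustrative name) -- check that one mma.sync m16n8k16
// equals two mma.sync m16n8k8 over the two K-halves of the fragments.
// Build: nvcc -arch=sm_80 mma_split_demo.cu && ./a.out
#include <cstdio>
#include <cuda_fp16.h>

__global__ void mma_split_demo()
{
    int lane = threadIdx.x;

    // Arbitrary but deterministic per-lane fragment contents. Both paths read
    // the same registers, so the equivalence holds for any values here.
    alignas(8) half a[8];
    alignas(4) half b[4];
    for (int i = 0; i < 8; ++i) a[i] = __float2half(0.125f * ((lane * 8 + i) % 16));
    for (int i = 0; i < 4; ++i) b[i] = __float2half(0.25f * ((lane * 4 + i) % 8));
    const unsigned *A = reinterpret_cast<const unsigned *>(a);
    const unsigned *B = reinterpret_cast<const unsigned *>(b);

    float d16[4] = {0.f, 0.f, 0.f, 0.f}; // accumulator, single k16 path
    float d8[4] = {0.f, 0.f, 0.f, 0.f};  // accumulator, 2x k8 path

    // One k16 mma: four A registers, two B registers.
    __asm__ __volatile__(
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};\n"
        : "+f"(d16[0]), "+f"(d16[1]), "+f"(d16[2]), "+f"(d16[3])
        : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]));

    // Same contraction as two k8 mmas: A[0..1]/B[0] cover k = 0..7,
    // A[2..3]/B[1] cover k = 8..15, both accumulating into the same D.
    __asm__ __volatile__(
        "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};\n"
        : "+f"(d8[0]), "+f"(d8[1]), "+f"(d8[2]), "+f"(d8[3])
        : "r"(A[0]), "r"(A[1]), "r"(B[0]));
    __asm__ __volatile__(
        "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 "
        "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};\n"
        : "+f"(d8[0]), "+f"(d8[1]), "+f"(d8[2]), "+f"(d8[3])
        : "r"(A[2]), "r"(A[3]), "r"(B[1]));

    if (lane == 0)
        printf("k16: %f %f %f %f | 2x k8: %f %f %f %f\n",
               d16[0], d16[1], d16[2], d16[3], d8[0], d8[1], d8[2], d8[3]);
}

int main()
{
    mma_split_demo<<<1, 32>>>();
    cudaDeviceSynchronize();
    return 0;
}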

awq/kernels/csrc/quantization_new/gemm/gemm_cuda.cu

Lines changed: 37 additions & 1 deletion
@@ -105,6 +105,7 @@ __inline__ __device__ void ldmatrix_m8n8_x4_trans_b16(half *shared_warp, int ax0
         : "r"(addr));
 }
 
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
 __inline__ __device__ void cp_async_cg_A(uint32_t smem_int_ptr, const uint4 *__restrict__ src, bool mask)
 {
     const int cp_size = 16;
@@ -117,14 +118,37 @@ __inline__ __device__ void cp_async_cg_A(uint32_t smem_int_ptr, const uint4 *__r
         "l"(src),
         "n"(cp_size));
 }
+#endif
 
 __device__ __inline__ void mma_m16n8k16(float *C_warp, half *A_shared_warp, half *B_shared_warp)
 {
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+    __asm__ __volatile__(
+        "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+        "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};"
+        : "=f"(((float *)C_warp)[0]), "=f"(((float *)C_warp)[1]), "=f"(((float *)C_warp)[2]), "=f"(((float *)C_warp)[3])
+        : "r"(((unsigned *)A_shared_warp)[0]), "r"(((unsigned *)A_shared_warp)[1]),
+          "r"(((unsigned *)B_shared_warp)[0]),
+          "f"(((float *)C_warp)[0]), "f"(((float *)C_warp)[1]), "f"(((float *)C_warp)[2]), "f"(((float *)C_warp)[3])
+    );
+    __asm__ __volatile__(
+        "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+        "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};"
+        : "=f"(((float *)C_warp)[0]), "=f"(((float *)C_warp)[1]), "=f"(((float *)C_warp)[2]), "=f"(((float *)C_warp)[3])
+        : "r"(((unsigned *)A_shared_warp)[2]), "r"(((unsigned *)A_shared_warp)[3]),
+          "r"(((unsigned *)B_shared_warp)[1]),
+          "f"(((float *)C_warp)[0]), "f"(((float *)C_warp)[1]), "f"(((float *)C_warp)[2]), "f"(((float *)C_warp)[3])
+    );
+#else
     __asm__ __volatile__(
         "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
         "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
         : "=f"(((float *)C_warp)[0]), "=f"(((float *)C_warp)[1]), "=f"(((float *)C_warp)[2]), "=f"(((float *)C_warp)[3])
-        : "r"(((unsigned *)A_shared_warp)[0]), "r"(((unsigned *)A_shared_warp)[1]), "r"(((unsigned *)A_shared_warp)[2]), "r"(((unsigned *)A_shared_warp)[3]), "r"(((unsigned *)B_shared_warp)[0]), "r"(((unsigned *)B_shared_warp)[1]), "f"(((float *)C_warp)[0]), "f"(((float *)C_warp)[1]), "f"(((float *)C_warp)[2]), "f"(((float *)C_warp)[3]));
+        : "r"(((unsigned *)A_shared_warp)[0]), "r"(((unsigned *)A_shared_warp)[1]), "r"(((unsigned *)A_shared_warp)[2]), "r"(((unsigned *)A_shared_warp)[3]),
+          "r"(((unsigned *)B_shared_warp)[0]), "r"(((unsigned *)B_shared_warp)[1]),
+          "f"(((float *)C_warp)[0]), "f"(((float *)C_warp)[1]), "f"(((float *)C_warp)[2]), "f"(((float *)C_warp)[3])
+    );
+#endif
 }
 
 template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
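In the quantization_new path the same K-split is factored into the mma_m16n8k16 helper: pre-SM 8.0 builds issue two m16n8k8 instructions over A registers 0-1/2-3 and B registers 0/1, so every call site picks up the Turing fallback unchanged. The remaining hunks below guard the cp.async machinery the same way.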
@@ -148,12 +172,14 @@ __device__ __inline__ void global_to_share_one_stage_A(half *src, half *dst, int
     int ld_col_swizzled = (ld_col ^ (ld_row) & 7) * PACK_SIZE;
     void *dst_ptr = (void *)(dst + ld_row * kSmemCol + ld_col_swizzled);
     uint4 *src_ptr = (uint4 *)(src + (ld_row + cta_offset_m) * global_ncols + ld_col * PACK_SIZE + global_iter_k * CTA_K + cta_offset_k); // cta_offset_m * global_ncols + global_iter * cta_step_m_or_n * global_ncols + threadIdx.y * warp_step_m_or_n * global_ncols + (threadIdx.x / threads_per_row) * global_ncols + global_iter_k * CTA_K + (threadIdx.x % threads_per_row) * PACK_SIZE);
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
     if constexpr (STAGES > 1)
     {
         uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
         cp_async_cg_A(addr, src_ptr, local_mask & (ld_row + cta_offset_m < global_nrows));
     }
     else
+#endif
     {
         if (local_mask & (ld_row + cta_offset_m < global_nrows))
             *(uint4 *)dst_ptr = *src_ptr;
@@ -183,12 +209,14 @@ __device__ __inline__ void global_to_share_one_stage_B(half *src, half *dst, int
     int ld_col_swizzled = ld_col ^ (ld_row % 2) & 7;
     void *dst_ptr = (void *)(dst + (ld_row * kSmemCol + ld_col_swizzled * PACK_SIZE));
     uint4 *src_ptr = (uint4 *)(src + global_iter_k * CTA_K + cta_offset_n / kInterleave * global_ncols + ld_row * global_ncols + ld_col * PACK_SIZE + cta_offset_k);
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
     if constexpr (STAGES > 1)
     {
         uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
         cp_async_cg_A(addr, src_ptr, local_mask);
     }
     else
+#endif
     {
         if (local_mask)
             *(uint4 *)dst_ptr = *src_ptr;
@@ -212,6 +240,7 @@ __device__ __inline__ void global_to_share_one_stage_scales(half *src, half *dst
     uint4 *src_ptr = (uint4 *)(src + g_idx * global_ncols + cta_offset_n + (threadIdx.x / threads_per_row) * global_ncols + (threadIdx.x % threads_per_row) * PACK_SIZE);
     void *dst_ptr_z = (void *)(dst_z + (threadIdx.x / threads_per_row) * kSmemCol + (threadIdx.x % threads_per_row) * PACK_SIZE);
     uint4 *src_ptr_z = (uint4 *)(src_z + g_idx * global_ncols + cta_offset_n + (threadIdx.x / threads_per_row) * global_ncols + (threadIdx.x % threads_per_row) * PACK_SIZE);
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
     if (STAGES > 1)
     {
         uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
@@ -220,6 +249,7 @@ __device__ __inline__ void global_to_share_one_stage_scales(half *src, half *dst
         cp_async_cg_A(addr_z, src_ptr_z, local_mask);
     }
     else
+#endif
     {
         if (local_mask)
         {
@@ -606,12 +636,14 @@ __device__ __inline__ void global_to_share_one_stage_A_T2(half *src, half *dst,
     int ld_col_swizzled = (ld_col ^ (ld_row) & 7) * PACK_SIZE;
     void *dst_ptr = (void *)(dst + ld_row * kSmemCol + ld_col_swizzled);
     uint4 *src_ptr = (uint4 *)(src + (ld_row + cta_offset_m) * global_ncols + ld_col * PACK_SIZE + global_iter_k * CTA_K); // cta_offset_m * global_ncols + global_iter * cta_step_m_or_n * global_ncols + threadIdx.y * warp_step_m_or_n * global_ncols + (threadIdx.x / threads_per_row) * global_ncols + global_iter_k * CTA_K + (threadIdx.x % threads_per_row) * PACK_SIZE);
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
     if constexpr (STAGES > 1)
     {
         uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
         cp_async_cg_A(addr, src_ptr, local_mask & (ld_row + cta_offset_m < global_nrows));
     }
     else
+#endif
     {
         if (local_mask & (ld_row + cta_offset_m < global_nrows))
             *(uint4 *)dst_ptr = *src_ptr;
@@ -641,12 +673,14 @@ __device__ __inline__ void global_to_share_one_stage_B_T2(half *src, half *dst,
     int ld_col_swizzled = ld_col ^ (ld_row % 2) & 7;
     void *dst_ptr = (void *)(dst + (ld_row * kSmemCol + ld_col_swizzled * PACK_SIZE));
     uint4 *src_ptr = (uint4 *)(src + global_iter_k * CTA_K + cta_offset_n / kInterleave * global_ncols + ld_row * global_ncols + ld_col * PACK_SIZE);
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
     if constexpr (STAGES > 1)
     {
         uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
         cp_async_cg_A(addr, src_ptr, local_mask);
     }
     else
+#endif
     {
         if (local_mask)
             *(uint4 *)dst_ptr = *src_ptr;
@@ -669,6 +703,7 @@ __device__ __inline__ void global_to_share_one_stage_scales_T2(half *src, half *
     uint4 *src_ptr = (uint4 *)(src + g_idx * global_ncols + cta_offset_n + (threadIdx.x % threads_per_row) * PACK_SIZE);
     void *dst_ptr_z = (void *)(dst_z + (threadIdx.x % threads_per_row) * PACK_SIZE);
     uint4 *src_ptr_z = (uint4 *)(src_z + g_idx * global_ncols + cta_offset_n + (threadIdx.x % threads_per_row) * PACK_SIZE);
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
     if (STAGES > 1)
     {
         uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
@@ -677,6 +712,7 @@ __device__ __inline__ void global_to_share_one_stage_scales_T2(half *src, half *
         cp_async_cg_A(addr_z, src_ptr_z, local_mask);
     }
     else
+#endif
     {
         if (local_mask)
         {
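These hunks all apply the same guard: cp.async, which cp_async_cg_A emits for the multi-stage (STAGES > 1) software-pipelined copies, exists only on SM 8.0+, so the async branch is compiled out on older architectures and the else-block's plain uint4 copy becomes the unconditional path. A condensed sketch of the pattern follows; the helper name is hypothetical, and it uses the standard __cvta_generic_to_shared intrinsic in place of this file's cast_smem_ptr_to_uint:

// copy_16B_guarded (illustrative name) -- the sm_80 guard pattern used above.
#include <cstdint>

__device__ __inline__ void copy_16B_guarded(void *dst_smem, const uint4 *src_gmem, bool mask)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    // Ampere+: asynchronous 16-byte global->shared copy; the caller must
    // later commit and wait (cp.async.commit_group / cp.async.wait_group).
    uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(dst_smem));
    __asm__ __volatile__(
        "{\n"
        "  .reg .pred p;\n"
        "  setp.ne.b32 p, %0, 0;\n"
        "  @p cp.async.cg.shared.global [%1], [%2], 16;\n"
        "}\n" ::"r"((int)mask), "r"(addr), "l"(src_gmem));
#else
    // Turing and older: synchronous copy through registers.
    if (mask)
        *reinterpret_cast<uint4 *>(dst_smem) = *src_gmem;
#endif
}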
