@@ -105,6 +105,7 @@ __inline__ __device__ void ldmatrix_m8n8_x4_trans_b16(half *shared_warp, int ax0
        : "r"(addr));
}

+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
__inline__ __device__ void cp_async_cg_A(uint32_t smem_int_ptr, const uint4 *__restrict__ src, bool mask)
{
    const int cp_size = 16;
@@ -117,14 +118,37 @@ __inline__ __device__ void cp_async_cg_A(uint32_t smem_int_ptr, const uint4 *__r
        "l"(src),
        "n"(cp_size));
}
+ #endif

__device__ __inline__ void mma_m16n8k16(float *C_warp, half *A_shared_warp, half *B_shared_warp)
{
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
+     __asm__ __volatile__(
+         "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+         "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};"
+         : "=f"(((float *)C_warp)[0]), "=f"(((float *)C_warp)[1]), "=f"(((float *)C_warp)[2]), "=f"(((float *)C_warp)[3])
+         : "r"(((unsigned *)A_shared_warp)[0]), "r"(((unsigned *)A_shared_warp)[1]),
+           "r"(((unsigned *)B_shared_warp)[0]),
+           "f"(((float *)C_warp)[0]), "f"(((float *)C_warp)[1]), "f"(((float *)C_warp)[2]), "f"(((float *)C_warp)[3])
+     );
+     __asm__ __volatile__(
+         "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
+         "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};"
+         : "=f"(((float *)C_warp)[0]), "=f"(((float *)C_warp)[1]), "=f"(((float *)C_warp)[2]), "=f"(((float *)C_warp)[3])
+         : "r"(((unsigned *)A_shared_warp)[2]), "r"(((unsigned *)A_shared_warp)[3]),
+           "r"(((unsigned *)B_shared_warp)[1]),
+           "f"(((float *)C_warp)[0]), "f"(((float *)C_warp)[1]), "f"(((float *)C_warp)[2]), "f"(((float *)C_warp)[3])
+     );
+ #else
    __asm__ __volatile__(
        "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
        "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
        : "=f"(((float *)C_warp)[0]), "=f"(((float *)C_warp)[1]), "=f"(((float *)C_warp)[2]), "=f"(((float *)C_warp)[3])
-         : "r"(((unsigned *)A_shared_warp)[0]), "r"(((unsigned *)A_shared_warp)[1]), "r"(((unsigned *)A_shared_warp)[2]), "r"(((unsigned *)A_shared_warp)[3]), "r"(((unsigned *)B_shared_warp)[0]), "r"(((unsigned *)B_shared_warp)[1]), "f"(((float *)C_warp)[0]), "f"(((float *)C_warp)[1]), "f"(((float *)C_warp)[2]), "f"(((float *)C_warp)[3]));
+         : "r"(((unsigned *)A_shared_warp)[0]), "r"(((unsigned *)A_shared_warp)[1]), "r"(((unsigned *)A_shared_warp)[2]), "r"(((unsigned *)A_shared_warp)[3]),
+           "r"(((unsigned *)B_shared_warp)[0]), "r"(((unsigned *)B_shared_warp)[1]),
+           "f"(((float *)C_warp)[0]), "f"(((float *)C_warp)[1]), "f"(((float *)C_warp)[2]), "f"(((float *)C_warp)[3])
+     );
+ #endif
}

template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
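
Note on the fallback above: the pre-sm_80 branch emulates one m16n8k16 MMA with two m16n8k8 instructions that accumulate into the same C fragment, pairing A registers 0-1 with B register 0 for the first half of the k dimension and A registers 2-3 with B register 1 for the second half. The standalone sketch below (illustration only, not part of this patch; plain float stands in for the half/f32 fragments) shows why splitting the reduction dimension this way yields the same result:

// Illustration only (not part of this patch): splitting the k dimension of a
// 16x8x16 tile into two k=8 steps that accumulate into the same C gives the
// same result as a single k=16 step.
#include <cmath>
#include <cstdio>

constexpr int M = 16, N = 8, K = 16;

void gemm_acc(const float *A, const float *B, float *C, int k_begin, int k_end)
{
    // C[m][n] += sum over k in [k_begin, k_end) of A[m][k] * B[k][n]
    for (int m = 0; m < M; ++m)
        for (int n = 0; n < N; ++n)
            for (int k = k_begin; k < k_end; ++k)
                C[m * N + n] += A[m * K + k] * B[k * N + n];
}

int main()
{
    float A[M * K], B[K * N], C_k16[M * N] = {}, C_2xk8[M * N] = {};
    for (int i = 0; i < M * K; ++i) A[i] = 0.25f * (i % 7);
    for (int i = 0; i < K * N; ++i) B[i] = 0.5f * (i % 5);

    gemm_acc(A, B, C_k16, 0, K);        // single m16n8k16 step (sm_80+ path)
    gemm_acc(A, B, C_2xk8, 0, K / 2);   // first m16n8k8 step  (A regs 0-1, B reg 0)
    gemm_acc(A, B, C_2xk8, K / 2, K);   // second m16n8k8 step (A regs 2-3, B reg 1)

    float max_diff = 0.0f;
    for (int i = 0; i < M * N; ++i)
        max_diff = std::fmax(max_diff, std::fabs(C_k16[i] - C_2xk8[i]));
    std::printf("max |C_k16 - C_2xk8| = %g\n", max_diff);  // expect 0
    return 0;
}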
@@ -148,12 +172,14 @@ __device__ __inline__ void global_to_share_one_stage_A(half *src, half *dst, int
    int ld_col_swizzled = (ld_col ^ (ld_row) & 7) * PACK_SIZE;
    void *dst_ptr = (void *)(dst + ld_row * kSmemCol + ld_col_swizzled);
    uint4 *src_ptr = (uint4 *)(src + (ld_row + cta_offset_m) * global_ncols + ld_col * PACK_SIZE + global_iter_k * CTA_K + cta_offset_k); // cta_offset_m * global_ncols + global_iter * cta_step_m_or_n * global_ncols + threadIdx.y * warp_step_m_or_n * global_ncols + (threadIdx.x / threads_per_row) * global_ncols + global_iter_k * CTA_K + (threadIdx.x % threads_per_row) * PACK_SIZE);
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    if constexpr (STAGES > 1)
    {
        uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
        cp_async_cg_A(addr, src_ptr, local_mask & (ld_row + cta_offset_m < global_nrows));
    }
    else
+ #endif
    {
        if (local_mask & (ld_row + cta_offset_m < global_nrows))
            *(uint4 *)dst_ptr = *src_ptr;
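
A small aside on the swizzle in this hunk and the analogous ones below: bitwise AND binds tighter than XOR in C++, so ld_col ^ (ld_row) & 7 parses as ld_col ^ (ld_row & 7), which appears to be the intended grouping. A trivial compile-time check with arbitrary sample values (illustration only, not part of this patch):

// Illustration only: '&' has higher precedence than '^' in C++, so the swizzled
// column index is col XOR (row masked to its low 3 bits).
static_assert((3 ^ (5) & 7) == (3 ^ (5 & 7)), "XOR swizzle parses as col ^ (row & 7)");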
@@ -183,12 +209,14 @@ __device__ __inline__ void global_to_share_one_stage_B(half *src, half *dst, int
    int ld_col_swizzled = ld_col ^ (ld_row % 2) & 7;
    void *dst_ptr = (void *)(dst + (ld_row * kSmemCol + ld_col_swizzled * PACK_SIZE));
    uint4 *src_ptr = (uint4 *)(src + global_iter_k * CTA_K + cta_offset_n / kInterleave * global_ncols + ld_row * global_ncols + ld_col * PACK_SIZE + cta_offset_k);
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    if constexpr (STAGES > 1)
    {
        uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
        cp_async_cg_A(addr, src_ptr, local_mask);
    }
    else
+ #endif
    {
        if (local_mask)
            *(uint4 *)dst_ptr = *src_ptr;
@@ -212,6 +240,7 @@ __device__ __inline__ void global_to_share_one_stage_scales(half *src, half *dst
    uint4 *src_ptr = (uint4 *)(src + g_idx * global_ncols + cta_offset_n + (threadIdx.x / threads_per_row) * global_ncols + (threadIdx.x % threads_per_row) * PACK_SIZE);
    void *dst_ptr_z = (void *)(dst_z + (threadIdx.x / threads_per_row) * kSmemCol + (threadIdx.x % threads_per_row) * PACK_SIZE);
    uint4 *src_ptr_z = (uint4 *)(src_z + g_idx * global_ncols + cta_offset_n + (threadIdx.x / threads_per_row) * global_ncols + (threadIdx.x % threads_per_row) * PACK_SIZE);
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    if (STAGES > 1)
    {
        uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
@@ -220,6 +249,7 @@ __device__ __inline__ void global_to_share_one_stage_scales(half *src, half *dst
        cp_async_cg_A(addr_z, src_ptr_z, local_mask);
    }
    else
+ #endif
    {
        if (local_mask)
        {
@@ -606,12 +636,14 @@ __device__ __inline__ void global_to_share_one_stage_A_T2(half *src, half *dst,
    int ld_col_swizzled = (ld_col ^ (ld_row) & 7) * PACK_SIZE;
    void *dst_ptr = (void *)(dst + ld_row * kSmemCol + ld_col_swizzled);
    uint4 *src_ptr = (uint4 *)(src + (ld_row + cta_offset_m) * global_ncols + ld_col * PACK_SIZE + global_iter_k * CTA_K); // cta_offset_m * global_ncols + global_iter * cta_step_m_or_n * global_ncols + threadIdx.y * warp_step_m_or_n * global_ncols + (threadIdx.x / threads_per_row) * global_ncols + global_iter_k * CTA_K + (threadIdx.x % threads_per_row) * PACK_SIZE);
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    if constexpr (STAGES > 1)
    {
        uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
        cp_async_cg_A(addr, src_ptr, local_mask & (ld_row + cta_offset_m < global_nrows));
    }
    else
+ #endif
    {
        if (local_mask & (ld_row + cta_offset_m < global_nrows))
            *(uint4 *)dst_ptr = *src_ptr;
@@ -641,12 +673,14 @@ __device__ __inline__ void global_to_share_one_stage_B_T2(half *src, half *dst,
    int ld_col_swizzled = ld_col ^ (ld_row % 2) & 7;
    void *dst_ptr = (void *)(dst + (ld_row * kSmemCol + ld_col_swizzled * PACK_SIZE));
    uint4 *src_ptr = (uint4 *)(src + global_iter_k * CTA_K + cta_offset_n / kInterleave * global_ncols + ld_row * global_ncols + ld_col * PACK_SIZE);
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    if constexpr (STAGES > 1)
    {
        uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
        cp_async_cg_A(addr, src_ptr, local_mask);
    }
    else
+ #endif
    {
        if (local_mask)
            *(uint4 *)dst_ptr = *src_ptr;
@@ -669,6 +703,7 @@ __device__ __inline__ void global_to_share_one_stage_scales_T2(half *src, half *
    uint4 *src_ptr = (uint4 *)(src + g_idx * global_ncols + cta_offset_n + (threadIdx.x % threads_per_row) * PACK_SIZE);
    void *dst_ptr_z = (void *)(dst_z + (threadIdx.x % threads_per_row) * PACK_SIZE);
    uint4 *src_ptr_z = (uint4 *)(src_z + g_idx * global_ncols + cta_offset_n + (threadIdx.x % threads_per_row) * PACK_SIZE);
+ #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    if (STAGES > 1)
    {
        uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
@@ -677,6 +712,7 @@ __device__ __inline__ void global_to_share_one_stage_scales_T2(half *src, half *
        cp_async_cg_A(addr_z, src_ptr_z, local_mask);
    }
    else
+ #endif
    {
        if (local_mask)
        {
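
Because the new guards are resolved per compiled architecture, pre-sm_80 builds always take the plain synchronous uint4 copy, so a multi-stage cp.async pipeline only helps on sm_80 and newer. A hypothetical host-side helper (the name pick_stages and the stage count of 3 are illustrative, not from this patch) for choosing the pipeline depth by compute capability:

// Hypothetical helper, not part of this patch: pick the software-pipeline depth
// at runtime so devices without cp.async (compute capability < 8.0) use the
// synchronous STAGES == 1 configuration that the new #if guards fall back to.
#include <cuda_runtime.h>

inline int pick_stages(int device = 0)
{
    cudaDeviceProp prop{};
    cudaGetDeviceProperties(&prop, device);
    return prop.major >= 8 ? 3 /* multi-stage cp.async pipeline */ : 1;
}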