SM90 Support #126
Changes from all commits
63782d4
bb1a825
d5975e8
9ebafdc
d49a52e
d31c736
546934c
e1f5516
9495689
21fc904
628ecb4
3206cea
2ca58ae
5e3802f
e0f5a3e
f88ea00
0c9d5e1
d1badb7
d3d97ab
c9e395e
73b6c7d
dc56737
3798c3a
83f0547
386b14b
```diff
@@ -76,12 +76,20 @@
 #include "cutlass/util/tensor_view_io.h"
 #include "cutlass/util/reference/device/gemm.h"
 #include "cutlass/util/reference/device/tensor_compare.h"
+#if defined(SYCL_NVIDIA_TARGET)
+#include "cutlass/util/reference/device/sycl_tensor_fill.h"
+#else
 #include "cutlass/util/reference/device/tensor_fill.h"
+#endif

 #include "helper.h"

 using namespace cute;

+#if defined(SYCL_NVIDIA_TARGET)
+using namespace cutlass;
+#endif
```
**Comment on lines +89 to +92**

Reviewer: Why is this needed?

Author: Because types like `cudaError_t` and `cudaSuccess` are defined in the `cutlass` namespace on the non-CUDA path.
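The author's reply is the whole story here: on the non-CUDA path the CUDA error types are declared inside `namespace cutlass`, so unqualified uses of them in the example only resolve once that namespace is pulled in. A minimal sketch of the name-lookup situation, using stand-in definitions rather than the actual CUTLASS headers:

```cpp
#include <cstdio>

namespace cutlass {  // hypothetical stand-ins for the non-CUDA path
  enum cudaError_t { cudaSuccess = 0, cudaErrorUnknown = 1 };
  inline cudaError_t cudaDeviceSynchronize() { return cudaSuccess; }
}

using namespace cutlass;  // what the PR adds under SYCL_NVIDIA_TARGET

int main() {
  // Unqualified names now resolve to cutlass::cudaError_t and
  // cutlass::cudaSuccess, just as they would resolve to the real
  // global CUDA symbols on the CUDA path.
  cudaError_t status = cudaDeviceSynchronize();
  if (status != cudaSuccess) std::printf("sync failed\n");
  return 0;
}
```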
```diff
 #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)

 /////////////////////////////////////////////////////////////////////////////////////////////////
```
```diff
@@ -379,7 +387,11 @@ bool verify(const Options &options) {
       ref_D);

   // Wait for kernel to finish
-  CUDA_CHECK(cudaDeviceSynchronize());
+#if defined(SYCL_NVIDIA_TARGET)
+  syclcompat::wait_and_throw();
+#else
+  CUDA_CHECK(cudaDeviceSynchronize());
+#endif

   // Check if output from CUTLASS kernel and reference kernel are equal or not
   bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size());
```
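Each call site now carries this two-way branch. One way to keep call sites uniform is a small helper; the sketch below is hypothetical (the `device_synchronize` name is not part of this PR) and assumes the `CUDA_CHECK` macro from `helper.h` and the syclcompat header are already in scope:

```cpp
// Hypothetical helper, not part of this PR: funnel the backend-specific
// "wait for the device" step into one place so call sites read identically.
inline void device_synchronize() {
#if defined(SYCL_NVIDIA_TARGET)
  syclcompat::wait_and_throw();         // waits on the default queue, rethrows async errors
#else
  CUDA_CHECK(cudaDeviceSynchronize());  // blocks until all preceding device work completes
#endif
}
```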
```diff
@@ -427,10 +439,10 @@ int run(Options &options)
   // Run profiling loop
   if (options.iterations > 0)
   {
+    CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
     GpuTimer timer;
     timer.start();
     for (int iter = 0; iter < options.iterations; ++iter) {
-      CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
       CUTLASS_CHECK(gemm.run());
     }
     timer.stop();
```
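This hoists `gemm.initialize(...)` out of the timed loop, so the `GpuTimer` window now measures only `gemm.run()`. A library-free sketch of the same measurement pattern, with `std::chrono` standing in for `GpuTimer`:

```cpp
#include <chrono>

// Time only the repeated work; one-time setup stays outside the window,
// mirroring how gemm.initialize() was moved ahead of timer.start().
template <class Setup, class Work>
double avg_ms(int iterations, Setup setup, Work work) {
  setup();                                        // e.g. gemm.initialize(...)
  auto start = std::chrono::steady_clock::now();
  for (int i = 0; i < iterations; ++i) {
    work();                                       // e.g. gemm.run()
  }
  auto stop = std::chrono::steady_clock::now();
  return std::chrono::duration<double, std::milli>(stop - start).count() / iterations;
}
```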
```diff
@@ -466,6 +478,7 @@ int main(int argc, char const **args) {

   // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example
   // and must have compute capability at least 90.
+#if !defined(SYCL_NVIDIA_TARGET)
   if (__CUDACC_VER_MAJOR__ < 12) {
     std::cerr << "This example requires CUDA 12 or newer.\n";
     // Returning zero so this test passes on older Toolkits. Its actions are no-op.
```
@@ -483,6 +496,7 @@ int main(int argc, char const **args) { | |
<< "later (compute capability 90 or greater).\n"; | ||
return 0; | ||
} | ||
#endif | ||
// | ||
// Parse options | ||
// | ||
|
```diff
@@ -355,7 +355,7 @@ struct SM90_TMA_LOAD_IM2COL_3D
     uint64_t gmem_int_desc = reinterpret_cast<uint64_t>(desc_ptr);
     uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr);
     uint32_t smem_int_ptr  = cast_smem_ptr_to_uint(smem_ptr);
-    // Copy from global to shared::cluster.
+    // Copy from global to shared::cluster
     asm volatile (
       "cp.async.bulk.tensor.3d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes"
       " [%0], [%1, {%3, %4, %5}], [%2], {%6};"
```

**Comment on the changed comment line**

Reviewer: Revert?

```diff
@@ -1113,7 +1113,7 @@ CUTE_HOST_DEVICE static void
 tma_store_fence() {
 #if defined(CUTE_ARCH_TMA_SM90_ENABLED)
   asm volatile ("fence.proxy.async.shared::cta;");
-#elif defined(__CUDA_ARCH__)
+#elif defined(__CUDA_ARCH__) || (__SYCL_CUDA_ARCH__)
   CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED.");
 #endif
 }
```
```diff
@@ -33,7 +33,9 @@
 #include <cute/config.hpp>
 #include <cute/arch/mma.hpp>

 // Config
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL))
+#if ((defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) || \
+     (defined(__SYCL_CUDA_ARCH__) && (__SYCL_CUDA_ARCH__ >= 900))) && \
+    defined(__CUDA_ARCH_FEAT_SM90_ALL)
 #  define CUTE_ARCH_MMA_SM90A_ENABLED
 #endif
```
```diff
@@ -84,15 +86,15 @@ warpgroup_fence_operand(uint32_t& reg) {
 // MSVC emits a build error for 'asm volatile'
 // even if it only occurs in a __device__ function.
 // This prevents the error.
-#if defined(__CUDA_ARCH__)
+#if defined(__CUDA_ARCH__) || defined(__SYCL_CUDA_ARCH__)
   asm volatile("" : "+r"(reg) :: "memory");
 #endif
 }

 CUTE_HOST_DEVICE
 void
 warpgroup_fence_operand(float& reg) {
-#if defined(__CUDA_ARCH__)
+#if defined(__CUDA_ARCH__) || defined(__SYCL_CUDA_ARCH__)
   asm volatile("" : "+f"(reg) :: "memory");
 #endif
 }
```

**Comment on the `__SYCL_CUDA_ARCH__` guards**

Reviewer: This `__SYCL_CUDA_ARCH__` seems to create a lot of noise in the code. Can we wrap it up with `__CUDA_ARCH__`?

Author: No, we cannot do that yet.
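For reference, the consolidation the reviewer is asking about might look like the sketch below. The `CUTE_CUDA_DEVICE_ARCH` name is hypothetical, and per the author's reply it is not an option yet:

```cpp
// Hypothetical consolidation (rejected "for now" in this thread): define one
// device-side arch macro in a shared config header, then guard with it.
#if defined(__CUDA_ARCH__)
#  define CUTE_CUDA_DEVICE_ARCH __CUDA_ARCH__        // nvcc device pass
#elif defined(__SYCL_CUDA_ARCH__)
#  define CUTE_CUDA_DEVICE_ARCH __SYCL_CUDA_ARCH__   // SYCL NVPTX device pass
#endif

// The two fences above would then collapse to a single check:
//   #if defined(CUTE_CUDA_DEVICE_ARCH)
//     asm volatile("" : "+r"(reg) :: "memory");
//   #endif
```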
```diff
@@ -762,7 +762,7 @@ print_latex_copy(LayoutS const& S, ThrIDS const& TS,  // (m,n) -> (tid,vid)
 #include <cute/atom/copy_traits_sm90.hpp>

 // Config
-#if (__CUDACC_VER_MAJOR__ >= 12)
+#if (__CUDACC_VER_MAJOR__ >= 12) || defined(SYCL_NVIDIA_TARGET)
 #  define CUTE_COPY_ATOM_TMA_SM90_ENABLED
 #endif
```

**Comment on the config change**

Reviewer: Can we use the PTX version for SYCL instead of SYCL_NVIDIA_TARGET, since SYCL_NVIDIA_TARGET is more generic than versioning?
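Later hunks in this PR already gate the SM90 MMA macros on `__PTX_VERSION__ >= 80` for the SYCL path, so a guard in the spirit of the reviewer's suggestion might read as follows. This is a sketch, assuming `__PTX_VERSION__` is exposed on the SYCL NVPTX path; PTX ISA 8.0 is the first version with sm_90 support:

```cpp
// Sketch of the reviewer's suggestion: key on the PTX ISA version rather than
// the broad SYCL_NVIDIA_TARGET macro (which also matches pre-SM90 targets).
#if (__CUDACC_VER_MAJOR__ >= 12) || \
    (defined(__PTX_VERSION__) && (__PTX_VERSION__ >= 80))
#  define CUTE_COPY_ATOM_TMA_SM90_ENABLED
#endif
```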
```diff
@@ -78,8 +78,8 @@ struct TMA_LOAD_Unpack
 #if 0
     auto [c0,c1,c2,c3,c4] = append<5>(src_coord, 0);
     printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n",
-           threadIdx.x, threadIdx.y, threadIdx.z,
-           blockIdx.x, blockIdx.y, blockIdx.z,
+           ThreadIdxX(), ThreadIdxY(), ThreadIdxZ(),
+           BlockIdxX(), BlockIdxY(), BlockIdxZ(),
            int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), dst_ptr);
 #endif
     return detail::explode_tuple(detail::CallCOPY<CopyOp>{},
```

```diff
@@ -314,8 +314,8 @@ struct TMA_STORE_Unpack
 #if 0
     auto [c0,c1,c2,c3,c4] = append<5>(dst_coord, 0);
     printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n",
-           threadIdx.x, threadIdx.y, threadIdx.z,
-           blockIdx.x, blockIdx.y, blockIdx.z,
+           ThreadIdxX(), ThreadIdxY(), ThreadIdxZ(),
+           BlockDimX(), BlockDimY(), BlockDimZ(),
            int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), src_ptr);
 #endif
     return detail::explode_tuple(detail::CallCOPY<SM90_TMA_STORE>{},
```

```diff
@@ -375,8 +375,8 @@ struct Copy_Traits<SM90_TMA_STORE, NumBitsPerTMA, AuxParams_>
 #if 0
     auto [c0,c1,c2,c3,c4] = append<5>(dst_coord, 0);
     printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n",
-           threadIdx.x, threadIdx.y, threadIdx.z,
-           blockIdx.x, blockIdx.y, blockIdx.z,
+           ThreadIdxX(), ThreadIdxY(), ThreadIdxZ(),
+           BlockIdxX(), BlockIdxY(), BlockIdxZ(),
            int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), src_ptr);
 #endif
     return detail::explode_tuple(detail::CallCOPY<SM90_TMA_STORE>{},
```

```diff
@@ -457,8 +457,8 @@ struct Copy_Traits<SM90_TMA_REDUCE_ADD, NumBitsPerTMA, AuxParams_>
 #if 0
     auto [c0,c1,c2,c3,c4] = append<5>(dst_coord, 0);
     printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n",
-           threadIdx.x, threadIdx.y, threadIdx.z,
-           blockIdx.x, blockIdx.y, blockIdx.z,
+           ThreadIdxX(), ThreadIdxY(), ThreadIdxZ(),
+           BlockIdxX(), BlockIdxY(), BlockIdxZ(),
            int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), src_ptr);
 #endif
```
```diff
@@ -974,7 +974,8 @@ make_tma_copy_desc(Tensor<GEngine,GLayout> const& gtensor,  // The origin
   // TMA general info
   //

-#if (__CUDACC_VER_MAJOR__ >= 12) && !defined(__CUDACC_RTC__)
+#if ((__CUDACC_VER_MAJOR__ >= 12) && !defined(__CUDACC_RTC__)) || \
+    defined(SYCL_NVIDIA_TARGET)

   CUtensorMapDataType tma_format = TMA::to_CUtensorMapDataType<TmaInternalType>();
   CUtensorMapInterleave tma_interleave = CU_TENSOR_MAP_INTERLEAVE_NONE;
```

```diff
@@ -984,7 +985,7 @@ make_tma_copy_desc(Tensor<GEngine,GLayout> const& gtensor,  // The origin
   // TMA smem swizzle type
   CUtensorMapSwizzle smem_swizzle = TMA::to_CUtensorMapSwizzle(get_tma_swizzle_bits(swizzle));
   CUresult result = cuTensorMapEncodeTiled(
-      &tma_desc,
+      reinterpret_cast<CUtensorMap*>(&tma_desc),
       tma_format,
       tma_dim,
       gmem_address,
```

**Comment on lines -987 to +988**

Reviewer: Is this needed?

Author: Yes, cuTensorMapEncodeTiled accepts a pointer to CUtensorMap.

Author: I must clarify that this change is only temporary.
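A cast like this only works if the library-side descriptor and the driver's `CUtensorMap` agree on size and alignment. A self-contained illustration of the pattern, with stand-in types rather than the actual cute/driver definitions:

```cpp
#include <cstdint>

// Stand-in for the driver's opaque CUtensorMap (128 bytes, 64-byte aligned).
struct alignas(64) DriverTensorMap { std::uint8_t opaque[128]; };

// Stand-in for a library-side descriptor declared when the CUDA headers are
// not visible; the cast below is sound only while the layouts stay in sync.
struct alignas(64) TmaDescriptorStandIn { std::uint8_t bytes[128]; };

// Fake encode function standing in for cuTensorMapEncodeTiled.
int encode(DriverTensorMap* out) { return out == nullptr; }

int make_desc() {
  TmaDescriptorStandIn tma_desc{};
  // The PR performs this kind of cast at the single driver-API call site;
  // the author flags it as a temporary measure.
  return encode(reinterpret_cast<DriverTensorMap*>(&tma_desc));
}
```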
```diff
@@ -46,25 +46,31 @@

 ////////////////////////////////////////////////////////////////////////////////

-#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 8))
+#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 8)) || \
+    defined(SYCL_NVIDIA_TARGET)
 #define CUTLASS_ARCH_MMA_SM90_F64_MMA_SUPPORTED
 #if (!defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED))
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) || \
+    (defined(__SYCL_CUDA_ARCH__) && (__SYCL_CUDA_ARCH__ >= 900) && \
+     defined(__PTX_VERSION__) && (__PTX_VERSION__ >= 80))
 #define CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED
 #endif
 #endif
 #endif

-#if (__CUDACC_VER_MAJOR__ >= 12)
+#if (__CUDACC_VER_MAJOR__ >= 12) || defined(SYCL_NVIDIA_TARGET)
 #define CUTLASS_ARCH_MMA_SM90_SUPPORTED
 #if (!defined(CUTLASS_ARCH_MMA_SM90_ENABLED))
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) || \
+    (defined(__SYCL_CUDA_ARCH__) && (__SYCL_CUDA_ARCH__ >= 900) && \
+     defined(__PTX_VERSION__) && (__PTX_VERSION__ >= 80))
 #define CUTLASS_ARCH_MMA_SM90_ENABLED
 #endif
 #endif
 #endif

-#if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 3)))
+#if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 3))) || \
+    defined(SYCL_NVIDIA_TARGET)
 #define CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED
 #endif
```

**Comment on the CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED guard**

Reviewer: Here as well, SYCL_NVIDIA_TARGET covers a wide range of targets, including SM80. We need to use the PTX version here, or at least make sure that the NVIDIA target is >= 900.
**Comment**

Reviewer: This flag was moved to line 58.

Author: There was a TODO comment, which I thought I had added as part of 0c9d5e1, about investigating why this line is needed. I was aware of this change, but for some reason I was still seeing a `*_with_offset` kernel, hence I added this as a temporary fix. This is also partly why this PR is a draft.