Skip to content

Commit

Permalink
[ck] Fix compilation errors with Clang20.0.
Browse files Browse the repository at this point in the history
  • Loading branch information
jagadish-amd committed Jan 17, 2025
1 parent f661e4f commit 0203188
Show file tree
Hide file tree
Showing 2 changed files with 361 additions and 1 deletion.
4 changes: 3 additions & 1 deletion cmake/external/composable_kernel.cmake
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
set(PATCH_CLANG ${PROJECT_SOURCE_DIR}/patches/composable_kernel/Fix_Clang_Build.patch)
set(PATCH_GFX12X ${PROJECT_SOURCE_DIR}/patches/composable_kernel/Add_gfx12x_support.patch)
set(PATCH_Clang20 ${PROJECT_SOURCE_DIR}/patches/composable_kernel/Fix_Clang20_error.patch)

include(FetchContent)
FetchContent_Declare(composable_kernel
URL ${DEP_URL_composable_kernel}
URL_HASH SHA1=${DEP_SHA1_composable_kernel}
PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PATCH_CLANG} &&
${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PATCH_GFX12X}
${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PATCH_GFX12X} &&
${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PATCH_Clang20}
)

FetchContent_GetProperties(composable_kernel)
Expand Down
358 changes: 358 additions & 0 deletions cmake/patches/composable_kernel/Fix_Clang20_error.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,358 @@
diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
index 5d137e67..758f25a5 100644
--- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
+++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
@@ -406,7 +406,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
}

template <>
- __device__ static constexpr auto TailScheduler<1>()
+ __device__ constexpr auto TailScheduler<1>()
{
// schedule
constexpr auto num_ds_read_inst =
@@ -433,7 +433,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
}

template <>
- __device__ static constexpr auto TailScheduler<2>()
+ __device__ constexpr auto TailScheduler<2>()
{
// schedule
constexpr auto num_ds_read_inst =
diff --git a/include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp b/include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp
index a1844316..409bb9f6 100644
--- a/include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp
+++ b/include/ck/tensor_operation/gpu/warp/dpp_gemm.hpp
@@ -324,55 +324,55 @@ struct DppSelector
static constexpr auto GetDpp();

template <>
- static constexpr auto GetDpp<half_t, 8, 32>()
+ constexpr auto GetDpp<half_t, 8, 32>()
{
return DppInstr::dpp8_f16_8x32x2;
}

template <>
- static constexpr auto GetDpp<half_t, 8, 16>()
+ constexpr auto GetDpp<half_t, 8, 16>()
{
return DppInstr::dpp8_f16_8x16x2;
}

template <>
- static constexpr auto GetDpp<half_t, 16, 16>()
+ constexpr auto GetDpp<half_t, 16, 16>()
{
return DppInstr::dpp8_f16_16x16x2;
}

template <>
- static constexpr auto GetDpp<half_t, 32, 8>()
+ constexpr auto GetDpp<half_t, 32, 8>()
{
return DppInstr::dpp8_f16_32x8x2;
}

template <>
- static constexpr auto GetDpp<half_t, 1, 32>()
+ constexpr auto GetDpp<half_t, 1, 32>()
{
return DppInstr::dpp8_f16_1x32x2;
}

template <>
- static constexpr auto GetDpp<half_t, 2, 32>()
+ constexpr auto GetDpp<half_t, 2, 32>()
{
return DppInstr::dpp8_f16_2x32x2;
}

template <>
- static constexpr auto GetDpp<half_t, 2, 16>()
+ constexpr auto GetDpp<half_t, 2, 16>()
{
return DppInstr::dpp8_f16_2x16x2;
}

template <>
- static constexpr auto GetDpp<half_t, 4, 16>()
+ constexpr auto GetDpp<half_t, 4, 16>()
{
return DppInstr::dpp8_f16_4x16x2;
}

template <>
- static constexpr auto GetDpp<half_t, 4, 32>()
+ constexpr auto GetDpp<half_t, 4, 32>()
{
return DppInstr::dpp8_f16_4x32x2;
}
diff --git a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
index 9a9ebf55..b435a2a1 100644
--- a/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
+++ b/include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
@@ -415,7 +415,7 @@ struct WmmaSelector
static constexpr auto GetWmma();

template <>
- static constexpr auto GetWmma<half_t, half_t, float, 16, 16>()
+ constexpr auto GetWmma<half_t, half_t, float, 16, 16>()
{
#ifdef __gfx12__
return WmmaInstr::wmma_f32_16x16x16_f16_gfx12;
@@ -425,7 +425,7 @@ struct WmmaSelector
}

template <>
- static constexpr auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>()
+ constexpr auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>()
{
#ifdef __gfx12__
return WmmaInstr::wmma_f32_16x16x16_bf16_gfx12;
@@ -435,19 +435,19 @@ struct WmmaSelector
}

template <>
- static constexpr auto GetWmma<half_t, half_t, half_t, 16, 16>()
+ constexpr auto GetWmma<half_t, half_t, half_t, 16, 16>()
{
return WmmaInstr::wmma_f16_16x16x16_f16;
}

template <>
- static constexpr auto GetWmma<bhalf_t, bhalf_t, bhalf_t, 16, 16>()
+ constexpr auto GetWmma<bhalf_t, bhalf_t, bhalf_t, 16, 16>()
{
return WmmaInstr::wmma_bf16_16x16x16_bf16;
}

template <>
- static constexpr auto GetWmma<int8_t, int8_t, int, 16, 16>()
+ constexpr auto GetWmma<int8_t, int8_t, int, 16, 16>()
{
#ifdef __gfx12__
return WmmaInstr::wmma_i32_16x16x16_iu8_gfx12;
@@ -458,7 +458,7 @@ struct WmmaSelector

#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
- static constexpr auto GetWmma<int4_t, int4_t, int, 16, 16>()
+ constexpr auto GetWmma<int4_t, int4_t, int, 16, 16>()
{
return WmmaInstr::wmma_i32_16x16x16_iu4;
}
diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
index 835075b7..24fac91e 100644
--- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
+++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp
@@ -651,97 +651,97 @@ struct MfmaSelector
static constexpr auto GetMfma();

template <>
- static constexpr auto GetMfma<double, 16, 16>()
+ constexpr auto GetMfma<double, 16, 16>()
{
return MfmaInstr::mfma_f64_16x16x4f64;
}

template <>
- static constexpr auto GetMfma<float, 64, 64>()
+ constexpr auto GetMfma<float, 64, 64>()
{
return MfmaInstr::mfma_f32_32x32x1xf32;
}

template <>
- static constexpr auto GetMfma<float, 32, 64>()
+ constexpr auto GetMfma<float, 32, 64>()
{
return MfmaInstr::mfma_f32_32x32x1xf32;
}

template <>
- static constexpr auto GetMfma<float, 16, 64>()
+ constexpr auto GetMfma<float, 16, 64>()
{
return MfmaInstr::mfma_f32_16x16x1xf32;
}

template <>
- static constexpr auto GetMfma<float, 8, 64>()
+ constexpr auto GetMfma<float, 8, 64>()
{
return MfmaInstr::mfma_f32_4x4x1xf32;
}

template <>
- static constexpr auto GetMfma<float, 4, 64>()
+ constexpr auto GetMfma<float, 4, 64>()
{
return MfmaInstr::mfma_f32_4x4x1xf32;
}

template <>
- static constexpr auto GetMfma<float, 32, 32>()
+ constexpr auto GetMfma<float, 32, 32>()
{
return MfmaInstr::mfma_f32_32x32x2xf32;
}

template <>
- static constexpr auto GetMfma<float, 16, 16>()
+ constexpr auto GetMfma<float, 16, 16>()
{
return MfmaInstr::mfma_f32_16x16x4xf32;
}

template <>
- static constexpr auto GetMfma<half_t, 64, 64>()
+ constexpr auto GetMfma<half_t, 64, 64>()
{
return MfmaInstr::mfma_f32_32x32x4f16;
}

template <>
- static constexpr auto GetMfma<half_t, 32, 64>()
+ constexpr auto GetMfma<half_t, 32, 64>()
{
return MfmaInstr::mfma_f32_32x32x4f16;
}

template <>
- static constexpr auto GetMfma<half_t, 32, 32>()
+ constexpr auto GetMfma<half_t, 32, 32>()
{
return MfmaInstr::mfma_f32_32x32x8f16;
}

template <>
- static constexpr auto GetMfma<half_t, 16, 16>()
+ constexpr auto GetMfma<half_t, 16, 16>()
{
return MfmaInstr::mfma_f32_16x16x16f16;
}

template <>
- static constexpr auto GetMfma<half_t, 16, 64>()
+ constexpr auto GetMfma<half_t, 16, 64>()
{
return MfmaInstr::mfma_f32_16x16x4f16;
}

template <>
- static constexpr auto GetMfma<half_t, 8, 64>()
+ constexpr auto GetMfma<half_t, 8, 64>()
{
return MfmaInstr::mfma_f32_4x4x4f16;
}

template <>
- static constexpr auto GetMfma<half_t, 4, 64>()
+ constexpr auto GetMfma<half_t, 4, 64>()
{
return MfmaInstr::mfma_f32_4x4x4f16;
}

template <>
- static constexpr auto GetMfma<bhalf_t, 32, 32>()
+ constexpr auto GetMfma<bhalf_t, 32, 32>()
{
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return MfmaInstr::mfma_f32_32x32x8bf16_1k;
@@ -751,7 +751,7 @@ struct MfmaSelector
}

template <>
- static constexpr auto GetMfma<bhalf_t, 16, 16>()
+ constexpr auto GetMfma<bhalf_t, 16, 16>()
{
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
return MfmaInstr::mfma_f32_16x16x16bf16_1k;
@@ -762,72 +762,72 @@ struct MfmaSelector

#if defined(CK_USE_AMD_MFMA_GFX940)
template <>
- static constexpr auto GetMfma<int8_t, 32, 32>()
+ constexpr auto GetMfma<int8_t, 32, 32>()
{
return MfmaInstr::mfma_i32_32x32x16i8;
}
template <>
- static constexpr auto GetMfma<int8_t, 16, 16>()
+ constexpr auto GetMfma<int8_t, 16, 16>()
{
return MfmaInstr::mfma_i32_16x16x32i8;
}
#else
template <>
- static constexpr auto GetMfma<int8_t, 32, 32>()
+ constexpr auto GetMfma<int8_t, 32, 32>()
{
return MfmaInstr::mfma_i32_32x32x8i8;
}
template <>
- static constexpr auto GetMfma<int8_t, 16, 16>()
+ constexpr auto GetMfma<int8_t, 16, 16>()
{
return MfmaInstr::mfma_i32_16x16x16i8;
}
#endif

template <>
- static constexpr auto GetMfma<f8_t, 32, 32>()
+ constexpr auto GetMfma<f8_t, 32, 32>()
{
return MfmaInstr::mfma_f32_32x32x16f8f8;
}

template <>
- static constexpr auto GetMfma<f8_t, 16, 16>()
+ constexpr auto GetMfma<f8_t, 16, 16>()
{
return MfmaInstr::mfma_f32_16x16x32f8f8;
}

template <>
- static constexpr auto GetMfma<bf8_t, 32, 32>()
+ constexpr auto GetMfma<bf8_t, 32, 32>()
{
return MfmaInstr::mfma_f32_32x32x16bf8bf8;
}

template <>
- static constexpr auto GetMfma<bf8_t, 16, 16>()
+ constexpr auto GetMfma<bf8_t, 16, 16>()
{
return MfmaInstr::mfma_f32_16x16x32bf8bf8;
}

template <>
- static constexpr auto GetMfma<f8_t, 32, 32, bf8_t>()
+ constexpr auto GetMfma<f8_t, 32, 32, bf8_t>()
{
return MfmaInstr::mfma_f32_32x32x16f8bf8;
}

template <>
- static constexpr auto GetMfma<f8_t, 16, 16, bf8_t>()
+ constexpr auto GetMfma<f8_t, 16, 16, bf8_t>()
{
return MfmaInstr::mfma_f32_16x16x32f8bf8;
}

template <>
- static constexpr auto GetMfma<bf8_t, 32, 32, f8_t>()
+ constexpr auto GetMfma<bf8_t, 32, 32, f8_t>()
{
return MfmaInstr::mfma_f32_32x32x16bf8f8;
}

template <>
- static constexpr auto GetMfma<bf8_t, 16, 16, f8_t>()
+ constexpr auto GetMfma<bf8_t, 16, 16, f8_t>()
{
return MfmaInstr::mfma_f32_16x16x32bf8f8;
}

0 comments on commit 0203188

Please sign in to comment.