diff --git a/library/include/rocwmma/internal/config.hpp b/library/include/rocwmma/internal/config.hpp index d61dd51e..08b112b5 100644 --- a/library/include/rocwmma/internal/config.hpp +++ b/library/include/rocwmma/internal/config.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -146,6 +146,11 @@ static_assert(0, "Unsupported architecture"); #define ROCWMMA_ARCH_GFX94X 1 #endif +#if ROCWMMA_ARCH_HOST +#define ROCWMMA_BLOCK_DIM_16_SUPPORTED 1 +#define ROCWMMA_BLOCK_DIM_32_SUPPORTED 1 +#endif + #if !defined(ROCWMMA_ARCH_GFX9) #define ROCWMMA_ARCH_GFX9 0 #endif @@ -201,10 +206,10 @@ static_assert((bool)(ROCWMMA_BLOCK_DIM_16_SUPPORTED) && !(bool)(ROCWMMA_BLOCK_DI #endif #if ROCWMMA_ARCH_GFX12 - static_assert((bool)(ROCWMMA_WAVE32_MODE) && !(bool)(ROCWMMA_WAVE64_MODE), - "rocWMMA supports only wave32 for gfx12 arch"); - static_assert((bool)(ROCWMMA_BLOCK_DIM_16_SUPPORTED) && !(bool)(ROCWMMA_BLOCK_DIM_32_SUPPORTED), - "rocWMMA supports only block size of 16 for gfx12 arch"); +static_assert((bool)(ROCWMMA_WAVE32_MODE) && !(bool)(ROCWMMA_WAVE64_MODE), + "rocWMMA supports only wave32 for gfx12 arch"); +static_assert((bool)(ROCWMMA_BLOCK_DIM_16_SUPPORTED) && !(bool)(ROCWMMA_BLOCK_DIM_32_SUPPORTED), + "rocWMMA supports only block size of 16 for gfx12 arch"); #endif /// diff --git a/library/include/rocwmma/internal/coop_io_config.hpp b/library/include/rocwmma/internal/coop_io_config.hpp index 9b39fded..b3bf287e 100644 --- a/library/include/rocwmma/internal/coop_io_config.hpp +++ b/library/include/rocwmma/internal/coop_io_config.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -26,6 +26,7 @@ #ifndef ROCWMMA_COOP_IO_CONFIG_HPP #define ROCWMMA_COOP_IO_CONFIG_HPP +#include "./layout/register_layout_transforms.hpp" #include "coop_load.hpp" #include "coop_store.hpp" #include "io_layout.hpp" @@ -85,6 +86,15 @@ namespace rocwmma typename IOLayout::MatrixLayout, IOLayout::VW>; + using PostLoadXForm = register_layout_transform; + + using PreMmaXForm = register_layout_transform; + + using PreStoreXForm = register_layout_transform; + using Storer = CooperativeStore; + using PostLoadXForm = register_layout_transform; + + using PreMmaXForm = register_layout_transform; + + // Currently, only makes sense to have a post-mma transform on acc layouts + using PostMmaXForm = conditional_t, + register_layout_transform, + register_layout_transform_nop>; + + using PreStoreXForm = register_layout_transform; + using Storer = OpaqueStore struct IOConfig { - using IOShape = IOShape; - using IOTraits = IOTraits; - using PackUtil = PackUtil; + using IOShape = IOShape; + using IOLayout = IOLayout; + using IOTraits = IOTraits; + using PackUtil = PackUtil; using Broadcaster = Broadcast; + + using PreMmaXForm = register_layout_transform; + + using PostMmaXForm = register_layout_transform; }; /** @}*/ diff --git a/library/include/rocwmma/internal/io_layout.hpp b/library/include/rocwmma/internal/io_layout.hpp index 1e6a0555..ec509e6b 100644 --- a/library/include/rocwmma/internal/io_layout.hpp +++ b/library/include/rocwmma/internal/io_layout.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -28,7 +28,7 @@ #include "api_fwd.hpp" #include "constants.hpp" -#include "layout.hpp" +#include "layout/layout.hpp" #include "types.hpp" namespace rocwmma @@ -45,47 +45,72 @@ namespace rocwmma = 4u * Constants::AMDGCN_DWORD_SIZE_BYTES / (uint32_t)sizeof(DataT)> struct MaxVWSelector { - private: - enum : uint32_t - { - // For small block sizes (16, 32): - // Best to keep MaxVW high and reduce splits amongst waves. - WaveCountFactor = (BlockDim <= 32) ? 1u : WaveCount, - - // Total number of elements in a single I/O operation - ElementsPerIO = Constants::AMDGCN_WAVE_SIZE * TestWidth * WaveCountFactor, - - // Total number of elements for the entire block - ElementCount = BlockDim * BlockK, - - // Ensure that for MaxVW: - // - A minimum of one IO from each wave can fit - // - A balanced multiple of IOs from each wave - ElementCountTest - = (ElementsPerIO <= ElementCount) && (ElementCount % ElementsPerIO == 0), - - // Currently, all layouts are using ColOrthoVW. This means that VW must be less than BlockK - LeadingDimTest = (TestWidth <= BlockK), - - MaxVectorWidth = (bool)ElementCountTest && (bool)LeadingDimTest - ? TestWidth - : MaxVWSelector::Result, - }; + // For small block sizes (16, 32): + // Best to keep MaxVW high and reduce splits amongst waves. + static constexpr uint32_t WaveCountFactor = (BlockDim <= 32) ? 1u : WaveCount; + + // Total number of elements in a single I/O operation + static constexpr uint32_t ElementsPerIO + = Constants::AMDGCN_WAVE_SIZE * TestWidth * WaveCountFactor; + + // Total number of elements for the entire block + static constexpr uint32_t ElementCount = BlockDim * BlockK; + + // Ensure that for MaxVW: + // - A minimum of one IO from each wave can fit + // - A balanced multiple of IOs from each wave + static constexpr bool ElementCountTest + = (ElementsPerIO <= ElementCount) && (ElementCount % ElementsPerIO == 0); + + // Layout fitness check: + // Basic non-interleaved layouts are classified into *OrthoVW (SOA) and *InlineVW (AOS) formats. + // For any BlockDim/BlockK geometry, we ensure that these layouts come up with the same MaxVW, + // so that the AOS <-> SOA transforms are possible and valid. The followings tests assure this. + static constexpr bool BlockKTest = (Constants::AMDGCN_WAVE_SIZE * TestWidth / min(BlockDim, Constants::AMDGCN_WAVE_SIZE)) <= BlockK; + static constexpr bool OrthoTest = TestWidth <= BlockK; + static constexpr bool InlineTest = TestWidth <= BlockDim; + static constexpr bool LayoutFitnessTest = (BlockKTest && OrthoTest && InlineTest); + + // Decide on final MaxVW + static constexpr uint32_t MaxVectorWidth = (ElementCountTest && LayoutFitnessTest) + ? TestWidth + : MaxVWSelector::Result; public: - enum : uint32_t - { - Result = (uint32_t)MaxVectorWidth - }; + static constexpr uint32_t Result = MaxVectorWidth; + }; + + // Accumulator case, is architecture specific + template + struct MaxVWSelector + { + static_assert(WaveCount == 1u, "Accumulators are not cooperative"); + + constexpr static uint32_t Result + = (bool)ROCWMMA_ARCH_GFX12 + ? 8u + : ((is_same_v || (bool)ROCWMMA_ARCH_GFX11) ? 1u : 4u); }; + // Fallback case for bad test. Stay safe to VW=1 template struct MaxVWSelector { - enum : uint32_t - { - Result = 1u - }; + static constexpr uint32_t Result = 1u; }; } // namespace detail /*! \struct IOLayout - * \brief Definition of VW, MaxVW, data and matrix mapping utilities - * in specific matrix context. - * - * @tparam MatrixT fragment context - * @tparam BlockDim Block leading dimension - * @tparam BlockK Block K-dimension - * @tparam DataT data type - * @tparam DataLayoutT in-memory layout as col_major or row_major - * @tparam WaveCount number of cooperative waves - */ + * \brief Definition of VW, MaxVW, data and matrix mapping utilities + * in specific matrix context. + * + * @tparam MatrixT fragment context + * @tparam BlockDim Block leading dimension + * @tparam BlockK Block K-dimension + * @tparam DataT data type + * @tparam DataLayoutT in-memory layout as col_major or row_major + * @tparam WaveCount number of cooperative waves + */ template { // Vector size properties - enum : uint32_t - { - MaxVW = detail:: - MaxVWSelector::Result, + constexpr static uint32_t MaxVW = detail:: + MaxVWSelector::Result; - VW = is_same::value || BlockDim > 32 ? MaxVW : 1u - }; + constexpr static uint32_t VW + = is_same_v || BlockDim > 32u ? MaxVW : 1u; + + // DataLayout + using DataLayout = DataLayout::template Array1d; + + // Matrix Layouts + // Small dim mma friendly + using SmallDimMatrixLayout + = MatrixLayout::ColOrthoVW; + + // Large dim not mma friendly + using LargeDimMatrixLayout + = conditional_t, + MatrixLayout::ColInlineVW, + MatrixLayout::ColOrthoVW>; + + using MatrixLayout + = conditional_t; - // Layout profile for 'matrix_a': ColNT for small frags, Col for large frags - using Profile = conditional_t< - BlockDim <= 32, - LayoutProfile::template ColNT, - LayoutProfile::template Col>; + // Register layout direct to memory storage (load / store) + using StorageLayout = RegisterLayout::Storage; - using DataLayout = typename Profile::DataLayout; - using MatrixLayout = typename Profile::MatrixLayout; - using RegisterLayout = typename Profile::RegisterLayout; + // Register layout required for mma. Expect non-interleaved SOA format. + // Quirk: gfx11 requires input duplication. + using MmaLayout = RegisterLayout::MmaInput; + // Fragments will keep storage layout + using FragmentLayout = StorageLayout; }; template { // Vector size properties - enum : uint32_t - { - MaxVW = detail:: - MaxVWSelector::Result, + constexpr static uint32_t MaxVW = detail:: + MaxVWSelector::Result; - VW = is_same::value || BlockDim > 32 ? MaxVW : 1u - }; + constexpr static uint32_t VW + = is_same_v || BlockDim > 32 ? MaxVW : 1u; + + // DataLayout + using DataLayout = DataLayout::template Array1d; + + // Matrix Layouts + // Small dim mma friendly + using SmallDimMatrixLayout = + MatrixLayout::RowOrthoVW; + + // Large dim not mma friendly + using LargeDimMatrixLayout + = conditional_t, + MatrixLayout::RowInlineVW, + MatrixLayout::RowOrthoVW>; + + using MatrixLayout + = conditional_t; - // Layout profile for 'matrix_b': RowNT for small frags, Row for large frags - using Profile = conditional_t< - BlockDim <= 32, - LayoutProfile::template RowNT, - LayoutProfile::template Row>; + // Register layout direct to memory storage (load / store) + using StorageLayout = RegisterLayout::Storage; - using DataLayout = typename Profile::DataLayout; - using MatrixLayout = typename Profile::MatrixLayout; - using RegisterLayout = typename Profile::RegisterLayout; + // Register layout required for mma. Expect non-interleaved SOA format. + // Quirk: gfx11 requires input duplication. + using MmaLayout = RegisterLayout::MmaInput; + + // Fragments will keep storage register layout. + using FragmentLayout = StorageLayout; }; template { // Vector size properties - enum : uint32_t + constexpr static uint32_t MaxVW = detail:: + MaxVWSelector::Result; + + constexpr static uint32_t VW = is_same_v ? MaxVW : 1u; + + // DataLayout + using DataLayout = DataLayout::template Array1d; + + // Always mma friendly + using MatrixLayout + = MatrixLayout::RowOrthoVW; + + // Register layout direct to memory storage (load / store) + using StorageLayout = RegisterLayout::Storage; + + // Register layout required for mma. Expect non-interleaved SOA format. + // Quirk: gfx11 requires padded acc. + using MmaLayout = RegisterLayout::MmaAcc; + + // Fragments will keep storage register layout. + using FragmentLayout = StorageLayout; + }; + + template + struct IOLayout + { + // We don't know which storage is needed: no DataLayout + using StorageLayout = void; + + // Register layout required for mma. Expect non-interleaved SOA format. + // Quirk: gfx11 requires padded acc. + using MmaLayout = RegisterLayout::MmaAcc; + + // Fragments will assume default mma register layout. + using FragmentLayout = RegisterLayout::MmaAcc; + }; + + namespace detail + { + template + struct MmaDimSelector + { + private: + // Smallest valid mma dim for mfma/wmma. + // Test MmaDim must not exceed BlockDim for valid layout. + static constexpr uint32_t MinMmaDim = 16u; + static constexpr uint32_t TestMmaDim = min(BlockDim, MmaDim); + + // For valid mma sizes, (BlockDim >= 16) + // Find minimum 16 byte load with MmaDim of 32 or 16 + static constexpr uint32_t MinLargeBytes = 16u; + static constexpr uint32_t DimPerThread = BlockDim / TestMmaDim; + static constexpr uint32_t BytesPerThread = DimPerThread * sizeof(DataT); + static constexpr uint32_t MmaDimResult = (BytesPerThread < MinLargeBytes ? MinMmaDim : TestMmaDim); + + // For invalid mma sizes (BlockDim < 16), we can have smaller MmaDim to increase VW. + // Try to balance DimPerThread and KPerThread by aiming to get half BlockDim bytes. + static constexpr bool SmallDim = TestMmaDim < MinMmaDim; + static constexpr uint32_t MinSmallBytes = BlockDim / 2u * sizeof(DataT); + static constexpr uint32_t SmallDimResult = (BytesPerThread < MinSmallBytes) ? + MmaDimSelector::Result : TestMmaDim; + + public: + static constexpr uint32_t Result = SmallDim ? SmallDimResult : MmaDimResult; + }; + + template + struct MmaDimSelector { - MaxVW = ROCWMMA_ARCH_GFX12 - ? 8u - : ((is_same::value || ROCWMMA_ARCH_GFX11) ? 1u : 4u), - VW = is_same::value ? MaxVW : 1u + static constexpr uint32_t Result = 1u; }; - // Layout profile for 'accumulator' set to RowNT - using Profile - = LayoutProfile::template RowNT; + } // namespace detail + + /*! \struct IOLayoutInt + * \brief Definition of VW, MaxVW, data and matrix mapping utilities + * in specific matrix context. + * + * @tparam MatrixT fragment context + * @tparam BlockDim Block leading dimension + * @tparam BlockK Block K-dimension + * @tparam DataT data type + * @tparam DataLayoutT in-memory layout as col_major or row_major + * @tparam WaveCount number of cooperative waves + */ + template + struct IOLayoutInt; + + template + struct IOLayoutInt + { + // Select an appropriate MmaDim + constexpr static uint32_t MmaDim = detail::MmaDimSelector::Result; + + // DataLayout + using DataLayout = DataLayout::template Array1d; + + // Matrix Layouts + using MatrixLayout + = conditional_t, + MatrixLayout::ColInlineInt, + MatrixLayout::ColOrthoInt>; + + // Register layout direct to memory storage (load / store) + using StorageLayout = RegisterLayout::Storage; + + // Register layout required for mma. Expect interleaved SOA format. + // Quirk: gfx11 requires input duplication. + using MmaLayout = RegisterLayout::MmaInput; + + // Fragments will keep storage register layout. + using FragmentLayout = StorageLayout; + + // Vector size properties derived from the matrix layout + constexpr static uint32_t MaxVW = layout_traits::MaxVectorWidth; + constexpr static uint32_t VW = MaxVW; + }; + + template + struct IOLayoutInt + { + // Select an appropriate MmaDim + constexpr static uint32_t MmaDim = detail::MmaDimSelector::Result; + + // DataLayout + using DataLayout = DataLayout::template Array1d; + + // Matrix Layouts + using MatrixLayout + = conditional_t, + MatrixLayout::RowOrthoInt, + MatrixLayout::RowInlineInt>; + + // Register layout direct to memory storage (load / store) + using StorageLayout = RegisterLayout::Storage; + + // Register layout required for mma. Expect interleaved SOA format. + // Quirk: gfx11 requires input duplication. + using MmaLayout = RegisterLayout::MmaInput; + // Fragments will keep storage register layout. + using FragmentLayout = StorageLayout; + + // Vector size properties derived from the matrix layout + constexpr static uint32_t MaxVW = layout_traits::MaxVectorWidth; + constexpr static uint32_t VW = MaxVW; + }; + + template + struct IOLayoutInt + { + // Select an appropriate MmaDim + constexpr static uint32_t MmaDim = detail::MmaDimSelector::Result; + + // DataLayout + using DataLayout = DataLayout::template Array1d; + + // Matrix Layouts + using MatrixLayout + = conditional_t, + MatrixLayout::RowOrthoInt, + MatrixLayout::RowInlineInt>; + + // Register layout direct to memory storage (load / store) + using StorageLayout = RegisterLayout::Storage; - using DataLayout = typename Profile::DataLayout; - using MatrixLayout = typename Profile::MatrixLayout; - using RegisterLayout = typename Profile::RegisterLayout; + // Register layout required for mma. Expect interleaved accum format for multiple blocks. + // Quirk: gfx11 requires padded mma acc + using MmaLayout + = RegisterLayout::MmaAcc; + + // Fragments will keep mma register layout. + using FragmentLayout + = RegisterLayout::MmaAcc; + + // Vector size properties derived from the matrix layout + constexpr static uint32_t MaxVW = layout_traits::MaxVectorWidth; + constexpr static uint32_t VW = MaxVW; }; template - struct IOLayout + struct IOLayoutInt { - // No layout mapping without VW, MaxVW and DataLayoutT info + // Select an appropriate MmaDim + constexpr static uint32_t MmaDim = detail::MmaDimSelector::Result; + + // We don't know which storage is needed: no DataLayout + using StorageLayout = void; + + // Register layout required for mma. Expect interleaved accum format for multiple blocks. + // Quirk: gfx11 requires padded mma acc + using MmaLayout + = RegisterLayout::MmaAcc; + + // Fragments will keep mma interleaved layout. + using FragmentLayout = RegisterLayout::MmaAcc; }; } // namespace rocwmma diff --git a/library/include/rocwmma/internal/layout.hpp b/library/include/rocwmma/internal/layout.hpp deleted file mode 100644 index c85c0c83..00000000 --- a/library/include/rocwmma/internal/layout.hpp +++ /dev/null @@ -1,305 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#ifndef ROCWMMA_LAYOUT_HPP -#define ROCWMMA_LAYOUT_HPP - -#include "mapping_util.hpp" - -namespace rocwmma -{ - // In relation to matrix space, DataLayouts describe whether consecutive elements in 1D data arrays are: - // 1. Contiguous rows (row_major) - // 2. Contiguous columns (col_major) - namespace DataLayout - { - template - using Array1d = typename ::rocwmma::detail::template DataSpace; - - using RowMajor = Array1d; - using ColMajor = Array1d; - - } // namespace DataLayout - - // In 2D space, Matrix Layouts describe per-thread offset coordinates and iterative spaces - // 1. Base thread offsets - // 2. Stride offsets - // 3. Stride spaces (counts) - // 4. Per-iteration offsets (stride step based on iteration) - // 5. Cumulative offsets (cumulative stride steps based on iteration) - namespace MatrixLayout - { - template - struct ColOrthoVW; - - template - struct ColInlineVW; - - template - struct RowOrthoVW; - - template - struct RowInlineVW; - - } // namespace MatrixLayout - - // Register layouts describe whether contiguous BlockDim elements are: - // 1. Captured in the same register lane as if the input were in Array-Of-Structures (AOS) - // 2. Captured across multiple register lanes as if the input were in Structure-Of-Arrays (SOA) - namespace RegisterLayout - { - template - struct Aos - { - }; - - template - struct Soa - { - }; - } - - // Layout profiles describe fragments in three mapped spaces: - // 1. DataLayout: data locality in memory space (row_major or col_major) - // 2. MatrixLayout: data locality in matrix space (ColOrthoVW, ColInlineVW, etc.) - // 3. RegisterLayout: data locality in register space (AOS or SOA) - namespace LayoutProfile - { - // ColNT is a layout profile that has the following properties: - // 1. Leading dimension is aligned with column elements of fragment data: - // - BlockDim is assumed the column size, or BlockM dimension. - // - Analogous to capturing columns of 'matrix A'. - // 2. Register elements are in MFMA friendly, or SOA register layout. - // 3. Register layout does NOT change whether DataLayout is col_major or row_major (fast DataLayoutT change). - // 4. MatrixLayout will capture contiguous column elements across multiple register lanes. - // 5. VectorWidth is fixed to 1 in col_major to ensure #3 (non-optimal). - template - struct ColNT - { - // Layouts - using DataLayout = DataLayout::template Array1d; - using MatrixLayout = conditional_t< - is_same_v, - MatrixLayout::ColOrthoVW, - MatrixLayout::ColOrthoVW>; - using RegisterLayout = RegisterLayout::template Soa; - - // Mapping - using MappingUtil = MappingUtil; - using MatrixCoordT = typename MappingUtil::MatrixCoordT; - - // Sanity checks - // ColNT enforces consistent in-register alignment of contiguous matrix column - // elements in both row_major or col_major data layouts. - // This layout cannot support for VW > 1 in col_major data layout otherwise the - // ordering is broken. - static_assert(!(is_same_v && VectorWidth > 1), - "ColNT in col_major does not support VectorWidth > 1"); - - // Must ensure that MaxVectorWidth fits inside the leading dimension - static_assert( - !(is_same_v && (MaxVectorWidth > BlockK)), - "MaxVectorWidth is larger than BlockK dimension. Try reducing MaxVectorWidth"); - }; - - // RowNT is a layout profile that has the following properties: - // 1. Leading dimension is aligned with row elements of fragment data: - // - BlockDim is assumed the row size, or BlockN dimension. - // - Analogous to capturing rows of 'matrix B' or 'accumulator'. - // 2. Register elements are in MFMA friendly, or SOA register layout. - // 3. Register layout does NOT change whether DataLayout is col_major or row_major (fast DataLayoutT change). - // 4. MatrixLayout will capture contiguous row elements across multiple register lanes. - // 5. VectorWidth is fixed to 1 in row_major to ensure #3 (non-optimal). - template - struct RowNT - { - // Layouts - using DataLayout = DataLayout::template Array1d; - using MatrixLayout = conditional_t< - is_same_v, - MatrixLayout::RowOrthoVW, - MatrixLayout::RowOrthoVW>; - using RegisterLayout = RegisterLayout::template Soa; - - // Mapping - using MappingUtil = MappingUtil; - using MatrixCoordT = typename MappingUtil::MatrixCoordT; - - // Sanity checks - // RowNT enforces consistent in-register alignment of contiguous matrix row - // elements in both in row_major or col_major data layouts. - // This layout cannot support for VW > 1 in row_major data layout. - static_assert(!(is_same_v && VectorWidth > 1), - "RowNT in row_major does not support VectorWidth > 1"); - - // Must ensure that MaxVectorWidth fits inside the leading dimension - static_assert( - !(is_same_v && (MaxVectorWidth > BlockK)), - "MaxVectorWidth is larger than BlockK dimension. Try reducing MaxVectorWidth"); - }; - - // Col is a layout profile that has the following properties: - // 1. Leading dimension is aligned with column elements of fragment data: - // - BlockDim is assumed the column size, or BlockM dimension. - // - Analogous to capturing columns of 'matrix A'. - // 2. Register layout is dynamic: - // - col_major data is stored in AOS register layout (non-MFMA friendly), and - // - row_major data is stored in SOA register layout (MFMA friendly). - // - Both data layouts cover the same geometric elements (transform friendly). - // 3. Register layout DOES change whether DataLayout is col_major or row_major (cost for DataLayoutT change). - // 4. VectorWidth is NOT fixed to 1 in either data layout (optimal). - // 5. User must convert to SOA if using with MFMA. - template - struct Col - { - // Layouts - using DataLayout = DataLayout::template Array1d; - using MatrixLayout = conditional_t< - is_same_v, - MatrixLayout::ColInlineVW, - MatrixLayout::ColOrthoVW>; - using RegisterLayout - = conditional_t, - RegisterLayout::template Aos, - RegisterLayout::template Soa>; - - // Mapping - using MappingUtil = MappingUtil; - using MatrixCoordT = typename MappingUtil::MatrixCoordT; - - // Sanity checks - // Must ensure that MaxVectorWidth fits inside the leading dimension - static_assert( - !(is_same_v && (MaxVectorWidth > BlockK)), - "MaxVectorWidth is larger than BlockK dimension. Try reducing MaxVectorWidth"); - }; - - // Row is a layout profile that has the following properties: - // 1. Leading dimension is aligned with row elements of fragment data: - // - BlockDim is assumed the row size, or BlockN dimension. - // - Analogous to capturing rows of 'matrix B' or 'accumulator'. - // 2. Register layout is dynamic: - // - row_major data is stored in AOS register layout (non-MFMA friendly), and - // - col_major data is stored in SOA register layout (MFMA friendly). - // - Both data layouts cover the same geometric elements (transform friendly). - // 3. Register layout DOES change whether DataLayout is col_major or row_major (cost for DataLayoutT change). - // 4. VectorWidth is NOT fixed to 1 in either data layout (optimal). - // 5. User must convert to SOA if using with MFMA. - template - struct Row - { - // Layouts - using DataLayout = DataLayout::template Array1d; - using MatrixLayout = conditional_t< - is_same_v, - MatrixLayout::RowInlineVW, - MatrixLayout::RowOrthoVW>; - using RegisterLayout - = conditional_t, - RegisterLayout::template Aos, - RegisterLayout::template Soa>; - - // Mapping - using MappingUtil = MappingUtil; - using MatrixCoordT = typename MappingUtil::MatrixCoordT; - - // Sanity checks - // Must ensure that MaxVectorWidth fits inside the leading dimension - static_assert( - !(is_same_v && (MaxVectorWidth > BlockK)), - "MaxVectorWidth is larger than BlockK dimension. Try reducing MaxVectorWidth"); - }; - - } // namespace FragmentLayout - - /// - /// Helper to ensure layout types are consistent (same, or equivalent) - /// - template - struct ConsistencyCheck : public false_type - { - }; - - /// - /// Check for layout orthogonality - /// - template - struct OrthogonalCheck : public false_type - { - }; - - template - struct OrthogonalLayout; - - template - using orthogonal_layout_t = typename OrthogonalLayout::Type; - - template - struct is_orthogonal; - - template - inline constexpr bool is_orthogonal_v = is_orthogonal::value; - -} // namespace rocwmma - -#include "layout_impl.hpp" - -#endif // ROCWMMA_LAYOUT_HPP diff --git a/library/include/rocwmma/internal/layout/data_layout_traits_impl.hpp b/library/include/rocwmma/internal/layout/data_layout_traits_impl.hpp new file mode 100644 index 00000000..dfdeca09 --- /dev/null +++ b/library/include/rocwmma/internal/layout/data_layout_traits_impl.hpp @@ -0,0 +1,198 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef ROCWMMA_DATA_LAYOUT_TRAITS_IMPL_HPP +#define ROCWMMA_DATA_LAYOUT_TRAITS_IMPL_HPP + +#include "layout.hpp" +#include "layout_traits.hpp" + +namespace rocwmma +{ + namespace LayoutTraits_impl + { + // Reference regular layouts + using DataLayout::ColMajor; + using DataLayout::RowMajor; + + // Build a basic set of meta-data classifiers. + // We will be interested in knowing things about our data layouts: + // - is_row_major + // - is_col_major + // - is_data_layout + // + // Note: We will qualify both: + // row_major / col_major (as meta-tags) + // RowMajor and ColMajor (as functional classes) + template + struct is_row_major : public false_type + { + }; + + template <> + struct is_row_major : public true_type + { + }; + + template <> + struct is_row_major : public true_type + { + }; + + template + struct is_col_major : public false_type + { + }; + + template <> + struct is_col_major : public true_type + { + }; + + template <> + struct is_col_major : public true_type + { + }; + + // Convenience evaluators + template + static constexpr bool is_row_major_v = is_row_major::value; + + template + static constexpr bool is_col_major_v = is_col_major::value; + + template + struct is_data_layout + : public integral_constant || is_col_major_v> + { + }; + + // Convenience evaluator + template + static constexpr bool is_data_layout_v = is_data_layout::value; + + // Cumulative traits about our data layouts + template + struct data_layout_traits + { + static constexpr bool is_row_major = is_row_major_v; + static constexpr bool is_col_major = is_col_major_v; + static constexpr bool is_data_layout = is_data_layout_v; + }; + +// Tidy some traits accesses +#define traits_lhs data_layout_traits +#define traits_rhs data_layout_traits + + template + ROCWMMA_HOST_DEVICE constexpr static bool testDataLayoutSame() + { + return (traits_lhs::is_row_major && traits_rhs::is_row_major) + || (traits_lhs::is_col_major && traits_rhs::is_col_major); + } + + template + ROCWMMA_HOST_DEVICE constexpr static bool testDataLayoutOrthogonal() + { + return (traits_lhs::is_row_major && traits_rhs::is_col_major) + || (traits_lhs::is_col_major && traits_rhs::is_row_major); + } + + // Implement sameness classifier for data layouts + template + struct is_layout_same> + : public integral_constant()> + { + }; + + // Implement orthogonality classifier for data layouts + template + struct is_layout_orthogonal< + DataLayoutLhs, + DataLayoutRhs, + enable_if_t> + : public integral_constant()> + { + }; + +#undef traits_lhs +#undef traits_rhs + + // Orthogonal layout guides + template <> + struct orthogonal_layout + { + using type = col_major; + }; + + template <> + struct orthogonal_layout + { + using type = row_major; + }; + + template + struct orthogonal_layout> + { + using type + = DataLayout::template Array1d::type>; + }; + + template + struct layout_traits>> + : public data_layout_traits + { + }; + + } // namespace LayoutTraits_impl + +} // namespace rocwmma + +#if !defined(__HIPCC_RTC__) +namespace std +{ + + template + inline ostream& + operator<<(ostream& stream, + rocwmma::LayoutTraits_impl::data_layout_traits const& traits) + { + using data_traits = decay_t; + stream << "DataLayout Traits: " << DataLayout{} << std::endl; + stream << "is_row_major: " << data_traits::is_row_major << std::endl; + stream << "is_col_major: " << data_traits::is_col_major << std::endl; + stream << "is_data_layout: " << data_traits::is_data_layout << std::endl; + return stream; + } + +} // namespace std + +#endif // !defined(__HIPCC_RTC__) + +#endif // ROCWMMA_DATA_LAYOUT_TRAITS_IMPL_HPP diff --git a/library/include/rocwmma/internal/layout/layout.hpp b/library/include/rocwmma/internal/layout/layout.hpp new file mode 100644 index 00000000..64c5132d --- /dev/null +++ b/library/include/rocwmma/internal/layout/layout.hpp @@ -0,0 +1,272 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef ROCWMMA_LAYOUT_HPP +#define ROCWMMA_LAYOUT_HPP + +#include "../api_fwd.hpp" +#include "../mapping_util.hpp" + +namespace rocwmma +{ + // DataLayout objects map 2D matrix coordinate space to 1D data arrays offsets. + // DataLayoutT tags describe whether consecutive elements are: + // 1. Contiguous rows (row_major) + // 2. Contiguous columns (col_major) + namespace DataLayout + { + /*! \class Array1d + * \brief A class to help map from 2D matrix space to 1D data space. + * @tparam DataLayoutT Meta-tag indicating whether data is stored in + * row_major or col_major order. + */ + template + using Array1d = typename ::rocwmma::detail::template DataSpace; + + /*! \class RowMajor + * \brief Maps 2D matrix space to row_major 1D data space + */ + using RowMajor = Array1d; + + /*! \class ColMajor + * \brief Maps 2D matrix space to col_major 1D data space + */ + using ColMajor = Array1d; + + } // namespace DataLayout + + // Matrix Layouts map thread offsets into 2D matrix coordinate space: + // 1. Base thread offsets + // 2. Stride offsets + // 3. Stride counts + // 4. Per-iteration offsets (stride step based on iteration) + // 5. Cumulative offsets (cumulative stride steps based on iteration) + namespace MatrixLayout + { + /*! \class ColOrthoVW + * \brief A matrix layout that maps contiguous threads to contiguous column elements, in the BlockDim direction. + * VectorWidth elements are mapped orthogonal to the column, in the BlockK Direction. + * @tparam BlockDim The height of the column + * @tparam BlockK The number of columns + * @tparam DataT The datatype + * @tparam VectorWidth The iterative vector width + * @tparam MaxVectorWidth The total vector width + */ + template + struct ColOrthoVW; + + template + struct ColInlineVW; + + template + struct RowOrthoVW; + + template + struct RowInlineVW; + + template // # of splits + struct ColInlineInt; + + template // # of splits + struct ColOrthoInt; + + template // # of splits + struct RowInlineInt; + + template // # of splits + struct RowOrthoInt; + + } // namespace MatrixLayout + + // Register layouts describe in-register layout and serve as transform states, or endpoints. + // These are mnemonics which provide: + // 1. A relationship between in-register layouts and combinations of matrix / data layouts. + // 2. Useful parameters that may be used in transformations between endpoints. + // 3. With indications from layout traits, can determine likeness or orthogonality between states. + // Note: For these mnemonics to be useful, there must exist a transformable path between layouts. + // Example: + // Suppose we associate associate fragment register data with Storage upon loading. + // To use the fragment register data with mma functions, we may attempt to transform the data from + // Storage to MmaInput<16> to serve as input to a 16x16xk mma builtin. + namespace RegisterLayout + { + // Format for data locality + enum struct Format : uint32_t + { + SOA = 0u, // Structure of Arrays (SOA), e.g., [{XX}, {YY}, {ZZ}] + AOS = 1u, // Array of Structures (AOS), e.g., [{X,Y,Z}, {X,Y,Z}] + SOA_INT = 2u, // SOA interleaved + AOS_INT = 3u, // AOS interleaved + ACC_INT_A_MAJOR = 4u, // Interleaved MmaAcc 'A' major order + ACC_INT_B_MAJOR = 5u, // Interleaved MmaAcc 'B' major order + WMMA_INPUT_GFX11 = 6u, // Gfx11 input format + WMMA_ACC_GFX11 = 7u, // Gfx11 acc format + Invalid, // Invalid register format + }; + + // A mnemonic used to describe the register layout is suitable for input/output + template + struct Storage + { + }; + + // A mnemonic used to describe the register layout is suitable for mma input for A/B + template + struct MmaInput + { + }; + + // A mnemonic used to describe the register layout is suitable for mma input for accumulator input/output + template + struct MmaAcc + { + }; + + } // namespace RegisterLayout + +} // namespace rocwmma + +#if !defined(__HIPCC_RTC__) +namespace std +{ + + inline ostream& operator<<(ostream& stream, rocwmma::row_major const& data_layout) + { + return stream << "row_major"; + } + + inline ostream& operator<<(ostream& stream, rocwmma::col_major const& data_layout) + { + return stream << "col_major"; + } + + inline ostream& operator<<(ostream& stream, rocwmma::DataLayout::RowMajor const& data_layout) + { + return stream << "RowMajor"; + } + + inline ostream& operator<<(ostream& stream, rocwmma::DataLayout::ColMajor const& data_layout) + { + return stream << "ColMajor"; + } + +} // namespace std + +#endif // !defined(__HIPCC_RTC__) + +#if !defined(__HIPCC_RTC__) +namespace std +{ + inline ostream& operator<<(ostream& stream, rocwmma::RegisterLayout::Format const& fmt) + { + return stream + << (fmt == rocwmma::RegisterLayout::Format::AOS ? "AOS" + : (fmt == rocwmma::RegisterLayout::Format::SOA) ? "SOA" + : (fmt == rocwmma::RegisterLayout::Format::AOS_INT) ? "AOS_INT" + : (fmt == rocwmma::RegisterLayout::Format::SOA_INT) ? "SOA_INT" + : (fmt == rocwmma::RegisterLayout::Format::ACC_INT_A_MAJOR) ? "ACC_INT_A_MAJOR" + : (fmt == rocwmma::RegisterLayout::Format::ACC_INT_B_MAJOR) ? "ACC_INT_B_MAJOR" + : (fmt == rocwmma::RegisterLayout::Format::WMMA_INPUT_GFX11) ? "WMMA_INPUT_GFX11" + : (fmt == rocwmma::RegisterLayout::Format::WMMA_ACC_GFX11) ? "WMMA_ACC_GFX11" + : "INVALID"); + } + + template + inline ostream& operator<<( + ostream& stream, + rocwmma::RegisterLayout::Storage const& register_layout) + { + return stream << "Storage<" << MatrixLayout{} << ", " << DataLayout{} << ">"; + } + + template + inline ostream& operator<<( + ostream& stream, + rocwmma::RegisterLayout::MmaInput const& register_layout) + { + return stream << "MmaInput<" << MmaDim << ", " << rocwmma::dataTypeToString() << ", " + << Interleaved << ", " << Fmt << ">"; + } + + template + inline ostream& operator<<( + ostream& stream, + rocwmma::RegisterLayout::MmaAcc const& register_layout) + { + return stream << "MmaAcc<" << MmaDim << ", " << rocwmma::dataTypeToString() << ", " + << Interleaved << ", " << Fmt << ">"; + } + +} // namespace std + +#endif // !defined(__HIPCC_RTC__) + +#include "matrix_layout_impl.hpp" + +#endif // ROCWMMA_LAYOUT_HPP diff --git a/library/include/rocwmma/internal/layout/layout_traits.hpp b/library/include/rocwmma/internal/layout/layout_traits.hpp new file mode 100644 index 00000000..ed775356 --- /dev/null +++ b/library/include/rocwmma/internal/layout/layout_traits.hpp @@ -0,0 +1,110 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef ROCWMMA_LAYOUT_TRAITS_HPP +#define ROCWMMA_LAYOUT_TRAITS_HPP + +// Need strict inclusion order here +// clang-format off +#include "layout_traits_impl.hpp" +#include "data_layout_traits_impl.hpp" +#include "matrix_layout_traits_impl.hpp" +#include "register_layout_traits_impl.hpp" +// clang-format on + +namespace rocwmma +{ + /*! \class is_layout_same + * \brief Compares layout types are the same, or are equivalent. + * Applicable to layout contexts: DataLayout, MatrixLayout and RegisterLayout. + * DataLayouts are same if they have the same 1D layout in memory. + * MatrixLayouts are the same if they have the same 2D matrix layout in memory. + * RegisterLayouts are the same if they have the same thread mapping in register. + * @tparam LhsLayout Comparative left hand side + * @tparam RhsLayout Comparative right hand side + */ + template + struct is_layout_same : public LayoutTraits_impl::is_layout_same + { + }; + + /*! \class is_layout_same_v + * \brief Evaluates is_layout_same + * @tparam LhsLayout Comparative left hand side + * @tparam RhsLayout Comparative right hand side + */ + template + constexpr static inline bool is_layout_same_v = is_layout_same::value; + + /*! \class is_layout_orthogonal + * \brief Describes a relationship between two layout endpoints. They are considered orthogonal if they + * are not the same, and there exists a reversible transformation path from one to the other. + * Applicable to layout contexts: DataLayout, MatrixLayout and RegisterLayout + * DataLayouts are orthogonal if their 1D layouts in memory are transformable (e.g., row major vs col major). + * MatrixLayouts are orthogonal if their 2D matrix layout geometry is transformable (e.g., layout transpose). + * RegisterLayouts are orthogonal if their in-register layouts are transformable (e.g., AOS vs SOA) + * @tparam LhsLayout Comparative left hand side + * @tparam RhsLayout Comparative right hand side + */ + template + struct is_layout_orthogonal + : public LayoutTraits_impl::is_layout_orthogonal + { + }; + + /*! \class is_layout_orthogonal + * \brief Evaluates is_layout_orthogonal + * @tparam LhsLayout Comparative left hand side + * @tparam RhsLayout Comparative right hand side + */ + template + constexpr static inline bool is_layout_orthogonal_v + = is_layout_orthogonal::value; + + /*! \class orthogonal_layout + * \brief Provides a guide to an orthogonal layout of the source layout. + * Applicable to layout contexts: DataLayout, MatrixLayout and RegisterLayout + * @tparam Layout the source layout + */ + template + struct orthogonal_layout : public LayoutTraits_impl::orthogonal_layout + { + }; + + /*! \class layout_transpose_t + * \brief Transforms the layout type into its orthogonal layout. + * @tparam Layout the source layout + */ + template + using orthogonal_layout_t = typename orthogonal_layout::type; + + template + struct layout_traits : public LayoutTraits_impl::layout_traits + { + }; + +} // namespace rocwmma + +#endif // ROCWMMA_LAYOUT_TRAITS_HPP diff --git a/library/include/rocwmma/internal/layout/layout_traits_impl.hpp b/library/include/rocwmma/internal/layout/layout_traits_impl.hpp new file mode 100644 index 00000000..42fbaec4 --- /dev/null +++ b/library/include/rocwmma/internal/layout/layout_traits_impl.hpp @@ -0,0 +1,59 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef ROCWMMA_LAYOUT_TRAITS_IMPL_HPP +#define ROCWMMA_LAYOUT_TRAITS_IMPL_HPP + +#include "../utility/type_traits.hpp" + +namespace rocwmma +{ + namespace LayoutTraits_impl + { + // Classifier to test layout sameness + template + struct is_layout_same : public false_type + { + }; + + // Classifer to test layout orthogonality + template + struct is_layout_orthogonal : public false_type + { + }; + + // Orthogonality guide + template + struct orthogonal_layout; + + // Meta traits for layouts + template + struct layout_traits; + + } // namespace LayoutTraits_impl + +} // namespace rocwmma + +#endif // ROCWMMA_LAYOUT_TRAITS_IMPL_HPP diff --git a/library/include/rocwmma/internal/layout/matrix_layout_impl.hpp b/library/include/rocwmma/internal/layout/matrix_layout_impl.hpp new file mode 100644 index 00000000..33ae14fd --- /dev/null +++ b/library/include/rocwmma/internal/layout/matrix_layout_impl.hpp @@ -0,0 +1,1158 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef ROCWMMA_MATRIX_LAYOUT_IMPL_HPP +#define ROCWMMA_MATRIX_LAYOUT_IMPL_HPP + +#include "layout.hpp" +#include "layout_traits.hpp" + +namespace rocwmma +{ + + // Implementations for the MatrixLayout classes + namespace MatrixLayout + { + /* Pattern that maps threads contiguously to matrix columns and assumes + * that VW will be mapped orthogonally to the column. + * This pattern considers VW up to MaxVW, BlockDim <= 64 and BlockDim > 64. + * + * Iterative thread stride cycles (same for all threads): + * Fill MaxVW => Fill BlockK => Fill BlockDim + * + * Example: + * BlockDim = 128 BlockK = 16 + * MaxVW = 4 VW = 1 + * + * BlockDim Stride Count = 2, BlockDimStride = (64, 0) + * BlockK Stride Count = 4, BlockKStride = (0, 4) + * VW Stride Count = 4, VWStride = (0, 1) + * + * Stride mapping (BlockDim, BlockK, VW) + * C_n = Matrix column + * i_n = cumulative iteration + * + * kDim ---------> + * VW Stride + * BlockDim |--1--| + * | |-- BlockK Stride = 4 --| + * | i0(0,0,0) i2(0,0,2) i4(0,1,0) i6(0,1,2) i14(0,3,2) + * | -- v_____ _____v_____ _____v_____ _____v_____ _____ v_____ _____ + * v | | | | | | | | | | | || | + * | | | | | | | | | | | || | + * BlockDim 64 | C0 | C1 | C2 | C3 | C4 | C5 | C6 | C7 | ... | C14 || C15 | + * Stride | | | | | | | | | | | || | + * -- |_____|_____|_____|_____|_____|_____|_____|_____| |_____||_____| + * i16(1,0,0) i18(1,0,2) i20(1,1,0) i22(1,1,2) i30(1,3,2) + * v_____ _____v_____ _____v_____ _____v_____ _____ v_____ _____ + * | | | | | | | | | | || | + * | | | | | | | | | | || | + * | C0 | C1 | C2 | C3 | C4 | C5 | C6 | C7 | ... | C14 || C15 | + * | | | | | | | | | | || | + * |_____|_____|_____|_____|_____|_____|_____|_____| |_____||_____| + * ^(128, 0) ^(BlockDim, BlockK) + * ... ... + * + * Register file (for all VectorWidths = [1, MaxVectorWidth]): + * + * Elements 0..............63 + * ______________ + * Reg0 | C0 [63:0] | + * Reg1 | C1 [63:0] | + * Reg2 | C2 [63:0] | + * ... ... + * Reg15 | C15[63:0] | + * Reg16 | C0 [127:64] | + * ... ... + * Reg31 | C15 [127:64] | + }*/ + + template + struct ColOrthoVW + { + struct Traits + { + // Number of threads per wave + static constexpr uint32_t WaveSize = Constants::AMDGCN_WAVE_SIZE; + + // Stride between tiles + static constexpr uint32_t BlockDimStride_X = min(BlockDim, WaveSize); + static constexpr uint32_t BlockDimStride_Y = 0u; + + static constexpr uint32_t BlockKStride_X = 0u; + static constexpr uint32_t BlockKStride_Y + = WaveSize * MaxVectorWidth / BlockDimStride_X; + + static constexpr uint32_t VWStride_X = 0u; + static constexpr uint32_t VWStride_Y = VectorWidth; + + // Stride space + static constexpr uint32_t BlockDimSegs = BlockDim / BlockDimStride_X; + static constexpr uint32_t BlockKSegs = BlockK / BlockKStride_Y; + static constexpr uint32_t VWSegs = MaxVectorWidth / VWStride_Y; + + // Thread-tile perspective + // TODO: rename to ThreadTile... + static constexpr uint32_t DimPerThread = BlockKSegs; + static constexpr uint32_t KPerThread = MaxVectorWidth; + static constexpr uint32_t ElementsPerThread + = DimPerThread * KPerThread * BlockDimSegs; + + static_assert(MaxVectorWidth <= BlockK, + "MaxVectorWidth cannot exceed BlockK"); + static_assert(BlockDim >= BlockDimStride_X, + "BlockDim must be larger than BlockDimStride_X"); + static_assert(BlockDim % BlockDimStride_X == 0, + "BlockDim must be a multiple of BlockDimStride_X"); + static_assert(BlockK >= BlockKStride_Y, + "BlockK must be larger than BlockKStride_Y"); + static_assert(BlockK % BlockKStride_Y == 0, + "BlockK must be a multiple of BlockKStride_Y"); + static_assert(MaxVectorWidth >= VWStride_Y, + "MaxVectorWidth must larger than VWStride_Y"); + static_assert(MaxVectorWidth % VWStride_Y == 0, + "MaxVectorWidth must be a multiple of VWStride_Y"); + + // Orthogonal layout, coordinates are reversed + // using OrthoLayout + // = RowOrthoVW; + + // using MatrixCoordT = Coord2d; + }; + + ROCWMMA_DEVICE constexpr static inline auto strideCounts() + { + return make_vector(Traits::BlockDimSegs, // BlockDim Segments + Traits::BlockKSegs, // BlockK Segments + Traits::VWSegs); // VW Segments + } + + ROCWMMA_DEVICE constexpr static inline auto strides() + { + return make_vector(make_coord2d(Traits::BlockDimStride_X, Traits::BlockDimStride_Y), + make_coord2d(Traits::BlockKStride_X, Traits::BlockKStride_Y), + make_coord2d(Traits::VWStride_X, Traits::VWStride_Y)); + } + + ROCWMMA_DEVICE static inline auto baseOffset() + { + if constexpr(Traits::BlockDimStride_X >= Traits::WaveSize) + { + // Don't need initial offset calc in Y direction: all threads fit in neighbouring rows + return make_coord2d(threadIdx.x % Traits::BlockDimStride_X, 0u); + } + else + { + // Threads need to spread over the Y direction as well + return make_coord2d(threadIdx.x % Traits::BlockDimStride_X, + (threadIdx.x / Traits::BlockDimStride_X) * MaxVectorWidth + % Traits::BlockKStride_Y); + } + } + + ROCWMMA_DEVICE static inline auto incrementalOffset(uint32_t iteration) + { + // Reference: + // VWOffsetY = VWStride_Y - ((i+1) % VWSegs ? 0u : VWStride_Y * VWSegs); + // Every set of VWSegs, we must iteratively reset the VWOffset back to 0, hence + // the subtraction. + // Optimization 1: if VWSegs == 1, there are no contributions from this stride + // Optimization 2: if BlockKSegs == 1 and BlockDimSegs == 1, there are no "reset" + // contributions from this stride + int32_t VWOffsetY = 0; + if constexpr((int32_t)Traits::VWSegs > 1) + { + // Offset contribution + VWOffsetY = (int32_t)Traits::VWStride_Y; + if constexpr(((int32_t)Traits::BlockKSegs > 1) + || ((int32_t)Traits::BlockDimSegs > 1)) + { + // "Reset" cycle + VWOffsetY -= (((int32_t)iteration + 1) % (int32_t)Traits::VWSegs + ? 0 + : (int32_t)Traits::VWStride_Y * (int32_t)Traits::VWSegs); + } + } + + // Reference: + // BlockKOffsetY = ((i+1) % VWSegs ? 0u : BlockKStride_Y) - + // ((i+1) % (VWSegs * BlockKSegs) ? 0u : BlockKSegs * BlockKStride_Y); + // Every set of BlockKSegs, we must iteratively reset the BlockKOffsetY back to 0, hence + // the subtraction. + // Optimization 1: if BlockKSegs == 1, there are no contributions from this stride + // Optimization 2: if BlockDimSegs == 1, there are no "reset" contributions from this stride + int32_t BlockKOffsetY = 0; + if constexpr((int32_t)Traits::BlockKSegs > 1) + { + // Offset contribution + BlockKOffsetY = (((int32_t)iteration + 1) % (int32_t)Traits::VWSegs + ? 0 + : (int32_t)Traits::BlockKStride_Y); + + if constexpr((int32_t)Traits::BlockDimSegs > 1) + { + // "Reset" cycle + BlockKOffsetY + -= (((int32_t)iteration + 1) + % ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs) + ? 0 + : (int32_t)Traits::BlockKSegs + * (int32_t)Traits::BlockKStride_Y); + } + } + + // Reference: + // BlockDimOffsetX = ((i+1) % VWSegs * BlockKSegs) ? 0u : BlockDimStride_X); + // Optimization 1: if BlockKSegs == 1, there are no contributions from this stride + // Optimization 2: There are no "reset" contributions from this stride because it is the last dim + int32_t BlockDimOffsetX = 0; + if constexpr((int32_t)Traits::BlockDimSegs > 1) + { + // Offset contribution + BlockDimOffsetX + = (((int32_t)iteration + 1) + % ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs) + ? 0 + : (int32_t)Traits::BlockDimStride_X); + } + + return make_coord2d(BlockDimOffsetX, VWOffsetY + BlockKOffsetY); + } + + ROCWMMA_DEVICE static inline auto cumulativeOffset(uint32_t iteration) + { + int32_t cumVWOffsetY + = (int32_t)Traits::VWStride_Y * ((int32_t)iteration % (int32_t)Traits::VWSegs); + int32_t cumBlockKOffsetY = ((int32_t)iteration / (int32_t)Traits::VWSegs) + % (int32_t)Traits::BlockKSegs + * (int32_t)Traits::BlockKStride_Y; + int32_t cumBlockDimOffsetX + = ((int32_t)iteration / ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs)) + * (int32_t)Traits::BlockDimStride_X; + + return make_coord2d(cumBlockDimOffsetX, cumVWOffsetY + cumBlockKOffsetY); + } + + ROCWMMA_DEVICE static inline auto debug() {} + }; + + /* Pattern that maps threads to matrix columns and assumes + * that VW will be mapped inline with the column. + * This pattern considers VW up to MaxVW, BlockDim <= 64 and BlockDim > 64. + * + * Iterative thread stride cycles (same for all threads): + * Fill MaxVW => Fill BlockK => Fill BlockDim + * + * Example: + * BlockDim = 256 BlockK = 4 + * MaxVW = 2 VW = 1 + * + * BlockDim Stride Count = 4, BlockDimStride = (64, 0) + * BlockK Stride Count = 2, BlockKStride = (0, 2) + * VW Stride Count = 2, VWStride = (1, 0) + * + * Stride mapping (BlockDim, BlockK, VW) + * C_n = Matrix column + * i_n = cumulative iteration + * + * Cartesian iteration offsets (row, col): + * i0 = (0, 0) i1 = (1, 0) i2 = (0, 2) i3 = (1, 2) + * i4 = (64, 0) i5 = (65, 0) i6 = (64, 2) i7 = (65, 2) + * i8 = (128, 0) i9 = (129, 0) i10 = (128, 2) i11 = (129, 2) + * i12 = (192, 0) i13 = (193, 0) i14 = (192, 2) i15 = (192, 2) + * + * Strides iteration offsets (BlockDim, BlockK, VW): + * i0 = (0,0,0) i1 = (0,0,1) + * i2 = (0,1,0) i3 = (0,1,1) + * i4 = (1,0,0) i5 = (1,0,1) + * i6 = (1,1,0) i7 = (1,1,1) + * i8 = (2,0,0) i9 = (2,0,1) + * i10 = (2,1,0) i11 = (2,1,1) + * i12 = (3,0,0) i13 = (3,0,1) + * i14 = (3,1,0) i15 = (3,1,1) + * + * Let's follow thread 0: + * + * kDim ---------> + * + * BlockDim1 + * | |-- BlockK Stride = 2 --| + * | i0(0,0,0) i2(0,1,0) + * | _ _ v_____ _____v_____ _____ + * v | | | | | | | + * | VW 1 | | | | | + * BlockDim | Stride | | C0 | C1 | C2 | C3 | + * Stride | _ v | v | | + * 64 i1(0,0,1) i3(0,1,1) | + * | | | | | | + * | | | | | | + * | | C0 | C1 | C2 | C3 | + * _ |_____|_____|_____|_____| + * i4(1,0,0) i6(1,1,0) + * v_____ _____v_____ _____ + * | | | | | + * | | | | | + * | C0 | C1 | C2 | C3 | + * v | v | | + * i5(1,0,1) i7(1,1,1) | + * | | | | | + * | | | | | + * | C0 | C1 | C2 | C3 | + * |_____|_____|_____|_____| + * ... ... + * ... ... + * ... ... + * v | v | | + * i13(3,0,1) i14(3,1,1) | + * | | | | | + * | | | | | + * | C0 | C1 | C2 | C3 | + * |_____|_____|_____|_____| + * + * ^(BlockDim, 0) ^(BlockDim, BlockK) + * + * Register file (for all VectorWidths = [MaxVectorWidth, 1]): + * + * Elements 0...........1........................................... ............64 + * ________________________________________________________________________ + * Reg0 | C0E0 | C0E2 | ... | C0E62 | C1E0 | C1E2 | ... | C1E62 | + * Reg1 | C0E1 | C0E3 | ... | C0E63 | C1E1 | C1E3 | ... | C1E63 | + * Reg2 | C2E0 | C2E2 | ... | C2E62 | C3E0 | C3E2 | ... | C3E62 | + * Reg3 | C2E1 | C2E3 | ... | C2E63 | C3E1 | C3E3 | ... | C3E63 | + * Reg4 | C0E64 | C0E66 | ... | C0E126 | C1E64 | C1E66 | ... | C1E126 | + * Reg5 | C0E65 | C0E67 | ... | C0E127 | C1E65 | C1E67 | ... | C1E127 | + * ... ... + * Reg10 | C2E192 | C2E194 | ... | C2E254 | C3E192 | C3E194 | ... | C3E254 | + * Reg11 | C2E193 | C2E195 | ... | C2E255 | C3E193 | C3E195 | ... | C3E255 | + * + */ + + template + struct ColInlineVW + { + + struct Traits + { + // Number of threads per wave + static constexpr uint32_t WaveSize = Constants::AMDGCN_WAVE_SIZE; + + // Strides + static constexpr uint32_t BlockDimStride_X = min(BlockDim, WaveSize); + static constexpr uint32_t BlockDimStride_Y = 0u; + + static constexpr uint32_t BlockKStride_X = 0u; + static constexpr uint32_t BlockKStride_Y + = WaveSize * MaxVectorWidth / BlockDimStride_X; + + static constexpr uint32_t VWStride_X = VectorWidth; + static constexpr uint32_t VWStride_Y = 0u; + + // Stride Space + static constexpr uint32_t BlockDimSegs = BlockDim / BlockDimStride_X; + static constexpr uint32_t BlockKSegs = BlockK / BlockKStride_Y; + static constexpr uint32_t VWSegs = MaxVectorWidth / VWStride_X; + + // Thread-tile perspective + // TODO: rename to ThreadTile... + static constexpr uint32_t DimPerThread = MaxVectorWidth; + static constexpr uint32_t KPerThread = BlockKSegs; + static constexpr uint32_t ElementsPerThread + = DimPerThread * KPerThread * BlockDimSegs; + + // Sanity checks for strides sizes + static_assert(MaxVectorWidth <= BlockDim, + "MaxVectorWidth cannot exceed BlockDim"); + static_assert(BlockDim >= BlockDimStride_X, + "BlockDim must be larger than BlockDimStride_X"); + static_assert(BlockDim % BlockDimStride_X == 0, + "BlockDim must be a multiple of BlockDimStride_X"); + static_assert(BlockK >= BlockKStride_Y, + "BlockK must be larger than BlockKStride_Y"); + static_assert(BlockK % BlockKStride_Y == 0, + "BlockK must be a multiple of BlockKStride_Y"); + static_assert(MaxVectorWidth >= VWStride_X, + "MaxVectorWidth must larger than VWStride_X"); + static_assert(MaxVectorWidth % VWStride_X == 0, + "MaxVectorWidth must be a multiple of VWStride_X"); + + // Orthogonal layout, coordinates are reversed + //using OrthoLayout + // = RowInlineVW; + + //using MatrixCoordT = Coord2d; + }; + + ROCWMMA_DEVICE constexpr static inline auto strideCounts() + { + return make_vector(Traits::BlockDimSegs, // BlockDim Segments + Traits::BlockKSegs, // BlockK Segments + Traits::VWSegs); // VW Segments + } + + ROCWMMA_DEVICE constexpr static inline auto strides() + { + return make_vector(make_coord2d(Traits::BlockDimStride_X, Traits::BlockDimStride_Y), + make_coord2d(Traits::BlockKStride_X, Traits::BlockKStride_Y), + make_coord2d(Traits::VWStride_X, Traits::VWStride_Y)); + } + + ROCWMMA_DEVICE static inline auto baseOffset() + { + if constexpr((Traits::BlockDimStride_X >= Traits::WaveSize) + && (MaxVectorWidth == 1)) + { + // Don't need initial offset calc in Y direction: all threads fit in neighbouring rows + return make_coord2d(threadIdx.x % Traits::BlockDimStride_X, 0u); + } + else + { + // Threads need to spread over the Y direction as well + return make_coord2d(threadIdx.x * MaxVectorWidth % Traits::BlockDimStride_X, + threadIdx.x * MaxVectorWidth / Traits::BlockDimStride_X + % Traits::BlockKStride_Y); + } + } + + // Incremental iteration offset + ROCWMMA_DEVICE static inline auto incrementalOffset(uint32_t iteration) + { + // Reference: + // VWOffsetX = VWStride_X - ((i+1) % VWSegs ? 0u : VWStride_X * VWSegs); + // Every set of VWSegs, we must iteratively reset the VWOffset back to 0, hence + // the subtraction. + // Optimization 1: if VWSegs == 1, there are no contributions from this stride + // Optimization 2: if BlockKSegs == 1 and BlockDimSegs == 1, there are no "reset" + // contributions from this stride + int32_t VWOffsetX = 0; + if constexpr((int32_t)Traits::VWSegs > 1) + { + // Offset contribution + VWOffsetX = (int32_t)Traits::VWStride_X; + if constexpr(((int32_t)Traits::BlockKSegs > 1) + || ((int32_t)Traits::BlockDimSegs > 1)) + { + // "Reset" cycle + VWOffsetX -= (((int32_t)iteration + 1) % (int32_t)Traits::VWSegs + ? 0 + : (int32_t)Traits::VWStride_X * (int32_t)Traits::VWSegs); + } + } + + // Reference: + // BlockKOffsetY = ((i+1) % VWSegs ? 0u : BlockKStride_Y) - + // ((i+1) % (VWSegs * BlockKSegs) ? 0u : BlockKSegs * BlockKStride_Y); + // Every set of BlockKSegs, we must iteratively reset the BlockKOffsetY back to 0, hence + // the subtraction. + // Optimization 1: if BlockKSegs == 1, there are no contributions from this stride + // Optimization 2: if BlockDimSegs == 1, there are no "reset" contributions from this stride + int32_t BlockKOffsetY = 0; + if constexpr((int32_t)Traits::BlockKSegs > 1) + { + // Offset contribution + BlockKOffsetY = (((int32_t)iteration + 1) % (int32_t)Traits::VWSegs + ? 0 + : (int32_t)Traits::BlockKStride_Y); + if constexpr((int32_t)Traits::BlockDimSegs > 1) + { + // "Reset" cycle + BlockKOffsetY + -= (((int32_t)iteration + 1) + % ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs) + ? 0 + : (int32_t)Traits::BlockKSegs + * (int32_t)Traits::BlockKStride_Y); + } + } + + // Reference: + // BlockDimOffsetX = ((i+1) % VWSegs * BlockKSegs) ? 0u : BlockDimStride_X); + // Optimization 1: if BlockKSegs == 1, there are no contributions from this stride + // Optimization 2: There are no "reset" contributions from this stride because it is the last dim + int32_t BlockDimOffsetX = 0; + if constexpr((int32_t)Traits::BlockDimSegs > 1) + { + // Offset contribution + BlockDimOffsetX + = (((int32_t)iteration + 1) + % ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs) + ? 0 + : (int32_t)Traits::BlockDimStride_X); + } + + return make_coord2d(VWOffsetX + BlockDimOffsetX, BlockKOffsetY); + } + + // Cumulative iteration offset + ROCWMMA_DEVICE static inline auto cumulativeOffset(uint32_t iteration) + { + int32_t cumVWOffsetX + = (int32_t)Traits::VWStride_X * ((int32_t)iteration % (int32_t)Traits::VWSegs); + int32_t cumBlockKOffsetY = ((int32_t)iteration / (int32_t)Traits::VWSegs) + % (int32_t)Traits::BlockKSegs + * (int32_t)Traits::BlockKStride_Y; + int32_t cumBlockDimOffsetX + = ((int32_t)iteration / ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs)) + * (int32_t)Traits::BlockKStride_X; + + return make_coord2d(cumVWOffsetX + cumBlockDimOffsetX, cumBlockKOffsetY); + } + + ROCWMMA_DEVICE static inline auto debug() {} + }; + + template // # of splits + struct ColInlineInt + { + struct Traits + { + // Number of threads per wave + static constexpr uint32_t WaveSize = Constants::AMDGCN_WAVE_SIZE; + + // Number of elements each thread will fetch in BlockDim direction + static constexpr uint32_t DimPerThread = BlockDim / MfmaDim; + + // Number of elements each thread will fetch in BlockK direction + static constexpr uint32_t KPerThread = BlockK * MfmaDim / (WaveSize * SplitK); + + // How many elements each thread will gather + static constexpr uint32_t ElementsPerThread = DimPerThread * KPerThread; + + // Strides + static constexpr uint32_t SplitKStride_X = 0u; + static constexpr uint32_t SplitKStride_Y = BlockK / SplitK; + + static constexpr uint32_t BlockKStride_X = 0u; + static constexpr uint32_t BlockKStride_Y = 1u; + + static constexpr uint32_t VWStride_X = DimPerThread; + static constexpr uint32_t VWStride_Y = 0u; + + // Stride Space + static constexpr uint32_t SplitKSegs = BlockK / SplitKStride_Y; + static constexpr uint32_t BlockKSegs = KPerThread / BlockKStride_Y; + static constexpr uint32_t VWSegs = DimPerThread / VWStride_X; + + // // Check VectorWidth validity + // static_assert((uint32_t)Traits::DimPerThread >= VectorWidth, "Invalid VectorWidth"); + // static_assert((uint32_t)Traits::DimPerThread % VectorWidth == 0, + // "DimPerThread not a multiple of VectorWidth"); + + // Check KPerThread validity + static_assert(BlockK >= KPerThread, "Invalid KPerThread"); + static_assert(BlockK % KPerThread == 0, "BlockK is not a multiple of KPerThread"); + + // Check SplitK validity + static_assert(BlockK >= SplitK, "Invalid SplitK"); + static_assert(BlockK % SplitK == 0, "BlockK is not a multiple of SplitK"); + + // Check MfmaDim validity + static_assert(BlockDim >= MfmaDim, "BlockDim must be larger than MfmaDim"); + static_assert(BlockDim % MfmaDim == 0, "BlockDim must be a multiple of MfmaDim"); + + // Orthogonal layout, coordinates are reversed + //using OrthoLayout = RowInlineInt; + + //using MatrixCoordT = Coord2d; + }; + + ROCWMMA_DEVICE constexpr static inline auto strideCounts() + { + return make_vector(Traits::SplitKSegs, Traits::BlockKSegs, Traits::VWSegs); + } + + ROCWMMA_DEVICE constexpr static inline auto strides() + { + return make_vector(make_coord2d(Traits::SplitKStride_X, Traits::SplitKStride_Y), + make_coord2d(Traits::BlockKStride_X, Traits::BlockKStride_Y), + make_coord2d(Traits::VWStride_X, Traits::VWStride_Y)); + } + + ROCWMMA_DEVICE static inline auto baseOffset() + { + return make_coord2d((threadIdx.x * Traits::DimPerThread) % BlockDim, + (threadIdx.x / MfmaDim * Traits::KPerThread) % BlockK); + } + + // Incremental iteration offset + ROCWMMA_DEVICE static inline auto incrementalOffset(uint32_t iteration) + { + // Reference: + // VWOffsetX = VWStride_X - ((i+1) % VWSegs ? 0u : VWStride_X * VWSegs); + // Every set of VWSegs, we must iteratively reset the VWOffset back to 0, hence + // the subtraction. + // Optimization 1: if VWSegs == 1, there are no contributions from this stride + // Optimization 2: if BlockKSegs == 1 and SplitKSegs == 1, there are no "reset" + // contributions from this stride + int32_t VWOffsetX = 0; + if constexpr((int32_t)Traits::VWSegs > 1) + { + // Offset contribution + VWOffsetX = (int32_t)Traits::VWStride_X; + if constexpr(((int32_t)Traits::BlockKSegs > 1) + || ((int32_t)Traits::SplitKSegs > 1)) + { + // "Reset" cycle + VWOffsetX -= (((int32_t)iteration + 1) % (int32_t)Traits::VWSegs + ? 0 + : (int32_t)Traits::VWStride_X * (int32_t)Traits::VWSegs); + } + } + + // Reference: + // BlockKOffsetY = ((i+1) % VWSegs ? 0u : BlockKStride_Y) - + // ((i+1) % (VWSegs * BlockKSegs) ? 0u : BlockKSegs * BlockKStride_Y); + // Every set of BlockKSegs, we must iteratively reset the BlockKOffsetY back to 0, hence + // the subtraction. + // Optimization 1: if BlockKSegs == 1, there are no contributions from this stride + // Optimization 2: if SplitKSegs == 1, there are no "reset" contributions from this stride + int32_t BlockKOffsetY = 0; + if constexpr((int32_t)Traits::BlockKSegs > 1) + { + // Offset contribution + BlockKOffsetY = (((int32_t)iteration + 1) % (int32_t)Traits::VWSegs + ? 0 + : (int32_t)Traits::BlockKStride_Y); + if constexpr((int32_t)Traits::SplitKSegs > 1) + { + // "Reset" cycle + BlockKOffsetY + -= (((int32_t)iteration + 1) + % ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs) + ? 0 + : (int32_t)Traits::BlockKSegs + * (int32_t)Traits::BlockKStride_Y); + } + } + + // Reference: + // BlockDimOffsetX = ((i+1) % VWSegs * BlockKSegs) ? 0u : SplitKStride_X); + // Optimization 1: if BlockKSegs == 1, there are no contributions from this stride + // Optimization 2: There are no "reset" contributions from this stride because it is the last dim + int32_t BlockDimOffsetX = 0; + if constexpr((int32_t)Traits::SplitKSegs > 1) + { + // Offset contribution + BlockDimOffsetX + = (((int32_t)iteration + 1) + % ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs) + ? 0 + : (int32_t)Traits::SplitKStride_X); + } + + return make_coord2d(VWOffsetX + BlockDimOffsetX, BlockKOffsetY); + } + + // Cumulative iteration offset + ROCWMMA_DEVICE static inline auto cumulativeOffset(uint32_t iteration) + { + int32_t cumVWOffsetX + = (int32_t)Traits::VWStride_X * ((int32_t)iteration % (int32_t)Traits::VWSegs); + int32_t cumBlockKOffsetY = ((int32_t)iteration / (int32_t)Traits::VWSegs) + % (int32_t)Traits::BlockKSegs + * (int32_t)Traits::BlockKStride_Y; + int32_t cumBlockDimOffsetX + = ((int32_t)iteration / ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs)) + * (int32_t)Traits::SplitKStride_X; + + return make_coord2d(cumVWOffsetX + cumBlockDimOffsetX, cumBlockKOffsetY); + } + ROCWMMA_DEVICE static inline auto debug() + { + // if(threadIdx.x == 0 && threadIdx.y == 0) + // { + // printf("SplitKSegs: %d, BlockKSegs: %d, VWSegs: %d\n", + // (uint32_t)Traits::SplitKSegs, + // (uint32_t)Traits::BlockKSegs, + // (uint32_t)Traits::VWSegs); + + // printf("SplitKStride_X: %d, SplitKStride_Y: %d\nBlockKStride_X: %d, " + // "BlockKStride_Y: %d\nVWStride_X: %d, VWStride_Y: %d\n", + // (uint32_t)Traits::SplitKStride_X, + // (uint32_t)Traits::SplitKStride_Y, + // (uint32_t)Traits::BlockKStride_X, + // (uint32_t)Traits::BlockKStride_Y, + // (uint32_t)Traits::VWStride_X, + // (uint32_t)Traits::VWStride_Y); + // } + // if(threadIdx.x <= 63 && threadIdx.y == 0) + // { + // printf("Tid: (%d) Base offset(X, Y): = (%d, %d)\n", + // threadIdx.x, + // get<0>(baseOffset()), + // get<1>(baseOffset())); + // } + } + }; + + template // # of splits + struct ColOrthoInt + { + struct Traits + { + // Number of threads per wave + static constexpr uint32_t WaveSize = Constants::AMDGCN_WAVE_SIZE; + + // Number of elements each thread will fetch in BlockDim direction + static constexpr uint32_t DimPerThread = BlockDim / MfmaDim; + + // Number of elements each thread will fetch in BlockK direction + static constexpr uint32_t KPerThread = BlockK * MfmaDim / (WaveSize * SplitK); + + // Number of elements that each thread is responsible for + static constexpr uint32_t ElementsPerThread = DimPerThread * KPerThread; + + // Strides + static constexpr uint32_t SplitKStride_X = 0u; + static constexpr uint32_t SplitKStride_Y = BlockK / SplitK; + + static constexpr uint32_t BlockKStride_X = 1u; + static constexpr uint32_t BlockKStride_Y = 0u; + + static constexpr uint32_t VWStride_X = 0u; + static constexpr uint32_t VWStride_Y = KPerThread; + + // Stride Space + static constexpr uint32_t SplitKSegs = BlockK / SplitKStride_Y; + static constexpr uint32_t BlockKSegs = DimPerThread / BlockKStride_X; + static constexpr uint32_t VWSegs = KPerThread / VWStride_Y; + + // Check KPerThread validity + static_assert(BlockK >= KPerThread, "Invalid KPerThread"); + static_assert(BlockK % KPerThread == 0, "BlockK is not a multiple of KPerThread"); + + // // Check VectorWidth validity + // static_assert((uint32_t)Traits::KPerThread >= VectorWidth, "Invalid VectorWidth"); + // static_assert((uint32_t)Traits::KPerThread % VectorWidth == 0, + // "KPerThread not a multiple of VectorWidth"); + + // Check SplitK validity + static_assert(BlockK >= SplitK, "Invalid SplitK"); + static_assert(BlockK % SplitK == 0, "BlockK is not a multiple of SplitK"); + + // Check MfmaDim validity + static_assert(BlockDim >= MfmaDim, "BlockDim must be larger than MfmaDim"); + static_assert(BlockDim % MfmaDim == 0, "BlockDim must be a multiple of MfmaDim"); + + // Orthogonal layout, coordinates are reversed + //using OrthoLayout = RowOrthoInt; + + //using MatrixCoordT = Coord2d; + }; + + ROCWMMA_DEVICE constexpr static inline auto strideCounts() + { + return make_vector(Traits::SplitKSegs, // WaveKSegs Segments + Traits::BlockKSegs, // BlockK Segments + Traits::VWSegs); // VW Segments + } + + ROCWMMA_DEVICE constexpr static inline auto strides() + { + return make_vector(make_coord2d(Traits::SplitKStride_X, Traits::SplitKStride_Y), + make_coord2d(Traits::BlockKStride_X, Traits::BlockKStride_Y), + make_coord2d(Traits::VWStride_X, Traits::VWStride_Y)); + } + + ROCWMMA_DEVICE static inline auto baseOffset() + { + return make_coord2d((threadIdx.x * Traits::DimPerThread) % BlockDim, + (threadIdx.x / MfmaDim * Traits::KPerThread) % BlockK); + } + + // Incremental iteration offset + ROCWMMA_DEVICE static inline auto incrementalOffset(uint32_t iteration) + { + // Reference: + // VWOffsetX = VWStride_X - ((i+1) % VWSegs ? 0u : VWStride_X * VWSegs); + // Every set of VWSegs, we must iteratively reset the VWOffset back to 0, hence + // the subtraction. + // Optimization 1: if VWSegs == 1, there are no contributions from this stride + // Optimization 2: if BlockKSegs == 1 and SplitKSegs == 1, there are no "reset" + // contributions from this stride + int32_t VWOffsetX = 0; + if constexpr((int32_t)Traits::VWSegs > 1) + { + // Offset contribution + VWOffsetX = (int32_t)Traits::VWStride_X; + if constexpr(((int32_t)Traits::BlockKSegs > 1) + || ((int32_t)Traits::SplitKSegs > 1)) + { + // "Reset" cycle + VWOffsetX -= (((int32_t)iteration + 1) % (int32_t)Traits::VWSegs + ? 0 + : (int32_t)Traits::VWStride_X * (int32_t)Traits::VWSegs); + } + } + + // Reference: + // BlockKOffsetY = ((i+1) % VWSegs ? 0u : BlockKStride_Y) - + // ((i+1) % (VWSegs * BlockKSegs) ? 0u : BlockKSegs * BlockKStride_Y); + // Every set of BlockKSegs, we must iteratively reset the BlockKOffsetY back to 0, hence + // the subtraction. + // Optimization 1: if BlockKSegs == 1, there are no contributions from this stride + // Optimization 2: if SplitKSegs == 1, there are no "reset" contributions from this stride + int32_t BlockKOffsetY = 0; + if constexpr((int32_t)Traits::BlockKSegs > 1) + { + // Offset contribution + BlockKOffsetY = (((int32_t)iteration + 1) % (int32_t)Traits::VWSegs + ? 0 + : (int32_t)Traits::BlockKStride_Y); + if constexpr((int32_t)Traits::SplitKSegs > 1) + { + // "Reset" cycle + BlockKOffsetY + -= (((int32_t)iteration + 1) + % ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs) + ? 0 + : (int32_t)Traits::BlockKSegs + * (int32_t)Traits::BlockKStride_Y); + } + } + + // Reference: + // BlockDimOffsetX = ((i+1) % VWSegs * BlockKSegs) ? 0u : SplitKStride_X); + // Optimization 1: if BlockKSegs == 1, there are no contributions from this stride + // Optimization 2: There are no "reset" contributions from this stride because it is the last dim + int32_t BlockDimOffsetX = 0; + if constexpr((int32_t)Traits::SplitKSegs > 1) + { + // Offset contribution + BlockDimOffsetX + = (((int32_t)iteration + 1) + % ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs) + ? 0 + : (int32_t)Traits::SplitKStride_X); + } + + return make_coord2d(VWOffsetX + BlockDimOffsetX, BlockKOffsetY); + } + + // Cumulative iteration offset + ROCWMMA_DEVICE static inline auto cumulativeOffset(uint32_t iteration) + { + int32_t cumVWOffsetX + = (int32_t)Traits::VWStride_X * ((int32_t)iteration % (int32_t)Traits::VWSegs); + int32_t cumBlockKOffsetY = ((int32_t)iteration / (int32_t)Traits::VWSegs) + % (int32_t)Traits::BlockKSegs + * (int32_t)Traits::BlockKStride_Y; + int32_t cumBlockDimOffsetX + = ((int32_t)iteration / ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs)) + * (int32_t)Traits::SplitKStride_X; + + return make_coord2d(cumVWOffsetX + cumBlockDimOffsetX, cumBlockKOffsetY); + } + + ROCWMMA_DEVICE static inline auto debug() + { + // if(threadIdx.x == 0 && threadIdx.y == 0) + // { + // printf("SplitKSegs: %d, BlockKSegs: %d, VWSegs: %d\n", + // (uint32_t)Traits::SplitKSegs, + // (uint32_t)Traits::BlockKSegs, + // (uint32_t)Traits::VWSegs); + + // printf("SplitKStride_X: %d, SplitKStride_Y: %d\nBlockKStride_X: %d, BlockKStride_Y: %d\nVWStride_X: %d, VWStride_Y: %d\n", + // (uint32_t)Traits::SplitKStride_X, + // (uint32_t)Traits::SplitKStride_Y, + // (uint32_t)Traits::BlockKStride_X, + // (uint32_t)Traits::BlockKStride_Y, + // (uint32_t)Traits::VWStride_X, + // (uint32_t)Traits::VWStride_Y); + + // } + // if(threadIdx.x <= 63 && threadIdx.y == 0) + // { + // printf("Base offset(X, Y): = (%d, %d)", get<0>(baseOffset()), get<1>(baseOffset())); + // } + } + }; + + template + struct OrthoTraits; + + template + struct OrthoTraits> + { + // Number of threads per wave + static constexpr uint32_t WaveSize = MatrixLayout::Traits::WaveSize; + + // Strides (swapped) + static constexpr uint32_t BlockDimStride_X = MatrixLayout::Traits::BlockDimStride_Y; + static constexpr uint32_t BlockDimStride_Y = MatrixLayout::Traits::BlockDimStride_X; + + static constexpr uint32_t BlockKStride_X = MatrixLayout::Traits::BlockKStride_Y; + static constexpr uint32_t BlockKStride_Y = MatrixLayout::Traits::BlockKStride_X; + + static constexpr uint32_t VWStride_X = MatrixLayout::Traits::VWStride_Y; + static constexpr uint32_t VWStride_Y = MatrixLayout::Traits::VWStride_X; + + // Stride space (same) + static constexpr uint32_t BlockDimSegs = MatrixLayout::Traits::BlockDimSegs; + static constexpr uint32_t BlockKSegs = MatrixLayout::Traits::BlockKSegs; + static constexpr uint32_t VWSegs = MatrixLayout::Traits::VWSegs; + }; + + template + struct OrthoTraits> + { + // Number of threads per wave + static constexpr uint32_t WaveSize = MatrixLayout::Traits::WaveSize; + + // Number of elements each thread will fetch in BlockDim direction + static constexpr uint32_t DimPerThread = MatrixLayout::Traits::DimPerThread; + + // Number of elements each thread will fetch in BlockK direction + static constexpr uint32_t KPerThread = MatrixLayout::Traits::KPerThread; + + // Number of elements that each thread is responsible for + static constexpr uint32_t ElementsPerThread = MatrixLayout::Traits::ElementsPerThread; + + // Swapped strides + static constexpr uint32_t SplitKStride_X = MatrixLayout::Traits::SplitKStride_Y; + static constexpr uint32_t SplitKStride_Y = MatrixLayout::Traits::SplitKStride_X; + + static constexpr uint32_t BlockKStride_X = MatrixLayout::Traits::BlockKStride_Y; + static constexpr uint32_t BlockKStride_Y = MatrixLayout::Traits::BlockKStride_X; + + static constexpr uint32_t VWStride_X = MatrixLayout::Traits::VWStride_Y; + static constexpr uint32_t VWStride_Y = MatrixLayout::Traits::VWStride_X; + + // Stride Space + static constexpr uint32_t SplitKSegs = MatrixLayout::Traits::SplitKSegs; + static constexpr uint32_t BlockKSegs = MatrixLayout::Traits::BlockKSegs; + static constexpr uint32_t VWSegs = MatrixLayout::Traits::VWSegs; + }; + + template + struct OrthoImpl + { + struct Traits : public OrthoTraits + { + }; + + ROCWMMA_DEVICE constexpr static inline auto strideCounts() + { + return MatrixLayout::strideCounts(); + } + + ROCWMMA_DEVICE constexpr static inline auto strides() + { + auto t = MatrixLayout::strides(); + // TODO: use apply + //apply([](auto const& v){ return swap(v); }); + return make_vector(swap(get<0>(t)), swap(get<1>(t)), swap(get<2>(t))); + } + + ROCWMMA_DEVICE static inline auto baseOffset() + { + return swap(MatrixLayout::baseOffset()); + } + + ROCWMMA_DEVICE static inline auto incrementalOffset(uint32_t iteration) + { + return swap(MatrixLayout::incrementalOffset(iteration)); + } + + ROCWMMA_DEVICE static inline auto cumulativeOffset(uint32_t iteration) + { + return swap(MatrixLayout::cumulativeOffset(iteration)); + } + + ROCWMMA_DEVICE static inline auto debug() {} + }; + + template + struct RowOrthoVW + : public OrthoImpl> + { + }; + + template + struct RowInlineVW + : public OrthoImpl> + { + }; + + template // # of splits + struct RowOrthoInt : public OrthoImpl> + { + }; + + template // # of splits + struct RowInlineInt + : public OrthoImpl> + { + }; + + } // namespace MatrixLayout + +} // namespace rocwmma + +#if !defined(__HIPCC_RTC__) +namespace std +{ + + template + inline ostream& operator<<( + ostream& stream, + rocwmma::MatrixLayout:: + ColOrthoVW const& matrix_layout) + { + return stream << "ColOrthoVW<" << BlockDim << ", " << BlockK << ", " + << rocwmma::dataTypeToString() << ", " << VectorWidth << ", " + << MaxVectorWidth << ">"; + } + + template + inline ostream& operator<<( + ostream& stream, + rocwmma::MatrixLayout:: + ColInlineVW const& matrix_layout) + { + return stream << "ColInlineVW<" << BlockDim << ", " << BlockK << ", " + << rocwmma::dataTypeToString() << ", " << VectorWidth << ", " + << MaxVectorWidth << ">"; + } + + template + inline ostream& operator<<( + ostream& stream, + rocwmma::MatrixLayout:: + RowOrthoVW const& matrix_layout) + { + return stream << "RowOrthoVW<" << BlockDim << ", " << BlockK << ", " + << rocwmma::dataTypeToString() << ", " << VectorWidth << ", " + << MaxVectorWidth << ">"; + } + + template + inline ostream& operator<<( + ostream& stream, + rocwmma::MatrixLayout:: + RowInlineVW const& matrix_layout) + { + return stream << "RowInlineVW<" << BlockDim << ", " << BlockK << ", " + << rocwmma::dataTypeToString() << ", " << VectorWidth << ", " + << MaxVectorWidth << ">"; + } + + template + inline ostream& operator<<( + ostream& stream, + rocwmma::MatrixLayout::ColOrthoInt const& + matrix_layout) + { + return stream << "ColOrthoInt<" << BlockDim << ", " << BlockK << ", " + << rocwmma::dataTypeToString() << ", " << MmaDim << ", " << SplitK + << ">"; + } + + template + inline ostream& operator<<( + ostream& stream, + rocwmma::MatrixLayout::ColInlineInt const& + matrix_layout) + { + return stream << "ColInlineInt<" << BlockDim << ", " << BlockK << ", " + << rocwmma::dataTypeToString() << ", " << MmaDim << ", " << SplitK + << ">"; + } + + template + inline ostream& operator<<( + ostream& stream, + rocwmma::MatrixLayout::RowOrthoInt const& + matrix_layout) + { + return stream << "RowOrthoInt<" << BlockDim << ", " << BlockK << ", " + << rocwmma::dataTypeToString() << ", " << MmaDim << ", " << SplitK + << ">"; + } + + template + inline ostream& operator<<( + ostream& stream, + rocwmma::MatrixLayout::RowInlineInt const& + matrix_layout) + { + return stream << "RowInlineInt<" << BlockDim << ", " << BlockK << ", " + << rocwmma::dataTypeToString() << ", " << MmaDim << ", " << SplitK + << ">"; + } + +} // namespace std + +#endif // !defined(__HIPCC_RTC__) + +#endif // ROCWMMA_MATRIX_LAYOUT_IMPL_HPP diff --git a/library/include/rocwmma/internal/layout/matrix_layout_traits_impl.hpp b/library/include/rocwmma/internal/layout/matrix_layout_traits_impl.hpp new file mode 100644 index 00000000..f79188f6 --- /dev/null +++ b/library/include/rocwmma/internal/layout/matrix_layout_traits_impl.hpp @@ -0,0 +1,605 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef ROCWMMA_MATRIX_LAYOUT_TRAITS_IMPL_HPP +#define ROCWMMA_MATRIX_LAYOUT_TRAITS_IMPL_HPP + +#include "../config.hpp" +#include "layout.hpp" +#include "layout_traits.hpp" + +namespace rocwmma +{ + // Common helpers for supported traits + namespace LayoutTraits_impl + { + // Reference regular layouts + using MatrixLayout::ColInlineVW; + using MatrixLayout::ColOrthoVW; + using MatrixLayout::RowInlineVW; + using MatrixLayout::RowOrthoVW; + + // Reference interleaved layouts + using MatrixLayout::ColInlineInt; + using MatrixLayout::ColOrthoInt; + using MatrixLayout::RowInlineInt; + using MatrixLayout::RowOrthoInt; + + // Build a basic set of meta-data classifiers. + // We will be interested in knowing things about our matrix layouts: + // - is_col_ortho + // - is_row_ortho + // - is_col_inline + // - is_row_inline + // - is_interleaved + // - is_matrix_layout + template + struct is_col_ortho : public false_type + { + }; + + template + struct is_col_ortho> + : public true_type + { + }; + + template + struct is_col_ortho> + : public true_type + { + }; + + template + struct is_row_ortho : public false_type + { + }; + + template + struct is_row_ortho> + : public true_type + { + }; + + template + struct is_row_ortho> + : public true_type + { + }; + + template + struct is_col_inline : public false_type + { + }; + + template + struct is_col_inline> + : public true_type + { + }; + + template + struct is_col_inline> + : public true_type + { + }; + + template + struct is_row_inline : public false_type + { + }; + + template + struct is_row_inline> + : public true_type + { + }; + + template + struct is_row_inline> + : public true_type + { + }; + + template + struct is_interleaved : public false_type + { + }; + + template + struct is_interleaved> + : public true_type + { + }; + + template + struct is_interleaved> + : public true_type + { + }; + + template + struct is_interleaved> + : public true_type + { + }; + + template + struct is_interleaved> + : public true_type + { + }; + + // Convenience evaluators + template + constexpr static bool is_col_ortho_v = is_col_ortho::value; + + template + constexpr static bool is_row_ortho_v = is_row_ortho::value; + + template + constexpr static bool is_col_inline_v = is_col_inline::value; + + template + constexpr static bool is_row_inline_v = is_row_inline::value; + + template + constexpr static bool is_interleaved_v = is_interleaved::value; + + template + struct is_matrix_layout + : public integral_constant || is_col_inline_v + || is_row_ortho_v + || is_row_inline_v> + { + }; + + template + constexpr static bool is_matrix_layout_v = is_matrix_layout::value; + + template + struct matrix_layout_classifier_traits + { + // Add associative traits + constexpr static bool is_col_ortho = is_col_ortho_v; + constexpr static bool is_col_inline = is_col_inline_v; + constexpr static bool is_row_ortho = is_row_ortho_v; + constexpr static bool is_row_inline = is_row_inline_v; + constexpr static bool is_interleaved = is_interleaved_v; + constexpr static bool is_matrix_layout = is_matrix_layout_v; + }; + + template + struct matrix_layout_derived_traits + { + // Interface for params we want to derive from matrix layouts + constexpr static uint32_t BlockDim = 0u; + constexpr static uint32_t KDim = 0u; + using DataT = void; + constexpr static uint32_t VectorWidth = 0u; + constexpr static uint32_t MaxVectorWidth = 0u; + constexpr static uint32_t MmaDim = 0u; + constexpr static uint32_t SplitK = 0u; + }; + +#define matrix_layout \ + MatrixLayout + + // Combine internal layout traits with template params + template + class MatrixLayout> + struct matrix_layout_derived_traits< + matrix_layout, + enable_if_t && !is_interleaved_v>> + : public matrix_layout::Traits // Base traits + { + // Common params derived from template params + constexpr static uint32_t BlockDim = LayoutBlockDim; + constexpr static uint32_t KDim = LayoutBlockK; + using DataT = LayoutDataT; + constexpr static uint32_t VectorWidth = LayoutVectorWidth; + constexpr static uint32_t MaxVectorWidth = LayoutMaxVectorWidth; + constexpr static uint32_t MmaDim = LayoutBlockDim; // Effective MmaDim + constexpr static uint32_t SplitK = 0; // Unused + }; + +#undef matrix_layout + +#define matrix_layout \ + MatrixLayout + + // Represent interleaved MatrixLayout instances + template + class MatrixLayout> + struct matrix_layout_derived_traits< + matrix_layout, + enable_if_t && is_interleaved_v>> + : public matrix_layout::Traits // Base traits + { + private: + // Wrapper to get fixed MaxVectorWidth / VectorWidth from layout + constexpr static inline uint32_t calcMaxVw() + { + if constexpr(is_col_inline_v || is_row_inline_v) + { + return matrix_layout::Traits::DimPerThread; + } + else if constexpr(is_col_ortho_v || is_row_ortho_v) + { + return matrix_layout::Traits::KPerThread; + } + else + { + return 0; + } + } + + public: + // Common params derived from template params + constexpr static uint32_t BlockDim = LayoutBlockDim; + constexpr static uint32_t KDim = LayoutBlockK; + using DataT = LayoutDataT; + constexpr static uint32_t VectorWidth = calcMaxVw(); + constexpr static uint32_t MaxVectorWidth = calcMaxVw(); + constexpr static uint32_t MmaDim = LayoutMmaDim; + constexpr static uint32_t SplitK = LayoutSplitK; + }; + +#undef matrix_layout + + // Combine base instance traits with specific layout classifiers + template + struct matrix_layout_traits : public matrix_layout_derived_traits, + public matrix_layout_classifier_traits + { + }; + + // NOTE: MatrixLayout assumptions + // When determining MatrixLayout traits, there are several strong assumptions. + // 1. Regarding same-ness: MatrixLayouts must match, as defined below: + // ____________________________________________________________________ + // | MatrixLayoutLhs | MatrixLayoutRhs | Compatibility test: | + // | | (Same) | Required Fixed Params | + // | ------------------------------------------------------------------ | + // | ColOrthoVW | ColOrthoVW | BlockDim, KDim, MaxVectorWidth | + // | ColInlineVW | ColInlineVW | BlockDim, KDim, MaxVectorWidth | + // | RowOrthoVW | RowOrthoVW | BlockDim, KDim, MaxVectorWidth | + // | RowInlineVW | RowInlineVW | BlockDim, KDim, MaxVectorWidth | + // | ------------------------------------------------------------------ | + // | ColOrthoInt | ColOrthoInt | BlockDim, KDim, MmaDim, SplitK | + // | ColInlineInt | ColInlineInt | BlockDim, KDim, MmaDim, SplitK | + // | RowOrthoInt | RowOrthoInt | BlockDim, KDim, MmaDim, SplitK | + // | RowInlineInt | RowInlineInt | BlockDim, KDim, MmaDim, SplitK | + // -------------------------------------------------------------------- + // + // 2. Regarding orthogonality: for all Col* layouts, their Row* + // orthogonal counterparts are implemented by row / col coordinate swaps. + // This is valid as long as we have some fixed parameters, as defined below: + // ____________________________________________________________________ + // | MatrixLayoutLhs | MatrixLayoutRhs | Compatibility test: | + // | | (Orthogonal) | Required Fixed Params | + // | ------------------------------------------------------------------ | + // | ColOrthoVW | RowOrthoVW | BlockDim, KDim, MaxVectorWidth | + // | ColInlineVW | RowInlineVW | BlockDim, KDim, MaxVectorWidth | + // | RowOrthoVW | ColOrthoVW | BlockDim, KDim, MaxVectorWidth | + // | RowInlineVW | ColInlineVW | BlockDim, KDim, MaxVectorWidth | + // | ------------------------------------------------------------------ | + // | ColOrthoInt | RowOrthoInt | BlockDim, KDim, MmaDim, SplitK | + // | ColInlineInt | RowInlineInt | BlockDim, KDim, MmaDim, SplitK | + // | RowOrthoInt | ColOrthoInt | BlockDim, KDim, MmaDim, SplitK | + // | RowInlineInt | ColInlineInt | BlockDim, KDim, MmaDim, SplitK | + // -------------------------------------------------------------------- + // This defines the need for MatrixLayout classifiers based upon: + // - ColOrtho / RowOrtho + // - ColInline / RowInline + // - Non-interleave / non-interleaved + // + // Following the above traits, we can build more complicated traits such as + // is_same, is_orthogonal and orthogonal_layout. + +// Tidy access to matrix layout traits. +#define traits_lhs matrix_layout_traits +#define traits_rhs matrix_layout_traits + + // For a fixed maxVW, we can change VW of a matrix layout to any common divisor + ROCWMMA_HOST_DEVICE constexpr static inline bool + testSupportedVW(uint32_t maxVW, uint32_t vw0, uint32_t vw1) + { + return (vw0 <= maxVW) && (vw1 <= maxVW) && (maxVW % vw0 == 0) && (maxVW % vw1 == 0); + } + + // As a predicate to is_layout_same or is_layout_orthogonal, their matrix parameters must + // be compatible (see above table). + template + ROCWMMA_HOST_DEVICE constexpr static bool testCompatibleMatrixParams() + { + if constexpr(!traits_lhs::is_matrix_layout && !traits_rhs::is_matrix_layout) + { + return false; + } + else if constexpr(!traits_lhs::is_interleaved && !traits_rhs::is_interleaved) + { + // Non-interleaved matrix layout compatibility requires: + // 1. Fixed: BlockDim, KDim, MaxVectorWidth + // 2. VectorWidths must satisfy criterion in testSupportedVW(). + return (traits_lhs::BlockDim == traits_rhs::BlockDim) + && (traits_lhs::KDim == traits_rhs::KDim) + && (traits_lhs::MaxVectorWidth == traits_rhs::MaxVectorWidth) + && (testSupportedVW(traits_lhs::MaxVectorWidth, + traits_lhs::VectorWidth, + traits_rhs::VectorWidth)); + } + else if constexpr(traits_lhs::is_interleaved && traits_rhs::is_interleaved) + { + // Interleaved matrix layout compatibility requires: + // 1. Must have fixed BlockDim, BlockK, MmaDim, SplitK + // 2. MmaDim values must satisfy criterion in testSupportedMmaDim(). + return (traits_lhs::BlockDim == traits_rhs::BlockDim) + && (traits_lhs::KDim == traits_rhs::KDim) + && (traits_lhs::MmaDim == traits_rhs::MmaDim) + && (traits_lhs::SplitK == traits_rhs::SplitK) + && (traits_lhs::DimPerThread == traits_rhs::DimPerThread) + && (traits_lhs::KPerThread == traits_rhs::KPerThread); + } + else + { + return false; + } + } + + // Test for same layout + template + ROCWMMA_HOST_DEVICE constexpr static bool testMatrixLayoutSame() + { + return ((traits_lhs::is_col_ortho && traits_rhs::is_col_ortho) + || (traits_lhs::is_row_ortho && traits_rhs::is_row_ortho) + || (traits_lhs::is_col_inline && traits_rhs::is_col_inline) + || (traits_lhs::is_row_inline && traits_rhs::is_row_inline)) + && testCompatibleMatrixParams(); + } + + // Test for orthogonal layout + template + ROCWMMA_HOST_DEVICE constexpr static bool testMatrixLayoutOrthogonal() + { + return ((traits_lhs::is_col_ortho && traits_rhs::is_row_ortho) + || (traits_lhs::is_row_ortho && traits_rhs::is_col_ortho) + || (traits_lhs::is_col_inline && traits_rhs::is_row_inline) + || (traits_lhs::is_row_inline && traits_rhs::is_col_inline)) + && testCompatibleMatrixParams(); + } + + // Now to implement the interfaces for is_layout_same and is_layout_orthogonal, + // with MatrixLayout types. + + // Implement sameness classifier for matrix layouts + template + struct is_layout_same< + MatrixLayoutLhs, + MatrixLayoutRhs, + enable_if_t> + : public integral_constant()> + { + }; + + // Implement orthogonality classifier for matrix layouts + template + struct is_layout_orthogonal< + MatrixLayoutLhs, + MatrixLayoutRhs, + enable_if_t> + : public integral_constant< + bool, + testMatrixLayoutOrthogonal()> + { + }; + +#undef traits_lhs +#undef traits_rhs + + // Matrix space transpose guide: Swap rows / cols + // VW stays consistent. + template + struct orthogonal_layout> + { + using type = RowOrthoVW; + }; + + template + struct orthogonal_layout> + { + using type = ColOrthoVW; + }; + + template + struct orthogonal_layout> + { + using type = RowInlineVW; + }; + + template + struct orthogonal_layout> + { + using type = ColInlineVW; + }; + + // Orthogonal guide for interleaved layouts + template + struct orthogonal_layout> + { + using type = RowOrthoInt; + }; + + template + struct orthogonal_layout> + { + using type = ColOrthoInt; + }; + + template + struct orthogonal_layout> + { + using type = RowInlineInt; + }; + + template + struct orthogonal_layout> + { + using type = ColInlineInt; + }; + template + struct layout_traits>> + : public matrix_layout_traits + { + }; + + } // namespace LayoutTraits_impl + +} // namespace rocwmma + +#if !defined(__HIPCC_RTC__) +namespace std +{ + + template + inline ostream& + operator<<(ostream& stream, + rocwmma::LayoutTraits_impl::matrix_layout_traits const& traits) + { + using matrix_traits = decay_t; + + stream << "MatrixLayout Traits: " << MatrixLayout{} << std::endl; + stream << "is_col_ortho: " << matrix_traits::is_col_ortho << std::endl; + stream << "is_row_ortho: " << matrix_traits::is_row_ortho << std::endl; + stream << "is_col_inline: " << matrix_traits::is_col_inline << std::endl; + stream << "is_row_inline: " << matrix_traits::is_row_inline << std::endl; + stream << "is_interleaved: " << matrix_traits::is_interleaved << std::endl; + stream << "is_matrix_layout: " << matrix_traits::is_matrix_layout << std::endl; + stream << "BlockDim: " << matrix_traits::BlockDim << std::endl; + stream << "KDim: " << matrix_traits::KDim << std::endl; + stream << "MmaDim: " << matrix_traits::MmaDim << std::endl; + stream << "SplitK: " << matrix_traits::SplitK << std::endl; + stream << "VectorWidth: " << matrix_traits::VectorWidth << std::endl; + stream << "MaxVectorWidth: " << matrix_traits::MaxVectorWidth << std::endl; + return stream; + } + +} // namespace std + +#endif // !defined(__HIPCC_RTC__) + +#endif // ROCWMMA_MATRIX_LAYOUT_TRAITS_IMPL_HPP diff --git a/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp b/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp new file mode 100644 index 00000000..5e88b8ea --- /dev/null +++ b/library/include/rocwmma/internal/layout/register_layout_traits_impl.hpp @@ -0,0 +1,640 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef ROCWMMA_REGISTER_LAYOUT_TRAITS_IMPL_HPP +#define ROCWMMA_REGISTER_LAYOUT_TRAITS_IMPL_HPP + +#include "../utility/type_traits.hpp" +#include "layout.hpp" +#include "layout_traits.hpp" + +namespace rocwmma +{ + namespace LayoutTraits_impl + { + using RegisterLayout::MmaAcc; + using RegisterLayout::MmaInput; + using RegisterLayout::Storage; + + // Start to build a basic set of meta-data classifiers. + // We will be interested in knowing things about our register layouts: + // - is_register_layout + // - is_storage + // - is_mma_input + // - is_mma_acc + template + struct is_register_layout : public false_type + { + }; + + template + struct is_register_layout> + : public is_matrix_layout + { + }; + + template + struct is_register_layout> : public true_type + { + }; + + template + struct is_register_layout> : public true_type + { + }; + + template + struct is_storage : public false_type + { + }; + + template + struct is_storage> : public is_matrix_layout + { + }; + + template + struct is_mma_input : public false_type + { + }; + + template + struct is_mma_input> : public true_type + { + }; + + template + struct is_mma_acc : public false_type + { + }; + + template + struct is_mma_acc> : public true_type + { + }; + + // Convenience evaluators + template + constexpr inline static bool is_register_layout_v + = is_register_layout::value; + + template + constexpr inline static bool is_storage_v = is_storage::value; + + template + constexpr inline static bool is_mma_input_v = is_mma_input::value; + + template + constexpr inline static bool is_mma_acc_v = is_mma_acc::value; + + template + struct register_layout_classifier_traits + { + constexpr static bool is_register_layout = is_register_layout_v; + constexpr static bool is_storage = is_storage_v; + constexpr static bool is_mma_input = is_mma_input_v; + constexpr static bool is_mma_acc = is_mma_acc_v; + }; + + template + struct register_layout_traits; + + // RegisterLayouts are consistent for both data layouts if we restrict + // VectorWidth to 1 in the opposite data layout grain. + // This applies to all matrix layouts. + template + ROCWMMA_HOST_DEVICE constexpr static bool testStorageLayoutIdentity() + { + using traits = register_layout_traits; + if constexpr(traits::is_col_inline) + { + return (traits::is_col_major || traits::VectorWidth == 1); + } + else if constexpr(traits::is_row_inline) + { + return (traits::is_row_major || traits::VectorWidth == 1); + } + else if constexpr(traits::is_col_ortho) + { + return (traits::is_row_major || traits::VectorWidth == 1u); + } + else if constexpr(traits::is_row_ortho) + { + return (traits::is_col_major || traits::VectorWidth == 1u); + } + + return false; + } + + // AOS is a strict register layout where thread VW is inline + // with contiguous BlockDim elements. + // To be valid, the layout must be consistent across row_major + // and col_major data layouts. + template + ROCWMMA_HOST_DEVICE constexpr static bool testStorageLayoutAos() + { + using traits = register_layout_traits; + return (traits::is_col_inline || traits::is_row_inline); + } + + // SOA is a strict register layout where thread VW is inline + // with contiguous BlockK elements, orthogonal to BlockDim. + // To be valid, the layout must be consistent across row_major + // and col_major data layouts. + template + ROCWMMA_HOST_DEVICE constexpr static bool testStorageLayoutSoa() + { + using traits = register_layout_traits; + return (traits::is_col_ortho || traits::is_row_ortho); + } + + // Based on the current architecture, which mma dimensions supported + template + ROCWMMA_HOST_DEVICE constexpr static inline bool testSupportedMmaDim() + { + using traits = register_layout_traits; + return (traits::MmaDim == 16u && (bool)ROCWMMA_BLOCK_DIM_16_SUPPORTED) + || (traits::MmaDim == 32u && (bool)ROCWMMA_BLOCK_DIM_32_SUPPORTED + && !is_same_v); // f64 mfma only 16 + } + + // Based on the current architecture, which register layout formats currently supported + template + ROCWMMA_HOST_DEVICE constexpr static inline bool testSupportedFormat() + { + using traits = register_layout_traits; + using rocwmma::RegisterLayout::Format; + + if constexpr(traits::is_mma_input) + { + if constexpr((bool)ROCWMMA_ARCH_GFX11) + { + return traits::Format == Format::WMMA_INPUT_GFX11; + } + else if constexpr(traits::is_interleaved) + { + return (traits::Format == Format::SOA_INT) + || (traits::Format == Format::AOS_INT); + } + else + { + return (traits::Format == Format::SOA) + || (traits::Format == Format::AOS); + } + } + else if constexpr(traits::is_mma_acc) + { + if constexpr((bool)ROCWMMA_ARCH_GFX11) + { + return (traits::Format == Format::WMMA_ACC_GFX11) + || (!traits::is_interleaved && (traits::Format == Format::SOA || traits::Format == Format::AOS)) + || (traits::is_interleaved && (traits::Format == Format::ACC_INT_A_MAJOR || traits::Format == Format::ACC_INT_B_MAJOR)); + } + else if constexpr(traits::is_interleaved) + { + // Intermediate accumulation format for interleaved layout + return (traits::Format == Format::ACC_INT_A_MAJOR) + || (traits::Format == Format::ACC_INT_B_MAJOR); + } + else + { + // Acc with void datalayout will take SOA format + return (traits::Format == Format::SOA) + || (traits::Format == Format::AOS); + } + } + else + { + return traits::is_storage + && ((traits::Format == Format::SOA) + || (traits::Format == Format::AOS) + || (traits::Format == Format::SOA_INT) + || (traits::Format == Format::AOS_INT)); + } + } + + template + ROCWMMA_HOST_DEVICE constexpr static inline auto registerFormat() + { + using traits = register_layout_traits; + using rocwmma::RegisterLayout::Format; + + // MmaInput and MmaAcc are statically assigned + if constexpr(traits::is_mma_input || traits::is_mma_acc) + { + return traits::Format; + } + // Determine the register format of the current storage layout + // based on the layout traits. + else if constexpr(traits::is_storage) + { + if constexpr(traits::is_interleaved) + { + return testStorageLayoutAos() + ? Format::AOS_INT + : (testStorageLayoutSoa() ? Format::SOA_INT + : Format::Invalid); + } + else + { + return testStorageLayoutAos() + ? Format::AOS + : (testStorageLayoutSoa() ? Format::SOA + : Format::Invalid); + } + } + else + { + return Format::Invalid; + } + } + + template + struct register_layout_derived_traits + { + }; + + template + struct register_layout_derived_traits> + : public matrix_layout_traits, + public data_layout_traits + { + using MatrixLayout = MatrixLayoutInternal; + using DataLayout = DataLayoutInternal; + + constexpr static RegisterLayout::Format Format + = registerFormat>(); + + constexpr static bool is_valid + = testStorageLayoutIdentity>() + && testSupportedFormat>(); + }; + + template + struct register_layout_derived_traits< + MmaInput> + : public matrix_layout_traits, public data_layout_traits + { + using MatrixLayout = void; + using DataLayout = void; + + using DataT = LayoutDataT; + + // Overrides + constexpr static bool is_interleaved = LayoutIsInterleaved; + constexpr static uint32_t MmaDim = LayoutMmaDim; + + // Template param driven format + constexpr static RegisterLayout::Format Format = Fmt; + + constexpr static bool is_valid + = testSupportedMmaDim< + MmaInput>() + && testSupportedFormat< + MmaInput>(); + }; + + template + struct register_layout_derived_traits< + MmaAcc> + : public matrix_layout_traits, public data_layout_traits + { + using MatrixLayout = void; + using DataLayout = void; + + using DataT = LayoutDataT; + + // Overrides + constexpr static bool is_interleaved = LayoutIsInterleaved; + constexpr static uint32_t MmaDim = LayoutMmaDim; + + // Template param driven format + constexpr static RegisterLayout::Format Format = Fmt; + + constexpr static bool is_valid + = testSupportedMmaDim>() + && testSupportedFormat< + MmaAcc>(); + }; + + // Combine base instance traits with specific layout classifiers + template + struct register_layout_traits : public register_layout_classifier_traits, + public register_layout_derived_traits + + { + }; + + // NOTE: RegisterLayout comparison assumptions + // When determining RegisterLayout traits, there are several strong assumptions. + // Register layouts are assigned Formats, based on their given matrix and data layouts. + // 1. Regarding same-ness: + // - Register formats match, if tested for matching register layout traits: + // MmaDim, is_interleaved and is_valid. + // - Register layouts match if register formats match, and there is congruency between + // Storage, MmaInput and MmaAcc types. + // - Congruency between Storage, MmaInput and MmaAcc types is partly defined by how + // MmaInput and MmaAcc register format template parameters are set for the Mma workflow, + // and partly by architecture (e.g., MmaAcc layout VW per block is fixed). + // + // 2. Regarding orthogonality: + // - Format orthogonality is defined as having an in-register transition from one distinct + // format to another. + // E.g,. AOS <-> SOA, SOA <-> ACC_INT_A_MAJOR, SOA <-> ACC_INT_B_MAJOR, + // AOS <-> ACC_INT_A_MAJOR or AOS <-> ACC_INT_B_MAJOR. + // These require matching MmaDim, is_interleaved and is_valid traits. + +// Keeps things a bit more tidy. Quick access to register layout traits. +#define traits_lhs register_layout_traits +#define traits_rhs register_layout_traits +#define traits register_layout_traits + + // As a predicate to is_layout_same or is_layout_orthogonal, their register traits must + // be compatible as per above. + template + ROCWMMA_HOST_DEVICE constexpr static bool testCompatibleRegisterParams() + { + // Basic test: + // Matching MmaDim, interleaving and validity + // Note: matching validity does not imply valid! + // Cannot mix valid with invalid layouts + // Datatype must be same + constexpr bool BaseTest + = (traits_lhs::MmaDim == traits_rhs::MmaDim) + && (traits_lhs::is_interleaved == traits_rhs::is_interleaved) + && (traits_lhs::is_valid == traits_rhs::is_valid) + && (is_same_v); + + // MmaInput <-> MmaInput + // MmaAcc <-> MmaAcc + if constexpr((traits_lhs::is_mma_input && traits_rhs::is_mma_input) + || (traits_lhs::is_mma_acc && traits_rhs::is_mma_acc)) + { + return BaseTest; + } + // Storage <-> MmaAcc + // Storage <-> MmaInput + // Storage must be valid layout + // Non-interleaved MmaAcc must check MaxVW + else if constexpr((traits_lhs::is_storage && traits_rhs::is_mma_input) + || (traits_lhs::is_mma_input && traits_rhs::is_storage) + || (traits_lhs::is_storage && traits_rhs::is_mma_acc) + || (traits_lhs::is_mma_acc && traits_rhs::is_storage)) + { + using storage_traits + = conditional_t; + using mma_traits = conditional_t; + + if constexpr(mma_traits::is_mma_input || mma_traits::is_interleaved) + { + return BaseTest && storage_traits::is_valid; + } + else + { + // Acc layout architecture quirks + constexpr uint32_t ExpectedAccMaxVW + = ((bool)ROCWMMA_ARCH_GFX12) ? 8u + : ((bool)ROCWMMA_ARCH_GFX11 + || is_same::value) + ? 1u + : 4u; + + constexpr bool TestMmaAccMaxVW + = (ExpectedAccMaxVW == storage_traits::MaxVectorWidth); + + return TestMmaAccMaxVW && BaseTest && storage_traits::is_valid; + } + } + // Storage <-> Storage + // Must check Matrix compatibility + else if constexpr(traits_lhs::is_storage && traits_rhs::is_storage) + { + return testCompatibleMatrixParams() + && BaseTest; + } + // MmaInput <-> MmaAcc not compatible + else + { + return false; + } + } + + template + ROCWMMA_HOST_DEVICE constexpr static bool testRegisterLayoutSame() + { + // Required compatibility + constexpr bool TestCompatibleParams + = testCompatibleRegisterParams(); + + // General case the formats match + constexpr bool TestFormatMatch = (traits_lhs::Format == traits_rhs::Format); + + if constexpr((traits_lhs::is_interleaved && traits_rhs::is_interleaved) + && ((traits_lhs::is_storage && traits_rhs::is_storage) + || (traits_lhs::is_storage && traits_rhs::is_mma_input) + || (traits_lhs::is_mma_input && traits_rhs::is_storage))) + { + using storage_traits + = conditional_t; + + // Gfx11 MmaInput requires some additional transforms + if constexpr((bool)ROCWMMA_ARCH_GFX11 + && (traits_lhs::is_mma_input || traits_rhs::is_mma_input)) + { + return TestCompatibleParams && TestFormatMatch; + } + else + { + // Special case: interleaved layouts + // Check matching thread dims and if either one is == 1u. + // Register contents will be identical, regardless of different formats. + constexpr bool TestIdentityQuirks + = (storage_traits::DimPerThread == 1u) || (storage_traits::KPerThread == 1u); + + return TestCompatibleParams && (TestFormatMatch || TestIdentityQuirks); + } + } + else + { + // Test both register layouts in same format + return TestCompatibleParams && TestFormatMatch; + } + } + + template + ROCWMMA_HOST_DEVICE constexpr static bool testRegisterLayoutOrthogonal() + { + // Required not same and compatible params + constexpr bool TestNotSame + = !testRegisterLayoutSame(); + + constexpr bool TestCompatibleParams + = testCompatibleRegisterParams(); + + // Identify valid paths in orthogonality. + // SOA <-> AOS + // ACC_INT_A_MAJOR <-> AOS, SOA + // ACC_INT_B_MAJOR <-> AOS, SOA + // Register layouts must be valid to be orthogonal + // clang-format off + using RegisterLayout::Format; + constexpr bool TestOpposingFormat + = ( + // Non-interleaved formats + // SOA <-> AOS + (traits_lhs::Format == Format::SOA && traits_rhs::Format == Format::AOS) + || (traits_lhs::Format == Format::AOS && traits_rhs::Format == Format::SOA) + // Non-interleaved gfx11 formats + // SOA, AOS <-> WMMA input + // SOA, AOS <-> WMMA acc + || (traits_lhs::Format == Format::SOA && traits_rhs::Format == Format::WMMA_INPUT_GFX11) + || (traits_lhs::Format == Format::AOS && traits_rhs::Format == Format::WMMA_INPUT_GFX11) + || (traits_lhs::Format == Format::WMMA_INPUT_GFX11 && traits_rhs::Format == Format::SOA) + || (traits_lhs::Format == Format::WMMA_INPUT_GFX11 && traits_rhs::Format == Format::AOS) + || (traits_lhs::Format == Format::SOA && traits_rhs::Format == Format::WMMA_ACC_GFX11) + || (traits_lhs::Format == Format::AOS && traits_rhs::Format == Format::WMMA_ACC_GFX11) + || (traits_lhs::Format == Format::WMMA_ACC_GFX11 && traits_rhs::Format == Format::SOA) + || (traits_lhs::Format == Format::WMMA_ACC_GFX11 && traits_rhs::Format == Format::AOS) + // Interleaved formats + // SOA_INT <-> AOS_INT + // SOA_INT, AOS_INT <-> A-major acc fmt + // SOA_INT, AOS_INT <-> B-major acc fmt + || (traits_lhs::Format == Format::SOA_INT && traits_rhs::Format == Format::AOS_INT) + || (traits_lhs::Format == Format::AOS_INT && traits_rhs::Format == Format::SOA_INT) + || (traits_lhs::Format == Format::SOA_INT && traits_rhs::Format == Format::ACC_INT_A_MAJOR) + || (traits_lhs::Format == Format::AOS_INT && traits_rhs::Format == Format::ACC_INT_A_MAJOR) + || (traits_lhs::Format == Format::ACC_INT_A_MAJOR && traits_rhs::Format == Format::SOA_INT) + || (traits_lhs::Format == Format::ACC_INT_A_MAJOR && traits_rhs::Format == Format::AOS_INT) + || (traits_lhs::Format == Format::SOA_INT && traits_rhs::Format == Format::ACC_INT_B_MAJOR) + || (traits_lhs::Format == Format::AOS_INT && traits_rhs::Format == Format::ACC_INT_B_MAJOR) + || (traits_lhs::Format == Format::ACC_INT_B_MAJOR && traits_rhs::Format == Format::SOA_INT) + || (traits_lhs::Format == Format::ACC_INT_B_MAJOR && traits_rhs::Format == Format::AOS_INT) + // Interleaved gfx11 formats + // SOA_INT, AOS_INT <-> WMMA input + // SOA_INT, AOS_INT <-> WMMA acc + // A-major acc fmt <-> WMMA acc + // B-major acc fmt <-> WMMA acc + || (traits_lhs::Format == Format::SOA_INT && traits_rhs::Format == Format::WMMA_INPUT_GFX11) + || (traits_lhs::Format == Format::AOS_INT && traits_rhs::Format == Format::WMMA_INPUT_GFX11) + || (traits_lhs::Format == Format::WMMA_INPUT_GFX11 && traits_rhs::Format == Format::SOA_INT) + || (traits_lhs::Format == Format::WMMA_INPUT_GFX11 && traits_rhs::Format == Format::AOS_INT) + || (traits_lhs::Format == Format::SOA_INT && traits_rhs::Format == Format::WMMA_ACC_GFX11) + || (traits_lhs::Format == Format::AOS_INT && traits_rhs::Format == Format::WMMA_ACC_GFX11) + || (traits_lhs::Format == Format::WMMA_ACC_GFX11 && traits_rhs::Format == Format::SOA_INT) + || (traits_lhs::Format == Format::WMMA_ACC_GFX11 && traits_rhs::Format == Format::AOS_INT) + || (traits_lhs::Format == Format::ACC_INT_A_MAJOR && traits_rhs::Format == Format::WMMA_ACC_GFX11) + || (traits_lhs::Format == Format::ACC_INT_B_MAJOR && traits_rhs::Format == Format::WMMA_ACC_GFX11) + || (traits_lhs::Format == Format::WMMA_ACC_GFX11 && traits_rhs::Format == Format::ACC_INT_A_MAJOR) + || (traits_lhs::Format == Format::WMMA_ACC_GFX11 && traits_rhs::Format == Format::ACC_INT_B_MAJOR) + ) + && (traits_lhs::is_valid && traits_rhs::is_valid); + // clang-format on + + return TestNotSame && TestCompatibleParams && TestOpposingFormat; + } + + // Checks if both RegisterLayout storages are the same with compatible params + template + struct is_layout_same< + RegisterLayoutLhs, + RegisterLayoutRhs, + enable_if_t> + : public integral_constant< + bool, + testRegisterLayoutSame()> + { + }; + + // Checks if RegisterLayouts are transposed with compatible params + template + struct is_layout_orthogonal< + RegisterLayoutLhs, + RegisterLayoutRhs, + enable_if_t> + : public integral_constant< + bool, + testRegisterLayoutOrthogonal()> + { + }; + +#undef traits_lhs +#undef traits_rhs +#undef traits + + // Use generic MatrixLayout orthogonality rules to guide the register layout transpose suggestion + // TODO: fix + template + struct orthogonal_layout> + { + using type = Storage::type, + typename orthogonal_layout::type>; + }; + + template + struct layout_traits>> + : public register_layout_traits + { + }; + + } // namespace LayoutTraits_impl + +} // namespace rocwmma + +#if !defined(__HIPCC_RTC__) +namespace std +{ + + template + inline ostream& + operator<<(ostream& stream, + rocwmma::LayoutTraits_impl::register_layout_traits const& traits) + { + using register_traits = decay_t; + + stream << "RegisterLayout Traits: " << RegisterLayout{} << std::endl; + stream << "is_register_layout: " << traits.is_register_layout << std::endl; + stream << "is_storage: " << traits.is_storage << std::endl; + stream << "is_mma_input: " << traits.is_mma_input << std::endl; + stream << "is_mma_acc: " << traits.is_mma_acc << std::endl; + stream << "is_interleaved: " << traits.is_interleaved << std::endl; + stream << "MmaDim: " << traits.MmaDim << std::endl; + stream << "is_valid: " << traits.is_valid << std::endl; + stream << "Format: " << traits.Format << std::endl; + + return stream; + } + +} // namespace std + +#endif // !defined(__HIPCC_RTC__) + +#endif // ROCWMMA_REGISTER_LAYOUT_TRAITS_IMPL_HPP diff --git a/library/include/rocwmma/internal/layout/register_layout_transforms.hpp b/library/include/rocwmma/internal/layout/register_layout_transforms.hpp new file mode 100644 index 00000000..b31d63a3 --- /dev/null +++ b/library/include/rocwmma/internal/layout/register_layout_transforms.hpp @@ -0,0 +1,321 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef ROCWMMA_REGISTER_LAYOUT_TRANSFORMS_HPP +#define ROCWMMA_REGISTER_LAYOUT_TRANSFORMS_HPP + +#include "../transforms.hpp" +#include "layout.hpp" +#include "layout_traits.hpp" + +namespace rocwmma +{ + template + struct soa_int_to_aos_int + { + template + ROCWMMA_DEVICE constexpr static inline decltype(auto) exec(VecT&& v) + { + return interleave<1u, KPerThread>(forward(v)); + } + }; + + template + struct aos_int_to_soa_int + { + template + ROCWMMA_DEVICE constexpr static inline decltype(auto) exec(VecT&& v) + { + return interleave<1u, DimPerThread>(forward(v)); + } + }; + + struct to_wmma_input_gfx11 + { + template + ROCWMMA_DEVICE constexpr static inline decltype(auto) exec(VecT&& v) + { + // v is unpacked + using VecTraits = VecTraits>; + using PackUtil = PackUtil; + + // Swap upper / lower 16's and then concatenate them + // to make sure we have each K value in each half. + // GFX11 wmma layout quirk needs the duplication. + auto packed = PackUtil::pack(v); + auto swapped = Swizzle::Swap16::exec(packed); + auto result = PackUtil::unpack(concat(packed, swapped)); + return result; // Return by copy + } + }; + + struct from_wmma_input_gfx11 + { + template + ROCWMMA_DEVICE constexpr static inline decltype(auto) exec(VecT&& v) + { + return extractLo(v); + } + }; + + struct to_wmma_acc_gfx11 + { + template + ROCWMMA_DEVICE constexpr static inline decltype(auto) exec(VecT&& v) + { + using VecTraits = VecTraits>; + + // SOA format to wmma acc padded accumulator (gfx11). + // f16 -> padded to f32 in lower 16 + // f32 -> nop + using PackUtil = PackUtil; + auto accum = PackUtil::unpack(PackUtil::template pad<>(v)); + return accum; // Return by copy + } + }; + + struct from_wmma_acc_gfx11 + { + template + ROCWMMA_DEVICE constexpr static inline decltype(auto) exec(VecT&& v) + { + using VecTraits = VecTraits>; + + // Padded wmma acc (gfx11) back to SOA format. + // f16 -> padded to f32 in lower 16 + // f32 -> nop + using PackUtil = PackUtil; + return PackUtil::template unpad<>(PackUtil::pack(v)); + } + }; + + template + struct soa_int_to_mma_acc_int_a_major + { + template + ROCWMMA_DEVICE constexpr static inline decltype(auto) exec(VecT&& v) + { + if constexpr((bool)ROCWMMA_ARCH_GFX11) + { + + } + else + { + + } + return interleave<1u, DimPerThread>(forward(v)); + } + }; + + template + struct aos_int_to_mma_acc_int_a_major + { + template + ROCWMMA_DEVICE constexpr static inline decltype(auto) exec(VecT&& v) + { + return interleave<1u, DimPerThread>(forward(v)); + } + }; + + namespace RegisterTransform_impl + { + using LayoutTraits_impl::matrix_layout_traits; + using LayoutTraits_impl::register_layout_traits; + +// Keeps things a bit more tidy. Quick access to register layout traits. +#define traits_lhs register_layout_traits +#define traits_rhs register_layout_traits + + // Note: If you arrive at an undefined register_transform error, it is likely + // the layout transformation is not currently supported. Need to either implement + // the transform or ensure your layout transform mapping is correct. + template + struct register_layout_transform; + + // No-op transform (same-layout): + template + struct register_layout_transform< + RegisterLayoutLhs, + RegisterLayoutRhs, + enable_if_t>> + { + template + ROCWMMA_DEVICE constexpr static inline decltype(auto) exec(VecT&& v) + { + // No-op + return v; + } + }; + + template + struct register_layout_transform< + RegisterLayoutLhs, + RegisterLayoutRhs, + enable_if_t + && (!traits_lhs::is_register_layout || !traits_rhs::is_register_layout + || !is_layout_orthogonal_v)>> + { + template + ROCWMMA_UNSUPPORTED_IMPL("Register layout transform is not supported") + ROCWMMA_DEVICE constexpr static inline decltype(auto) exec(VecT&& v) + { + // No-op + return v; + } + }; + + // Apply paths between orthogonal transforms + template + struct register_layout_transform< + RegisterLayoutLhs, + RegisterLayoutRhs, + enable_if_t<(traits_lhs::is_register_layout && traits_rhs::is_register_layout) + && is_layout_orthogonal_v>> + { + template + ROCWMMA_DEVICE constexpr static inline decltype(auto) exec(VecT&& v) + { + using RegisterLayout::Format; + using storage_traits + = conditional_t; + + // Non-interleaved + if constexpr(traits_lhs::Format == Format::AOS + && traits_rhs::Format == Format::SOA) + { + return Transforms:: + AosToSoa::exec( + forward(v)); + } + else if constexpr(traits_lhs::Format == Format::SOA + && traits_rhs::Format == Format::AOS) + { + return Transforms:: + SoaToAos::exec( + forward(v)); + } + else if constexpr(traits_lhs::Format == Format::AOS + && traits_rhs::Format == Format::WMMA_INPUT_GFX11) + { + return to_wmma_input_gfx11::exec(Transforms::AosToSoa::exec(forward(v))); + } + else if constexpr(traits_lhs::Format == Format::SOA + && traits_rhs::Format == Format::WMMA_INPUT_GFX11) + { + return to_wmma_input_gfx11::exec(forward(v)); + } + // Interleaved + else if constexpr(traits_lhs::Format == Format::AOS_INT + && traits_rhs::Format == Format::SOA_INT) + { + return aos_int_to_soa_int::exec(forward(v)); + } + else if constexpr(traits_lhs::Format == Format::SOA_INT + && traits_rhs::Format == Format::AOS_INT) + { + return soa_int_to_aos_int::exec(forward(v)); + } + else if constexpr(traits_lhs::Format == Format::SOA_INT + && traits_rhs::Format == Format::WMMA_INPUT_GFX11) + { + return to_wmma_input_gfx11::exec(forward(v)); + } + else if constexpr(traits_lhs::Format == Format::AOS_INT + && traits_rhs::Format == Format::WMMA_INPUT_GFX11) + { + return to_wmma_input_gfx11::exec(aos_int_to_soa_int::exec(forward(v))); + } + else if constexpr(traits_lhs::Format == Format::ACC_INT_A_MAJOR + && traits_rhs::Format == Format::AOS_INT) + { + return interleave<1u, 4u>(forward(v)); + } + else if constexpr(traits_lhs::Format == Format::AOS_INT + && traits_rhs::Format == Format::ACC_INT_A_MAJOR) + { + return interleave<1u, storage_traits::KPerThread>(forward(v)); + } + else if constexpr(traits_lhs::Format == Format::SOA_INT + && traits_rhs::Format == Format::ACC_INT_A_MAJOR) + { + return interleave<1u, 4u>(forward(v)); + } + + else if constexpr(traits_lhs::Format == Format::ACC_INT_A_MAJOR + && traits_rhs::Format == Format::SOA_INT) + { + return interleave<1u, storage_traits::KPerThread>(forward(v)); + } + + else if constexpr((traits_lhs::Format == Format::SOA + || traits_lhs::Format == Format::ACC_INT_A_MAJOR + || traits_lhs::Format == Format::ACC_INT_B_MAJOR) + && (traits_rhs::Format == Format::WMMA_ACC_GFX11)) + { + return to_wmma_acc_gfx11::exec(forward(v)); + } + else if constexpr(traits_lhs::Format == Format::AOS + && traits_rhs::Format == Format::WMMA_ACC_GFX11) + { + return to_wmma_acc_gfx11::exec(forward(v)); + } + else if constexpr((traits_lhs::Format == Format::WMMA_ACC_GFX11) + && (traits_rhs::Format == Format::SOA + || traits_rhs::Format == Format::ACC_INT_A_MAJOR + || traits_rhs::Format == Format::ACC_INT_B_MAJOR)) + { + // Padded wmma acc (gfx11) back to SOA format. + // f16 -> padded to f32 in lower 16 + // f32 -> nop + using PackUtil = PackUtil; + return PackUtil::template unpad<>(PackUtil::pack(v)); + } + else + { + static_assert(0, "Register layout transform is not implemented"); + return v; + } + } + }; + +#undef traits_lhs +#undef traits_rhs + + } // namespace RegisterTransform_impl + + /*! \class register_layout_transform + * \brief Invokes an in-register transform from one register layout to the other + * @tparam RegisterLayoutLhs Source register layout + * @tparam RegisterLayoutRhs Target register layout + */ + template + using register_layout_transform + = RegisterTransform_impl::register_layout_transform; + + using register_layout_transform_nop = register_layout_transform; + +} // namespace rocWMMA + +#endif // ROCWMMA_REGISTER_LAYOUT_TRANSFORMS_HPP diff --git a/library/include/rocwmma/internal/layout_impl.hpp b/library/include/rocwmma/internal/layout_impl.hpp deleted file mode 100644 index 2ef65dfa..00000000 --- a/library/include/rocwmma/internal/layout_impl.hpp +++ /dev/null @@ -1,827 +0,0 @@ -/******************************************************************************* - * - * MIT License - * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - *******************************************************************************/ -#ifndef ROCWMMA_LAYOUT_IMPL_HPP -#define ROCWMMA_LAYOUT_IMPL_HPP - -#include "io_traits.hpp" -#include "layout.hpp" -#include "mapping_util.hpp" -#include "utils.hpp" - -namespace rocwmma -{ - // Implementations for the MatrixLayout classes - namespace MatrixLayout - { - /* Pattern that maps threads contiguously to matrix columns and assumes - * that VW will be mapped orthogonally to the column. - * This pattern considers VW up to MaxVW, BlockDim <= 64 and BlockDim > 64. - * - * Iterative thread stride cycles (same for all threads): - * Fill MaxVW => Fill BlockK => Fill BlockDim - * - * Example: - * BlockDim = 128 BlockK = 16 - * MaxVW = 4 VW = 1 - * - * BlockDim Stride Count = 2, BlockDimStride = (64, 0) - * BlockK Stride Count = 4, BlockKStride = (0, 4) - * VW Stride Count = 4, VWStride = (0, 1) - * - * Stride mapping (BlockDim, BlockK, VW) - * C_n = Matrix column - * i_n = cumulative iteration - * - * kDim ---------> - * VW Stride - * BlockDim |--1--| - * | |-- BlockK Stride = 4 --| - * | i0(0,0,0) i2(0,0,2) i4(0,1,0) i6(0,1,2) i14(0,3,2) - * | -- v_____ _____v_____ _____v_____ _____v_____ _____ v_____ _____ - * v | | | | | | | | | | | || | - * | | | | | | | | | | | || | - * BlockDim 64 | C0 | C1 | C2 | C3 | C4 | C5 | C6 | C7 | ... | C14 || C15 | - * Stride | | | | | | | | | | | || | - * -- |_____|_____|_____|_____|_____|_____|_____|_____| |_____||_____| - * i16(1,0,0) i18(1,0,2) i20(1,1,0) i22(1,1,2) i30(1,3,2) - * v_____ _____v_____ _____v_____ _____v_____ _____ v_____ _____ - * | | | | | | | | | | || | - * | | | | | | | | | | || | - * | C0 | C1 | C2 | C3 | C4 | C5 | C6 | C7 | ... | C14 || C15 | - * | | | | | | | | | | || | - * |_____|_____|_____|_____|_____|_____|_____|_____| |_____||_____| - * ^(128, 0) ^(BlockDim, BlockK) - * ... ... - * - * Register file (for all VectorWidths = [1, MaxVectorWidth]): - * - * Elements 0..............63 - * ______________ - * Reg0 | C0 [63:0] | - * Reg1 | C1 [63:0] | - * Reg2 | C2 [63:0] | - * ... ... - * Reg15 | C15[63:0] | - * Reg16 | C0 [127:64] | - * ... ... - * Reg31 | C15 [127:64] | - }*/ - - template - struct ColOrthoVW - { - using IOTraits = IOTraits; - struct Traits - { - enum : uint32_t - { - // Number of threads per wave - WaveSize = IOTraits::ThreadsPerIO, - - // Strides - BlockDimStride_X = min(BlockDim, WaveSize), - BlockDimStride_Y = 0u, - - BlockKStride_X = 0u, - BlockKStride_Y = WaveSize * MaxVectorWidth / BlockDimStride_X, - - VWStride_X = 0u, - VWStride_Y = VectorWidth, - - // Stride space - BlockDimSegs = BlockDim / BlockDimStride_X, - BlockKSegs = BlockK / BlockKStride_Y, - VWSegs = MaxVectorWidth / VWStride_Y, - }; - - static_assert(BlockDim >= (uint32_t)Traits::BlockDimStride_X, - "BlockDim must be larger than BlockDimStride_X"); - static_assert(BlockDim % (uint32_t)Traits::BlockDimStride_X == 0, - "BlockDim must be a multiple of BlockDimStride_X"); - static_assert(BlockK >= (uint32_t)Traits::BlockKStride_Y, - "BlockK must be larger than BlockKStride_Y"); - static_assert(BlockK % (uint32_t)Traits::BlockKStride_Y == 0, - "BlockK must be a multiple of BlockKStride_Y"); - static_assert(MaxVectorWidth >= (uint32_t)Traits::VWStride_Y, - "MaxVectorWidth must larger than VWStride_Y"); - static_assert(MaxVectorWidth % (uint32_t)Traits::VWStride_Y == 0, - "MaxVectorWidth must be a multiple of VWStride_Y"); - - // Orthogonal layout, coordinates are reversed - using OrthoLayout - = RowOrthoVW; - - using MatrixCoordT = Coord2d; - }; - - ROCWMMA_DEVICE constexpr static inline auto strideCounts() - { - return make_vector((uint32_t)Traits::BlockDimSegs, // BlockDim Segments - (uint32_t)Traits::BlockKSegs, // BlockK Segments - (uint32_t)Traits::VWSegs); // VW Segments - } - - ROCWMMA_DEVICE constexpr static inline auto strides() - { - return make_vector( - make_coord2d((uint32_t)Traits::BlockDimStride_X, - (uint32_t)Traits::BlockDimStride_Y), - make_coord2d((uint32_t)Traits::BlockKStride_X, - (uint32_t)Traits::BlockKStride_Y), - make_coord2d((uint32_t)Traits::VWStride_X, (uint32_t)Traits::VWStride_Y)); - } - - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT baseOffset() - { - if constexpr((uint32_t)Traits::BlockDimStride_X >= (uint32_t)Traits::WaveSize) - { - // Don't need initial offset calc in Y direction: all threads fit in neighbouring rows - return make_coord2d(threadIdx.x % (uint32_t)Traits::BlockDimStride_X, 0u); - } - else - { - // Threads need to spread over the Y direction as well - return make_coord2d(threadIdx.x % (uint32_t)Traits::BlockDimStride_X, - (threadIdx.x / (uint32_t)Traits::BlockDimStride_X) - * MaxVectorWidth % (uint32_t)Traits::BlockKStride_Y); - } - } - - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT - incrementalOffset(uint32_t iteration) - { - // Reference: - // VWOffsetY = VWStride_Y - ((i+1) % VWSegs ? 0u : VWStride_Y * VWSegs); - // Every set of VWSegs, we must iteratively reset the VWOffset back to 0, hence - // the subtraction. - // Optimization 1: if VWSegs == 1, there are no contributions from this stride - // Optimization 2: if BlockKSegs == 1 and BlockDimSegs == 1, there are no "reset" - // contributions from this stride - int32_t VWOffsetY = 0; - if constexpr((int32_t)Traits::VWSegs > 1) - { - // Offset contribution - VWOffsetY = (int32_t)Traits::VWStride_Y; - if constexpr(((int32_t)Traits::BlockKSegs > 1) - || ((int32_t)Traits::BlockDimSegs > 1)) - { - // "Reset" cycle - VWOffsetY -= (((int32_t)iteration + 1) % (int32_t)Traits::VWSegs - ? 0 - : (int32_t)Traits::VWStride_Y * (int32_t)Traits::VWSegs); - } - } - - // Reference: - // BlockKOffsetY = ((i+1) % VWSegs ? 0u : BlockKStride_Y) - - // ((i+1) % (VWSegs * BlockKSegs) ? 0u : BlockKSegs * BlockKStride_Y); - // Every set of BlockKSegs, we must iteratively reset the BlockKOffsetY back to 0, hence - // the subtraction. - // Optimization 1: if BlockKSegs == 1, there are no contributions from this stride - // Optimization 2: if BlockDimSegs == 1, there are no "reset" contributions from this stride - int32_t BlockKOffsetY = 0; - if constexpr((int32_t)Traits::BlockKSegs > 1) - { - // Offset contribution - BlockKOffsetY = (((int32_t)iteration + 1) % (int32_t)Traits::VWSegs - ? 0 - : (int32_t)Traits::BlockKStride_Y); - if constexpr((int32_t)Traits::BlockDimSegs > 1) - { - // "Reset" cycle - BlockKOffsetY - -= (((int32_t)iteration + 1) - % ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs) - ? 0 - : (int32_t)Traits::BlockKSegs - * (int32_t)Traits::BlockKStride_Y); - } - } - - // Reference: - // BlockDimOffsetX = ((i+1) % VWSegs * BlockKSegs) ? 0u : BlockDimStride_X); - // Optimization 1: if BlockKSegs == 1, there are no contributions from this stride - // Optimization 2: There are no "reset" contributions from this stride because it is the last dim - int32_t BlockDimOffsetX = 0; - if constexpr((int32_t)Traits::BlockDimSegs > 1) - { - // Offset contribution - BlockDimOffsetX - = (((int32_t)iteration + 1) - % ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs) - ? 0 - : (int32_t)Traits::BlockDimStride_X); - } - - return make_coord2d(BlockDimOffsetX, VWOffsetY + BlockKOffsetY); - } - - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT - cumulativeOffset(uint32_t iteration) - { - int32_t cumVWOffsetY - = (int32_t)Traits::VWStride_Y * ((int32_t)iteration % (int32_t)Traits::VWSegs); - int32_t cumBlockKOffsetY = ((int32_t)iteration / (int32_t)Traits::VWSegs) - % (int32_t)Traits::BlockKSegs - * (int32_t)Traits::BlockKStride_Y; - int32_t cumBlockDimOffsetX - = ((int32_t)iteration / ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs)) - * (int32_t)Traits::BlockDimStride_X; - - return make_coord2d(cumBlockDimOffsetX, cumVWOffsetY + cumBlockKOffsetY); - } - }; - - /* Pattern that maps threads to matrix columns and assumes - * that VW will be mapped inline with the column. - * This pattern considers VW up to MaxVW, BlockDim <= 64 and BlockDim > 64. - * - * Iterative thread stride cycles (same for all threads): - * Fill MaxVW => Fill BlockK => Fill BlockDim - * - * Example: - * BlockDim = 256 BlockK = 4 - * MaxVW = 2 VW = 1 - * - * BlockDim Stride Count = 4, BlockDimStride = (64, 0) - * BlockK Stride Count = 2, BlockKStride = (0, 2) - * VW Stride Count = 2, VWStride = (1, 0) - * - * Stride mapping (BlockDim, BlockK, VW) - * C_n = Matrix column - * i_n = cumulative iteration - * - * Cartesian iteration offsets (row, col): - * i0 = (0, 0) i1 = (1, 0) i2 = (0, 2) i3 = (1, 2) - * i4 = (64, 0) i5 = (65, 0) i6 = (64, 2) i7 = (65, 2) - * i8 = (128, 0) i9 = (129, 0) i10 = (128, 2) i11 = (129, 2) - * i12 = (192, 0) i13 = (193, 0) i14 = (192, 2) i15 = (192, 2) - * - * Strides iteration offsets (BlockDim, BlockK, VW): - * i0 = (0,0,0) i1 = (0,0,1) - * i2 = (0,1,0) i3 = (0,1,1) - * i4 = (1,0,0) i5 = (1,0,1) - * i6 = (1,1,0) i7 = (1,1,1) - * i8 = (2,0,0) i9 = (2,0,1) - * i10 = (2,1,0) i11 = (2,1,1) - * i12 = (3,0,0) i13 = (3,0,1) - * i14 = (3,1,0) i15 = (3,1,1) - * - * Let's follow thread 0: - * - * kDim ---------> - * - * BlockDim1 - * | |-- BlockK Stride = 2 --| - * | i0(0,0,0) i2(0,1,0) - * | _ _ v_____ _____v_____ _____ - * v | | | | | | | - * | VW 1 | | | | | - * BlockDim | Stride | | C0 | C1 | C2 | C3 | - * Stride | _ v | v | | - * 64 i1(0,0,1) i3(0,1,1) | - * | | | | | | - * | | | | | | - * | | C0 | C1 | C2 | C3 | - * _ |_____|_____|_____|_____| - * i4(1,0,0) i6(1,1,0) - * v_____ _____v_____ _____ - * | | | | | - * | | | | | - * | C0 | C1 | C2 | C3 | - * v | v | | - * i5(1,0,1) i7(1,1,1) | - * | | | | | - * | | | | | - * | C0 | C1 | C2 | C3 | - * |_____|_____|_____|_____| - * ... ... - * ... ... - * ... ... - * v | v | | - * i13(3,0,1) i14(3,1,1) | - * | | | | | - * | | | | | - * | C0 | C1 | C2 | C3 | - * |_____|_____|_____|_____| - * - * ^(BlockDim, 0) ^(BlockDim, BlockK) - * - * Register file (for all VectorWidths = [MaxVectorWidth, 1]): - * - * Elements 0...........1........................................... ............64 - * ________________________________________________________________________ - * Reg0 | C0E0 | C0E2 | ... | C0E62 | C1E0 | C1E2 | ... | C1E62 | - * Reg1 | C0E1 | C0E3 | ... | C0E63 | C1E1 | C1E3 | ... | C1E63 | - * Reg2 | C2E0 | C2E2 | ... | C2E62 | C3E0 | C3E2 | ... | C3E62 | - * Reg3 | C2E1 | C2E3 | ... | C2E63 | C3E1 | C3E3 | ... | C3E63 | - * Reg4 | C0E64 | C0E66 | ... | C0E126 | C1E64 | C1E66 | ... | C1E126 | - * Reg5 | C0E65 | C0E67 | ... | C0E127 | C1E65 | C1E67 | ... | C1E127 | - * ... ... - * Reg10 | C2E192 | C2E194 | ... | C2E254 | C3E192 | C3E194 | ... | C3E254 | - * Reg11 | C2E193 | C2E195 | ... | C2E255 | C3E193 | C3E195 | ... | C3E255 | - * - */ - - template - struct ColInlineVW - { - using IOTraits = IOTraits; - struct Traits - { - enum : uint32_t - { - // Number of threads per wave - WaveSize = IOTraits::ThreadsPerIO, - - // Strides - BlockDimStride_X = min(BlockDim, WaveSize), - BlockDimStride_Y = 0u, - - BlockKStride_X = 0u, - BlockKStride_Y = WaveSize * MaxVectorWidth / BlockDimStride_X, - - VWStride_X = VectorWidth, - VWStride_Y = 0u, - - // Stride Space - BlockDimSegs = BlockDim / BlockDimStride_X, - BlockKSegs = BlockK / BlockKStride_Y, - VWSegs = MaxVectorWidth / VWStride_X, - }; - - // Sanity checks for strides sizes - static_assert(BlockDim >= (uint32_t)Traits::BlockDimStride_X, - "BlockDim must be larger than BlockDimStride_X"); - static_assert(BlockDim % (uint32_t)Traits::BlockDimStride_X == 0, - "BlockDim must be a multiple of BlockDimStride_X"); - static_assert(BlockK >= (uint32_t)Traits::BlockKStride_Y, - "BlockK must be larger than BlockKStride_Y"); - static_assert(BlockK % (uint32_t)Traits::BlockKStride_Y == 0, - "BlockK must be a multiple of BlockKStride_Y"); - static_assert(MaxVectorWidth >= (uint32_t)Traits::VWStride_X, - "MaxVectorWidth must larger than VWStride_X"); - static_assert(MaxVectorWidth % (uint32_t)Traits::VWStride_X == 0, - "MaxVectorWidth must be a multiple of VWStride_X"); - - // Orthogonal layout, coordinates are reversed - using OrthoLayout - = RowInlineVW; - - using MatrixCoordT = Coord2d; - }; - - ROCWMMA_DEVICE constexpr static inline auto strideCounts() - { - return make_vector((uint32_t)Traits::BlockDimSegs, // BlockDim Segments - (uint32_t)Traits::BlockKSegs, // BlockK Segments - (uint32_t)Traits::VWSegs); // VW Segments - } - - ROCWMMA_DEVICE constexpr static inline auto strides() - { - return make_vector( - make_coord2d((uint32_t)Traits::BlockDimStride_X, - (uint32_t)Traits::BlockDimStride_Y), - make_coord2d((uint32_t)Traits::BlockKStride_X, - (uint32_t)Traits::BlockKStride_Y), - make_coord2d((uint32_t)Traits::VWStride_X, (uint32_t)Traits::VWStride_Y)); - } - - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT baseOffset() - { - if constexpr(((uint32_t)Traits::BlockDimStride_X >= (uint32_t)Traits::WaveSize) - && (MaxVectorWidth == 1)) - { - // Don't need initial offset calc in Y direction: all threads fit in neighbouring rows - return make_coord2d(threadIdx.x % (uint32_t)Traits::BlockDimStride_X, 0u); - } - else - { - // Threads need to spread over the Y direction as well - return make_coord2d( - threadIdx.x * MaxVectorWidth % (uint32_t)Traits::BlockDimStride_X, - threadIdx.x * MaxVectorWidth / (uint32_t)Traits::BlockDimStride_X - % (uint32_t)Traits::BlockKStride_Y); - } - } - - // Incremental iteration offset - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT - incrementalOffset(uint32_t iteration) - { - // Reference: - // VWOffsetX = VWStride_X - ((i+1) % VWSegs ? 0u : VWStride_X * VWSegs); - // Every set of VWSegs, we must iteratively reset the VWOffset back to 0, hence - // the subtraction. - // Optimization 1: if VWSegs == 1, there are no contributions from this stride - // Optimization 2: if BlockKSegs == 1 and BlockDimSegs == 1, there are no "reset" - // contributions from this stride - int32_t VWOffsetX = 0; - if constexpr((int32_t)Traits::VWSegs > 1) - { - // Offset contribution - VWOffsetX = (int32_t)Traits::VWStride_X; - if constexpr(((int32_t)Traits::BlockKSegs > 1) - || ((int32_t)Traits::BlockDimSegs > 1)) - { - // "Reset" cycle - VWOffsetX -= (((int32_t)iteration + 1) % (int32_t)Traits::VWSegs - ? 0 - : (int32_t)Traits::VWStride_X * (int32_t)Traits::VWSegs); - } - } - - // Reference: - // BlockKOffsetY = ((i+1) % VWSegs ? 0u : BlockKStride_Y) - - // ((i+1) % (VWSegs * BlockKSegs) ? 0u : BlockKSegs * BlockKStride_Y); - // Every set of BlockKSegs, we must iteratively reset the BlockKOffsetY back to 0, hence - // the subtraction. - // Optimization 1: if BlockKSegs == 1, there are no contributions from this stride - // Optimization 2: if BlockDimSegs == 1, there are no "reset" contributions from this stride - int32_t BlockKOffsetY = 0; - if constexpr((int32_t)Traits::BlockKSegs > 1) - { - // Offset contribution - BlockKOffsetY = (((int32_t)iteration + 1) % (int32_t)Traits::VWSegs - ? 0 - : (int32_t)Traits::BlockKStride_Y); - if constexpr((int32_t)Traits::BlockDimSegs > 1) - { - // "Reset" cycle - BlockKOffsetY - -= (((int32_t)iteration + 1) - % ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs) - ? 0 - : (int32_t)Traits::BlockKSegs - * (int32_t)Traits::BlockKStride_Y); - } - } - - // Reference: - // BlockDimOffsetX = ((i+1) % VWSegs * BlockKSegs) ? 0u : BlockDimStride_X); - // Optimization 1: if BlockKSegs == 1, there are no contributions from this stride - // Optimization 2: There are no "reset" contributions from this stride because it is the last dim - int32_t BlockDimOffsetX = 0; - if constexpr((int32_t)Traits::BlockDimSegs > 1) - { - // Offset contribution - BlockDimOffsetX - = (((int32_t)iteration + 1) - % ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs) - ? 0 - : (int32_t)Traits::BlockDimStride_X); - } - - return make_coord2d(VWOffsetX + BlockDimOffsetX, BlockKOffsetY); - } - - // Cumulative iteration offset - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT - cumulativeOffset(uint32_t iteration) - { - int32_t cumVWOffsetX - = (int32_t)Traits::VWStride_X * ((int32_t)iteration % (int32_t)Traits::VWSegs); - int32_t cumBlockKOffsetY = ((int32_t)iteration / (int32_t)Traits::VWSegs) - % (int32_t)Traits::BlockKSegs - * (int32_t)Traits::BlockKStride_Y; - int32_t cumBlockDimOffsetX - = ((int32_t)iteration / ((int32_t)Traits::VWSegs * (int32_t)Traits::BlockKSegs)) - * (int32_t)Traits::BlockDimStride_X; - - return make_coord2d(cumVWOffsetX + cumBlockDimOffsetX, cumBlockKOffsetY); - } - }; - - template - struct RowInlineVW - { - // RowInlineVW is orthogonal to ColInlineVW, therefore we can use reversed coordinates - struct Traits - { - using OrthoLayout - = ColInlineVW; - - using MatrixCoordT = Coord2d; - }; - - // Matrix coord offsets - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT baseOffset() - { - return swap(Traits::OrthoLayout::baseOffset()); - } - - ROCWMMA_DEVICE constexpr static inline auto strideCounts() - { - return Traits::OrthoLayout::strideCounts(); - } - - ROCWMMA_DEVICE constexpr static inline auto strides() - { - auto t = Traits::OrthoLayout::strides(); - return make_vector( - swap(get<0>(t)), swap(get<1>(t)), swap(get<2>(t))); - } - - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT - incrementalOffset(uint32_t iteration) - { - return swap(Traits::OrthoLayout::incrementalOffset(iteration)); - } - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT - cumulativeOffset(uint32_t iteration) - { - return swap(Traits::OrthoLayout::cumulativeOffset(iteration)); - } - }; - - template - struct RowOrthoVW - { - // RowOrthoVW is orthogonal to ColOrthoVW, therefore we can use reversed coordinates - struct Traits - { - using OrthoLayout - = ColOrthoVW; - - using MatrixCoordT = Coord2d; - }; - - // Matrix coord offsets - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT baseOffset() - { - return swap(Traits::OrthoLayout::baseOffset()); - } - - ROCWMMA_DEVICE constexpr static inline auto strideCounts() - { - return Traits::OrthoLayout::strideCounts(); - } - - ROCWMMA_DEVICE constexpr static inline auto strides() - { - auto t = Traits::OrthoLayout::strides(); - return make_vector( - swap(get<0>(t)), swap(get<1>(t)), swap(get<2>(t))); - } - - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT - incrementalOffset(uint32_t iteration) - { - return swap(Traits::OrthoLayout::incrementalOffset(iteration)); - } - ROCWMMA_DEVICE static inline typename Traits::MatrixCoordT - cumulativeOffset(uint32_t iteration) - { - return swap(Traits::OrthoLayout::cumulativeOffset(iteration)); - } - }; - - } // namespace MatrixLayout - - template - struct RegisterLayoutOfMatrix; - - template - struct RegisterLayoutOfMatrix< - MatrixLayout::template ColOrthoVW> - { - using Type = RegisterLayout::template Soa; - }; - - template - struct RegisterLayoutOfMatrix< - MatrixLayout::template ColInlineVW> - { - using Type = RegisterLayout::template Aos; - }; - - template - struct RegisterLayoutOfMatrix< - MatrixLayout::template RowOrthoVW> - { - using Type = RegisterLayout::template Soa; - }; - - template - struct RegisterLayoutOfMatrix< - MatrixLayout::template RowInlineVW> - { - using Type = RegisterLayout::template Aos; - }; - - /// - /// Helper to obtain orthogonal data layout - /// - - // Data Layouts - - template <> - struct OrthogonalLayout - { - using Type = col_major; - }; - - template <> - struct OrthogonalLayout - { - using Type = row_major; - }; - - template - struct OrthogonalLayout> - { - using Type = DataLayout::template Array1d::Type>; - }; - - // Matrix Layouts - template - struct OrthogonalLayout< - MatrixLayout::template ColOrthoVW> - { - using Type = MatrixLayout:: - template RowOrthoVW; - }; - - template - struct OrthogonalLayout< - MatrixLayout::template ColInlineVW> - { - using Type = MatrixLayout:: - template RowInlineVW; - }; - - template - struct OrthogonalLayout< - MatrixLayout::template RowOrthoVW> - { - using Type = MatrixLayout:: - template ColOrthoVW; - }; - - template - struct OrthogonalLayout< - MatrixLayout::template RowInlineVW> - { - using Type = MatrixLayout:: - template ColInlineVW; - }; - - // Register layouts - template - struct OrthogonalLayout> - { - using Type = RegisterLayout::template Soa; - }; - - template - struct OrthogonalLayout> - { - using Type = RegisterLayout::template Aos; - }; - - /// - /// Helper to check if layout types are orthogonal - /// - - // In general, assume that an orthogonal layout has been assigned - template - struct is_orthogonal : public integral_constant< - bool, - is_same_v, RhsDataLayout>> - { - }; - - // Special case for self: not orthogonal - template - struct is_orthogonal : public false_type - { - }; - - // Special cases for MatrixLayouts, the VectorWidth used does not matter in determining orthogonality, however all other properties must match. - template - struct is_orthogonal< - MatrixLayout::template ColOrthoVW, - MatrixLayout::template RowOrthoVW> - : public true_type - { - }; - - template - struct is_orthogonal< - MatrixLayout::template RowOrthoVW, - MatrixLayout::template ColOrthoVW> - : public true_type - { - }; - - template - struct is_orthogonal< - MatrixLayout::template ColInlineVW, - MatrixLayout::template RowInlineVW> - : public true_type - { - }; - - template - struct is_orthogonal< - MatrixLayout::template RowInlineVW, - MatrixLayout::template ColInlineVW> - : public true_type - { - }; - -} // namespace rocwmma - -#endif // ROCWMMA_LAYOUT_IMPL_HPP diff --git a/library/include/rocwmma/internal/opaque_load.hpp b/library/include/rocwmma/internal/opaque_load.hpp index c14e1978..94d63302 100644 --- a/library/include/rocwmma/internal/opaque_load.hpp +++ b/library/include/rocwmma/internal/opaque_load.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -27,7 +27,6 @@ #define ROCWMMA_OPAQUE_LOAD_HPP #include "io_traits.hpp" -#include "layout.hpp" #include "tuple.hpp" #include "types.hpp" #include "vector_iterator.hpp" @@ -61,7 +60,8 @@ namespace rocwmma typename DataT, class DataLayout, class MatrixLayout, - uint32_t VectorWidth> + uint32_t VectorWidth, + bool Debug = false> struct OpaqueLoad { using IOTraits = IOTraits; @@ -78,10 +78,7 @@ namespace rocwmma // Outer loop = index 0, // Inner loop = index N-1 - template + template ROCWMMA_DEVICE static inline auto unroll_right(Iterator& out, DataT const* dataPtr, uint32_t ldm, @@ -94,6 +91,14 @@ namespace rocwmma // Last depth layer will invoke the load if constexpr(Depth == (VecTraits>::size() - 1u)) { + if constexpr(Debug) + { + printf("Depth: %d, StrideCount: %d\n", Depth, get(strideCounts)); + printf("StrideX: %d, StrideY: %d\n", + get<0>(get(strides2d)), + get<1>(get(strides2d))); + printf("Executing!\n"); + } #pragma unroll for(int i = 0; i < strideCount; i++) { @@ -105,6 +110,14 @@ namespace rocwmma // Recurse to the next nested layer else { + if constexpr(Debug) + { + printf("Depth: %d, StrideCount: %d\n", Depth, get(strideCounts)); + printf("StrideX: %d, StrideY: %d\n", + get<0>(get(strides2d)), + get<1>(get(strides2d))); + printf("Recursing!\n"); + } #pragma unroll for(int i = 0; i < strideCount; i++) { @@ -117,6 +130,7 @@ namespace rocwmma ROCWMMA_DEVICE static void exec(typename Traits::OutputT& data, DataT const* dataPtr, uint32_t ldm) { + //MatrixLayout::debug(); // Arrange wave threads to starting matrix layout offsets. auto baseOffset2d = MatrixLayout::baseOffset(); auto it = makeVectorIterator(data).begin(); @@ -130,12 +144,17 @@ namespace rocwmma MatrixLayout::strideCounts()), "IOCount inconsistent with total strides"); + // Initialize the stride details as constexpr + // so that the compiler can optimize them as args. + constexpr auto strideCounts = MatrixLayout::strideCounts(); + constexpr auto strides = MatrixLayout::strides(); + // Unroll loading in each strided dimension unroll_right(it, dataPtr + DataLayout::fromMatrixCoord(baseOffset2d, ldm), ldm, - MatrixLayout::strideCounts(), - MatrixLayout::strides()); + strideCounts, + strides); } }; diff --git a/library/include/rocwmma/internal/opaque_store.hpp b/library/include/rocwmma/internal/opaque_store.hpp index 1f1f9990..2880806a 100644 --- a/library/include/rocwmma/internal/opaque_store.hpp +++ b/library/include/rocwmma/internal/opaque_store.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -27,7 +27,6 @@ #define ROCWMMA_OPAQUE_STORE_HPP #include "io_traits.hpp" -#include "layout.hpp" #include "types.hpp" #include "vector_iterator.hpp" @@ -73,10 +72,7 @@ namespace rocwmma using StoreVecTraits = VecTraits; - template + template ROCWMMA_DEVICE static inline auto unroll_right(DataT* dataPtr, Iterator& in, uint32_t ldm, @@ -90,7 +86,7 @@ namespace rocwmma if constexpr(Depth == (VecTraits>::size() - 1u)) { #pragma unroll - for(int i = 0; i < strideCount; i++) + for(unsigned int i = 0; i < strideCount; i++) { Traits::Storer::exec(dataPtr, *in); dataPtr += strideOffset; @@ -101,7 +97,7 @@ namespace rocwmma else { #pragma unroll - for(int i = 0; i < strideCount; i++) + for(unsigned int i = 0; i < strideCount; i++) { unroll_right(dataPtr, in, ldm, strideCounts, strides2d); dataPtr += strideOffset; @@ -125,11 +121,16 @@ namespace rocwmma MatrixLayout::strideCounts()), "IOCount inconsistent with total strides"); + // Initialize the stride details as constexpr + // so that the compiler can optimize them as args. + constexpr auto strideCounts = MatrixLayout::strideCounts(); + constexpr auto strides = MatrixLayout::strides(); + unroll_right(dataPtr + DataLayout::fromMatrixCoord(baseOffset2d, ldm), it, ldm, - MatrixLayout::strideCounts(), - MatrixLayout::strides()); + strideCounts, + strides); } }; diff --git a/library/include/rocwmma/internal/permute_impl.hpp b/library/include/rocwmma/internal/permute_impl.hpp index 2f6a82c5..0e2fcece 100644 --- a/library/include/rocwmma/internal/permute_impl.hpp +++ b/library/include/rocwmma/internal/permute_impl.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -26,6 +26,7 @@ #ifndef ROCWMMA_PERMUTE_IMPL_HPP #define ROCWMMA_PERMUTE_IMPL_HPP +#include "mapping_util.hpp" #include "permute.hpp" namespace rocwmma diff --git a/library/include/rocwmma/internal/transforms.hpp b/library/include/rocwmma/internal/transforms.hpp index acbc4c69..49ef6d9b 100644 --- a/library/include/rocwmma/internal/transforms.hpp +++ b/library/include/rocwmma/internal/transforms.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -28,6 +28,7 @@ #include "transforms_impl.hpp" #include "vector.hpp" +#include "vector_iterator.hpp" namespace rocwmma { @@ -59,10 +60,10 @@ namespace rocwmma template using AosToSoa = Driver>; - + template using SoaToAos = Driver>; - + } // namespace Transforms } // namespace rocwmma diff --git a/library/include/rocwmma/internal/transforms_impl.hpp b/library/include/rocwmma/internal/transforms_impl.hpp index e9e4ebc7..10959b42 100644 --- a/library/include/rocwmma/internal/transforms_impl.hpp +++ b/library/include/rocwmma/internal/transforms_impl.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -32,6 +32,7 @@ #include "io_traits.hpp" #include "pack_util.hpp" #include "permute.hpp" +#include "swizzle.hpp" #include "utils.hpp" #include "vector_util.hpp" @@ -213,7 +214,6 @@ namespace rocwmma { namespace Ops { - template struct AosToSoa { diff --git a/library/include/rocwmma/internal/utility/type_traits.hpp b/library/include/rocwmma/internal/utility/type_traits.hpp index ac42d080..595b40eb 100644 --- a/library/include/rocwmma/internal/utility/type_traits.hpp +++ b/library/include/rocwmma/internal/utility/type_traits.hpp @@ -70,7 +70,7 @@ namespace rocwmma // TODO: override namespace not detail using __hip_internal::is_standard_layout; using __hip_internal::is_trivial; - + using detail::is_void; using detail::is_void_v; using detail::remove_const; @@ -85,9 +85,18 @@ namespace rocwmma using detail::remove_volatile_t; using detail::true_type; + // TODO: goes into algorithm using detail::max; using detail::min; + // TODO: goes into functional + using detail::logical_or; + //using detail::logical_or_v; + using detail::logical_and; + //using detail::logical_and_v; + using detail::logical_not; + //using detail::logical_not_v; + } // namespace rocwmma #define ROCWMMA_TYPE_TRAITS_IMPL_NAMESPACE rocwmma::detail @@ -95,6 +104,10 @@ namespace rocwmma #else #include + +// TODO: move to own files +#include +#include namespace rocwmma { // std implementations @@ -147,9 +160,18 @@ namespace rocwmma using std::remove_volatile_t; using std::true_type; + // TODO: goes into algorithm using std::max; using std::min; + // TODO: goes into functional + using std::logical_or; + //using std::logical_or_v; + using std::logical_and; + //using std::logical_and_v; + using std::logical_not; + //using std::logical_not_v; + } // namespace rocwmma #define ROCWMMA_TYPE_TRAITS_IMPL_NAMESPACE std @@ -159,13 +181,13 @@ namespace rocwmma // Define some convenience traits namespace rocwmma { - template + template using enable_if_integral_t = enable_if_t{}>; - - template + + template using enable_if_signed_t = enable_if_t{}>; - template + template using enable_if_arithmetic_t = enable_if_t{}>; } diff --git a/library/include/rocwmma/internal/vector_util.hpp b/library/include/rocwmma/internal/vector_util.hpp index abc56c4f..bd6a97c4 100644 --- a/library/include/rocwmma/internal/vector_util.hpp +++ b/library/include/rocwmma/internal/vector_util.hpp @@ -122,6 +122,17 @@ namespace rocwmma template ROCWMMA_DEVICE constexpr static inline auto unpackHi(VecT const& v0, VecT const& v1); + + //! Interleaves elements from the vector, according to group size + //! E.g. GroupSize = 4 + //! v0 = [0, 1, 2, 3, 4, 5, 6, 7] + //! result = [0, 4, 1, 5, 2, 6, 3, 7] + /*! + \param v0 Vector from which interleaved elements are selected from + */ + template + ROCWMMA_DEVICE constexpr static inline auto interleave(VecT const& v0); + } // namespace rocwmma #include "vector_util_impl.hpp" diff --git a/library/include/rocwmma/internal/vector_util_impl.hpp b/library/include/rocwmma/internal/vector_util_impl.hpp index e0115a46..b9e9d1b1 100644 --- a/library/include/rocwmma/internal/vector_util_impl.hpp +++ b/library/include/rocwmma/internal/vector_util_impl.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2022-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2022-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -416,6 +416,46 @@ namespace rocwmma } } + // A permutation of vector indices, given a gather size and a stride + // Examples: + // row_major col_major + // [0, 1] => interleave<1, 2>([0, 1, 2, 3, 4, 5]) = [0, 2, 4, 1, 3, 5] + // A = [2, 3] col_major row_major + // [4, 5] => interleave<1, 4>([0, 2, 4, 1, 3, 5]) = [0, 1, 2, 3, 4, 5] + // + // [0, 1] + // A = [2, 3] => interleave<2, 4>([0, 1, 2, 3, 4, 5, 6, 7]) = [0, 1, 4, 5, 2, 3, 6, 7] + // [4, 5] + // [6, 7] + // + template + ROCWMMA_DEVICE constexpr static inline decltype(auto) interleave(VecT const& v0) + { + static_assert((GatherSize >= 1u) && (GatherSize <= ElementStride) + && (ElementStride % GatherSize == 0) && (VecSize % GatherSize == 0), + "Invalid GatherSize"); + static_assert(ElementStride >= 1u && ElementStride <= VecSize, "Invalid Stride"); + + // No transform is needed (NOP) + if constexpr(GatherSize == ElementStride || ElementStride == VecSize) + { + return v0; + } + else + { + auto offset = [](auto&& idx, auto&& v0) { + constexpr auto Index = decay_t::value; + constexpr auto Offset0 = (Index / GatherSize) * ElementStride % VecSize; + constexpr auto Offset1 = Index % GatherSize; + constexpr auto Offset2 + = (Index * ElementStride) / (VecSize * GatherSize) * GatherSize; + return get(v0); + }; + + return vector_generator()(offset, v0); + } + } + } // namespace rocwmma #endif // ROCWMMA_VECTOR_UTIL_IMPL_HPP diff --git a/library/include/rocwmma/internal/wmma.hpp b/library/include/rocwmma/internal/wmma.hpp index 7308b5a1..e4667c3b 100644 --- a/library/include/rocwmma/internal/wmma.hpp +++ b/library/include/rocwmma/internal/wmma.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -129,56 +129,33 @@ namespace rocwmma exec(InputARegsT const& regsA, InputBRegsT const& regsB, InputCRegsT const& regsC) { // Inputs from outside will come in as fully packed - static_assert(VecTraits::size() == IOTraitsA::PackedSize, - "WMMA input size mismatch"); - static_assert(VecTraits::size() == IOTraitsB::PackedSize, - "WMMA input size mismatch"); - static_assert(VecTraits::size() == IOTraitsAcc::PackedSize, - "WMMA input size mismatch"); - - // WMMA accumulator operates on unpacked, padded data in separate 32b elements. - // In the case of f16, what needs to happen is extend each unpacked element to 32b wide - // and shift the 16b data to the correct spot (determined by the WMMA backend). - // The nasty bit is that due of the extended 32b element size, the final accumulation vector - // is masqueraded as a 'packed' type, but with the same vector size as unpacked. - auto accum = PackUtil::template pad(PackUtil::unpack(regsC)); + static_assert(VecTraits::size() == VecTraitsA::size() * Traits::WmmaCount, + "WMMA A input size mismatch"); + static_assert(VecTraits::size() == VecTraitsB::size() * Traits::WmmaCount, + "WMMA B input size mismatch"); + static_assert(VecTraits::size() == VecTraitsC::size(), + "WMMA Acc input size mismatch"); + + auto accum = regsC; // Iterate over packed WMMA inputs - auto const aIt - = makeVectorIterator(regsA).begin(); - auto const bIt - = makeVectorIterator(regsB).begin(); + auto const aIt = makeVectorIterator(regsA).begin(); + auto const bIt = makeVectorIterator(regsB).begin(); // Accumulate over WMMA count #pragma unroll for(uint32_t i = 0; i < Traits::WmmaCount; i++) { -#if ROCWMMA_ARCH_GFX11 - // Swap upper / lower 16 elements - auto swappedA = Swizzle::Swap16::exec(*aIt); - auto swappedB = Swizzle::Swap16::exec(*bIt); - - // Combine duplicated data for mult/accum. - // Evens: non-swapped - // Odds: swapped - accum = WMMA::exec(concat(unpackLo(*aIt, swappedA), unpackHi(*aIt, swappedA)), - concat(unpackLo(*bIt, swappedB), unpackHi(*bIt, swappedB)), - accum); -#else - accum = WMMA::exec(*aIt, *bIt, accum); - -#endif - aIt++; bIt++; } - return PackUtil::pack(PackUtil::template unpad(accum)); + return accum; } }; -#endif // ROCWMMA_ARCH_GFX11 +#endif // ROCWMMA_ARCH_GFX11 || ROCWMMA_ARCH_GFX12 } // namespace rocwmma diff --git a/library/include/rocwmma/rocwmma_coop_impl.hpp b/library/include/rocwmma/rocwmma_coop_impl.hpp index de563205..09a05383 100644 --- a/library/include/rocwmma/rocwmma_coop_impl.hpp +++ b/library/include/rocwmma/rocwmma_coop_impl.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -66,8 +66,10 @@ namespace rocwmma uint32_t waveCount) { - using FragT = decay_t; - using Loader = typename GetCoopIOConfig_t::Loader; + using FragT = decay_t; + using IOConfig = GetCoopIOConfig_t; + using Loader = typename IOConfig::Loader; + using PostLoad = typename IOConfig::PostLoadXForm; // Sanity checks static_assert(!is_same::value, @@ -82,6 +84,9 @@ namespace rocwmma // Note: the frag will only be partially filled with useful data. // Layout and thread locality is not guaranteed. Loader::exec(frag.mAccess, data, ldm, waveIndex, waveCount); + + // Post-load transformation + frag.mAccess = PostLoad::exec(frag.mAccess); } template ; - using Loader = typename GetCoopIOConfig_t::Loader; + using FragT = decay_t; + using IOConfig = GetCoopIOConfig_t; + using Loader = typename IOConfig::Loader; + using PostLoad = typename IOConfig::PostLoadXForm; // Sanity checks static_assert(!is_same::value, @@ -156,6 +163,9 @@ namespace rocwmma // Note: the frag will only be partially filled with useful data. // Layout and thread locality is not guaranteed. Loader::template exec(frag.mAccess, data, ldm, waveIndex); + + // Post-load transformation + frag.mAccess = PostLoad::exec(frag.mAccess); } template ; - using Storer = typename GetCoopIOConfig_t::Storer; + using FragT = decay_t; + using IOConfig = GetCoopIOConfig_t; + using PreStore = typename IOConfig::PreStoreXForm; + using Storer = typename IOConfig::Storer; // Sanity checks static_assert(!is_same::value, @@ -204,7 +216,7 @@ namespace rocwmma // Implicit unpack and store // Note: the frag is only be partially filled with useful data. // Layout and thread locality is not guaranteed. - Storer::exec(data, frag.mAccess, ldm, waveIndex, waveCount); + Storer::exec(data, PreStore::exec(frag.mAccess), ldm, waveIndex, waveCount); } template ; - using Storer = typename GetCoopIOConfig_t::Storer; + using FragT = decay_t; + using IOConfig = GetCoopIOConfig_t; + using PreStore = typename IOConfig::PreStoreXForm; + using Storer = typename IOConfig::Storer; // Sanity checks static_assert(!is_same::value, @@ -281,7 +294,7 @@ namespace rocwmma // Implicit unpack and store // Note: the frag is only be partially filled with useful data. // Layout and thread locality is not guaranteed. - Storer::template exec(data, frag.mAccess, ldm, waveIndex); + Storer::template exec(data, PreStore::exec(frag.mAccess), ldm, waveIndex); } } // namespace rocwmma diff --git a/library/include/rocwmma/rocwmma_impl.hpp b/library/include/rocwmma/rocwmma_impl.hpp index deea2b45..fda96c1c 100644 --- a/library/include/rocwmma/rocwmma_impl.hpp +++ b/library/include/rocwmma/rocwmma_impl.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -39,7 +39,7 @@ #include "internal/io_layout.hpp" #include "internal/io_shape.hpp" #include "internal/io_traits.hpp" -#include "internal/layout.hpp" +#include "internal/layout/layout.hpp" #include "internal/mapping_util.hpp" #include "internal/mfma.hpp" #include "internal/opaque_load.hpp" @@ -227,8 +227,10 @@ namespace rocwmma const DataT* data, uint32_t ldm) { - using FragT = decay_t; - using Loader = typename GetIOConfig_t::Loader; + using FragT = decay_t; + using IOConfig = GetIOConfig_t; + using Loader = typename IOConfig::Loader; + using PostLoad = typename IOConfig::PostLoadXForm; // Sanity checks static_assert(!is_same::value, @@ -241,6 +243,9 @@ namespace rocwmma // Load then implicit pack Loader::exec(frag.mAccess, data, ldm); + + // Post-load transformation + frag.mAccess = PostLoad::exec(frag.mAccess); } template @@ -274,8 +279,10 @@ namespace rocwmma fragment const& frag, uint32_t ldm) { - using FragT = decay_t; - using Storer = typename GetIOConfig_t::Storer; + using FragT = decay_t; + using IOConfig = GetIOConfig_t; + using PreStore = typename IOConfig::PreStoreXForm; + using Storer = typename IOConfig::Storer; // Sanity check static_assert(!is_same::value, @@ -287,7 +294,7 @@ namespace rocwmma "Fragment access and store input types do not match"); // Implicit unpack and then store - Storer::exec(data, frag.mAccess, ldm); + Storer::exec(data, PreStore::exec(frag.mAccess), ldm); } template @@ -326,11 +333,22 @@ namespace rocwmma fragment const& b, fragment const& c) { - using FragA = decay_t; - using FragB = decay_t; + using FragA = decay_t; + using FragB = decay_t; + using FragAcc = decay_t; + + using IOConfigA = GetIOConfig_t; + using IOConfigB = GetIOConfig_t; + using IOConfigAcc = GetIOConfig_t; + + using PreMmaA = typename IOConfigA::PreMmaXForm; + using PreMmaB = typename IOConfigB::PreMmaXForm; + using PreMmaAcc = typename IOConfigAcc::PreMmaXForm; + using PostMmaAcc = typename IOConfigAcc::PostMmaXForm; - using IOConfigA = GetIOConfig_t; - using IOConfigB = GetIOConfig_t; + using PackA = typename IOConfigA::PackUtil; + using PackB = typename IOConfigB::PackUtil; + using PackAcc = typename IOConfigAcc::PackUtil; // Sanity checks static_assert((IOConfigA::IOShape::BlockDim >= 16) && (IOConfigB::IOShape::BlockDim >= 16) @@ -338,29 +356,27 @@ namespace rocwmma && (IOConfigB::IOShape::BlockDim <= 32), "Input fragment BlockDim is not mfma friendly"); - static_assert(IOConfigA::IOShape::KDim == IOConfigB::IOShape::KDim, - "KDim of input fragments must match"); + static_assert((IOConfigA::IOShape::BlockDim == IOConfigB::IOShape::BlockDim) + && (IOConfigA::IOShape::KDim == IOConfigB::IOShape::KDim), + "BlockDim and KDim of input fragments must match"); - static_assert(is_orthogonal_v, - "Input fragment matrix layouts are not orthogonal"); - - static_assert(is_same_v, + static_assert(is_layout_same_v, "Input fragment register layouts do not match"); - static_assert(is_same_v>, - "Input fragment register layouts are not mfma friendly"); - // Gfx9 uses MFMA, gfx11 uses WMMA - using MMA = conditional_t, Wmma>; - // mma functions operate on packed vectors - (*d) = MMA::exec(*a, *b, *c); + // 1. Perform input pre-ops on A, B, Acc (unpacked) + // 2. Mma (packed) + // 3. Perform acc post-op on Acc + // 4. Pack back to register + d.mAccess = PostMmaAcc::exec( + PackAcc::unpack(Mma::exec(PackA::pack(PreMmaA::exec(a.mAccess)), + PackB::pack(PreMmaB::exec(b.mAccess)), + PackAcc::pack(PreMmaAcc::exec(c.mAccess))))); } ROCWMMA_DEVICE void synchronize_workgroup() diff --git a/library/include/rocwmma/rocwmma_transforms_impl.hpp b/library/include/rocwmma/rocwmma_transforms_impl.hpp index 9f5c70c4..f58cad3a 100644 --- a/library/include/rocwmma/rocwmma_transforms_impl.hpp +++ b/library/include/rocwmma/rocwmma_transforms_impl.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -26,6 +26,8 @@ #ifndef ROCWMMA_TRANSFORMS_API_IMPL_HPP #define ROCWMMA_TRANSFORMS_API_IMPL_HPP +#include "internal/layout/layout.hpp" +#include "internal/layout/register_layout_transforms.hpp" #include "internal/transforms.hpp" #include "rocwmma_transforms.hpp" @@ -38,7 +40,7 @@ namespace rocwmma /// // Below are defined as fast implicit transposes: - // - We reinterpret meaning between cols of A and rows of B, + // - We re-interpret meaning between cols of A and rows of B, // in order to change the shape of our data for reading / writing. // Implicit transposes of fragment objects are designed to be // relatively cheap, and should only require a signature cast. @@ -50,7 +52,7 @@ namespace rocwmma // Example: // - A matrix_a fragment of (BlockM x BlockK) = 32x8 in col_major may be reinterpreted // as a matrix_b fragment of (BlockK x BlockN) = 8x32 in row_major. - // Here, we have transposed (reimagined) 8 cols of matrix_a into 8 rows of matrix_b. + // Here, we have transposed (implicitly) 8 cols of matrix_a into 8 rows of matrix_b. template struct ApplyTranspose; @@ -65,7 +67,9 @@ namespace rocwmma // Original frag A type using FragA = fragment; - // Transpose to frag B type in opposite data layout. + // Transpose to frag B type in opposite data layout: + // - Exchange Block M for BlockN + // - Exchange row_major for col_major and vice-versa using FragB = fragment, + static_assert(is_layout_orthogonal_v, "Data Layouts are not orthogonal"); - static_assert(is_orthogonal_v, + static_assert(is_layout_orthogonal_v, "Matrix Layouts are not orthogonal"); - static_assert(is_same_v, + static_assert(is_layout_same_v, "Register layouts do not match"); public: @@ -115,10 +119,12 @@ namespace rocwmma struct ApplyTranspose> { private: - // Original frag A type + // Original frag B type using FragB = fragment; - // Transpose to frag A type in opposite data layout. + // Transpose to frag A type in opposite data layout: + // - Exchange Block M for BlockN + // - Exchange row_major for col_major and vice-versa using FragA = fragment, + static_assert(is_layout_orthogonal_v, "Data Layouts are not orthogonal"); - static_assert(is_orthogonal_v, + static_assert(is_layout_orthogonal_v, "Matrix Layouts are not orthogonal"); - static_assert(is_same_v, - "Register layouts do not match"); + static_assert(is_layout_same_v, + "Fragment register layouts do not match"); public: // Interface @@ -177,25 +183,6 @@ namespace rocwmma template struct ApplyDataLayout; - // Same layout case - template - struct ApplyDataLayout, - DataLayoutT> - { - // Interface - using Type = fragment; - template - ROCWMMA_DEVICE constexpr static inline Type const& exec(Type const& frag) - { - return frag; - } - }; - // Other layout case template , NewDataLayoutT> { - private: - using FragIn = fragment; - using FragOut = fragment; - - using IOConfigIn = GetIOConfig_t; - - using RegisterLayoutIn = typename GetIOConfig_t::IOLayout::RegisterLayout; - using RegisterLayoutOut = typename GetIOConfig_t::IOLayout::RegisterLayout; + using Type = fragment; - // Matrix context, BlockDim and KDim implicitly the same due to re-use of - // MatrixT, BlockM, BlockN, BlockK - - public: - // Interface - using Type = FragOut; - - // Optimal case: input and output register layouts match - template < - uint32_t WaveCount = 1, - typename FragT, - enable_if_t< - is_same_v && is_same_v, - int> = 0> + template ROCWMMA_DEVICE constexpr static inline decltype(auto) exec(FragT const& frag) { - return reinterpret_cast(frag); - } + static_assert( + is_same_v, FragT>, + "Input fragment types do not match"); - // Input and output register layouts do not match: must transform using AOS<->SOA - template < - uint32_t WaveCount = 1, - typename FragT, - enable_if_t< - is_same_v && !is_same_v, - int> = 0> - ROCWMMA_DEVICE constexpr static inline auto exec(FragT const& frag) - { - // TODO: Make sure to use coop configs to get the right MaxVW!!! - using IOConfigCoop = GetCoopIOConfig_t; - constexpr uint32_t BlockDim = IOConfigCoop::IOShape::BlockDim; - constexpr uint32_t MaxVW = IOConfigCoop::IOLayout::MaxVW; - using RegisterLayoutIncoming = typename IOConfigCoop::IOLayout::RegisterLayout; - - // Target layouts - using AosLayout = RegisterLayout::template Aos; - using SoaLayout = RegisterLayout::template Soa; - - auto result = FragOut{}; - - if constexpr(is_same_v) - { - result.mAccess = Transforms::AosToSoa::exec(frag.mAccess); - } - else if constexpr(is_same_v) - { - result.mAccess = Transforms::SoaToAos::exec(frag.mAccess); - } + using DstFrag = Type; + + // Make sure to use coop configs to get the right MaxVW!!! + using SrcLayout = + typename GetCoopIOConfig_t::IOLayout::FragmentLayout; + using DstLayout = + typename GetCoopIOConfig_t::IOLayout::FragmentLayout; + auto result = DstFrag{}; + result.mAccess + = register_layout_transform::exec(frag.mAccess); return result; } }; diff --git a/samples/common.hpp b/samples/common.hpp index 247759f9..2ef073c9 100644 --- a/samples/common.hpp +++ b/samples/common.hpp @@ -198,6 +198,15 @@ __host__ static inline void } } +template +__host__ static inline void fillVal(DataT* mat, uint32_t m, uint32_t n, DataT val = 1) +{ + for(int i = 0; i < m * n; ++i) + { + mat[i] = val; + } +} + // Host matrix data random initialization template __host__ static inline void fillRand(DataT* mat, uint32_t m, uint32_t n) @@ -223,6 +232,67 @@ __host__ static inline void fillRand(DataT* mat, uint32_t m, uint32_t n) } } +#include + +template +__host__ static inline void fillEnc(DataT* mat, uint32_t m, uint32_t n) +{ + using EncT = std::conditional_t; + //#pragma omp parallel for + for(int i = 0; i < m; ++i) + { + for(int j = 0; j < n; j++) + { + // Use binary encoding for the row / col coords + // 0x MMMM NNNN + EncT enc = ((i & 0xFF) << (sizeof(DataT) * 4)) | (j & 0xFF); + //std::cout << "row: " << i << " col: " << j << " :"; + //std::cout << "0x" << std::setfill('0') << std::setw(sizeof(DataT)*2) << std::right << std::hex << ((i & 0xFF) << (sizeof(DataT)*4)) << " " << (j & 0xFF) << std::endl; + //std::cout << "0x" << std::setfill('0') << std::setw(sizeof(DataT)*2) << std::right << std::hex << enc << std::endl; + auto idx = std::is_same_v ? (i * n + j) : (i + m * j); + mat[idx] = reinterpret_cast(enc); + } + } +} + +template +__host__ static inline void printEnc(DataT* mat, uint32_t m, uint32_t n) +{ + using EncT = std::conditional_t; + for(int i = 0; i < m; ++i) + { + for(int j = 0; j < n; j++) + { + // Use binary encoding for the row / col coords + // 0x MMMM NNNN + auto idx = std::is_same_v ? (i * n + j) : (i + m * j); + std::cout << "0x" << std::setfill('0') << std::setw(sizeof(DataT) * 2) << std::right + << std::hex << reinterpret_cast(mat[idx]) << " "; + //std::cout << reinterpret_cast(mat[idx]) << " "; + } + std::cout << std::endl; + } + std::cout << std::endl; +} + +template +__host__ static inline void printData(DataT* mat, uint32_t m, uint32_t n) +{ + for(int i = 0; i < m; ++i) + { + for(int j = 0; j < n; j++) + { + // Use binary encoding for the row / col coords + // 0x MMMM NNNN + auto idx = std::is_same_v ? (i * n + j) : (i + m * j); + std::cout << std::setw(8) << std::right << float(mat[idx]) << " "; + //std::cout << reinterpret_cast(mat[idx]) << " "; + } + std::cout << std::endl; + } + std::cout << std::endl; +} + // Host GEMM validation template (warpTileBound) > m || get<1>(warpTileBound) > n) + { + return; + } - // Tile Sizes - constexpr auto warpTileSize = make_coord2d(WARP_TILE_X, WARP_TILE_Y); - constexpr auto macroTileSize = make_coord2d(MACRO_TILE_X, MACRO_TILE_Y); + /// + /// 1D global read coordinate setup + /// + using GRBuffAMap1d = GetDataLayout_t; + using GRBuffBMap1d = GetDataLayout_t; - // Local warp coordinate relative to current threadblock (wg). - constexpr auto warpDims = make_coord2d(WARPS_X, WARPS_Y); - auto localWarpCoord = make_coord2d(threadIdx.x / WARP_SIZE, threadIdx.y); - auto localWarpOffset = localWarpCoord * warpTileSize; + // Initial globa read address offsets + auto globalReadOffsetA + = GRBuffAMap1d::fromMatrixCoord(make_coord2d(get<0>(macroTileCoord), 0u), lda); + auto globalReadOffsetB + = GRBuffBMap1d::fromMatrixCoord(make_coord2d(0u, get<1>(macroTileCoord)), ldb); - // Global matrix coordinates for C/D - auto macroTileCoord = make_coord2d(blockIdx.x, blockIdx.y) * macroTileSize; - auto warpTileCoord = macroTileCoord + localWarpOffset; + // Incremental global read address offsets + auto kStepOffsetA = GRBuffAMap1d::fromMatrixCoord(make_coord2d(0u, ROCWMMA_K), lda); + auto kStepOffsetB = GRBuffBMap1d::fromMatrixCoord(make_coord2d(ROCWMMA_K, 0u), ldb); - // Bounds check - auto warpTileBound = warpTileCoord + warpTileSize; - if(get<0>(warpTileBound) > m || get<1>(warpTileBound) > n) - { - return; - } + /// + /// Cooperative config for global read A / B + /// - /// - /// 1D global read coordinate setup - /// - using GRBuffAMap1d = GetDataLayout_t; - using GRBuffBMap1d = GetDataLayout_t; - - // Initial globa read address offsets - auto globalReadOffsetA - = GRBuffAMap1d::fromMatrixCoord(make_coord2d(get<0>(macroTileCoord), 0u), lda); - auto globalReadOffsetB - = GRBuffBMap1d::fromMatrixCoord(make_coord2d(0u, get<1>(macroTileCoord)), ldb); - - // Incremental global read address offsets - auto kStepOffsetA = GRBuffAMap1d::fromMatrixCoord(make_coord2d(0u, ROCWMMA_K), lda); - auto kStepOffsetB = GRBuffBMap1d::fromMatrixCoord(make_coord2d(ROCWMMA_K, 0u), ldb); - - /// - /// Cooperative config for global read A / B - /// - - // WorkItems will be split up by minimum IOCount to perform either global read or local write. - // These are inputs to cooperative functions. - constexpr auto warpCount = get<0>(warpDims) * get<1>(warpDims); - - // Scheduling warp order is analogous to row major priority. - // E.g. Wg = (128, 2) = 2x2 warps - // (0, 0) (0, 1) Share Schedule: w0 = (0, 0), w1 = (0, 1), - // (1, 0) (1, 1) w2 = (1, 0), w3 = (1, 1), count = 4 - const auto warpIndex = get<0>(localWarpCoord) * get<1>(warpDims) + get<1>(localWarpCoord); - - /// - /// Perform initial global pre-fetch - /// - - GRBuffA grBuffA; - GRBuffB grBuffB; - - globalReadCoopA(grBuffA, a + globalReadOffsetA, lda, warpIndex); - globalReadCoopB(grBuffB, b + globalReadOffsetB, ldb, warpIndex); - - globalReadOffsetA += kStepOffsetA; - globalReadOffsetB += kStepOffsetB; - - /// - /// Setup LDS addressing - /// This kernel will use 2 separate LDS blocks for pipelining - /// the input prefetching during the accumulation loop - /// - - HIP_DYNAMIC_SHARED(void*, localMemPtr); - using LWBuffAShape = GetIOShape_t; - using LWBuffBShape = GetIOShape_t; - using LWBuffAMap1d = GetDataLayout_t; - using LWBuffBMap1d = GetDataLayout_t; - - constexpr uint32_t ldsWidth = ROCWMMA_K; - constexpr uint32_t ldsHeight = LWBuffAShape::BlockHeight + LWBuffBShape::BlockHeight; - constexpr uint32_t sizeLds = ldsHeight * ldsWidth; - constexpr uint32_t ldsld = std::is_same_v ? ldsWidth : ldsHeight; - - auto* ldsPtrLo = reinterpret_cast(localMemPtr); - auto* ldsPtrHi = ldsPtrLo + sizeLds; - - // Local write offsets to start of A / B data - auto ldsWriteOffsetA = 0u; - auto ldsWriteOffsetB - = LWBuffAMap1d::fromMatrixCoord(make_coord2d(LWBuffAShape::BlockHeight, 0u), ldsld); - - // Local read offsets for mfma frags - auto ldsReadOffsetA - = ldsWriteOffsetA - + LWBuffAMap1d::fromMatrixCoord(make_coord2d(get<0>(localWarpOffset), 0u), ldsld); - auto ldsReadOffsetB - = ldsWriteOffsetB - + LWBuffBMap1d::fromMatrixCoord(make_coord2d(get<1>(localWarpOffset), 0u), ldsld); - - /// - /// Write prefetch to local - /// - localWriteCoopA(ldsPtrLo + ldsWriteOffsetA, grBuffA, ldsld, warpIndex); - localWriteCoopB(ldsPtrLo + ldsWriteOffsetB, grBuffB, ldsld, warpIndex); - - /// - /// Initialize accumulation frags - /// - MfmaFragAcc fragsAcc[BLOCKS_X][BLOCKS_Y]; - fill(fragsAcc, 0.0f); - - /// - /// Synchronize warps and memory - /// - synchronize_workgroup(); - - /// - /// Accumulate A * B for all mfma frags in warp tile - /// - for(auto currentK = ROCWMMA_K; currentK < k; currentK += ROCWMMA_K) - { - MfmaFragA fragsA[BLOCKS_X]; - MfmaFragB fragsB[BLOCKS_Y]; + // WorkItems will be split up by minimum IOCount to perform either global read or local write. + // These are inputs to cooperative functions. + constexpr auto warpCount = get<0>(warpDims) * get<1>(warpDims); - // Local read mfma frags from first LDS buffer - localReadA(fragsA, ldsPtrLo + ldsReadOffsetA, ldsld); - localReadB(fragsB, ldsPtrLo + ldsReadOffsetB, ldsld); + // Scheduling warp order is analogous to row major priority. + // E.g. Wg = (128, 2) = 2x2 warps + // (0, 0) (0, 1) Share Schedule: w0 = (0, 0), w1 = (0, 1), + // (1, 0) (1, 1) w2 = (1, 0), w3 = (1, 1), count = 4 + const auto warpIndex = get<0>(localWarpCoord) * get<1>(warpDims) + get<1>(localWarpCoord); + + /// + /// Perform initial global pre-fetch + /// + + GRBuffA grBuffA; + GRBuffB grBuffB; - // Prefetch next round of global frags globalReadCoopA(grBuffA, a + globalReadOffsetA, lda, warpIndex); globalReadCoopB(grBuffB, b + globalReadOffsetB, ldb, warpIndex); - // Advance offsets to next k step globalReadOffsetA += kStepOffsetA; globalReadOffsetB += kStepOffsetB; - // accum(A * B) - mfma(fragsAcc, fragsA, fragsB, fragsAcc); + /// + /// Setup LDS addressing + /// This kernel will use 2 separate LDS blocks for pipelining + /// the input prefetching during the accumulation loop + /// + + HIP_DYNAMIC_SHARED(void*, localMemPtr); + using LWBuffAShape = GetIOShape_t; + using LWBuffBShape = GetIOShape_t; + using LWBuffAMap1d = GetDataLayout_t; + using LWBuffBMap1d = GetDataLayout_t; + + constexpr uint32_t ldsWidth = ROCWMMA_K; + constexpr uint32_t ldsHeight = LWBuffAShape::BlockHeight + LWBuffBShape::BlockHeight; + constexpr uint32_t sizeLds = ldsHeight * ldsWidth; + constexpr uint32_t ldsld = std::is_same_v ? ldsWidth : ldsHeight; + + auto* ldsPtrLo = reinterpret_cast(localMemPtr); + auto* ldsPtrHi = ldsPtrLo + sizeLds; + + // Local write offsets to start of A / B data + auto ldsWriteOffsetA = 0u; + auto ldsWriteOffsetB + = LWBuffAMap1d::fromMatrixCoord(make_coord2d(LWBuffAShape::BlockHeight, 0u), ldsld); + + // Local read offsets for mfma frags + auto ldsReadOffsetA + = ldsWriteOffsetA + + LWBuffAMap1d::fromMatrixCoord(make_coord2d(get<0>(localWarpOffset), 0u), ldsld); + auto ldsReadOffsetB + = ldsWriteOffsetB + + LWBuffBMap1d::fromMatrixCoord(make_coord2d(get<1>(localWarpOffset), 0u), ldsld); + + /// + /// Write prefetch to local + /// + localWriteCoopA(ldsPtrLo + ldsWriteOffsetA, grBuffA, ldsld, warpIndex); + localWriteCoopB(ldsPtrLo + ldsWriteOffsetB, grBuffB, ldsld, warpIndex); + + /// + /// Initialize accumulation frags + /// + MfmaFragAcc fragsAcc[BLOCKS_X][BLOCKS_Y]; + fill(fragsAcc, 0.0f); + + /// + /// Synchronize warps and memory + /// + synchronize_workgroup(); - // Write prefetch to second LDS buffer - localWriteCoopA(ldsPtrHi + ldsWriteOffsetA, grBuffA, ldsld, warpIndex); - localWriteCoopB(ldsPtrHi + ldsWriteOffsetB, grBuffB, ldsld, warpIndex); + /// + /// Accumulate A * B for all mfma frags in warp tile + /// + for(auto currentK = ROCWMMA_K; currentK < k; currentK += ROCWMMA_K) + { + MfmaFragA fragsA[BLOCKS_X]; + MfmaFragB fragsB[BLOCKS_Y]; - // Make sure that all waves have finished reading / writing to lds for currentK. - synchronize_workgroup(); + // Local read mfma frags from first LDS buffer + localReadA(fragsA, ldsPtrLo + ldsReadOffsetA, ldsld); + localReadB(fragsB, ldsPtrLo + ldsReadOffsetB, ldsld); - // Swap Lds buffers - auto* tmp = ldsPtrLo; - ldsPtrLo = ldsPtrHi; - ldsPtrHi = tmp; - } + // Prefetch next round of global frags + globalReadCoopA(grBuffA, a + globalReadOffsetA, lda, warpIndex); + globalReadCoopB(grBuffB, b + globalReadOffsetB, ldb, warpIndex); + + // Advance offsets to next k step + globalReadOffsetA += kStepOffsetA; + globalReadOffsetB += kStepOffsetB; + + // accum(A * B) + mfma(fragsAcc, fragsA, fragsB, fragsAcc); + + // Write prefetch to second LDS buffer + localWriteCoopA(ldsPtrHi + ldsWriteOffsetA, grBuffA, ldsld, warpIndex); + localWriteCoopB(ldsPtrHi + ldsWriteOffsetB, grBuffB, ldsld, warpIndex); + + // Make sure that all waves have finished reading / writing to lds for currentK. + synchronize_workgroup(); + + // Swap Lds buffers + auto* tmp = ldsPtrLo; + ldsPtrLo = ldsPtrHi; + ldsPtrHi = tmp; + } + + /// + /// Start loading C + /// + using MfmaFragCMap1d = GetDataLayout_t; + using MfmaFragDMap1d = GetDataLayout_t; + + MfmaFragC fragsC[BLOCKS_X][BLOCKS_Y]; + globalReadC(fragsC, c + MfmaFragCMap1d::fromMatrixCoord(warpTileCoord, ldc), ldc); - /// - /// Start loading C - /// - using MfmaFragCMap1d = GetDataLayout_t; - using MfmaFragDMap1d = GetDataLayout_t; - - MfmaFragC fragsC[BLOCKS_X][BLOCKS_Y]; - globalReadC(fragsC, c + MfmaFragCMap1d::fromMatrixCoord(warpTileCoord, ldc), ldc); - - /// - /// Clean up tail A * B - /// - MfmaFragA fragsA[BLOCKS_X]; - MfmaFragB fragsB[BLOCKS_Y]; - - // Local read mfma frags - localReadA(fragsA, ldsPtrLo + ldsReadOffsetA, ldsld); - localReadB(fragsB, ldsPtrLo + ldsReadOffsetB, ldsld); - mfma(fragsAcc, fragsA, fragsB, fragsAcc); - - /// - /// D = alpha * accum + beta * C - /// - MfmaFragD fragsD[BLOCKS_X][BLOCKS_Y]; - uniformFma(fragsD, alpha, fragsAcc, beta, fragsC); - globalWriteD(d + MfmaFragDMap1d::fromMatrixCoord(warpTileCoord, ldd), fragsD, ldd); + /// + /// Clean up tail A * B + /// + MfmaFragA fragsA[BLOCKS_X]; + MfmaFragB fragsB[BLOCKS_Y]; + + // Local read mfma frags + localReadA(fragsA, ldsPtrLo + ldsReadOffsetA, ldsld); + localReadB(fragsB, ldsPtrLo + ldsReadOffsetB, ldsld); + mfma(fragsAcc, fragsA, fragsB, fragsAcc); + + /// + /// D = alpha * accum + beta * C + /// + MfmaFragD fragsD[BLOCKS_X][BLOCKS_Y]; + uniformFma(fragsD, alpha, fragsAcc, beta, fragsC); + globalWriteD(d + MfmaFragDMap1d::fromMatrixCoord(warpTileCoord, ldd), fragsD, ldd); + } } ROCWMMA_HOST void gemm_test(uint32_t m, uint32_t n, uint32_t k, ComputeT alpha, ComputeT beta) diff --git a/test/gemm/gemm_kernel_base_impl.hpp b/test/gemm/gemm_kernel_base_impl.hpp index 1e5cdf62..2584093c 100644 --- a/test/gemm/gemm_kernel_base_impl.hpp +++ b/test/gemm/gemm_kernel_base_impl.hpp @@ -644,35 +644,39 @@ namespace rocwmma this->mBeta); // beta }; + hipEvent_t startEvent, stopEvent; + CHECK_HIP_ERROR(hipEventCreate(&startEvent)); + CHECK_HIP_ERROR(hipEventCreate(&stopEvent)); + // Cold runs for frequency warm-up for(uint32_t i = 0; i < mColdRuns; ++i) { rocwmmaKernel(); } - // Use the hot runs for timing - hipEvent_t startEvent, stopEvent; - CHECK_HIP_ERROR(hipEventCreate(&startEvent)); - CHECK_HIP_ERROR(hipEventCreate(&stopEvent)); - CHECK_HIP_ERROR(hipEventRecord(startEvent)); + // Finish cold runs + CHECK_HIP_ERROR(hipEventRecord(stopEvent)); + CHECK_HIP_ERROR(hipEventSynchronize(stopEvent)); + + // Use the hot runs for timing. Ensure sequential execution. + mElapsedTimeMs = 0.0; for(uint32_t i = 0; i < mHotRuns; ++i) { + CHECK_HIP_ERROR(hipEventRecord(startEvent)); rocwmmaKernel(); + CHECK_HIP_ERROR(hipEventRecord(stopEvent)); + CHECK_HIP_ERROR(hipEventSynchronize(stopEvent)); + auto timeMs = 0.0f; + CHECK_HIP_ERROR(hipEventElapsedTime(&timeMs, startEvent, stopEvent)); + mElapsedTimeMs += timeMs; } - CHECK_HIP_ERROR(hipEventRecord(stopEvent)); - CHECK_HIP_ERROR(hipEventSynchronize(stopEvent)); - - auto timeMs = 0.0f; - CHECK_HIP_ERROR(hipEventElapsedTime(&timeMs, startEvent, stopEvent)); // Calculate efficiency auto& deviceInfo = DeviceInfo::instance(); auto devicePeakGFlopsPerSec = deviceInfo->peakGFlopsPerSec(); - - mElapsedTimeMs = float64_t(timeMs); - mTotalGFlops = calculateGFlops(mM, mN, mK); - mMeasuredTFlopsPerSec = calculateTFlopsPerSec(mM, mN, mK, mElapsedTimeMs) + mTotalGFlops = calculateGFlops(mM, mN, mK); + mMeasuredTFlopsPerSec = calculateTFlopsPerSec(mM, mN, mK, mElapsedTimeMs) * static_cast(mHotRuns); mEfficiency = round(mMeasuredTFlopsPerSec / devicePeakGFlopsPerSec * 100000.0); @@ -802,37 +806,42 @@ namespace rocwmma std::numeric_limits::signaling_NaN()); } + hipEvent_t startEvent, stopEvent; + CHECK_HIP_ERROR(hipEventCreate(&startEvent)); + CHECK_HIP_ERROR(hipEventCreate(&stopEvent)); + // Cold runs for frequency warm-up for(uint32_t i = 0; i < mColdRuns; ++i) { refKernel(); } + // Finish cold runs + CHECK_HIP_ERROR(hipEventRecord(stopEvent)); + CHECK_HIP_ERROR(hipEventSynchronize(stopEvent)); + // Hot runs for timing - hipEvent_t startEvent, stopEvent; - CHECK_HIP_ERROR(hipEventCreate(&startEvent)); - CHECK_HIP_ERROR(hipEventCreate(&stopEvent)); - CHECK_HIP_ERROR(hipEventRecord(startEvent)); + auto elapsedTimeMs = 0.0; for(uint32_t i = 0; i < mHotRuns; ++i) { + CHECK_HIP_ERROR(hipEventRecord(startEvent)); refKernel(); + CHECK_HIP_ERROR(hipEventRecord(stopEvent)); + CHECK_HIP_ERROR(hipEventSynchronize(stopEvent)); + auto timeMs = 0.0f; + CHECK_HIP_ERROR(hipEventElapsedTime(&timeMs, startEvent, stopEvent)); + elapsedTimeMs += timeMs; } - CHECK_HIP_ERROR(hipEventRecord(stopEvent)); - CHECK_HIP_ERROR(hipEventSynchronize(stopEvent)); - auto timeMs = 0.0f; - CHECK_HIP_ERROR(hipEventElapsedTime(&timeMs, startEvent, stopEvent)); CHECK_HIP_ERROR(hipEventDestroy(startEvent)); CHECK_HIP_ERROR(hipEventDestroy(stopEvent)); // Calculate reference efficiency if constexpr(mBenchRef) { - auto& deviceInfo = DeviceInfo::instance(); auto devicePeakGFlopsPerSec = deviceInfo->peakGFlopsPerSec(); - auto elapsedTimeMs = float64_t(timeMs); auto measuredTFlopsPerSec = calculateTFlopsPerSec(mM, mN, mK, elapsedTimeMs) * static_cast(mHotRuns); diff --git a/test/unit/CMakeLists.txt b/test/unit/CMakeLists.txt index 37709de6..a6183bca 100644 --- a/test/unit/CMakeLists.txt +++ b/test/unit/CMakeLists.txt @@ -2,7 +2,7 @@ # # MIT License # - # Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + # Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. # # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal @@ -51,6 +51,7 @@ endfunction() # Add unit tests add_subdirectory(contamination_test) add_subdirectory(layout_test) +add_subdirectory(layout_traits_test) add_subdirectory(map_util_test) add_subdirectory(load_store_matrix_sync_test) add_subdirectory(load_store_matrix_coop_sync_test) diff --git a/test/unit/io_shape_test/detail/io_shape.hpp b/test/unit/io_shape_test/detail/io_shape.hpp index 7f23eecf..ca8930c1 100644 --- a/test/unit/io_shape_test/detail/io_shape.hpp +++ b/test/unit/io_shape_test/detail/io_shape.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -58,49 +58,43 @@ namespace rocwmma template bool waveTest() { - bool err = false; - constexpr auto BlockDim = std::is_same_v ? BlockM : BlockN; - constexpr auto KDim = std::is_same_v ? BlockM : BlockK; - - constexpr auto MaxVW - = std::is_same_v ? detail:: - MaxVWSelector:: - Result - : std::is_same_v - ? detail:: - MaxVWSelector:: - Result - : (std::is_same::value || ROCWMMA_ARCH_GFX11 ? 1u : 4u); - constexpr auto VW - = std::is_same_v - ? std::is_same::value || BlockDim > 32 ? MaxVW : 1u - : std::is_same_v - ? (std::is_same::value || BlockDim > 32 ? MaxVW : 1u) - : (std::is_same::value ? MaxVW : 1u); - - using RowNT - = LayoutProfile::template RowNT; - using ColNT - = LayoutProfile::template ColNT; - - using Row = LayoutProfile::template Row; - using Col = LayoutProfile::template Col; - - using Profile = typename std::conditional_t< - std::is_same_v, - std::conditional_t, - std::conditional_t, - std::conditional_t, - RowNT>>; - - using DataLayout = DataLayout::template Array1d; - - using IOLayout = IOLayout; - - err |= (IOLayout::MaxVW != MaxVW); - err |= (IOLayout::VW != VW); - err |= (!std::is_same::value); - err |= (!std::is_same::value); + bool err = false; + + // Accum requires WaveCount > 1 + if constexpr(!std::is_same_v || WaveCount == 1) + { + constexpr auto BlockDim = std::is_same_v ? BlockM : BlockN; + constexpr auto KDim = std::is_same_v ? BlockM : BlockK; + + using detail::MaxVWSelector; + using detail::MmaDimSelector; + + constexpr auto ExpectMaxVW + = MaxVWSelector::Result; + + constexpr auto ExpectVW + = std::is_same_v + ? std::is_same::value || BlockDim > 32 + ? ExpectMaxVW + : 1u + : std::is_same_v + ? (std::is_same::value || BlockDim > 32 + ? ExpectMaxVW + : 1u) + : (std::is_same::value ? ExpectMaxVW : 1u); + + constexpr auto ExpectMmaDim = MmaDimSelector::Result; + + using IOLayout = IOLayout; + using IOLayoutInt + = IOLayoutInt; + using ExpectDataLayout = DataLayout::template Array1d; + + err |= (IOLayout::MaxVW != ExpectMaxVW); + err |= (IOLayout::VW != ExpectVW); + err |= (IOLayoutInt::MmaDim != ExpectMmaDim); + err |= (!std::is_same_v); + } return err; } diff --git a/test/unit/layout_test/device/col_layout.hpp b/test/unit/layout_test/device/col_layout.hpp index 278618c7..b189983c 100644 --- a/test/unit/layout_test/device/col_layout.hpp +++ b/test/unit/layout_test/device/col_layout.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -28,16 +28,13 @@ #define ROCWMMA_DEVICE_COL_LAYOUT_HPP #include "unit_test_traits.hpp" +#include #include -#include #include namespace rocwmma { - template + template __global__ void ColLayout(uint32_t m, uint32_t n, DataT const* in, @@ -46,47 +43,51 @@ namespace rocwmma DataT param1, DataT param2) { - if constexpr (FragSize_guard::enable()) + if constexpr(FragSize_guard::enable()) { enum : uint32_t { - MaxVectorWidth - = detail::MaxVWSelector::Result, - VectorWidth = std::is_same_v ? MaxVectorWidth : 1 + BlockHeight = BlockM, + BlockWidth = BlockN, + + BlockDim = BlockM, + BlockK = BlockN, + + MaxVectorWidth = detail::MaxVWSelector::Result, + VectorWidth = MaxVectorWidth }; - using IOTraits = IOTraits; - using LayoutT = typename LayoutProfile::Col::MatrixLayout; - using Mapping = MappingUtil; + using IOTraits = IOTraits; + + using LayoutT = conditional_t< + is_same_v, + MatrixLayout::ColInlineVW, + MatrixLayout::ColOrthoVW>; + using Mapping = MappingUtil; + + constexpr auto ioCount = IOTraits::IOCount; auto baseOffset = LayoutT::baseOffset(); - auto iocount = IOTraits::IOCount; auto matrixCoord = Mapping::matrixCoord(); - enum : uint32_t + auto currentOffset = matrixCoord + baseOffset; + for(auto i = 0u; i < ioCount; ++i) { - MajorIndex = std::is_same_v ? 0 : 1, - MinorIndex = std::is_same_v ? 1 : 0 - }; - - for(uint32_t i = 0; i < iocount; ++i) - { - for(int j = 0; j < VectorWidth; j++) + for(auto j = 0u; j < VectorWidth; ++j) { - auto index = (get(matrixCoord) * ld + get(matrixCoord)) - + Mapping::dataOffset(baseOffset, ld) + j; + auto index = Mapping::dataOffset(currentOffset, ld) + j; out[index] = in[index]; } - baseOffset += LayoutT::incrementalOffset(i); + currentOffset += LayoutT::incrementalOffset(i); } } } - + } // namespace rocwmma #endif // ROCWMMA_DEVICE_COL_LAYOUT_HPP diff --git a/test/unit/layout_test/device/colnt_layout.hpp b/test/unit/layout_test/device/colnt_layout.hpp index 07e1399a..47ccb0ef 100644 --- a/test/unit/layout_test/device/colnt_layout.hpp +++ b/test/unit/layout_test/device/colnt_layout.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -28,16 +28,13 @@ #define ROCWMMA_DEVICE_COLNT_LAYOUT_HPP #include "unit_test_traits.hpp" +#include #include -#include #include namespace rocwmma { - template + template __global__ void ColNTLayout(uint32_t m, uint32_t n, DataT const* in, @@ -46,44 +43,45 @@ namespace rocwmma DataT param1, DataT param2) { - if constexpr (FragSize_guard::enable()) + if constexpr(FragSize_guard::enable()) { enum : uint32_t { - MaxVectorWidth - = detail::MaxVWSelector::Result, - VectorWidth = std::is_same_v ? MaxVectorWidth : 1 + BlockHeight = BlockM, + BlockWidth = BlockN, + + BlockDim = BlockM, + BlockK = BlockN, + + MaxVectorWidth = detail::MaxVWSelector::Result, + VectorWidth = std::is_same_v ? MaxVectorWidth : 1u, }; - using IOTraits = IOTraits; - using LayoutT - = typename LayoutProfile::ColNT::MatrixLayout; - using Mapping = MappingUtil; + using IOTraits = IOTraits; + using LayoutT = conditional_t< + is_same_v, + MatrixLayout::ColOrthoVW, + MatrixLayout::ColOrthoVW>; + using Mapping = MappingUtil; + constexpr auto ioCount = IOTraits::IOCount; auto baseOffset = LayoutT::baseOffset(); - auto iocount = IOTraits::IOCount; auto matrixCoord = Mapping::matrixCoord(); - enum : uint32_t - { - MajorIndex = std::is_same_v ? 0 : 1, - MinorIndex = std::is_same_v ? 1 : 0 - }; - - for(uint32_t i = 0; i < iocount; ++i) + auto currentOffset = matrixCoord + baseOffset; + for(auto i = 0u; i < ioCount; ++i) { - for(int j = 0; j < VectorWidth; j++) + for(auto j = 0u; j < VectorWidth; ++j) { - auto index = (get(matrixCoord) * ld + get(matrixCoord)) - + Mapping::dataOffset(baseOffset, ld) + j; + auto index = Mapping::dataOffset(currentOffset, ld) + j; out[index] = in[index]; } - baseOffset += LayoutT::incrementalOffset(i); + currentOffset += LayoutT::incrementalOffset(i); } } } diff --git a/test/unit/layout_test/device/row_layout.hpp b/test/unit/layout_test/device/row_layout.hpp index ac12ad23..da201fbf 100644 --- a/test/unit/layout_test/device/row_layout.hpp +++ b/test/unit/layout_test/device/row_layout.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -28,17 +28,14 @@ #define ROCWMMA_DEVICE_ROW_LAYOUT_HPP #include "unit_test_traits.hpp" +#include #include -#include #include namespace rocwmma { - template + template __global__ void RowLayout(uint32_t m, uint32_t n, DataT const* in, @@ -47,12 +44,12 @@ namespace rocwmma DataT param1, DataT param2) { - if constexpr (FragSize_guard::enable()) + if constexpr(FragSize_guard::enable()) { enum : uint32_t { @@ -60,40 +57,37 @@ namespace rocwmma BlockWidth = BlockN, BlockDim = BlockN, - KDim = BlockM, + BlockK = BlockM, - MaxVectorWidth - = detail::MaxVWSelector::Result, - VectorWidth = std::is_same_v ? MaxVectorWidth : 1 + MaxVectorWidth = detail::MaxVWSelector::Result, + VectorWidth = MaxVectorWidth }; - using IOTraits = IOTraits; - using LayoutT = typename LayoutProfile::Row::MatrixLayout; - using Mapping = MappingUtil; + using IOTraits = IOTraits; + using LayoutT = conditional_t< + is_same_v, + MatrixLayout::RowInlineVW, + MatrixLayout::RowOrthoVW>; + using Mapping = MappingUtil; + + constexpr auto ioCount = IOTraits::IOCount; auto baseOffset = LayoutT::baseOffset(); - auto iocount = IOTraits::IOCount; auto matrixCoord = Mapping::matrixCoord(); - enum : uint32_t - { - MajorIndex = std::is_same_v ? 0 : 1, - MinorIndex = std::is_same_v ? 1 : 0 - }; - - for(uint32_t i = 0; i < iocount; ++i) + auto currentOffset = matrixCoord + baseOffset; + for(auto i = 0u; i < ioCount; ++i) { - for(uint32_t j = 0; j < VectorWidth; j++) + for(auto j = 0u; j < VectorWidth; ++j) { - auto index = (get(matrixCoord) * ld + get(matrixCoord)) - + Mapping::dataOffset(baseOffset, ld) + j; + auto index = Mapping::dataOffset(currentOffset, ld) + j; out[index] = in[index]; } - baseOffset += LayoutT::incrementalOffset(i); + currentOffset += LayoutT::incrementalOffset(i); } } } - + } // namespace rocwmma #endif // ROCWMMA_DEVICE_ROW_LAYOUT_HPP diff --git a/test/unit/layout_test/device/rownt_layout.hpp b/test/unit/layout_test/device/rownt_layout.hpp index beaae4b7..fc8967b8 100644 --- a/test/unit/layout_test/device/rownt_layout.hpp +++ b/test/unit/layout_test/device/rownt_layout.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -28,17 +28,14 @@ #define ROCWMMA_DEVICE_ROWNT_LAYOUT_HPP #include "unit_test_traits.hpp" +#include #include -#include #include namespace rocwmma { - template + template __global__ void RowNTLayout(uint32_t m, uint32_t n, DataT const* in, @@ -47,12 +44,12 @@ namespace rocwmma DataT param1, DataT param2) { - if constexpr (FragSize_guard::enable()) + if constexpr(FragSize_guard::enable()) { enum : uint32_t { @@ -60,39 +57,32 @@ namespace rocwmma BlockWidth = BlockN, BlockDim = BlockN, - KDim = BlockM, + BlockK = BlockM, - MaxVectorWidth - = std::is_same_v - ? 1 - : detail::MaxVWSelector::Result, - VectorWidth = std::is_same_v ? MaxVectorWidth : 1, + MaxVectorWidth = detail::MaxVWSelector::Result, + VectorWidth = std::is_same_v ? MaxVectorWidth : 1u, }; - using IOTraits = IOTraits; - using LayoutT - = typename LayoutProfile::RowNT::MatrixLayout; + using IOTraits = IOTraits; + using LayoutT = conditional_t< + is_same_v, + MatrixLayout::RowOrthoVW, + MatrixLayout::RowOrthoVW>; using Mapping = MappingUtil; + constexpr auto ioCount = IOTraits::IOCount; auto baseOffset = LayoutT::baseOffset(); - auto iocount = IOTraits::IOCount; auto matrixCoord = Mapping::matrixCoord(); - enum : uint32_t - { - MajorIndex = std::is_same_v ? 0 : 1, - MinorIndex = std::is_same_v ? 1 : 0 - }; - - for(uint32_t i = 0; i < iocount; ++i) + auto currentOffset = matrixCoord + baseOffset; + for(auto i = 0u; i < ioCount; ++i) { - for(uint32_t j = 0; j < VectorWidth; j++) + for(auto j = 0u; j < VectorWidth; ++j) { - auto index = (get(matrixCoord) * ld + get(matrixCoord)) - + Mapping::dataOffset(baseOffset, ld) + j; + auto index = Mapping::dataOffset(currentOffset, ld) + j; out[index] = in[index]; } - baseOffset += LayoutT::incrementalOffset(i); + currentOffset += LayoutT::incrementalOffset(i); } } } diff --git a/test/unit/layout_traits_test/CMakeLists.txt b/test/unit/layout_traits_test/CMakeLists.txt new file mode 100644 index 00000000..de94d7e7 --- /dev/null +++ b/test/unit/layout_traits_test/CMakeLists.txt @@ -0,0 +1,51 @@ +############################################################################### +# +# MIT License +# +# Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +############################################################################### + +# Include path for current test files +set(ROCWMMA_TEST_INCLUDE_DIRS ${CMAKE_CURRENT_SOURCE_DIR} ${ROCWMMA_TEST_INCLUDE_DIRS}) + +set(LayoutTraitsTestSources ${UnitCommonSources} + ${CMAKE_CURRENT_SOURCE_DIR}/test/layout_traits_16.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test/layout_traits_32.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test/layout_traits_64.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test/layout_traits_128.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test/layout_traits_256.cpp + + ${CMAKE_CURRENT_SOURCE_DIR}/test/layout_traits_int_16.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test/layout_traits_int_32.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test/layout_traits_int_64.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test/layout_traits_int_128.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test/layout_traits_int_256.cpp + + ${CMAKE_CURRENT_SOURCE_DIR}/test/emulation/smoketest_layout_traits.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test/emulation/smoketest_layout_traits_int.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test/emulation/regressiontest_layout_traits.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test/emulation/regressiontest_layout_traits_int.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test/emulation/extendedtest_layout_traits.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/test/emulation/extendedtest_layout_traits_int.cpp + ) + +add_rocwmma_unit_test(layout_traits_test ${LayoutTraitsTestSources}) diff --git a/test/unit/layout_traits_test/detail/layout_traits.hpp b/test/unit/layout_traits_test/detail/layout_traits.hpp new file mode 100644 index 00000000..d7e1c4c8 --- /dev/null +++ b/test/unit/layout_traits_test/detail/layout_traits.hpp @@ -0,0 +1,155 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#ifndef ROCWMMA_LAYOUT_TRAITS_TEST_DETAIL_HPP +#define ROCWMMA_LAYOUT_TRAITS_TEST_DETAIL_HPP + +#include "device/layout_traits.hpp" +#include "helper_macros.hpp" +#include "unit_kernel_base.hpp" + +namespace rocwmma +{ + + // Wrapper into the actual device function + template + struct LayoutTraitsKernel final : public UnitKernelBase + { + private: + using Base = UnitKernelBase; + + template + using TestGuard = FragSize_guard; + + public: + LayoutTraitsKernel() = default; + ~LayoutTraitsKernel() final = default; + + void setupImpl(typename Base::DataStorage::ProblemSize const& probsize) final + { + // Need at least 1 element for the result + auto& dataInstance = Base::DataStorage::instance(); + dataInstance->resizeStorage(probsize); + + dataInstance->hostOut().get()[0] = static_cast(ERROR_VALUE); + dataInstance->copyData(dataInstance->deviceOut(), dataInstance->hostOut(), 1); + + // Pass in warpSize from host to validate + Base::mParam1 = static_cast(Base::DeviceInfo::instance()->warpSize()); + } + + void validateResultsImpl() final + { + auto& dataInstance = Base::DataStorage::instance(); + + // Cache current kernel result from device + dataInstance->copyData(dataInstance->hostOut(), dataInstance->deviceOut(), 1); + + // Check the single output result + Base::mValidationResult = (dataInstance->hostOut().get()[0] == DataT(SUCCESS_VALUE)); + } + + bool checkQuirks() const final + { + auto waveSize = Base::DeviceInfo::instance()->warpSize(); + auto deviceArch = Base::DeviceInfo::instance()->getGcnArch(); + + // The test guard for this class requires 2 values at runtime. + auto dispatchGuard = [waveSize, deviceArch]() { + bool dispatchResult = false; + +#define CASE_IMPL_ASSIGN2(WAVE_SIZE, ARCH_ID) \ + dispatchResult = TestGuard::enable(); + +#define SWITCH_BODY_WAVE_SIZE(ARCH_ID) \ + ROCWMMA_SWITCH_BODY2_ARG2( \ + waveSize, CASE_IMPL_ASSIGN2, HipDevice::Wave32, HipDevice::Wave64, ARCH_ID) + +#define DISPATCH_GUARD_BODY \ + ROCWMMA_SWITCH_BODY10_ARG1(deviceArch, \ + SWITCH_BODY_WAVE_SIZE, \ + HipDevice::GFX908, \ + HipDevice::GFX90A, \ + HipDevice::GFX940, \ + HipDevice::GFX941, \ + HipDevice::GFX942, \ + HipDevice::GFX1100, \ + HipDevice::GFX1101, \ + HipDevice::GFX1102, \ + HipDevice::GFX1200, \ + HipDevice::GFX1201) + + DISPATCH_GUARD_BODY + +#undef CASE_IMPL_ASSIGN2 +#undef SWITCH_BODY_WAVE_SIZE +#undef DISPATCH_GUARD_BODY + + return dispatchResult; + }; + + return Base::checkQuirks() && dispatchGuard(); + } + + typename Base::KernelFunc kernelImpl() const final + { + return typename Base::KernelFunc(layoutTraitsTest); + } + }; + + // This is the GeneratorImpl class + struct LayoutTraitsGenerator + { + // Indices to test parameters + enum : uint32_t + { + BlockM = 0, + BlockN = 1, + DataT = 2, + DataLayoutT = 3 + }; + + using ResultT = std::shared_ptr; + + template + static ResultT generate(std::tuple testParams) + { + // Map GTest params to Kernel params + using TestParamsT = std::tuple; + using KernelT + = LayoutTraitsKernel::value, // BlockM + std::tuple_element_t::value, // BlockN + std::tuple_element_t, // DataT + std::tuple_element_t // DataLayout + >; + + return std::make_shared(); + } + }; + +} // namespace rocwmma + +#endif // ROCWMMA_LAYOUT_TRAITS_TEST_DETAIL_HPP diff --git a/test/unit/layout_traits_test/detail/layout_traits_int.hpp b/test/unit/layout_traits_test/detail/layout_traits_int.hpp new file mode 100644 index 00000000..4de13641 --- /dev/null +++ b/test/unit/layout_traits_test/detail/layout_traits_int.hpp @@ -0,0 +1,165 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#ifndef ROCWMMA_LAYOUT_TRAITS_INT_TEST_DETAIL_HPP +#define ROCWMMA_LAYOUT_TRAITS_INT_TEST_DETAIL_HPP + +#include "device/layout_traits_int.hpp" +#include "helper_macros.hpp" +#include "unit_kernel_base.hpp" + +namespace rocwmma +{ + + // Wrapper into the actual device function + template + struct LayoutTraitsIntKernel final : public UnitKernelBase + { + private: + using Base = UnitKernelBase; + + template + using TestGuard = FragSize_guard; + + public: + LayoutTraitsIntKernel() = default; + ~LayoutTraitsIntKernel() final = default; + + void setupImpl(typename Base::DataStorage::ProblemSize const& probsize) final + { + // Need at least 1 element for the result + auto& dataInstance = Base::DataStorage::instance(); + dataInstance->resizeStorage(probsize); + + dataInstance->hostOut().get()[0] = static_cast(ERROR_VALUE); + dataInstance->copyData(dataInstance->deviceOut(), dataInstance->hostOut(), 1); + + // Pass in warpSize from host to validate + Base::mParam1 = static_cast(Base::DeviceInfo::instance()->warpSize()); + } + + void validateResultsImpl() final + { + auto& dataInstance = Base::DataStorage::instance(); + + // Cache current kernel result from device + dataInstance->copyData(dataInstance->hostOut(), dataInstance->deviceOut(), 1); + + // Check the single output result + Base::mValidationResult = (dataInstance->hostOut().get()[0] == DataT(SUCCESS_VALUE)); + } + + bool checkQuirks() const final + { + auto waveSize = Base::DeviceInfo::instance()->warpSize(); + auto deviceArch = Base::DeviceInfo::instance()->getGcnArch(); + + // The test guard for this class requires 2 values at runtime. + auto dispatchGuard = [waveSize, deviceArch]() { + bool dispatchResult = false; + +#define CASE_IMPL_ASSIGN2(WAVE_SIZE, ARCH_ID) \ + dispatchResult = TestGuard::enable(); + +#define SWITCH_BODY_WAVE_SIZE(ARCH_ID) \ + ROCWMMA_SWITCH_BODY2_ARG2( \ + waveSize, CASE_IMPL_ASSIGN2, HipDevice::Wave32, HipDevice::Wave64, ARCH_ID) + +#define DISPATCH_GUARD_BODY \ + ROCWMMA_SWITCH_BODY10_ARG1(deviceArch, \ + SWITCH_BODY_WAVE_SIZE, \ + HipDevice::GFX908, \ + HipDevice::GFX90A, \ + HipDevice::GFX940, \ + HipDevice::GFX941, \ + HipDevice::GFX942, \ + HipDevice::GFX1100, \ + HipDevice::GFX1101, \ + HipDevice::GFX1102, \ + HipDevice::GFX1200, \ + HipDevice::GFX1201) + + DISPATCH_GUARD_BODY + +#undef CASE_IMPL_ASSIGN2 +#undef SWITCH_BODY_WAVE_SIZE +#undef DISPATCH_GUARD_BODY + + return dispatchResult; + }; + + return Base::checkQuirks() && dispatchGuard(); + } + + typename Base::KernelFunc kernelImpl() const final + { + return typename Base::KernelFunc( + layoutTraitsIntTest); + } + }; + + // This is the GeneratorImpl class + struct LayoutTraitsIntGenerator + { + // Indices to test parameters + enum : uint32_t + { + BlockM = 0, + BlockN = 1, + DataT = 2, + DataLayoutT = 3, + MmaDim = 4, + SplitK = 5, + }; + + using ResultT = std::shared_ptr; + + template + static ResultT generate(std::tuple testParams) + { + // Map GTest params to Kernel params + using TestParamsT = std::tuple; + using KernelT = LayoutTraitsIntKernel< + std::tuple_element_t::value, // BlockM + std::tuple_element_t::value, // BlockN + std::tuple_element_t, // DataT + std::tuple_element_t, // DataLayout + std::tuple_element_t::value, // MmaDim + std::tuple_element_t::value // SplitK + >; + + return std::make_shared(); + } + }; + +} // namespace rocwmma + +#endif // ROCWMMA_LAYOUT_TRAITS_TEST_DETAIL_HPP diff --git a/test/unit/layout_traits_test/device/layout_traits.hpp b/test/unit/layout_traits_test/device/layout_traits.hpp new file mode 100644 index 00000000..e469ffdb --- /dev/null +++ b/test/unit/layout_traits_test/device/layout_traits.hpp @@ -0,0 +1,1566 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#ifndef ROCWMMA_DEVICE_LAYOUT_TRAITS_TEST_HPP +#define ROCWMMA_DEVICE_LAYOUT_TRAITS_TEST_HPP + +#include + +#include "unit_test_traits.hpp" + +static constexpr uint32_t ERROR_VALUE = 7; +static constexpr uint32_t SUCCESS_VALUE = 0; + +namespace rocwmma +{ + + template + ROCWMMA_HOST bool + testLayoutPair(const char* file, const char* line, std::ostream& stream = std::cout) + { + constexpr bool is_layout_same_result = is_layout_same_v; + constexpr bool is_layout_orthogonal_result = is_layout_orthogonal_v; + constexpr bool compare_result = ((is_layout_same_result == ExpectSame) + && (is_layout_orthogonal_result == ExpectOrthogonal)); + + if constexpr(DebugOnFail) + { + stream << "File: " << file << " L:" << line << std::endl; + stream << "" << std::endl; + stream << "Lhs: " << LayoutLhs{} << std::endl; + stream << rocwmma::layout_traits{}; + stream << "Rhs: " << LayoutRhs{} << std::endl; + stream << rocwmma::layout_traits{}; + stream << "is_layout_same: " << is_layout_same_result << " Expected: " << ExpectSame + << std::endl; + stream << "is_layout_orthogonal: " << is_layout_orthogonal_result + << " Expected: " << ExpectOrthogonal << std::endl; + stream << "Result:" << (compare_result ? "PASS" : "FAIL") << std::endl; + stream << "" << std::endl; + } + + return compare_result; + } + + ROCWMMA_DEVICE inline bool isFirstThread() + { + return (threadIdx.x == 0) && (threadIdx.y == 0) && (threadIdx.z == 0) && (blockIdx.x == 0) + && (blockIdx.y == 0) && (blockIdx.z == 0); + } + + template + ROCWMMA_DEVICE bool testLayoutPair(const char* file, uint32_t line) + { + constexpr bool is_layout_same_result = is_layout_same_v; + constexpr bool is_layout_orthogonal_result = is_layout_orthogonal_v; + constexpr bool compare_result = ((is_layout_same_result == ExpectSame) + && (is_layout_orthogonal_result == ExpectOrthogonal)); + + if(!compare_result && DebugOnFail && isFirstThread()) + { + printf("File: %s L:%d\n", file, line); + printf("\n"); + printf("is_layout_same: %d (Expected: %d)\n", is_layout_same_result, ExpectSame); + printf("is_layout_orthogonal: %d (Expected: %d)\n", + is_layout_orthogonal_result, + ExpectOrthogonal); + printf("%s\n", (compare_result ? "PASS" : "FAIL")); + printf("\n"); + } + + return compare_result; + } + +#define ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( \ + LayoutLhs, LayoutRhs, ExpectSame, ExpectOrthogonal, DebugOnFail) \ + testLayoutPair(__FILE__, \ + __LINE__); + + template + ROCWMMA_DEVICE bool dataLayoutTraitsTest() + { + constexpr bool debug_on_fail = true; + + using rocwmma::DataLayout::ColMajor; + using rocwmma::DataLayout::RowMajor; + + // DataLayouts are invariant of matrix layout properties + // Test both the meta tags and functional classes + using SameMeta = conditional_t, row_major, col_major>; + using OrthoMeta = conditional_t, col_major, row_major>; + using SameFunc = conditional_t, RowMajor, ColMajor>; + using OrthoFunc = conditional_t, ColMajor, RowMajor>; + + bool result = true; + + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(DataLayoutT, SameMeta, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(DataLayoutT, OrthoMeta, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(DataLayoutT, SameFunc, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(DataLayoutT, OrthoFunc, false, true, debug_on_fail); + // clang-format on + + return result; + } + + template + struct RegisterLayoutTestingSet + { + using ColInline = RegisterLayout::Storage< + MatrixLayout::ColInlineVW, + DataLayout>; + using ColOrtho = RegisterLayout::Storage< + MatrixLayout::ColOrthoVW, + DataLayout>; + using RowInline = RegisterLayout::Storage< + MatrixLayout::RowInlineVW, + DataLayout>; + using RowOrtho = RegisterLayout::Storage< + MatrixLayout::RowOrthoVW, + DataLayout>; + + using MmaInput = RegisterLayout::MmaInput; + using MmaAcc = RegisterLayout::MmaAcc; + }; + + template + using MatrixLayout_t = typename layout_traits::MatrixLayout; + + template + ROCWMMA_DEVICE bool matrixLayoutTraitsTestNonInterleaved0() + { + constexpr bool debug_on_fail = true; + + // Testing MatrixLayout properties + // MatrixLayouts are invariant to vector width + constexpr uint32_t VectorWidth = MaxVectorWidth; + using Set = RegisterLayoutTestingSet; + + bool result = true; + + // Matrix <-> Matrix layout + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, true, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, true, false, debug_on_fail); + // clang-format on + + return result; + } + + template + ROCWMMA_DEVICE constexpr bool testRowMajor() + { + return is_layout_same_v; + } + + template + ROCWMMA_DEVICE constexpr bool testColMajor() + { + return is_layout_same_v; + } + + template + ROCWMMA_DEVICE constexpr bool testMmaDim() + { + return (MmaDim == 16u && (bool)ROCWMMA_BLOCK_DIM_16_SUPPORTED) + || (MmaDim == 32u && (bool)ROCWMMA_BLOCK_DIM_32_SUPPORTED + && !is_same_v); + } + + template + ROCWMMA_DEVICE constexpr bool testMmaAccVW() + { + return MaxVectorWidth + == ((bool)ROCWMMA_ARCH_GFX12 + ? 8u + : ((is_same_v || (bool)ROCWMMA_ARCH_GFX11) ? 1u : 4u)); + } + + template + ROCWMMA_DEVICE constexpr void debug() + { + if(isFirstThread()) + { + using traits_lhs = layout_traits; + using traits_rhs = layout_traits; + printf("testCompatibleRegisterParams: %d\n", LayoutTraits_impl::testCompatibleRegisterParams()); + printf("MmaDim: %d, MmaDim: %d\n", traits_lhs::MmaDim, traits_rhs::MmaDim); + printf("DataFormat: %d, DataFormat: %d\n", (int)traits_lhs::Format, (int)traits_rhs::Format); + printf("is_valid: %d, is_valid: %d\n", traits_lhs::is_valid, traits_rhs::is_valid); + printf("is_same_dataT: %d\n", is_same_v); + } + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved0() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // VW = MaxVW + // datalayout = same + constexpr uint32_t VectorWidth = MaxVectorWidth; + using Set0 = RegisterLayoutTestingSet; + using Set1 = RegisterLayoutTestingSet; + + constexpr bool is_row_mjr = testRowMajor(); + constexpr bool is_col_mjr = testColMajor(); + constexpr bool is_mma_dim = testMmaDim(); + constexpr bool is_acc_vw = testMmaAccVW(); + + constexpr bool is_mma_row_mjr = is_row_mjr && is_mma_dim; + constexpr bool is_mma_col_mjr = is_col_mjr && is_mma_dim; + constexpr bool is_mma_acc_row_mjr = is_row_mjr && is_mma_dim && is_acc_vw; + constexpr bool is_mma_acc_col_mjr = is_col_mjr && is_mma_dim && is_acc_vw; + + bool result = true; + + // Covered in another test case + if constexpr(VectorWidth == 1u) + { + return result; + } + + // Storage <-> storage layout + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, is_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, is_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline,typename Set1::RowOrtho, false, is_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, is_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, true, false, debug_on_fail); + + // Storage <-> mma layouts + // gfx11 have unique mma formats that must always be transformed from storage + if constexpr ((bool)ROCWMMA_ARCH_GFX11) + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, is_mma_acc_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, is_mma_acc_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, is_mma_acc_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, is_mma_acc_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, is_mma_acc_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_acc_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, is_mma_acc_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_acc_row_mjr, debug_on_fail); + } + // other targets have mma formats that may overlap storage formats + else + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, is_mma_acc_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, is_mma_acc_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, is_mma_acc_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, is_mma_acc_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, is_mma_acc_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_acc_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, is_mma_acc_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_acc_row_mjr, debug_on_fail); + } + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // clang-format on + + return result; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved1() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // VW = MaxVW + // datalayout = orthogonal + constexpr uint32_t VectorWidth = MaxVectorWidth; + using Set0 = RegisterLayoutTestingSet; + using Set1 = RegisterLayoutTestingSet>; + + constexpr bool is_row_mjr = testRowMajor(); + constexpr bool is_col_mjr = testColMajor(); + constexpr bool is_mma_dim = testMmaDim(); + constexpr bool is_acc_vw = testMmaAccVW(); + + constexpr bool is_mma_row_mjr = is_row_mjr && is_mma_dim; + constexpr bool is_mma_col_mjr = is_col_mjr && is_mma_dim; + constexpr bool is_mma_acc_row_mjr = is_row_mjr && is_mma_dim && is_acc_vw; + constexpr bool is_mma_acc_col_mjr = is_col_mjr && is_mma_dim && is_acc_vw; + + bool result = true; + + // Covered in another test case + if constexpr(VectorWidth == 1u) + { + return result; + } + + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, is_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, is_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, true, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, is_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, is_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + // Storage <-> mma layouts + // gfx11 have unique mma formats that must always be transformed from storage + if constexpr ((bool)ROCWMMA_ARCH_GFX11) + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_col_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, is_mma_acc_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, is_mma_acc_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, is_mma_acc_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, is_mma_acc_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, is_mma_acc_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_acc_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, is_mma_acc_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_acc_col_mjr, debug_on_fail); + } + // other targets have mma formats that may overlap storage formats + else + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_col_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, is_mma_acc_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, is_mma_acc_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, is_mma_acc_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, is_mma_acc_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, is_mma_acc_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_acc_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, is_mma_acc_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_acc_col_mjr, debug_on_fail); + } + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // clang-format on + + return result; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved2() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // VW = 1u + // datalayout = same + constexpr uint32_t VectorWidth = 1u; + using Set0 = RegisterLayoutTestingSet; + using Set1 = RegisterLayoutTestingSet; + + constexpr bool is_mma_dim = testMmaDim(); + constexpr bool is_acc_vw = testMmaAccVW(); + + bool result = true; + + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, true, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, true, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, true, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, true, false, debug_on_fail); + + // Storage <-> mma layouts + // gfx11 have unique mma formats that must always be transformed from storage + if constexpr ((bool)ROCWMMA_ARCH_GFX11) + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, (is_acc_vw && is_mma_dim), debug_on_fail); + } + // other targets have mma formats that may overlap storage formats + else + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, (is_acc_vw && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, (is_acc_vw && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, (is_acc_vw && is_mma_dim), debug_on_fail); + } + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // clang-format on + + return result; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved3() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // VW = 1u + // datalayout = orthogonal + constexpr uint32_t VectorWidth = 1u; + using Set0 = RegisterLayoutTestingSet; + using Set1 = RegisterLayoutTestingSet>; + + constexpr bool is_mma_dim = testMmaDim(); + constexpr bool is_acc_vw = testMmaAccVW(); + + bool result = true; + + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, true, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, true, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, true, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, true, false, debug_on_fail); + + // Storage <-> mma layouts + // gfx11 have unique mma formats that must always be transformed from storage + if constexpr ((bool)ROCWMMA_ARCH_GFX11) + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, (is_acc_vw && is_mma_dim), debug_on_fail); + } + // other targets have mma formats that may overlap storage formats + else + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, (is_acc_vw && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, (is_acc_vw && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, (is_acc_vw && is_mma_dim), debug_on_fail); + } + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // clang-format on + + return result; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved4() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // VW0 = 1u + // VW1 = MaxVW + // datalayout = same + constexpr uint32_t VectorWidth0 = 1u; + constexpr uint32_t VectorWidth1 = MaxVectorWidth; + using Set0 = RegisterLayoutTestingSet; + using Set1 = RegisterLayoutTestingSet; + + constexpr bool is_row_mjr = testRowMajor(); + constexpr bool is_col_mjr = testColMajor(); + constexpr bool is_mma_dim = testMmaDim(); + constexpr bool is_acc_vw = testMmaAccVW(); + + constexpr bool is_mma_row_mjr = is_row_mjr && is_mma_dim; + constexpr bool is_mma_col_mjr = is_col_mjr && is_mma_dim; + constexpr bool is_mma_acc_row_mjr = is_row_mjr && is_mma_dim && is_acc_vw; + constexpr bool is_mma_acc_col_mjr = is_col_mjr && is_mma_dim && is_acc_vw; + + bool result = true; + + // Case tested in #0,1,2,3 + if constexpr(VectorWidth0 == VectorWidth1) + { + return result; + } + + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, is_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, is_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, is_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, is_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, is_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, is_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, is_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, is_col_mjr, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, is_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, is_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, is_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, is_col_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, is_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, is_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, is_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, is_row_mjr, false, debug_on_fail); + + // Storage <-> mma layouts + // gfx11 have unique mma formats that must always be transformed from storage + if constexpr ((bool)ROCWMMA_ARCH_GFX11) + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, is_mma_acc_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_acc_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, is_mma_acc_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_acc_row_mjr, debug_on_fail); + } + // other targets have mma formats that may overlap storage formats + else + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, is_mma_acc_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_acc_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, is_mma_acc_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_acc_row_mjr, debug_on_fail); + } + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // clang-format on + + return result; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved5() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // VW0 = 1u + // VW1 = MaxVW + // datalayout = orthogonal + constexpr uint32_t VectorWidth0 = 1u; + constexpr uint32_t VectorWidth1 = MaxVectorWidth; + using Set0 = RegisterLayoutTestingSet; + using Set1 = RegisterLayoutTestingSet>; + + constexpr bool is_row_mjr = testRowMajor(); + constexpr bool is_col_mjr = testColMajor(); + constexpr bool is_mma_dim = testMmaDim(); + constexpr bool is_acc_vw = testMmaAccVW(); + + constexpr bool is_mma_row_mjr = is_row_mjr && is_mma_dim; + constexpr bool is_mma_col_mjr = is_col_mjr && is_mma_dim; + constexpr bool is_mma_acc_row_mjr = is_row_mjr && is_mma_dim && is_acc_vw; + constexpr bool is_mma_acc_col_mjr = is_col_mjr && is_mma_dim && is_acc_vw; + + bool result = true; + + // Case tested in #0,1,2,3 + if constexpr(VectorWidth0 == VectorWidth1) + { + return result; + } + + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, is_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, is_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, is_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, is_col_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, is_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, is_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, is_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, is_row_mjr, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, is_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, is_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, is_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, is_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, is_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, is_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, is_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, is_col_mjr, false, debug_on_fail); + + // Storage <-> mma layouts + // gfx11 have unique mma formats that must always be transformed from storage + if constexpr ((bool)ROCWMMA_ARCH_GFX11) + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_col_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, is_mma_acc_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_acc_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, is_mma_acc_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_acc_col_mjr, debug_on_fail); + + } + // other targets have mma formats that may overlap storage formats + else + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_col_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, (is_acc_vw && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, is_mma_acc_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_acc_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, is_mma_acc_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_acc_col_mjr, debug_on_fail); + } + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // clang-format on + + return result; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved6() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // VW = 1 + // MaxVW0 = 1 + // MaxVW1 = MaxVW + // datalayout = same + constexpr uint32_t VectorWidth = 1u; + constexpr uint32_t MaxVectorWidth0 = MaxVectorWidth == 1u ? 4u : 1u; + constexpr uint32_t MaxVectorWidth1 = MaxVectorWidth; + using Set0 = RegisterLayoutTestingSet; + using Set1 = RegisterLayoutTestingSet; + + constexpr bool is_mma_dim = testMmaDim(); + constexpr bool is_acc_vw0 = testMmaAccVW(); + constexpr bool is_acc_vw1 = testMmaAccVW(); + + bool result = true; + + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + // Storage <-> mma layouts + // gfx11 have unique mma formats that must always be transformed from storage + if constexpr ((bool)ROCWMMA_ARCH_GFX11) + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); + } + // other targets have mma formats that may overlap storage formats + else + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, (is_acc_vw0 && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, (is_acc_vw0 && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, (is_acc_vw1 && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, (is_acc_vw1 && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); + } + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // clang-format on + + return result; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved7() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // VW = 1 + // MaxVW0 = 1 + // MaxVW1 = MaxVW + // datalayout = orthogonal + constexpr uint32_t VectorWidth = 1u; + constexpr uint32_t MaxVectorWidth0 = MaxVectorWidth == 1u ? 4u : 1u; + constexpr uint32_t MaxVectorWidth1 = MaxVectorWidth; + using Set0 = RegisterLayoutTestingSet; + using Set1 = RegisterLayoutTestingSet>; + + constexpr bool is_mma_dim = testMmaDim(); + constexpr bool is_acc_vw0 = testMmaAccVW(); + constexpr bool is_acc_vw1 = testMmaAccVW(); + + bool result = true; + + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + // Storage <-> mma layouts + // gfx11 have unique mma formats that must always be transformed from storage + if constexpr ((bool)ROCWMMA_ARCH_GFX11) + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); + } + // other targets have mma formats that may overlap storage formats + else + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_dim, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_dim, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_dim, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, (is_acc_vw0 && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, (is_acc_vw0 && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, (is_acc_vw0 && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, (is_acc_vw1 && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, (is_acc_vw1 && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, (is_acc_vw1 && is_mma_dim), debug_on_fail); + } + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // clang-format on + + return result; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved8() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // VW = MaxVW + // datalayout = same + // Different BlockDim / BlockK + constexpr uint32_t VectorWidth = MaxVectorWidth; + constexpr uint32_t BlockDim0 = BlockDim; + constexpr uint32_t BlockDim1 = BlockDim == 32u ? 64u : 32u; + constexpr uint32_t BlockK0 = BlockK; + constexpr uint32_t BlockK1 = BlockK == 32u ? 64u : 32u; + using Set0 = RegisterLayoutTestingSet; + using Set1 = RegisterLayoutTestingSet; + + bool result = true; + + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + // Storage <-> mma layouts + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // clang-format on + + return result; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved9() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // VW = MaxVW + // datalayout = same + // Different size DataT + constexpr uint32_t VectorWidth = MaxVectorWidth; + using DataT0 = DataT; + using DataT1 = conditional_t< + sizeof(DataT) == 1u, + int16_t, + conditional_t>>>; + + using Set0 = RegisterLayoutTestingSet; + using Set1 = RegisterLayoutTestingSet; + + bool result = true; + + // Already checked same types + if constexpr(is_same_v) + { + return result; + } + + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + // Storage <-> mma layouts + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // clang-format on + + return result; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestNonInterleaved10() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // VW = MaxVW + // datalayout = same + // Same size DataT + constexpr uint32_t VectorWidth = MaxVectorWidth; + using DataT0 = DataT; + using DataT1 = conditional_t< + sizeof(DataT) == 1u, + int8_t, + conditional_t>>>; + + using Set0 = RegisterLayoutTestingSet; + using Set1 = RegisterLayoutTestingSet; + + bool result = true; + + // Already tested same type + if constexpr(is_same_v) + { + return result; + } + + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + // Storage <-> mma layouts + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // clang-format on + + return result; + } + + template + ROCWMMA_DEVICE bool testBarrageNonInterleaved() + { + bool result = true; + + // clang-format off + result &= matrixLayoutTraitsTestNonInterleaved0(); + result &= registerLayoutTraitsTestNonInterleaved0(); + result &= registerLayoutTraitsTestNonInterleaved1(); + result &= registerLayoutTraitsTestNonInterleaved2(); + result &= registerLayoutTraitsTestNonInterleaved3(); + result &= registerLayoutTraitsTestNonInterleaved4(); + result &= registerLayoutTraitsTestNonInterleaved5(); + result &= registerLayoutTraitsTestNonInterleaved6(); + result &= registerLayoutTraitsTestNonInterleaved7(); + result &= registerLayoutTraitsTestNonInterleaved8(); + result &= registerLayoutTraitsTestNonInterleaved9(); + result &= registerLayoutTraitsTestNonInterleaved10(); + // clang-format on + + return result; + } + + template + ROCWMMA_DEVICE bool layoutTraitsTestA() + { + // TODO: WaveCount + constexpr uint32_t WaveCount = 1u; + constexpr uint32_t BlockDim = BlockM; + constexpr uint32_t BlockK = BlockN; + constexpr uint32_t MaxVW = detail::MaxVWSelector::Result; + + bool result = true; + result &= dataLayoutTraitsTest(); + result &= testBarrageNonInterleaved(); + + return result; + } + + template + ROCWMMA_DEVICE bool layoutTraitsTestB() + { + // TODO: WaveCount + constexpr uint32_t WaveCount = 1u; + constexpr uint32_t BlockDim = BlockN; + constexpr uint32_t BlockK = BlockM; + constexpr uint32_t MaxVW = rocwmma::detail:: + MaxVWSelector::Result; + + bool result = true; + result &= dataLayoutTraitsTest(); + result &= testBarrageNonInterleaved(); + + return result; + } + + template + ROCWMMA_DEVICE bool layoutTraitsTestAcc() + { + // TODO: WaveCount + constexpr uint32_t WaveCount = 1u; + constexpr uint32_t BlockDim = BlockN; + constexpr uint32_t BlockK = BlockM; + constexpr uint32_t MaxVW = rocwmma::detail:: + MaxVWSelector::Result; + + bool result = true; + result &= dataLayoutTraitsTest(); + result &= testBarrageNonInterleaved(); + + return result; + } + + template + __global__ void layoutTraitsTest(uint32_t m, + uint32_t n, + DataT const* in, + DataT* out, + uint32_t ld, + DataT param1, + DataT param2) + { + __shared__ int32_t result; + result = 0; + synchronize_workgroup(); + + bool success = true; + + success &= layoutTraitsTestA(); + success &= layoutTraitsTestB(); + success &= layoutTraitsTestAcc(); + + // Reduce error count + atomicAdd(&result, (int32_t)success); + + // Wait for all threads + synchronize_workgroup(); + + // Just need one thread to update output + if(threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0 && blockIdx.x == 0 + && blockIdx.y == 0 && blockIdx.z == 0) + { + out[0] = static_cast(result == 0 ? ERROR_VALUE : SUCCESS_VALUE); + } + } + +} // namespace rocwmma + +#endif // ROCWMMA_DEVICE_LAYOUT_TRAITS_TEST_HPP diff --git a/test/unit/layout_traits_test/device/layout_traits_int.hpp b/test/unit/layout_traits_test/device/layout_traits_int.hpp new file mode 100644 index 00000000..acd69f54 --- /dev/null +++ b/test/unit/layout_traits_test/device/layout_traits_int.hpp @@ -0,0 +1,1605 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#ifndef ROCWMMA_DEVICE_LAYOUT_TRAITS_INT_TEST_HPP +#define ROCWMMA_DEVICE_LAYOUT_TRAITS_INT_TEST_HPP + +#include + +#include "unit_test_traits.hpp" + +static constexpr uint32_t ERROR_VALUE = 7; +static constexpr uint32_t SUCCESS_VALUE = 0; + +namespace rocwmma +{ + template + ROCWMMA_HOST bool + testLayoutPair(const char* file, const char* line, std::ostream& stream = std::cout) + { + constexpr bool is_layout_same_result = rocwmma::is_layout_same_v; + constexpr bool is_layout_orthogonal_result + = rocwmma::is_layout_orthogonal_v; + constexpr bool compare_result = ((is_layout_same_result == ExpectSame) + && (is_layout_orthogonal_result == ExpectOrthogonal)); + + if constexpr(DebugOnFail) + { + stream << "File: " << file << " L:" << line << std::endl; + stream << "" << std::endl; + stream << "Lhs: " << LayoutLhs{} << std::endl; + stream << rocwmma::layout_traits{}; + stream << "Rhs: " << LayoutRhs{} << std::endl; + stream << rocwmma::layout_traits{}; + stream << "is_layout_same: " << is_layout_same_result << " Expected: " << ExpectSame + << std::endl; + stream << "is_layout_orthogonal: " << is_layout_orthogonal_result + << " Expected: " << ExpectOrthogonal << std::endl; + stream << "Result:" << (compare_result ? "PASS" : "FAIL") << std::endl; + stream << "" << std::endl; + } + + return compare_result; + } + + ROCWMMA_DEVICE inline bool isFirstThread() + { + return (threadIdx.x == 0) && (threadIdx.y == 0) && (threadIdx.z == 0) && (blockIdx.x == 0) + && (blockIdx.y == 0) && (blockIdx.z == 0); + } + + template + ROCWMMA_DEVICE constexpr void debugRegisterFormats() + { + if(isFirstThread()) + { + using traits_lhs = layout_traits; + using traits_rhs = layout_traits; + printf("testCompatibleRegisterParams: %d\n", LayoutTraits_impl::testCompatibleRegisterParams()); + printf("MmaDim: %d, MmaDim: %d\n", traits_lhs::MmaDim, traits_rhs::MmaDim); + printf("DataFormat: %d, DataFormat: %d\n", (int)traits_lhs::Format, (int)traits_rhs::Format); + printf("is_valid: %d, is_valid: %d\n", traits_lhs::is_valid, traits_rhs::is_valid); + printf("is_same_dataT: %d\n", is_same_v); + } + } + + template + ROCWMMA_DEVICE bool testLayoutPair(const char* file, uint32_t line) + { + constexpr bool is_layout_same_result = rocwmma::is_layout_same_v; + constexpr bool is_layout_orthogonal_result + = rocwmma::is_layout_orthogonal_v; + constexpr bool compare_result = ((is_layout_same_result == ExpectSame) + && (is_layout_orthogonal_result == ExpectOrthogonal)); + + if(!compare_result && DebugOnFail && isFirstThread()) + { + printf("File: %s L:%d\n", file, line); + printf("\n"); + printf("is_layout_same: %d (Expected: %d)\n", is_layout_same_result, ExpectSame); + printf("is_layout_orthogonal: %d (Expected: %d)\n", + is_layout_orthogonal_result, + ExpectOrthogonal); + printf("%s\n", (compare_result ? "PASS" : "FAIL")); + printf("\n"); + } + + return compare_result; + } + +#define ROCWMMA_TEST_LAYOUT_TRAITS_PAIR( \ + LayoutLhs, LayoutRhs, ExpectSame, ExpectOrthogonal, DebugOnFail) \ + testLayoutPair(__FILE__, \ + __LINE__); + + template + struct RegisterLayoutIntTestingSet + { + using ColInline = RegisterLayout::Storage< + MatrixLayout::ColInlineInt, + DataLayoutT>; + using ColOrtho = RegisterLayout::Storage< + MatrixLayout::ColOrthoInt, + DataLayoutT>; + using RowInline = RegisterLayout::Storage< + MatrixLayout::RowInlineInt, + DataLayoutT>; + using RowOrtho = RegisterLayout::Storage< + MatrixLayout::RowOrthoInt, + DataLayoutT>; + + using MmaInput = RegisterLayout::MmaInput; + using MmaAcc = RegisterLayout::MmaAcc; + }; + + template + using MatrixLayout_t = typename layout_traits::MatrixLayout; + + template + ROCWMMA_DEVICE bool matrixLayoutTraitsTestInterleaved0() + { + constexpr bool debug_on_fail = true; + + // Testing MatrixLayout properties + // MatrixLayouts are invariant to vector width + using Set + = RegisterLayoutIntTestingSet; + + bool result = true; + + // Matrix <-> Matrix layout + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, true, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, true, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(MatrixLayout_t, MatrixLayout_t, true, false, debug_on_fail); + // clang-format on + + return result; + } + + template + ROCWMMA_DEVICE constexpr bool testRowMajor() + { + return is_layout_same_v; + } + + template + ROCWMMA_DEVICE constexpr bool testColMajor() + { + return is_layout_same_v; + } + + template + ROCWMMA_DEVICE constexpr bool testMmaDim() + { + return (MmaDim == 16u && (bool)ROCWMMA_BLOCK_DIM_16_SUPPORTED) + || (MmaDim == 32u && (bool)ROCWMMA_BLOCK_DIM_32_SUPPORTED + && !is_same_v); + } + + template + ROCWMMA_DEVICE constexpr uint32_t dimPerThread() + { + return BlockDim / MmaDim; + } + + template + ROCWMMA_DEVICE constexpr uint32_t kPerThread() + { + return BlockK * MmaDim / (WaveSize * SplitK); + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestInterleaved0() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // BlockDim = same + // BlockK = same + // DataT = same + // DataLayoutT = same + // MmaDim = same + // SplitK = same + // Checks identity quirk condition + using Set0 + = RegisterLayoutIntTestingSet; + using Set1 + = RegisterLayoutIntTestingSet; + + constexpr bool is_row_mjr = testRowMajor(); + constexpr bool is_col_mjr = testColMajor(); + constexpr bool is_mma_dim = testMmaDim(); + constexpr uint32_t dpt = dimPerThread(); + constexpr uint32_t kpt = kPerThread(); + + // VW tests for quirks + constexpr bool is_dpt_eq_1 = (dpt == 1u); + constexpr bool is_kpt_eq_1 = (kpt == 1u); + + // Identity quirk where interleaved register layouts match + // regardless of their formats (e.g., AOS / SOA) + constexpr bool is_id_quirk = is_dpt_eq_1 || is_kpt_eq_1; + + constexpr bool is_row_mjr_dpt_1 = is_row_mjr && is_dpt_eq_1; + constexpr bool is_row_mjr_kpt_1 = is_row_mjr && is_kpt_eq_1; + constexpr bool is_col_mjr_dpt_1 = is_col_mjr && is_dpt_eq_1; + constexpr bool is_col_mjr_kpt_1 = is_col_mjr && is_kpt_eq_1; + + bool result = true; + + if constexpr(!is_id_quirk) + { + return result; + } + + // Storage <-> storage layout + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, (is_row_mjr_dpt_1 || is_col_mjr_kpt_1 ), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, is_kpt_eq_1, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, (is_row_mjr || (is_dpt_eq_1 && is_kpt_eq_1)) , false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, (is_row_mjr_dpt_1 || is_col_mjr_kpt_1), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, (is_col_mjr || (is_dpt_eq_1 && is_kpt_eq_1)), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, is_dpt_eq_1, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, is_kpt_eq_1, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline,typename Set1::RowOrtho, (is_col_mjr || (is_dpt_eq_1 && is_kpt_eq_1)), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, (is_col_mjr_dpt_1 || is_row_mjr_kpt_1), false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, (is_row_mjr || (is_dpt_eq_1 && is_kpt_eq_1)), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, is_dpt_eq_1, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, (is_col_mjr_dpt_1 || is_row_mjr_kpt_1), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, true, false, debug_on_fail); + + // Storage <-> mma layouts + // gfx11 have unique mma formats that must always be transformed from storage + if constexpr ((bool)ROCWMMA_ARCH_GFX11) + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, ((is_row_mjr || is_kpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, ((is_col_mjr || is_dpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, ((is_col_mjr || is_kpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, ((is_row_mjr || is_dpt_eq_1) && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, ((is_row_mjr || is_kpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, ((is_col_mjr || is_dpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, ((is_col_mjr || is_kpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, ((is_row_mjr || is_dpt_eq_1) && is_mma_dim), debug_on_fail); + } + // other targets have mma formats that may overlap storage formats + else + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, ((is_row_mjr || is_kpt_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput,((is_col_mjr || is_dpt_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, ((is_col_mjr || is_kpt_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput,((is_row_mjr || is_dpt_eq_1) && is_mma_dim), false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, ((is_row_mjr || is_kpt_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline,((is_col_mjr || is_dpt_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, ((is_col_mjr || is_kpt_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline,((is_row_mjr || is_dpt_eq_1) && is_mma_dim), false, debug_on_fail); + } + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, ((is_row_mjr || is_kpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false,((is_col_mjr || is_dpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, ((is_col_mjr || is_kpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false,((is_row_mjr || is_dpt_eq_1) && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, ((is_row_mjr || is_kpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false,((is_col_mjr || is_dpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, ((is_col_mjr || is_kpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false,((is_row_mjr || is_dpt_eq_1) && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // clang-format on + + return result; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestInterleaved1() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // BlockDim = same + // BlockK = same + // DataT = same + // DataLayoutT = same + // MmaDim = same + // SplitK = same + // Checks non-identity quirk condition + using Set0 + = RegisterLayoutIntTestingSet; + using Set1 + = RegisterLayoutIntTestingSet; + + constexpr bool is_row_mjr = testRowMajor(); + constexpr bool is_col_mjr = testColMajor(); + constexpr bool is_mma_dim = testMmaDim(); + constexpr uint32_t dpt = dimPerThread(); + constexpr uint32_t kpt = kPerThread(); + + // Identity quirk where interleaved register layouts match + // regardless of their formats (e.g., AOS / SOA) + constexpr bool is_id_quirk = (dpt == 1u) || (kpt == 1u); + + constexpr bool is_mma_row_mjr = is_row_mjr && is_mma_dim; + constexpr bool is_mma_col_mjr = is_col_mjr && is_mma_dim; + + bool result = true; + + if constexpr(is_id_quirk) + { + return result; + } + + // Storage <-> storage layout + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, is_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, is_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline,typename Set1::RowOrtho, false, is_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, is_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, true, false, debug_on_fail); + + // Storage <-> mma layouts + // gfx11 have unique mma formats that must always be transformed from storage + if constexpr ((bool)ROCWMMA_ARCH_GFX11) + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); + } + // other targets have mma formats that may overlap storage formats + else + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); + } + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // clang-format on + + return result; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestInterleaved2() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // BlockDim = same + // BlockK = same + // DataT = same + // DataLayoutT = orthogonal + // MmaDim = same + // SplitK = same + // Checks identity quirk condition + using Set0 + = RegisterLayoutIntTestingSet; + using Set1 = RegisterLayoutIntTestingSet, + MmaDim, + SplitK>; + + constexpr bool is_row_mjr = testRowMajor(); + constexpr bool is_col_mjr = testColMajor(); + constexpr bool is_mma_dim = testMmaDim(); + constexpr uint32_t dpt = dimPerThread(); + constexpr uint32_t kpt = kPerThread(); + + // VW tests for quirks + constexpr bool is_dpt_eq_1 = (dpt == 1u); + constexpr bool is_kpt_eq_1 = (kpt == 1u); + + // Identity quirk where interleaved register layouts match + // regardless of their formats (e.g., AOS / SOA) + constexpr bool is_id_quirk = is_dpt_eq_1 || is_kpt_eq_1; + + constexpr bool is_row_mjr_dpt_1 = is_row_mjr && is_dpt_eq_1; + constexpr bool is_row_mjr_kpt_1 = is_row_mjr && is_kpt_eq_1; + constexpr bool is_col_mjr_dpt_1 = is_col_mjr && is_dpt_eq_1; + constexpr bool is_col_mjr_kpt_1 = is_col_mjr && is_kpt_eq_1; + + bool result = true; + + if constexpr(!is_id_quirk) + { + return result; + } + + // Storage <-> storage layout + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, is_kpt_eq_1, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, (is_col_mjr || (is_dpt_eq_1 && is_kpt_eq_1)), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, true , false, debug_on_fail); // Can be invalid in same way + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, (is_col_mjr_dpt_1 || is_row_mjr_kpt_1), false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, (is_row_mjr || (is_dpt_eq_1 && is_kpt_eq_1)), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, is_dpt_eq_1, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, (is_col_mjr_dpt_1 || is_row_mjr_kpt_1), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, true, false, debug_on_fail); // Can be invalid in same way + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); // Can be invalid in same way + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline,typename Set1::RowOrtho, (is_row_mjr_dpt_1 || is_col_mjr_kpt_1), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, is_kpt_eq_1, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, (is_row_mjr || (is_dpt_eq_1 && is_kpt_eq_1)), false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, (is_row_mjr_dpt_1 || is_col_mjr_kpt_1), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, true, false, debug_on_fail); // Can be invalid in same way + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, (is_col_mjr || (is_dpt_eq_1 && is_kpt_eq_1)), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, is_dpt_eq_1, false, debug_on_fail); + + // Storage <-> mma layouts + // gfx11 have unique mma formats that must always be transformed from storage + if constexpr ((bool)ROCWMMA_ARCH_GFX11) + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, ((is_row_mjr || is_kpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, ((is_col_mjr || is_dpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, ((is_col_mjr || is_kpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, ((is_row_mjr || is_dpt_eq_1) && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, ((is_col_mjr || is_kpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, ((is_row_mjr || is_dpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, ((is_row_mjr || is_kpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, ((is_col_mjr || is_dpt_eq_1) && is_mma_dim), debug_on_fail); + } + // other targets have mma formats that may overlap storage formats + else + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, ((is_row_mjr || is_kpt_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, ((is_col_mjr || is_dpt_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, ((is_col_mjr || is_kpt_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, ((is_row_mjr || is_dpt_eq_1) && is_mma_dim), false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, ((is_col_mjr || is_kpt_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, ((is_row_mjr || is_dpt_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, ((is_row_mjr || is_kpt_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, ((is_col_mjr || is_dpt_eq_1) && is_mma_dim), false, debug_on_fail); + } + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, ((is_row_mjr || is_kpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, ((is_col_mjr || is_dpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, ((is_col_mjr || is_kpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, ((is_row_mjr || is_dpt_eq_1) && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, ((is_col_mjr || is_kpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, ((is_row_mjr || is_dpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, ((is_row_mjr || is_kpt_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, ((is_col_mjr || is_dpt_eq_1) && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // clang-format on + + return result; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestInterleaved3() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // BlockDim = same + // BlockK = same + // DataT = same + // DataLayoutT = orthogonal + // MmaDim = same + // SplitK = same + // Checks non-identity quirk condition + using Set0 + = RegisterLayoutIntTestingSet; + using Set1 = RegisterLayoutIntTestingSet, + MmaDim, + SplitK>; + + constexpr bool is_row_mjr = testRowMajor(); + constexpr bool is_col_mjr = testColMajor(); + constexpr bool is_mma_dim = testMmaDim(); + constexpr uint32_t dpt = dimPerThread(); + constexpr uint32_t kpt = kPerThread(); + + // VW tests for quirks + constexpr bool is_dpt_eq_1 = (dpt == 1u); + constexpr bool is_kpt_eq_1 = (kpt == 1u); + + // Identity quirk where interleaved register layouts match + // regardless of their formats (e.g., AOS / SOA) + constexpr bool is_id_quirk = is_dpt_eq_1 || is_kpt_eq_1; + + constexpr bool is_mma_row_mjr = is_row_mjr && is_mma_dim; + constexpr bool is_mma_col_mjr = is_col_mjr && is_mma_dim; + + bool result = true; + + if constexpr(is_id_quirk) + { + return result; + } + + // Storage <-> storage layout + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, is_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, true, false, debug_on_fail); // Can be invalid in same way + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, is_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, true, false, debug_on_fail); // Can be invalid in same way + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, true, false, debug_on_fail); // Can be invalid in same way + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline,typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, is_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, true, false, debug_on_fail); // Can be invalid in same way + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, is_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + // Storage <-> mma layouts + // gfx11 have unique mma formats that must always be transformed from storage + if constexpr ((bool)ROCWMMA_ARCH_GFX11) + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_col_mjr, debug_on_fail); + } + // other targets have mma formats that may overlap storage formats + else + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_col_mjr, debug_on_fail); + } + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_col_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // clang-format on + + return result; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestInterleaved4() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // BlockDim = different + // BlockK = different + // DataT = same + // DataLayoutT = same + // MmaDim = same + // SplitK = same + constexpr uint32_t BlockDim0 = BlockDim; + constexpr uint32_t BlockDim1 = BlockDim == 32u ? 64u : 32u; + constexpr uint32_t BlockK0 = BlockK; + constexpr uint32_t BlockK1 = BlockK == 32u ? 64u : 32u; + using Set0 + = RegisterLayoutIntTestingSet; + using Set1 + = RegisterLayoutIntTestingSet; + + constexpr bool is_row_mjr = testRowMajor(); + constexpr bool is_col_mjr = testColMajor(); + constexpr bool is_mma_dim = testMmaDim(); + constexpr uint32_t dpt0 = dimPerThread(); + constexpr uint32_t kpt0 = kPerThread(); + constexpr uint32_t dpt1 = dimPerThread(); + constexpr uint32_t kpt1 = kPerThread(); + + // VW tests for quirks + constexpr bool is_dpt0_eq_1 = (dpt0 == 1u); + constexpr bool is_kpt0_eq_1 = (kpt0 == 1u); + constexpr bool is_dpt1_eq_1 = (dpt1 == 1u); + constexpr bool is_kpt1_eq_1 = (kpt1 == 1u); + + // Identity quirk where interleaved register layouts match + // regardless of their formats (e.g., AOS / SOA) + constexpr bool is_id_quirk0 = (is_dpt0_eq_1 || is_kpt0_eq_1); + constexpr bool is_id_quirk1 = (is_dpt1_eq_1 || is_kpt1_eq_1); + + bool result = true; + + // Current test case deals with quirks validation + if constexpr((is_id_quirk0 != is_id_quirk1) || !is_id_quirk0) + { + return result; + } + + // Ensure MmaDim layout constraints are met + if constexpr(BlockDim0 >= MmaDim && BlockDim1 >= MmaDim) + { + // Storage <-> storage layout + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline,typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + // Storage <-> mma layouts + // Same MmaDim + // gfx11 have unique mma formats that must always be transformed from storage + if constexpr ((bool)ROCWMMA_ARCH_GFX11) + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, ((is_row_mjr || is_kpt0_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, ((is_col_mjr || is_dpt0_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, ((is_col_mjr || is_kpt0_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, ((is_row_mjr || is_dpt0_eq_1) && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, ((is_row_mjr || is_kpt1_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, ((is_col_mjr || is_dpt1_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, ((is_col_mjr || is_kpt1_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, ((is_row_mjr || is_dpt1_eq_1) && is_mma_dim), debug_on_fail); + } + // other targets have mma formats that may overlap storage formats + else + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, ((is_row_mjr || is_kpt0_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput,((is_col_mjr || is_dpt0_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, ((is_col_mjr || is_kpt0_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput,((is_row_mjr || is_dpt0_eq_1) && is_mma_dim), false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, ((is_row_mjr || is_kpt1_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline,((is_col_mjr || is_dpt1_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, ((is_col_mjr || is_kpt1_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline,((is_row_mjr || is_dpt1_eq_1) && is_mma_dim), false, debug_on_fail); + } + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, ((is_row_mjr || is_kpt0_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false,((is_col_mjr || is_dpt0_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, ((is_col_mjr || is_kpt0_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false,((is_row_mjr || is_dpt0_eq_1) && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, ((is_row_mjr || is_kpt1_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false,((is_col_mjr || is_dpt1_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, ((is_col_mjr || is_kpt1_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false,((is_row_mjr || is_dpt1_eq_1) && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // clang-format on + } + + return result; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestInterleaved5() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // BlockDim = different + // BlockK = different + // DataT = same + // DataLayoutT = same + // MmaDim = same + // SplitK = same + constexpr uint32_t BlockDim0 = BlockDim; + constexpr uint32_t BlockDim1 = BlockDim == 32u ? 64u : 32u; + constexpr uint32_t BlockK0 = BlockK; + constexpr uint32_t BlockK1 = BlockK == 32u ? 64u : 32u; + using Set0 + = RegisterLayoutIntTestingSet; + using Set1 + = RegisterLayoutIntTestingSet; + + constexpr bool is_row_mjr = testRowMajor(); + constexpr bool is_col_mjr = testColMajor(); + constexpr bool is_mma_dim = testMmaDim(); + constexpr uint32_t dpt0 = dimPerThread(); + constexpr uint32_t kpt0 = kPerThread(); + constexpr uint32_t dpt1 = dimPerThread(); + constexpr uint32_t kpt1 = kPerThread(); + + // VW tests for quirks + constexpr bool is_dpt0_eq_1 = (dpt0 == 1u); + constexpr bool is_kpt0_eq_1 = (kpt0 == 1u); + constexpr bool is_dpt1_eq_1 = (dpt1 == 1u); + constexpr bool is_kpt1_eq_1 = (kpt1 == 1u); + + // Identity quirk where interleaved register layouts match + // regardless of their formats (e.g., AOS / SOA) + constexpr bool is_id_quirk0 = (is_dpt0_eq_1 || is_kpt0_eq_1); + constexpr bool is_id_quirk1 = (is_dpt1_eq_1 || is_kpt1_eq_1); + + constexpr bool is_mma_row_mjr = is_row_mjr && is_mma_dim; + constexpr bool is_mma_col_mjr = is_col_mjr && is_mma_dim; + + bool result = true; + + // Other test case deals with quirks validation + if constexpr((is_id_quirk0 != is_id_quirk1) || is_id_quirk0) + { + return result; + } + + // Ensure MmaDim layout constraints are met + if constexpr(BlockDim0 >= MmaDim && BlockDim1 >= MmaDim) + { + // Storage <-> storage layout + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline,typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + // Storage <-> mma layouts + // Same MmaDim + // gfx11 have unique mma formats that must always be transformed from storage + if constexpr ((bool)ROCWMMA_ARCH_GFX11) + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); + } + // other targets have mma formats that may overlap storage formats + else + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); + } + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // clang-format on + } + + return result; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestInterleaved6() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // BlockDim = same + // BlockK = same + // DataT = different, same size + // DataLayoutT = same + // MmaDim = same + // SplitK = same + using DataT0 = DataT; + using DataT1 = conditional_t< + sizeof(DataT) == 1u, + int8_t, + conditional_t>>>; + using Set0 + = RegisterLayoutIntTestingSet; + using Set1 + = RegisterLayoutIntTestingSet; + + bool result = true; + + // Already checked same types + if constexpr(is_same_v) + { + return result; + } + + // Storage <-> storage layout + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline,typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + // Storage <-> mma layouts + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // clang-format on + + return result; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestInterleaved7() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // BlockDim = same + // BlockK = same + // DataT = different, different size + // DataLayoutT = same + // MmaDim = same + // SplitK = same + using DataT0 = DataT; + using DataT1 = conditional_t< + sizeof(DataT) == 1u, + int16_t, + conditional_t>>>; + using Set0 + = RegisterLayoutIntTestingSet; + using Set1 + = RegisterLayoutIntTestingSet; + + bool result = true; + + // Already checked same types + if constexpr(is_same_v) + { + return result; + } + + // Storage <-> storage layout + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline,typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + // Storage <-> mma layouts + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // clang-format on + + return result; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestInterleaved8() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // BlockDim = same + // BlockK = same + // DataT = same + // DataLayoutT = same + // MmaDim = different, valid + // SplitK = same + constexpr uint32_t MmaDim0 = MmaDim; + constexpr uint32_t MmaDim1 = MmaDim == 16 ? 32u : 16u; + + constexpr uint32_t dpt0 = dimPerThread(); + constexpr uint32_t kpt0 = kPerThread(); + constexpr uint32_t dpt1 = dimPerThread(); + constexpr uint32_t kpt1 = kPerThread(); + + bool result = true; + + // Ensure MmaDim layout constraints are met + if constexpr(dpt0 > 0u && kpt0 > 0u && dpt1 > 0u && kpt1 > 0u) + { + using Set0 = RegisterLayoutIntTestingSet; + using Set1 = RegisterLayoutIntTestingSet; + + // Storage <-> storage layout + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline,typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + // Storage <-> mma layouts + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // clang-format on + } + + return result; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestInterleaved9() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // BlockDim = same + // BlockK = same + // DataT = same + // DataLayoutT = same + // MmaDim = same + // SplitK = different + constexpr uint32_t SplitK0 = SplitK; + constexpr uint32_t SplitK1 = SplitK == 4u ? 2u : 4u; + + using Set0 + = RegisterLayoutIntTestingSet; + using Set1 + = RegisterLayoutIntTestingSet; + + constexpr bool is_row_mjr = testRowMajor(); + constexpr bool is_col_mjr = testColMajor(); + constexpr bool is_mma_dim = testMmaDim(); + constexpr uint32_t dpt0 = dimPerThread(); + constexpr uint32_t kpt0 = kPerThread(); + constexpr uint32_t dpt1 = dimPerThread(); + constexpr uint32_t kpt1 = kPerThread(); + + // VW tests for quirks + constexpr bool is_dpt0_eq_1 = (dpt0 == 1u); + constexpr bool is_kpt0_eq_1 = (kpt0 == 1u); + constexpr bool is_dpt1_eq_1 = (dpt1 == 1u); + constexpr bool is_kpt1_eq_1 = (kpt1 == 1u); + + // Identity quirk where interleaved register layouts match + // regardless of their formats (e.g., AOS / SOA) + constexpr bool is_id_quirk0 = (is_dpt0_eq_1 || is_kpt0_eq_1); + constexpr bool is_id_quirk1 = (is_dpt1_eq_1 || is_kpt1_eq_1); + + constexpr bool is_mma_row_mjr = is_row_mjr && is_mma_dim; + constexpr bool is_mma_col_mjr = is_col_mjr && is_mma_dim; + + bool result = true; + + // Current test deals with quirk case validation + if constexpr((is_id_quirk0 != is_id_quirk1) || !is_id_quirk0) + { + return result; + } + + // Ensure layout requirements are satisfied + if constexpr(dpt0 > 0u && kpt0 > 0u && dpt1 > 0u && kpt1 > 0u) + { + // Storage <-> storage layout + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline,typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + // Storage <-> mma layouts + // gfx11 have unique mma formats that must always be transformed from storage + if constexpr ((bool)ROCWMMA_ARCH_GFX11) + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, ((is_row_mjr || is_kpt0_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, ((is_col_mjr || is_dpt0_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, ((is_col_mjr || is_kpt0_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, ((is_row_mjr || is_dpt0_eq_1) && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, ((is_row_mjr || is_kpt1_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, ((is_col_mjr || is_dpt1_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, ((is_col_mjr || is_kpt1_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, ((is_row_mjr || is_dpt1_eq_1) && is_mma_dim), debug_on_fail); + } + // other targets have mma formats that may overlap storage formats + else + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, ((is_row_mjr || is_kpt0_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput,((is_col_mjr || is_dpt0_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, ((is_col_mjr || is_kpt0_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput,((is_row_mjr || is_dpt0_eq_1) && is_mma_dim), false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, ((is_row_mjr || is_kpt1_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline,((is_col_mjr || is_dpt1_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, ((is_col_mjr || is_kpt1_eq_1) && is_mma_dim), false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline,((is_row_mjr || is_dpt1_eq_1) && is_mma_dim), false, debug_on_fail); + } + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, ((is_row_mjr || is_kpt0_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false,((is_col_mjr || is_dpt0_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, ((is_col_mjr || is_kpt0_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false,((is_row_mjr || is_dpt0_eq_1) && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, ((is_row_mjr || is_kpt1_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false,((is_col_mjr || is_dpt1_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, ((is_col_mjr || is_kpt1_eq_1) && is_mma_dim), debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false,((is_row_mjr || is_dpt1_eq_1) && is_mma_dim), debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // clang-format on + } + + return result; + } + + template + ROCWMMA_DEVICE bool registerLayoutTraitsTestInterleaved10() + { + constexpr bool debug_on_fail = true; + + // Non-interleaved + // BlockDim = same + // BlockK = same + // DataT = same + // DataLayoutT = same + // MmaDim = same + // SplitK = different + constexpr uint32_t SplitK0 = SplitK; + constexpr uint32_t SplitK1 = SplitK == 4u ? 2u : 4u; + + using Set0 + = RegisterLayoutIntTestingSet; + using Set1 + = RegisterLayoutIntTestingSet; + + constexpr bool is_row_mjr = testRowMajor(); + constexpr bool is_col_mjr = testColMajor(); + constexpr bool is_mma_dim = testMmaDim(); + constexpr uint32_t dpt0 = dimPerThread(); + constexpr uint32_t kpt0 = kPerThread(); + constexpr uint32_t dpt1 = dimPerThread(); + constexpr uint32_t kpt1 = kPerThread(); + + // VW tests for quirks + constexpr bool is_dpt0_eq_1 = (dpt0 == 1u); + constexpr bool is_kpt0_eq_1 = (kpt0 == 1u); + constexpr bool is_dpt1_eq_1 = (dpt1 == 1u); + constexpr bool is_kpt1_eq_1 = (kpt1 == 1u); + + // Identity quirk where interleaved register layouts match + // regardless of their formats (e.g., AOS / SOA) + constexpr bool is_id_quirk0 = (is_dpt0_eq_1 || is_kpt0_eq_1); + constexpr bool is_id_quirk1 = (is_dpt1_eq_1 || is_kpt1_eq_1); + + constexpr bool is_mma_row_mjr = is_row_mjr && is_mma_dim; + constexpr bool is_mma_col_mjr = is_col_mjr && is_mma_dim; + + bool result = true; + + // Other test handles quirk case validation + if constexpr((is_id_quirk0 != is_id_quirk1) || is_id_quirk0) + { + return result; + } + + // Ensure layout requirements are satisfied + if constexpr(dpt0 > 0u && kpt0 > 0u && dpt1 > 0u && kpt1 > 0u) + { + // Storage <-> storage layout + // clang-format off + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::ColInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::ColInline, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline,typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowOrtho, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowOrtho, false, false, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::RowInline, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::RowInline, false, false, debug_on_fail); + + // Storage <-> mma layouts + // gfx11 have unique mma formats that must always be transformed from storage + if constexpr ((bool)ROCWMMA_ARCH_GFX11) + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); + } + // other targets have mma formats that may overlap storage formats + else + { + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaInput, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaInput, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaInput, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaInput, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColOrtho, is_mma_row_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowOrtho, is_mma_col_mjr, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); + } + + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColOrtho, typename Set1::MmaAcc, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::ColInline, typename Set1::MmaAcc, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowOrtho, typename Set1::MmaAcc, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::RowInline, typename Set1::MmaAcc, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColOrtho, false, is_mma_row_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::ColInline, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowOrtho, false, is_mma_col_mjr, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::RowInline, false, is_mma_row_mjr, debug_on_fail); + + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaInput, typename Set1::MmaAcc, false, false, debug_on_fail); + result &= ROCWMMA_TEST_LAYOUT_TRAITS_PAIR(typename Set0::MmaAcc, typename Set1::MmaInput, false, false, debug_on_fail); + // clang-format on + } + + return result; + } + + template + ROCWMMA_DEVICE bool testBarrageInterleaved() + { + bool result = true; + + constexpr uint32_t dpt = dimPerThread(); + constexpr uint32_t kpt = kPerThread(); + + // Must satisfy layout requirement + if constexpr(dpt > 0u && kpt > 0u && BlockK >= kpt) + { + // clang-format off + result &= matrixLayoutTraitsTestInterleaved0(); + result &= registerLayoutTraitsTestInterleaved0(); + result &= registerLayoutTraitsTestInterleaved1(); + result &= registerLayoutTraitsTestInterleaved2(); + result &= registerLayoutTraitsTestInterleaved3(); + result &= registerLayoutTraitsTestInterleaved4(); + result &= registerLayoutTraitsTestInterleaved5(); + result &= registerLayoutTraitsTestInterleaved6(); + result &= registerLayoutTraitsTestInterleaved7(); + result &= registerLayoutTraitsTestInterleaved8(); + result &= registerLayoutTraitsTestInterleaved9(); + result &= registerLayoutTraitsTestInterleaved10(); + // clang-format on + } + + return result; + } + + template + ROCWMMA_DEVICE bool layoutTraitsTestA() + { + constexpr uint32_t BlockDim = BlockM; + constexpr uint32_t BlockK = BlockN; + + bool result = true; + result &= testBarrageInterleaved(); + + return result; + } + + template + ROCWMMA_DEVICE bool layoutTraitsTestB() + { + constexpr uint32_t BlockDim = BlockN; + constexpr uint32_t BlockK = BlockM; + + bool result = true; + result &= testBarrageInterleaved(); + + return result; + } + + template + ROCWMMA_DEVICE bool layoutTraitsTestAcc() + { + // TODO: WaveCount + constexpr uint32_t BlockDim = BlockN; + constexpr uint32_t BlockK = BlockM; + + bool result = true; + result &= testBarrageInterleaved(); + + return result; + } + + template + __global__ void layoutTraitsIntTest(uint32_t m, + uint32_t n, + DataT const* in, + DataT* out, + uint32_t ld, + DataT param1, + DataT param2) + { + __shared__ int32_t result; + result = 0; + synchronize_workgroup(); + + bool success = true; + + success &= layoutTraitsTestA(); + success &= layoutTraitsTestB(); + success &= layoutTraitsTestAcc(); + + // Reduce error count + atomicAdd(&result, (int32_t)success); + + // Wait for all threads + synchronize_workgroup(); + + // Just need one thread to update output + if(threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0 && blockIdx.x == 0 + && blockIdx.y == 0 && blockIdx.z == 0) + { + out[0] = static_cast(result == 0 ? 7 : 0); + } + } + +} // namespace rocwmma + +#endif // ROCWMMA_DEVICE_LAYOUT_TRAITS_TEST_HPP diff --git a/test/unit/layout_traits_test/test/common_includes.hpp b/test/unit/layout_traits_test/test/common_includes.hpp new file mode 100644 index 00000000..a534adab --- /dev/null +++ b/test/unit/layout_traits_test/test/common_includes.hpp @@ -0,0 +1,36 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#ifndef ROCWMMA_LAYOUT_TRAITS_TEST_COMMON_INCLUDES_HPP +#define ROCWMMA_LAYOUT_TRAITS_TEST_COMMON_INCLUDES_HPP + +#include "kernel_generator.hpp" +#include "unit_test.hpp" +#include "unit_test_macros.hpp" +#include "unit_test_params.hpp" +#include "unit_test_traits.hpp" + +#endif // ROCWMMA_LAYOUT_TRAITS_TEST_COMMON_INCLUDES_HPP diff --git a/test/unit/layout_traits_test/test/emulation/extendedtest_layout_traits.cpp b/test/unit/layout_traits_test/test/emulation/extendedtest_layout_traits.cpp new file mode 100644 index 00000000..202b724b --- /dev/null +++ b/test/unit/layout_traits_test/test/emulation/extendedtest_layout_traits.cpp @@ -0,0 +1,77 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include "detail/layout_traits.hpp" +#include "test/common_includes.hpp" + +namespace rocwmma +{ + + struct TestParams : public UnitTestParams + { + using Base = UnitTestParams; + using Types = typename Base::TestAllSizeTypes; + using BlockSizes = std::tuple, I<16u>>, + std::tuple, I<32u>>, + std::tuple, I<64u>>, + std::tuple, I<128u>>>; + using DataLayouts = typename Base::TestLayoutsAll; + using KernelParams = typename CombineLists::Result; + + // Assemble the kernel generator + using GeneratorImpl = LayoutTraitsGenerator; + using KernelGenerator = KernelGenerator; + + // Sanity check for kernel generator + static_assert(std::is_same::value, + "Kernels from this generator do not match testing interface"); + + static inline typename KernelGenerator::ResultT kernels() + { + return KernelGenerator::generate(); + } + + static inline std::vector threadBlocks() + { + auto warpSize = HipDevice::instance()->warpSize(); + // clang-format off + return { {warpSize, 1} }; + // clang-format on + } + + static inline std::vector problemSizes() + { + // clang-format off + return { {1024, 1024} }; + // clang-format on + } + }; + +} // namespace rocwmma + +ROCWMMA_GENERATE_UNIT_GTEST_SUITE(EmulationExtendedLayoutTraitsTest, TestParams); diff --git a/test/unit/layout_traits_test/test/emulation/extendedtest_layout_traits_int.cpp b/test/unit/layout_traits_test/test/emulation/extendedtest_layout_traits_int.cpp new file mode 100644 index 00000000..165a8b08 --- /dev/null +++ b/test/unit/layout_traits_test/test/emulation/extendedtest_layout_traits_int.cpp @@ -0,0 +1,81 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include "detail/layout_traits_int.hpp" +#include "test/common_includes.hpp" + +namespace rocwmma +{ + + struct TestParams : public UnitTestParams + { + using Base = UnitTestParams; + using Types = typename Base::TestAllSizeTypes; + using MmaDims = std::tuple, I<32u>>; + using SplitKs = std::tuple, I<2u>, I<4u>>; + using BlockSizes = std::tuple, I<16u>>, + std::tuple, I<32u>>, + std::tuple, I<64u>>, + std::tuple, I<128u>>>; + using DataLayouts = typename Base::TestLayoutsAll; + + using KernelParams = + typename CombineLists::Result; + + // Assemble the kernel generator + using GeneratorImpl = LayoutTraitsIntGenerator; + using KernelGenerator = KernelGenerator; + + // Sanity check for kernel generator + static_assert(std::is_same::value, + "Kernels from this generator do not match testing interface"); + + static inline typename KernelGenerator::ResultT kernels() + { + return KernelGenerator::generate(); + } + + static inline std::vector threadBlocks() + { + auto warpSize = HipDevice::instance()->warpSize(); + // clang-format off + return { {warpSize, 1} }; + // clang-format on + } + + static inline std::vector problemSizes() + { + // clang-format off + return { {1024, 1024} }; + // clang-format on + } + }; + +} // namespace rocwmma + +ROCWMMA_GENERATE_UNIT_GTEST_SUITE(EmulationExtendedLayoutTraitsIntTest, TestParams); diff --git a/test/unit/layout_traits_test/test/emulation/regressiontest_layout_traits.cpp b/test/unit/layout_traits_test/test/emulation/regressiontest_layout_traits.cpp new file mode 100644 index 00000000..dbd84026 --- /dev/null +++ b/test/unit/layout_traits_test/test/emulation/regressiontest_layout_traits.cpp @@ -0,0 +1,76 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include "detail/layout_traits.hpp" +#include "test/common_includes.hpp" + +namespace rocwmma +{ + + struct TestParams : public UnitTestParams + { + using Base = UnitTestParams; + using Types = typename Base::TestAllSizeTypes; + using BlockSizes = std::tuple, I<16u>>, + std::tuple, I<32u>>, + std::tuple, I<64u>>>; + using DataLayouts = typename Base::TestLayoutsAll; + using KernelParams = typename CombineLists::Result; + + // Assemble the kernel generator + using GeneratorImpl = LayoutTraitsGenerator; + using KernelGenerator = KernelGenerator; + + // Sanity check for kernel generator + static_assert(std::is_same::value, + "Kernels from this generator do not match testing interface"); + + static inline typename KernelGenerator::ResultT kernels() + { + return KernelGenerator::generate(); + } + + static inline std::vector threadBlocks() + { + auto warpSize = HipDevice::instance()->warpSize(); + // clang-format off + return { {warpSize, 1} }; + // clang-format on + } + + static inline std::vector problemSizes() + { + // clang-format off + return { {1024, 1024} }; + // clang-format on + } + }; + +} // namespace rocwmma + +ROCWMMA_GENERATE_UNIT_GTEST_SUITE(EmulationRegressionLayoutTraitsTest, TestParams); diff --git a/test/unit/layout_traits_test/test/emulation/regressiontest_layout_traits_int.cpp b/test/unit/layout_traits_test/test/emulation/regressiontest_layout_traits_int.cpp new file mode 100644 index 00000000..71d30a9e --- /dev/null +++ b/test/unit/layout_traits_test/test/emulation/regressiontest_layout_traits_int.cpp @@ -0,0 +1,80 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include "detail/layout_traits_int.hpp" +#include "test/common_includes.hpp" + +namespace rocwmma +{ + + struct TestParams : public UnitTestParams + { + using Base = UnitTestParams; + using Types = typename Base::TestAllSizeTypes; + using MmaDims = std::tuple, I<32u>>; + using SplitKs = std::tuple, I<2u>, I<4u>>; + using BlockSizes = std::tuple, I<16u>>, + std::tuple, I<32u>>, + std::tuple, I<64u>>>; + using DataLayouts = typename Base::TestLayoutsAll; + + using KernelParams = + typename CombineLists::Result; + + // Assemble the kernel generator + using GeneratorImpl = LayoutTraitsIntGenerator; + using KernelGenerator = KernelGenerator; + + // Sanity check for kernel generator + static_assert(std::is_same::value, + "Kernels from this generator do not match testing interface"); + + static inline typename KernelGenerator::ResultT kernels() + { + return KernelGenerator::generate(); + } + + static inline std::vector threadBlocks() + { + auto warpSize = HipDevice::instance()->warpSize(); + // clang-format off + return { {warpSize, 1} }; + // clang-format on + } + + static inline std::vector problemSizes() + { + // clang-format off + return { {1024, 1024} }; + // clang-format on + } + }; + +} // namespace rocwmma + +ROCWMMA_GENERATE_UNIT_GTEST_SUITE(EmulationRegressionLayoutTraitsIntTest, TestParams); diff --git a/test/unit/layout_traits_test/test/emulation/smoketest_layout_traits.cpp b/test/unit/layout_traits_test/test/emulation/smoketest_layout_traits.cpp new file mode 100644 index 00000000..603dfdb8 --- /dev/null +++ b/test/unit/layout_traits_test/test/emulation/smoketest_layout_traits.cpp @@ -0,0 +1,74 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include "detail/layout_traits.hpp" +#include "test/common_includes.hpp" + +namespace rocwmma +{ + + struct TestParams : public UnitTestParams + { + using Base = UnitTestParams; + using Types = typename Base::TestAllSizeTypes; + using BlockSizes = std::tuple, I<16u>>, std::tuple, I<32u>>>; + using DataLayouts = std::tuple; + using KernelParams = typename CombineLists::Result; + + // Assemble the kernel generator + using GeneratorImpl = LayoutTraitsGenerator; + using KernelGenerator = KernelGenerator; + + // Sanity check for kernel generator + static_assert(std::is_same::value, + "Kernels from this generator do not match testing interface"); + + static inline typename KernelGenerator::ResultT kernels() + { + return KernelGenerator::generate(); + } + + static inline std::vector threadBlocks() + { + auto warpSize = HipDevice::instance()->warpSize(); + // clang-format off + return { {warpSize, 1} }; + // clang-format on + } + + static inline std::vector problemSizes() + { + // clang-format off + return { {1024, 1024} }; + // clang-format on + } + }; + +} // namespace rocwmma + +ROCWMMA_GENERATE_UNIT_GTEST_SUITE(EmulationSmokeLayoutTraitsTest, TestParams); diff --git a/test/unit/layout_traits_test/test/emulation/smoketest_layout_traits_int.cpp b/test/unit/layout_traits_test/test/emulation/smoketest_layout_traits_int.cpp new file mode 100644 index 00000000..602a2a3d --- /dev/null +++ b/test/unit/layout_traits_test/test/emulation/smoketest_layout_traits_int.cpp @@ -0,0 +1,78 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include "detail/layout_traits_int.hpp" +#include "test/common_includes.hpp" + +namespace rocwmma +{ + + struct TestParams : public UnitTestParams + { + using Base = UnitTestParams; + using Types = typename Base::TestAllSizeTypes; + using MmaDims = std::tuple>; + using SplitKs = std::tuple>; + using BlockSizes = std::tuple, I<16u>>, std::tuple, I<32u>>>; + using DataLayouts = std::tuple; + + using KernelParams = + typename CombineLists::Result; + + // Assemble the kernel generator + using GeneratorImpl = LayoutTraitsIntGenerator; + using KernelGenerator = KernelGenerator; + + // Sanity check for kernel generator + static_assert(std::is_same::value, + "Kernels from this generator do not match testing interface"); + + static inline typename KernelGenerator::ResultT kernels() + { + return KernelGenerator::generate(); + } + + static inline std::vector threadBlocks() + { + auto warpSize = HipDevice::instance()->warpSize(); + // clang-format off + return { {warpSize, 1} }; + // clang-format on + } + + static inline std::vector problemSizes() + { + // clang-format off + return { {1024, 1024} }; + // clang-format on + } + }; + +} // namespace rocwmma + +ROCWMMA_GENERATE_UNIT_GTEST_SUITE(EmulationSmokeLayoutTraitsIntTest, TestParams); diff --git a/test/unit/layout_traits_test/test/layout_traits_128.cpp b/test/unit/layout_traits_test/test/layout_traits_128.cpp new file mode 100644 index 00000000..4f0f580b --- /dev/null +++ b/test/unit/layout_traits_test/test/layout_traits_128.cpp @@ -0,0 +1,74 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include "common_includes.hpp" +#include "detail/layout_traits.hpp" + +namespace rocwmma +{ + + struct TestParams : public UnitTestParams + { + using Base = UnitTestParams; + using Types = typename Base::TestAllSizeTypes; + using BlockSizes = typename Base::TestBlockSizes128; + using DataLayouts = typename Base::TestLayoutsAll; + using KernelParams = typename CombineLists::Result; + + // Assemble the kernel generator + using GeneratorImpl = LayoutTraitsGenerator; + using KernelGenerator = KernelGenerator; + + // Sanity check for kernel generator + static_assert(std::is_same::value, + "Kernels from this generator do not match testing interface"); + + static inline typename KernelGenerator::ResultT kernels() + { + return KernelGenerator::generate(); + } + + static inline std::vector threadBlocks() + { + auto warpSize = HipDevice::instance()->warpSize(); + // clang-format off + return { {warpSize, 1} }; + // clang-format on + } + + static inline std::vector problemSizes() + { + // clang-format off + return { {1024, 1024} }; + // clang-format on + } + }; + +} // namespace rocwmma + +ROCWMMA_GENERATE_UNIT_GTEST_SUITE(LayoutTraitsTest128, TestParams); diff --git a/test/unit/layout_traits_test/test/layout_traits_16.cpp b/test/unit/layout_traits_test/test/layout_traits_16.cpp new file mode 100644 index 00000000..d97bca1a --- /dev/null +++ b/test/unit/layout_traits_test/test/layout_traits_16.cpp @@ -0,0 +1,74 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include "common_includes.hpp" +#include "detail/layout_traits.hpp" + +namespace rocwmma +{ + + struct TestParams : public UnitTestParams + { + using Base = UnitTestParams; + using Types = typename Base::TestAllSizeTypes; + using BlockSizes = typename Base::TestBlockSizes16; + using DataLayouts = typename Base::TestLayoutsAll; + using KernelParams = typename CombineLists::Result; + + // Assemble the kernel generator + using GeneratorImpl = LayoutTraitsGenerator; + using KernelGenerator = KernelGenerator; + + // Sanity check for kernel generator + static_assert(std::is_same::value, + "Kernels from this generator do not match testing interface"); + + static inline typename KernelGenerator::ResultT kernels() + { + return KernelGenerator::generate(); + } + + static inline std::vector threadBlocks() + { + auto warpSize = HipDevice::instance()->warpSize(); + // clang-format off + return { {warpSize, 1} }; + // clang-format on + } + + static inline std::vector problemSizes() + { + // clang-format off + return { {1024, 1024} }; + // clang-format on + } + }; + +} // namespace rocwmma + +ROCWMMA_GENERATE_UNIT_GTEST_SUITE(LayoutTraitsTest16, TestParams); diff --git a/test/unit/layout_traits_test/test/layout_traits_256.cpp b/test/unit/layout_traits_test/test/layout_traits_256.cpp new file mode 100644 index 00000000..1917a999 --- /dev/null +++ b/test/unit/layout_traits_test/test/layout_traits_256.cpp @@ -0,0 +1,74 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include "common_includes.hpp" +#include "detail/layout_traits.hpp" + +namespace rocwmma +{ + + struct TestParams : public UnitTestParams + { + using Base = UnitTestParams; + using Types = typename Base::TestAllSizeTypes; + using BlockSizes = typename Base::TestBlockSizes256; + using DataLayouts = typename Base::TestLayoutsAll; + using KernelParams = typename CombineLists::Result; + + // Assemble the kernel generator + using GeneratorImpl = LayoutTraitsGenerator; + using KernelGenerator = KernelGenerator; + + // Sanity check for kernel generator + static_assert(std::is_same::value, + "Kernels from this generator do not match testing interface"); + + static inline typename KernelGenerator::ResultT kernels() + { + return KernelGenerator::generate(); + } + + static inline std::vector threadBlocks() + { + auto warpSize = HipDevice::instance()->warpSize(); + // clang-format off + return { {warpSize, 1} }; + // clang-format on + } + + static inline std::vector problemSizes() + { + // clang-format off + return { {1024, 1024} }; + // clang-format on + } + }; + +} // namespace rocwmma + +ROCWMMA_GENERATE_UNIT_GTEST_SUITE(LayoutTraitsTest256, TestParams); diff --git a/test/unit/layout_traits_test/test/layout_traits_32.cpp b/test/unit/layout_traits_test/test/layout_traits_32.cpp new file mode 100644 index 00000000..58c4b016 --- /dev/null +++ b/test/unit/layout_traits_test/test/layout_traits_32.cpp @@ -0,0 +1,74 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include "common_includes.hpp" +#include "detail/layout_traits.hpp" + +namespace rocwmma +{ + + struct TestParams : public UnitTestParams + { + using Base = UnitTestParams; + using Types = typename Base::TestAllSizeTypes; + using BlockSizes = typename Base::TestBlockSizes32; + using DataLayouts = typename Base::TestLayoutsAll; + using KernelParams = typename CombineLists::Result; + + // Assemble the kernel generator + using GeneratorImpl = LayoutTraitsGenerator; + using KernelGenerator = KernelGenerator; + + // Sanity check for kernel generator + static_assert(std::is_same::value, + "Kernels from this generator do not match testing interface"); + + static inline typename KernelGenerator::ResultT kernels() + { + return KernelGenerator::generate(); + } + + static inline std::vector threadBlocks() + { + auto warpSize = HipDevice::instance()->warpSize(); + // clang-format off + return { {warpSize, 1} }; + // clang-format on + } + + static inline std::vector problemSizes() + { + // clang-format off + return { {1024, 1024} }; + // clang-format on + } + }; + +} // namespace rocwmma + +ROCWMMA_GENERATE_UNIT_GTEST_SUITE(LayoutTraitsTest32, TestParams); diff --git a/test/unit/layout_traits_test/test/layout_traits_64.cpp b/test/unit/layout_traits_test/test/layout_traits_64.cpp new file mode 100644 index 00000000..e961e12d --- /dev/null +++ b/test/unit/layout_traits_test/test/layout_traits_64.cpp @@ -0,0 +1,74 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include "common_includes.hpp" +#include "detail/layout_traits.hpp" + +namespace rocwmma +{ + + struct TestParams : public UnitTestParams + { + using Base = UnitTestParams; + using Types = typename Base::TestAllSizeTypes; + using BlockSizes = typename Base::TestBlockSizes64; + using DataLayouts = typename Base::TestLayoutsAll; + using KernelParams = typename CombineLists::Result; + + // Assemble the kernel generator + using GeneratorImpl = LayoutTraitsGenerator; + using KernelGenerator = KernelGenerator; + + // Sanity check for kernel generator + static_assert(std::is_same::value, + "Kernels from this generator do not match testing interface"); + + static inline typename KernelGenerator::ResultT kernels() + { + return KernelGenerator::generate(); + } + + static inline std::vector threadBlocks() + { + auto warpSize = HipDevice::instance()->warpSize(); + // clang-format off + return { {warpSize, 1} }; + // clang-format on + } + + static inline std::vector problemSizes() + { + // clang-format off + return { {1024, 1024} }; + // clang-format on + } + }; + +} // namespace rocwmma + +ROCWMMA_GENERATE_UNIT_GTEST_SUITE(LayoutTraitsTest64, TestParams); diff --git a/test/unit/layout_traits_test/test/layout_traits_int_128.cpp b/test/unit/layout_traits_test/test/layout_traits_int_128.cpp new file mode 100644 index 00000000..459ff6a8 --- /dev/null +++ b/test/unit/layout_traits_test/test/layout_traits_int_128.cpp @@ -0,0 +1,78 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include "common_includes.hpp" +#include "detail/layout_traits_int.hpp" + +namespace rocwmma +{ + + struct TestParams : public UnitTestParams + { + using Base = UnitTestParams; + using Types = typename Base::TestAllSizeTypes; + using MmaDims = std::tuple, I<32>, I<64>>; + using SplitKs = std::tuple, I<2>, I<4>>; + using BlockSizes = typename Base::TestBlockSizes128; + using DataLayouts = typename Base::TestLayoutsAll; + + using KernelParams = + typename CombineLists::Result; + + // Assemble the kernel generator + using GeneratorImpl = LayoutTraitsIntGenerator; + using KernelGenerator = KernelGenerator; + + // Sanity check for kernel generator + static_assert(std::is_same::value, + "Kernels from this generator do not match testing interface"); + + static inline typename KernelGenerator::ResultT kernels() + { + return KernelGenerator::generate(); + } + + static inline std::vector threadBlocks() + { + auto warpSize = HipDevice::instance()->warpSize(); + // clang-format off + return { {warpSize, 1} }; + // clang-format on + } + + static inline std::vector problemSizes() + { + // clang-format off + return { {1024, 1024} }; + // clang-format on + } + }; + +} // namespace rocwmma + +ROCWMMA_GENERATE_UNIT_GTEST_SUITE(LayoutTraitsIntTest128, TestParams); diff --git a/test/unit/layout_traits_test/test/layout_traits_int_16.cpp b/test/unit/layout_traits_test/test/layout_traits_int_16.cpp new file mode 100644 index 00000000..5e708a5c --- /dev/null +++ b/test/unit/layout_traits_test/test/layout_traits_int_16.cpp @@ -0,0 +1,78 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include "common_includes.hpp" +#include "detail/layout_traits_int.hpp" + +namespace rocwmma +{ + + struct TestParams : public UnitTestParams + { + using Base = UnitTestParams; + using Types = typename Base::TestAllSizeTypes; + using MmaDims = std::tuple>; + using SplitKs = std::tuple, I<2>, I<4>>; + using BlockSizes = typename Base::TestBlockSizes16; + using DataLayouts = typename Base::TestLayoutsAll; + + using KernelParams = + typename CombineLists::Result; + + // Assemble the kernel generator + using GeneratorImpl = LayoutTraitsIntGenerator; + using KernelGenerator = KernelGenerator; + + // Sanity check for kernel generator + static_assert(std::is_same::value, + "Kernels from this generator do not match testing interface"); + + static inline typename KernelGenerator::ResultT kernels() + { + return KernelGenerator::generate(); + } + + static inline std::vector threadBlocks() + { + auto warpSize = HipDevice::instance()->warpSize(); + // clang-format off + return { {warpSize, 1} }; + // clang-format on + } + + static inline std::vector problemSizes() + { + // clang-format off + return { {1024, 1024} }; + // clang-format on + } + }; + +} // namespace rocwmma + +ROCWMMA_GENERATE_UNIT_GTEST_SUITE(LayoutTraitsIntTest16, TestParams); diff --git a/test/unit/layout_traits_test/test/layout_traits_int_256.cpp b/test/unit/layout_traits_test/test/layout_traits_int_256.cpp new file mode 100644 index 00000000..a50a44bf --- /dev/null +++ b/test/unit/layout_traits_test/test/layout_traits_int_256.cpp @@ -0,0 +1,78 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include "common_includes.hpp" +#include "detail/layout_traits_int.hpp" + +namespace rocwmma +{ + + struct TestParams : public UnitTestParams + { + using Base = UnitTestParams; + using Types = typename Base::TestAllSizeTypes; + using MmaDims = std::tuple, I<32>, I<64>>; + using SplitKs = std::tuple, I<2>, I<4>>; + using BlockSizes = typename Base::TestBlockSizes256; + using DataLayouts = typename Base::TestLayoutsAll; + + using KernelParams = + typename CombineLists::Result; + + // Assemble the kernel generator + using GeneratorImpl = LayoutTraitsIntGenerator; + using KernelGenerator = KernelGenerator; + + // Sanity check for kernel generator + static_assert(std::is_same::value, + "Kernels from this generator do not match testing interface"); + + static inline typename KernelGenerator::ResultT kernels() + { + return KernelGenerator::generate(); + } + + static inline std::vector threadBlocks() + { + auto warpSize = HipDevice::instance()->warpSize(); + // clang-format off + return { {warpSize, 1} }; + // clang-format on + } + + static inline std::vector problemSizes() + { + // clang-format off + return { {1024, 1024} }; + // clang-format on + } + }; + +} // namespace rocwmma + +ROCWMMA_GENERATE_UNIT_GTEST_SUITE(LayoutTraitsIntTest256, TestParams); diff --git a/test/unit/layout_traits_test/test/layout_traits_int_32.cpp b/test/unit/layout_traits_test/test/layout_traits_int_32.cpp new file mode 100644 index 00000000..5f222865 --- /dev/null +++ b/test/unit/layout_traits_test/test/layout_traits_int_32.cpp @@ -0,0 +1,78 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include "common_includes.hpp" +#include "detail/layout_traits_int.hpp" + +namespace rocwmma +{ + + struct TestParams : public UnitTestParams + { + using Base = UnitTestParams; + using Types = typename Base::TestAllSizeTypes; + using MmaDims = std::tuple, I<32>>; + using SplitKs = std::tuple, I<2>, I<4>>; + using BlockSizes = typename Base::TestBlockSizes32; + using DataLayouts = typename Base::TestLayoutsAll; + + using KernelParams = + typename CombineLists::Result; + + // Assemble the kernel generator + using GeneratorImpl = LayoutTraitsIntGenerator; + using KernelGenerator = KernelGenerator; + + // Sanity check for kernel generator + static_assert(std::is_same::value, + "Kernels from this generator do not match testing interface"); + + static inline typename KernelGenerator::ResultT kernels() + { + return KernelGenerator::generate(); + } + + static inline std::vector threadBlocks() + { + auto warpSize = HipDevice::instance()->warpSize(); + // clang-format off + return { {warpSize, 1} }; + // clang-format on + } + + static inline std::vector problemSizes() + { + // clang-format off + return { {1024, 1024} }; + // clang-format on + } + }; + +} // namespace rocwmma + +ROCWMMA_GENERATE_UNIT_GTEST_SUITE(LayoutTraitsIntTest32, TestParams); diff --git a/test/unit/layout_traits_test/test/layout_traits_int_64.cpp b/test/unit/layout_traits_test/test/layout_traits_int_64.cpp new file mode 100644 index 00000000..3de30ed2 --- /dev/null +++ b/test/unit/layout_traits_test/test/layout_traits_int_64.cpp @@ -0,0 +1,78 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ + +#include + +#include "common_includes.hpp" +#include "detail/layout_traits_int.hpp" + +namespace rocwmma +{ + + struct TestParams : public UnitTestParams + { + using Base = UnitTestParams; + using Types = typename Base::TestAllSizeTypes; + using MmaDims = std::tuple, I<32>, I<64>>; + using SplitKs = std::tuple, I<2>, I<4>>; + using BlockSizes = typename Base::TestBlockSizes64; + using DataLayouts = typename Base::TestLayoutsAll; + + using KernelParams = + typename CombineLists::Result; + + // Assemble the kernel generator + using GeneratorImpl = LayoutTraitsIntGenerator; + using KernelGenerator = KernelGenerator; + + // Sanity check for kernel generator + static_assert(std::is_same::value, + "Kernels from this generator do not match testing interface"); + + static inline typename KernelGenerator::ResultT kernels() + { + return KernelGenerator::generate(); + } + + static inline std::vector threadBlocks() + { + auto warpSize = HipDevice::instance()->warpSize(); + // clang-format off + return { {warpSize, 1} }; + // clang-format on + } + + static inline std::vector problemSizes() + { + // clang-format off + return { {1024, 1024} }; + // clang-format on + } + }; + +} // namespace rocwmma + +ROCWMMA_GENERATE_UNIT_GTEST_SUITE(LayoutTraitsIntTest64, TestParams); diff --git a/test/unit/tuple_test/device/tuple.hpp b/test/unit/tuple_test/device/tuple.hpp index cce72d5a..8de10318 100644 --- a/test/unit/tuple_test/device/tuple.hpp +++ b/test/unit/tuple_test/device/tuple.hpp @@ -2,7 +2,7 @@ * * MIT License * - * Copyright (C) 2021-2024 Advanced Micro Devices, Inc. + * Copyright (C) 2021-2025 Advanced Micro Devices, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal @@ -28,7 +28,6 @@ #define ROCWMMA_DEVICE_TUPLE_TEST_HPP #include -#include #include #include #include