Skip to content

Commit

Permalink
Merge pull request #472 from cgmillette/interleave-dev
Browse files Browse the repository at this point in the history
Interleave dev
  • Loading branch information
cgmillette authored Dec 18, 2024
2 parents ada4c03 + 19af315 commit c831ff5
Show file tree
Hide file tree
Showing 61 changed files with 9,270 additions and 1,797 deletions.
15 changes: 10 additions & 5 deletions library/include/rocwmma/internal/config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
*
* MIT License
*
* Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand Down Expand Up @@ -146,6 +146,11 @@ static_assert(0, "Unsupported architecture");
#define ROCWMMA_ARCH_GFX94X 1
#endif

#if ROCWMMA_ARCH_HOST
#define ROCWMMA_BLOCK_DIM_16_SUPPORTED 1
#define ROCWMMA_BLOCK_DIM_32_SUPPORTED 1
#endif

#if !defined(ROCWMMA_ARCH_GFX9)
#define ROCWMMA_ARCH_GFX9 0
#endif
Expand Down Expand Up @@ -201,10 +206,10 @@ static_assert((bool)(ROCWMMA_BLOCK_DIM_16_SUPPORTED) && !(bool)(ROCWMMA_BLOCK_DI
#endif

#if ROCWMMA_ARCH_GFX12
static_assert((bool)(ROCWMMA_WAVE32_MODE) && !(bool)(ROCWMMA_WAVE64_MODE),
"rocWMMA supports only wave32 for gfx12 arch");
static_assert((bool)(ROCWMMA_BLOCK_DIM_16_SUPPORTED) && !(bool)(ROCWMMA_BLOCK_DIM_32_SUPPORTED),
"rocWMMA supports only block size of 16 for gfx12 arch");
static_assert((bool)(ROCWMMA_WAVE32_MODE) && !(bool)(ROCWMMA_WAVE64_MODE),
"rocWMMA supports only wave32 for gfx12 arch");
static_assert((bool)(ROCWMMA_BLOCK_DIM_16_SUPPORTED) && !(bool)(ROCWMMA_BLOCK_DIM_32_SUPPORTED),
"rocWMMA supports only block size of 16 for gfx12 arch");
#endif

///
Expand Down
12 changes: 11 additions & 1 deletion library/include/rocwmma/internal/coop_io_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
*
* MIT License
*
* Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -26,6 +26,7 @@
#ifndef ROCWMMA_COOP_IO_CONFIG_HPP
#define ROCWMMA_COOP_IO_CONFIG_HPP

#include "./layout/register_layout_transforms.hpp"
#include "coop_load.hpp"
#include "coop_store.hpp"
#include "io_layout.hpp"
Expand Down Expand Up @@ -85,6 +86,15 @@ namespace rocwmma
typename IOLayout::MatrixLayout,
IOLayout::VW>;

using PostLoadXForm = register_layout_transform<typename IOLayout::StorageLayout,
typename IOLayout::FragmentLayout>;

using PreMmaXForm = register_layout_transform<typename IOLayout::FragmentLayout,
typename IOLayout::MmaLayout>;

using PreStoreXForm = register_layout_transform<typename IOLayout::FragmentLayout,
typename IOLayout::StorageLayout>;

using Storer = CooperativeStore<IOShape::BlockDim,
IOShape::KDim,
DataT,
Expand Down
1 change: 0 additions & 1 deletion library/include/rocwmma/internal/coop_load.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
#define ROCWMMA_COOP_LOAD_HPP

#include "io_traits.hpp"
#include "layout.hpp"
#include "opaque_load.hpp"
#include "types.hpp"
#include "utils.hpp"
Expand Down
1 change: 0 additions & 1 deletion library/include/rocwmma/internal/coop_store.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
#define ROCWMMA_COOP_STORE_HPP

#include "io_traits.hpp"
#include "layout.hpp"
#include "opaque_store.hpp"
#include "types.hpp"
#include "utils.hpp"
Expand Down
32 changes: 27 additions & 5 deletions library/include/rocwmma/internal/io_config.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
*
* MIT License
*
* Copyright (C) 2021-2024 Advanced Micro Devices, Inc. All rights reserved.
* Copyright (C) 2021-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
Expand All @@ -26,6 +26,7 @@
#ifndef ROCWMMA_IO_CONFIG_HPP
#define ROCWMMA_IO_CONFIG_HPP

#include "./layout/register_layout_transforms.hpp"
#include "broadcast.hpp"
#include "coop_load.hpp"
#include "coop_store.hpp"
Expand All @@ -37,7 +38,6 @@

namespace rocwmma
{

/**
* \defgroup Rocwmma_ioconf ROCWMMA IOConfig
* @brief ROCWMMA fragment input and output configurations
Expand Down Expand Up @@ -88,6 +88,21 @@ namespace rocwmma
typename IOLayout::MatrixLayout,
IOLayout::VW>;

using PostLoadXForm = register_layout_transform<typename IOLayout::StorageLayout,
typename IOLayout::FragmentLayout>;

using PreMmaXForm = register_layout_transform<typename IOLayout::FragmentLayout,
typename IOLayout::MmaLayout>;

// Currently, only makes sense to have a post-mma transform on acc layouts
using PostMmaXForm = conditional_t<is_same_v<MatrixT, accumulator>,
register_layout_transform<typename IOLayout::MmaLayout,
typename IOLayout::FragmentLayout>,
register_layout_transform_nop>;

using PreStoreXForm = register_layout_transform<typename IOLayout::FragmentLayout,
typename IOLayout::StorageLayout>;

using Storer = OpaqueStore<IOShape::BlockDim,
IOShape::KDim,
DataT,
Expand All @@ -106,10 +121,17 @@ namespace rocwmma
template <uint32_t BlockM, uint32_t BlockN, uint32_t BlockK, typename DataT>
struct IOConfig<accumulator, BlockM, BlockN, BlockK, DataT, void>
{
using IOShape = IOShape<accumulator, BlockM, BlockN, BlockK>;
using IOTraits = IOTraits<IOShape::BlockDim, IOShape::KDim, DataT>;
using PackUtil = PackUtil<DataT>;
using IOShape = IOShape<accumulator, BlockM, BlockN, BlockK>;
using IOLayout = IOLayout<accumulator, IOShape::BlockDim, IOShape::KDim, DataT, void, 1u>;
using IOTraits = IOTraits<IOShape::BlockDim, IOShape::KDim, DataT>;
using PackUtil = PackUtil<DataT>;
using Broadcaster = Broadcast<DataT, IOTraits::UnpackedSize>;

using PreMmaXForm = register_layout_transform<typename IOLayout::FragmentLayout,
typename IOLayout::MmaLayout>;

using PostMmaXForm = register_layout_transform<typename IOLayout::MmaLayout,
typename IOLayout::FragmentLayout>;
};
/** @}*/

Expand Down
Loading

0 comments on commit c831ff5

Please sign in to comment.