diff --git a/examples/09_gate_recurrent_unit/kernel_func.hpp b/examples/09_gate_recurrent_unit/kernel_func.hpp index 0dd76188a..dd5fb4abc 100644 --- a/examples/09_gate_recurrent_unit/kernel_func.hpp +++ b/examples/09_gate_recurrent_unit/kernel_func.hpp @@ -86,7 +86,7 @@ struct fused_config_t { {start_x_b, start_y_b}); \ gemm_args.init(mem_desc_a, mem_desc_b, inner_loop_count_##id); \ op(g, matAcc_##acc_id, gemm_args); \ - SW_BARRIER(); + sw_barrier(); #define MATC_STORE(ptr_c) \ mem_desc_c.init( \ @@ -229,7 +229,7 @@ struct gru_layer { int start_n = (j)*wg_tile_n; CONFIG_SETTING(batch_size, -1, hidden_size); matAcc_0.init(0); - SW_BARRIER(); + sw_barrier(); // calculate reset gate: r_t = \sigmoid(X_t x W_ir + h_{t - 1} x W_hr) // acc0 = X_t x W_ir @@ -278,19 +278,19 @@ struct gru_layer { matAcc_0.reg = matAcc_0.reg * (1 - matAcc_1.reg) + matAcc_1.reg * xetla_cvt(mat_hidden.reg); - SW_BARRIER(); + sw_barrier(); if (seq_id == seq_len - 1) { MATC_STORE(args->layer_output); - SW_BARRIER(); + sw_barrier(); __esimd_barrier(); } MATC_STORE(args->cell_out_ptr + seq_id * io_size); - SW_BARRIER(); + sw_barrier(); __esimd_barrier(); MATC_STORE(args->one_cell_ptr + (seq_id % 2) * io_size); - SW_BARRIER(); + sw_barrier(); __esimd_barrier(); } args->hx_ptr = args->one_cell_ptr + (seq_id % 2) * io_size; @@ -386,7 +386,7 @@ struct kernel_xcoder_gru_fusion { args.W_hz_ptr = (W_hz_ptr); args.W_in_ptr = (W_in_ptr); args.W_hn_ptr = (W_hn_ptr); - SW_BARRIER(); + sw_barrier(); fused_op::call(item, &args); ping = (ping + 1) % 2; pong = (pong + 1) % 2; @@ -411,7 +411,7 @@ struct kernel_xcoder_gru_fusion { ? hidden_out_ptr : (ping_pong_buffer + ping * one_layer_size); args.layer_ptr = ((ping_pong_buffer + pong * one_layer_size)); - SW_BARRIER(); + sw_barrier(); fused_op::call(item, &args); ping = (ping + 1) % 2; pong = (pong + 1) % 2; diff --git a/include/common/core/arch_config.hpp b/include/common/core/arch_config.hpp index 315e5d01c..b6d14a148 100644 --- a/include/common/core/arch_config.hpp +++ b/include/common/core/arch_config.hpp @@ -35,16 +35,41 @@ template <> struct load_store_attr_t { /// HW limitation checks https://gfxspecs.intel.com/Predator/Home/Index/55490 static constexpr bool has_hw_block_2d = true; + // If Transposed and Transformed are both set to false + // BlockHeight must not exceed 32. static constexpr uint32_t max_load_height_in_elem = 32; + + // BlockWidth * NBlocks must not exceed 64 for bytes, 32 for words, 16 for + // dwords, and 8 for qwords. static constexpr uint32_t max_load_width_in_bytes = 64; + + // If Transposed is true then + // BlockWidth must be 1,2,4 for qwords and be in range [1..8] for dwords. static constexpr uint32_t max_trans_load_width_in_bytes = 32; + + // BlockHeight must be 8 for qwords and be in range [1..32] for dwords. + static constexpr uint32_t max_trans_load_height_in_elem = 32; + + // If Transformed is true + // BlockWidth must be in range [4..16] for bytes and [2..16] for word. static constexpr uint32_t max_vnni_load_width_in_elems = 16; + + // BlockHeight must be in range [4..32] for bytes and [2..32] for words. static constexpr uint32_t min_vnni_load_height_in_bytes = 4; + // BlockHeight must not exceed 8. static constexpr uint32_t max_store_height_in_elem = 8; + + // BlockWidth must not exceed 64 for bytes, 32 for words, 16 for dwords, and 8 + // for qwords. static constexpr uint32_t max_store_width_in_bytes = 64; + // BlockHeight must not exceed 32. + // BlockWidth * NBlocks must not exceed 64 for bytes, 32 for words, 16 for + // dwords, and 8 for qwords. 
  static constexpr uint32_t max_load_size_in_bytes = 2048;
+
+  // BlockWidth * BlockHeight * sizeof(T) must not exceed 512.
  static constexpr uint32_t max_store_size_in_bytes = 512;
  static constexpr uint32_t special_prefetch_width_in_bytes = 64;
@@ -97,7 +122,7 @@ struct load_store_attr_t {
  static constexpr uint32_t max_aligned_load_vec_len = 256;
  static constexpr uint32_t max_store_vec_len = 256;
  static constexpr uint32_t max_aligned_store_vec_len = 256;
-  static constexpr uint32_t max_prefetch_vec_len = 32;
+  static constexpr uint32_t max_prefetch_vec_len = 256;
  static constexpr uint32_t max_channel_num = 16;
};
diff --git a/include/common/core/barrier.hpp b/include/common/core/barrier.hpp
index 4ad0eaa69..3956798ed 100644
--- a/include/common/core/barrier.hpp
+++ b/include/common/core/barrier.hpp
@@ -26,6 +26,16 @@ namespace gpu::xetla {
 /// @addtogroup xetla_core_barrier
 /// @{
+/// @brief Inserts a software scheduling barrier: instructions are not
+/// reordered across it, which gives finer control over the generated code.
+///
+__XETLA_API void sw_barrier() {
+#if __INTEL_LLVM_COMPILER >= 20250000
+#else
+  __ESIMD_NS::fence<__ESIMD_NS::fence_mask::sw_barrier>();
+#endif
+}
+
 /// @brief Initialize the number of named barrier index for a kernel.
 /// Available only on PVC. Only need to initialize once at the beginning.
 ///
diff --git a/include/common/core/memory.hpp b/include/common/core/memory.hpp
index 93bedbfe0..e5ff7bf17 100644
--- a/include/common/core/memory.hpp
+++ b/include/common/core/memory.hpp
@@ -320,6 +320,53 @@ __XETLA_API void xetla_prefetch_global(
 #endif
 }
+/// 2D USM pointer block prefetch.
+/// Supported platforms: PVC
+/// VISA instruction: lsc_load_block2d.ugm
+///
+/// Prefetches elements located at specified address.
+///
+/// @tparam T is element type.
+/// @tparam BlockWidth is the block width in number of elements.
+/// @tparam BlockHeight is the block height in number of elements.
+/// @tparam NBlocks is the number of blocks.
+/// @tparam L1H is L1 cache hint.
+/// @tparam L2H is L2 cache hint.
+/// @tparam N is the data size
+/// @param Ptr is the surface base address for this operation.
+/// @param SurfaceWidth is the surface width minus 1 in bytes
+/// @param SurfaceHeight is the surface height minus 1 in rows
+/// @param SurfacePitch is the surface pitch minus 1 in bytes
+/// @param X is zero based X-coordinate of the left upper rectangle corner in
+/// number of elements.
+/// @param Y is zero based Y-coordinate of the left upper rectangle corner in
+/// rows.
+///
+template <
+    typename T,
+    int BlockWidth,
+    int BlockHeight = 1,
+    int NBlocks = 1,
+    cache_hint L1H = cache_hint::none,
+    cache_hint L2H = cache_hint::none,
+    int N = __ESIMD_ENS::detail::get_lsc_block_2d_data_size<
+        T,
+        NBlocks,
+        BlockHeight,
+        BlockWidth,
+        false,
+        false>()>
+__XETLA_API void xetla_prefetch_global(
+    const T* Ptr,
+    unsigned SurfaceWidth,
+    unsigned SurfaceHeight,
+    unsigned SurfacePitch,
+    int X,
+    int Y) {
+  return __ESIMD_ENS::lsc_prefetch_2d<
+      T,
+      BlockWidth,
+      BlockHeight,
+      NBlocks,
+      gpu::xetla::detail::get_cache_hint(L1H),
+      gpu::xetla::detail::get_cache_hint(L2H),
+      N>(Ptr, SurfaceWidth, SurfaceHeight, SurfacePitch, X, Y);
+}
+
 /// template
 /// void prefetch(const T *p, OffsetT byte_offset,
@@ -358,14 +405,102 @@ __XETLA_API void xetla_prefetch_global(T* p, uint64_t byte_offset = 0) {
 #endif
 }
+/// 2D USM pointer block load.
+/// Supported platforms: PVC
+/// VISA instruction: lsc_load_block2d.ugm
+///
+/// Collects elements located at specified address and returns them
+/// as a single \ref simd object.
+///
+/// @tparam T is element type.
+/// @tparam BlockWidth is the block width in number of elements.
+/// @tparam BlockHeight is the block height in number of elements.
+/// @tparam NBlocks is the number of blocks. +/// @tparam Transposed is the transposed version or not. +/// @tparam Transformed is apply VNNI transform or not. +/// @tparam L1H is L1 cache hint. +/// @tparam L2H is L2 cache hint. +/// @tparam N is the data size +/// @param Ptr is the surface base address for this operation. +/// @param SurfaceWidth is the surface width minus 1 in bytes +/// @param SurfaceHeight is the surface height minus 1 in rows +/// @param SurfacePitch is the surface pitch minus 1 in bytes +/// @param X is zero based X-coordinate of the left upper rectangle corner in +/// number of elements. +/// @param Y is zero based Y-coordinate of the left upper rectangle corner in +/// rows. +/// @return is a vector of type T and size N, where N is +/// BlockWidth * BlockHeight * NBlocks, if transformed; +/// otherwise, +/// N = roundUpNextMultiple(BlockHeight, 4 / sizeof(T)) * +/// getNextPowerOf2(BlockWidth) * NBlocks +/// +template < + typename T, + int BlockWidth, + int BlockHeight = 1, + int NBlocks = 1, + bool Transposed = false, + bool Transformed = false, + cache_hint L1H = cache_hint::none, + cache_hint L2H = cache_hint::none, + int N = __ESIMD_ENS::detail::get_lsc_block_2d_data_size< + T, + NBlocks, + BlockHeight, + BlockWidth, + Transposed, + Transformed>()> +__XETLA_API xetla_vector xetla_load_global( + const T* Ptr, + size_t SurfaceWidth, + size_t SurfaceHeight, + size_t SurfacePitch, + int X, + int Y) { + if constexpr (std::is_same_v) { + auto ret = xetla_load_global< + fp16, + BlockWidth, + BlockHeight, + NBlocks, + Transposed, + Transformed, + L1H, + L2H>( + reinterpret_cast(Ptr), + SurfaceWidth, + SurfaceHeight, + SurfacePitch, + X, + Y); + return ret.xetla_format(); + } else if constexpr (BlockWidth * sizeof(T) < sizeof(uint32_t)) { + xetla_vector byte_offsets = + xetla_vector_gen(0, SurfacePitch); + return xetla_load_global(Ptr, byte_offsets); + } else { + return __ESIMD_ENS::lsc_load_2d< + T, + BlockWidth, + BlockHeight, + NBlocks, + Transposed, + Transformed, + gpu::xetla::detail::get_cache_hint(L1H), + gpu::xetla::detail::get_cache_hint(L2H), + N>(Ptr, SurfaceWidth - 1, SurfaceHeight - 1, SurfacePitch - 1, X, Y); + } +} + /// simd block_load(const T* ptr, size_t byte_offset, /// props={}); // (usm-bl-2) /// This function loads a contiguous memory block from address referenced /// by USM pointer \p ptr and the given \p byte_offset. /// /// There may be temporary restrictions depending on L1, L2 cache hints, -/// See details in the 'Restrictions' section below. The restrictions will be -/// relaxed in the future. +/// See details in the 'Restrictions' section below. The restrictions will +/// be relaxed in the future. /// /// The parameter \p props specifies the optional compile-time properties /// of the type esimd::properties and may include esimd::cache_hint_L1, @@ -383,7 +518,8 @@ __XETLA_API void xetla_prefetch_global(T* p, uint64_t byte_offset = 0) { /// /// Restrictions - cache hint imposed - temporary: /// If L1 or L2 cache hint is passed, then: -/// R1: The pointer must be at least 4-byte aligned for elements of 4-bytes or +/// R1: The pointer must be at least 4-byte aligned for elements of 4-bytes +/// or /// smaller and 8-byte aligned for 8-byte elements. /// R2: The number of elements for 8-byte data: 1, 2, 3, 4, 8, 16, 32, 64; /// for 4-byte data: 1, 2, 3, 4, 8, 16, 32, 64, @@ -574,6 +710,71 @@ __XETLA_API xetla_vector xetla_load_global( #endif } +/// 2D USM pointer block store. 
+/// Supported platforms: PVC +/// VISA instruction: lsc_store_block2d.ugm +/// +/// Stores elements at specified address. +/// +/// @tparam T is element type. +/// @tparam BlockWidth is the block width in number of elements. +/// @tparam BlockHeight is the block height in number of elements. +/// @tparam L1H is L1 cache hint. +/// @tparam L2H is L2 cache hint. +/// @tparam N is the data size +/// @param Ptr is the surface base address for this operation. +/// @param SurfaceWidth is the surface width minus 1 in bytes +/// @param SurfaceHeight is the surface height minus 1 in rows +/// @param SurfacePitch is the surface pitch minus 1 in bytes +/// @param X is zero based X-coordinate of the left upper rectangle corner in +/// number of elements. +/// @param Y is zero based Y-coordinate of the left upper rectangle corner in +/// rows. +/// @param Vals is a vector to store of type T and size N, where +/// N = roundUpNextMultiple(BlockHeight, 4 / sizeof(T)) * +/// getNextPowerOf2(BlockWidth) * NBlocks +/// +template < + typename T, + int BlockWidth, + int BlockHeight = 1, + cache_hint L1H = cache_hint::none, + cache_hint L2H = cache_hint::none, + int N = __ESIMD_ENS::detail::get_lsc_block_2d_data_size< + T, + 1u, + BlockHeight, + BlockWidth, + false, + false>()> +__XETLA_API void xetla_store_global( + T* Ptr, + unsigned SurfaceWidth, + unsigned SurfaceHeight, + unsigned SurfacePitch, + int X, + int Y, + auto&& Vals) { + if constexpr (std::is_same_v) { + xetla_store_global( + reinterpret_cast(Ptr), + SurfaceWidth, + SurfaceHeight, + SurfacePitch, + X, + Y, + Vals.xetla_format()); + } else { + __ESIMD_ENS::lsc_store_2d< + T, + BlockWidth, + BlockHeight, + gpu::xetla::detail::get_cache_hint(L1H), + gpu::xetla::detail::get_cache_hint(L2H), + N>( + Ptr, SurfaceWidth - 1, SurfaceHeight - 1, SurfacePitch - 1, X, Y, Vals); + } +} /// template /// void scatter(T *p, simd byte_offsets, simd vals, @@ -951,6 +1152,10 @@ __XETLA_API xetla_vector xetla_load_local( xetla_vector offsets, xetla_mask pred = 1) { using T = native_type_t; + DEBUG_INVOKE( + dbg_level::core, + core::general_1d:: + template check_restriction(offsets)); return __ESIMD_ENS:: lsc_slm_gather( @@ -975,6 +1180,11 @@ __XETLA_API xetla_vector xetla_load_local( template __XETLA_API xetla_vector xetla_load_local(uint32_t offset) { using T = native_type_t; + // DEBUG_INVOKE( + // dbg_level::core, + // core::general_1d::template + // check_restriction( + // (uint64_t)offset)); return __ESIMD_NS::slm_block_load(offset); } @@ -1005,6 +1215,10 @@ __XETLA_API void xetla_store_local( xetla_vector vals, xetla_mask pred = 1) { using T = native_type_t; + DEBUG_INVOKE( + dbg_level::core, + core::general_1d:: + template check_restriction(offsets)); __ESIMD_ENS:: lsc_slm_scatter( diff --git a/include/common/utils/common.hpp b/include/common/utils/common.hpp index 7d3b8fe15..d3bca4818 100644 --- a/include/common/utils/common.hpp +++ b/include/common/utils/common.hpp @@ -275,8 +275,8 @@ enum class reg_layout : uint8_t { tiled = 1, vnni_tiled = 2, transpose_tiled = 3, - /// this is vnni tiled format, but for each block, they are stored in col - /// major order + /// this is vnni tiled format, but for each block, they are stored in + /// col-major order vnni_tiled_col_major = 4 }; enum class store_op : uint8_t { diff --git a/include/subgroup/tile/impl/load_xe.hpp b/include/subgroup/tile/impl/load_xe.hpp index c793e2acb..140e61593 100644 --- a/include/subgroup/tile/impl/load_xe.hpp +++ b/include/subgroup/tile/impl/load_xe.hpp @@ -89,12 +89,13 @@ 
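// A minimal usage sketch of the raw 2D block load/store wrappers introduced
// above (illustrative only: the surface geometry, block shape and function
// name below are made-up example values, not part of this patch). The wrappers
// subtract 1 from SurfaceWidth/SurfaceHeight/SurfacePitch before calling
// lsc_load_2d / lsc_store_2d, so callers pass the full extents, exactly as
// tile_load / tile_store do through payload.surface_width / surface_height /
// surface_pitch.
//
//   inline void copy_one_block_example(const float* src, float* dst) {
//     constexpr int block_width = 16;                // elements per block row
//     constexpr int block_height = 8;                // rows per block
//     unsigned surface_width = 512 * sizeof(float);  // full width in bytes
//     unsigned surface_height = 1024;                // full height in rows
//     unsigned surface_pitch = 512 * sizeof(float);  // full pitch in bytes
//     int x = 32;                                    // block origin, elements
//     int y = 64;                                    // block origin, rows
//     auto blk = xetla_load_global<float, block_width, block_height>(
//         src, surface_width, surface_height, surface_pitch, x, y);
//     xetla_store_global<float, block_width, block_height>(
//         dst, surface_width, surface_height, surface_pitch, x, y, blk);
//   }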
tile_load(tile_t& tile, payload_t& payload) { static constexpr uint32_t num_block_x = tile_desc::num_block_x; static constexpr uint32_t num_block_y = tile_desc::num_block_y; - static constexpr uint32_t num_block = tile_desc::num_block; static constexpr gpu_arch arch_tag = payload_t::arch_tag; static constexpr reg_layout reg_layout_ = tile_desc::register_layout; - static constexpr bool is_vnni_reverse = payload_t::mem_dword_transpose && + // In the case of pack, tranpose is in vnni format + static constexpr bool is_vnni_reverse = + payload_t::mem_transpose_dtype_less4bytes && ((reg_layout_ == reg_layout::tiled) || (reg_layout_ == reg_layout::transpose_tiled)); static constexpr bool reg_transpose = tile_desc::reg_transpose; @@ -106,36 +107,62 @@ tile_load(tile_t& tile, payload_t& payload) { static constexpr bool mem_transform = payload_t::mem_transform; using load_store_attr = load_store_attr_t; + + // static constexpr uint32_t max_load_width_in_elem = trans + // ? load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype) + // : load_store_attr::max_load_width_in_bytes / sizeof(dtype); + // static constexpr uint32_t max_load_height_in_elem = trans + // ? load_store_attr::max_trans_load_height_in_elem + // : load_store_attr::max_load_height_in_elem; + // static constexpr uint32_t max_trans_load_width_in_elem = + // load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype); + // static constexpr uint32_t max_load_width_in_elem = + // load_store_attr::max_load_width_in_bytes / sizeof(dtype); + + // static constexpr uint32_t max_trans_load_height_in_elem = + // load_store_attr::max_trans_load_height_in_elem; + + // static constexpr uint32_t max_load_height_in_elem = + // load_store_attr::max_load_height_in_elem; + static constexpr uint32_t elems_per_CL = load_store_attr::cache_line_size_in_bytes / sizeof(dtype); + static constexpr uint32_t elems_per_reg = register_bytes_t::reg_in_bytes / sizeof(dtype); - static constexpr int32_t max_load_block_height = - load_store_attr::max_load_height_in_elem; - static constexpr int32_t max_block_width = - load_store_attr::max_load_width_in_bytes / sizeof(dtype); - static constexpr int32_t max_trans_block_width = - load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype); - - static constexpr uint32_t ld_blk_size_y_limit = - mem_transpose ? max_trans_block_width : max_load_block_height; - static constexpr uint32_t ld_blk_size_y = reg_transpose - ? block_size_y - : (block_size_y > ld_blk_size_y_limit ? ld_blk_size_y_limit - : block_size_y); + + static constexpr uint32_t max_load_width_in_elem = trans + ? load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype) + : load_store_attr::max_load_width_in_bytes / sizeof(dtype); + + static constexpr uint32_t max_load_blk_height_in_elem = trans + ? load_store_attr::max_trans_load_height_in_elem + : load_store_attr::max_load_height_in_elem; + + static constexpr uint32_t ld_blk_width = std::min( + (mem_transpose ? block_size_y : block_size_x), max_load_width_in_elem); + + static constexpr uint32_t ld_blk_height = std::min( + (mem_transpose ? block_size_x : block_size_y), + max_load_blk_height_in_elem); + + static constexpr uint32_t ld_blk_size_y = + mem_transpose ? ld_blk_width : ld_blk_height; + + static constexpr uint32_t ld_blk_size_y_limit = mem_transpose + ? 
load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype) + : load_store_attr::max_load_height_in_elem; // array len is used to make sure memory load is cache line aligned // disabled while register or memory transpose static constexpr uint8_t arr_len_candidate = - (reg_transpose || - mem_transpose + ((reg_transpose || mem_transpose) // block elements should be integer // times of register bytes - || ((block_size_y * block_size_x) % elems_per_reg != 0) + || ((block_elems) % elems_per_reg != 0) // tail blocks also need to meet above condition - || - (((tile_size_y % block_size_y) * block_size_x) % elems_per_reg != 0)) || - (block_size_y > ld_blk_size_y_limit) + || (((tile_size_y % block_size_y) * block_size_x) % elems_per_reg != 0)) + // || (block_size_y > load_store_attr::max_load_height_in_elem) ? 1 : (((tile_size_x % elems_per_CL) == 0) ? (((elems_per_CL % block_size_x) == 0) @@ -143,98 +170,82 @@ tile_load(tile_t& tile, payload_t& payload) { : 1) : ((tile_size_x < elems_per_CL) ? (tile_size_x / block_size_x) : 1)); - static constexpr bool is_valid_arr_len_candidate = (arr_len_candidate == 1) || - (arr_len_candidate == 2) || (arr_len_candidate == 4); - - static constexpr uint8_t arr_len = - is_valid_arr_len_candidate ? arr_len_candidate : 1; - - static_assert( - reg_transpose || mem_transpose || - (!mem_transpose && (block_size_x * arr_len) <= max_block_width), - "When reg_transpose was disabled, check 2d block width " - "restriction"); - static_assert( - !reg_transpose || - (!mem_transpose && - (block_size_x * arr_len) <= max_trans_block_width) || - (mem_transpose && (block_size_y * arr_len) <= max_block_width), - "When reg_transpose was enabled, check 2d block width " - "restriction"); - static_assert( - !reg_transpose || - (!mem_transpose && (block_size_y <= max_load_block_height)) || - (mem_transpose && (block_size_x) <= max_load_block_height), - "When reg_transpose was enabled, check 2d block height " - "restriction"); - static_assert( - tile_size_x % (block_size_x * arr_len) == 0, - "tile_size_x should be a multiple of (block_size_x * arr_len)"); + // NBlocks must be {1,2,4} for bytes and words, {1,2} for dwords, 1 for + // qwords. + static constexpr uint8_t arr_len = + ((arr_len_candidate == 1) || + (arr_len_candidate == 2 && sizeof(dtype) <= 4) || + (arr_len_candidate == 4 && sizeof(dtype) <= 2)) + ?
arr_len_candidate + : 1; + + if constexpr (!trans && !mem_transform) { + static_assert( + (ld_blk_width * arr_len) <= max_load_width_in_elem, + "When Transposed and Transformed are both set to false, BlockWidth * NBlocks must not exceed 64 for bytes, 32 for words, 16 for dwords, and 8 for qwords"); + } else if constexpr (mem_transform) { + static_assert( + (ld_blk_width * arr_len) <= max_load_width_in_elem, + "When Transformed is true then, BlockWidth * NBlocks must not exceed 64 for bytes and 32 for words."); + } static_assert( (reg_transpose && ((block_size_x * sizeof(dtype)) % sizeof(load_dtype) == 0)) || ((block_size_y * sizeof(dtype)) % sizeof(load_dtype) == 0), "check vnni limitation for DW transpose"); - auto payload_2d = payload.payloads.xetla_format(); #pragma unroll for (uint32_t i = 0; i < num_block_y; ++i) { - constexpr uint32_t load_block_elems = block_elems * arr_len; - auto payload_row = - payload_2d.xetla_select(i * num_block_x, 0); - detail::reset_tile_desc_core< - num_block_x, - block_size_x, - ld_blk_size_y, - scale_factor, - arr_len, - mem_transpose>(payload_row); + int offset_y = i * block_size_y; #pragma unroll for (uint32_t j = 0; j < num_block_x; j += arr_len) { - xetla_tdescriptor tdesc = payload_row.row(j); + int32_t offset_x = j * block_size_x; + constexpr uint32_t load_block_elems = block_elems * arr_len; auto reg_blk = tile.reg.xetla_select( (i * num_block_x + j) * block_elems); - constexpr uint32_t ld_blk_height = (reg_transpose && trans) - ? detail::getNextPowerOf2() - : ld_blk_size_y; - constexpr uint32_t tmp_size = ld_blk_height * block_size_x * arr_len; + constexpr uint32_t tmp_size = ld_blk_width * ld_blk_height * arr_len; xetla_vector reg_tmp; #pragma unroll for (uint32_t ii = 0; ii < block_size_y / ld_blk_size_y; ++ii) { constexpr uint32_t load_elems = ld_blk_size_y * block_size_x * arr_len; - - reg_tmp.xetla_format>() = xetla_tload_global< - load_dtype, - ld_blk_height * block_size_x * arr_len / scale_factor, - L1, - L2, + uint32_t address_offset_x = + (mem_transpose ? (offset_y + ii * ld_blk_size_y) : offset_x) / + scale_factor; + uint32_t address_offset_y = + mem_transpose ? 
offset_x : (offset_y + ii * ld_blk_size_y); + reg_tmp.xetla_format>() = xetla_load_global< + native_type_t, + ld_blk_width / scale_factor, + ld_blk_height, + arr_len, trans, mem_transform, - arch_tag>(tdesc); + L1, + L2>( + reinterpret_cast>( + payload.base_ptr), + payload.surface_width, + payload.surface_height, + payload.surface_pitch, + payload.offset_x + address_offset_x, + payload.offset_y + address_offset_y); + if constexpr (reg_transpose && trans) { reg_blk.xetla_select(ii * load_elems) .xetla_format>() = reg_tmp .xetla_format< native_type_t, - block_size_x / scale_factor, + ld_blk_width / scale_factor, ld_blk_height>() .xetla_select< - block_size_x / scale_factor, + ld_blk_width / scale_factor, 1, ld_blk_size_y, 1>(0, 0); } else { reg_blk.xetla_select(ii * tmp_size) = reg_tmp; } - - if constexpr (mem_transpose) { - xetla_update_tdesc_offsetx( - tdesc.xetla_format(), ld_blk_size_y / scale_factor); - } else { - xetla_update_tdesc_offsety( - tdesc.xetla_format(), ld_blk_size_y); - } } // exceed HW limitation if constexpr (block_size_y % ld_blk_size_y != 0) { @@ -244,26 +255,29 @@ tile_load(tile_t& tile, payload_t& payload) { remained_start_y * block_size_x * arr_len; constexpr uint32_t remained_blk_size_y = block_size_y % ld_blk_size_y; constexpr uint32_t load_elems = - remained_blk_size_y * block_size_x * arr_len; + remained_blk_size_y * block_size_x * arr_len / scale_factor; constexpr uint8_t block_width = - mem_transpose ? (remained_blk_size_y / scale_factor) : block_size_x; + (mem_transpose ? remained_blk_size_y : block_size_x) / scale_factor; constexpr uint8_t block_height = - trans ? block_size_x : remained_blk_size_y; - constexpr uint32_t block_widthx_widthy_arrlen = - (block_width - 1) | ((block_height - 1) << 8); - gpu::xetla::detail::xetla_set_block_widthx_widthy_arrlen( - tdesc.xetla_format(), block_widthx_widthy_arrlen); - + mem_transpose ? block_size_x : remained_blk_size_y; reg_blk.xetla_select(remained_start) - .xetla_format>() = xetla_tload_global< - load_dtype, - (load_elems / scale_factor), - L1, - L2, + .xetla_format>() = xetla_load_global< + native_type_t, + block_width, + block_height, + arr_len, trans, mem_transform, - arch_tag>(tdesc); + L1, + L2>( + reinterpret_cast>( + payload.base_ptr), + payload.surface_width, + payload.surface_height, + payload.surface_pitch, + payload.offset_x + offset_x / scale_factor, + payload.offset_y + offset_y + remained_start_y); } } } @@ -276,18 +290,11 @@ tile_load(tile_t& tile, payload_t& payload) { (!reg_transpose && (remained_size_y > ld_blk_size_y_limit)) ? 
ld_blk_size_y_limit : remained_size_y; - auto payload_row = payload_2d.xetla_select( - num_block_y * num_block_x, 0); - detail::reset_tile_desc_core< - num_block_x, - block_size_x, - remained_ld_blk_size_y, - scale_factor, - arr_len, - mem_transpose>(payload_row); + #pragma unroll for (uint32_t j = 0; j < num_block_x; j += arr_len) { - xetla_tdescriptor tdesc = payload_row.row(j); + int32_t offset_x = j * block_size_x; + // xetla_tdescriptor tdesc = payload_row.row(j); auto reg_blk = tile.reg.xetla_select( processed_elems + j * remained_block_elems); constexpr uint32_t ld_blk_height = (reg_transpose && trans) @@ -301,15 +308,23 @@ tile_load(tile_t& tile, payload_t& payload) { constexpr uint32_t load_elems = remained_ld_blk_size_y * block_size_x * arr_len; - reg_tmp.xetla_format>() = xetla_tload_global< - load_dtype, - (ld_blk_height * block_size_x * arr_len / scale_factor), - L1, - L2, + reg_tmp.xetla_format>() = xetla_load_global< + native_type_t, + block_size_x / scale_factor, + remained_ld_blk_size_y, + arr_len, trans, mem_transform, - arch_tag>(tdesc); - + L1, + L2>( + reinterpret_cast>( + payload.base_ptr), + payload.surface_width, + payload.surface_height, + payload.surface_pitch, + payload.offset_x + offset_x / scale_factor, + payload.offset_y + num_block_y * block_size_y + + ii * remained_ld_blk_size_y); if constexpr (reg_transpose && trans) { reg_blk.xetla_select(ii * load_elems) .xetla_format>() = @@ -326,14 +341,14 @@ tile_load(tile_t& tile, payload_t& payload) { } else { reg_blk.xetla_select(ii * tmp_size) = reg_tmp; } - if constexpr (mem_transpose) { - xetla_update_tdesc_offsetx( - tdesc.xetla_format(), - remained_ld_blk_size_y / scale_factor); - } else { - xetla_update_tdesc_offsety( - tdesc.xetla_format(), remained_ld_blk_size_y); - } + // if constexpr (mem_transpose) { + // xetla_update_tdesc_offsetx( + // tdesc.xetla_format(), + // remained_ld_blk_size_y / scale_factor); + // } else { + // xetla_update_tdesc_offsety( + // tdesc.xetla_format(), remained_ld_blk_size_y); + // } } constexpr uint32_t final_ld_blk_size_y = remained_size_y % remained_ld_blk_size_y; @@ -344,22 +359,40 @@ tile_load(tile_t& tile, payload_t& payload) { constexpr uint32_t final_load_elems = final_ld_blk_size_y * block_size_x * arr_len; constexpr uint8_t block_width = - mem_transpose ? (final_ld_blk_size_y / scale_factor) : block_size_x; + (mem_transpose ? final_ld_blk_size_y : block_size_x) / scale_factor; constexpr uint8_t block_height = - trans ? block_size_x : final_ld_blk_size_y; - constexpr uint32_t block_widthx_widthy_arrlen = - (block_width - 1) | ((block_height - 1) << 8); - gpu::xetla::detail::xetla_set_block_widthx_widthy_arrlen( - tdesc.xetla_format(), block_widthx_widthy_arrlen); + mem_transpose ? 
block_size_x : final_ld_blk_size_y; + // constexpr uint32_t block_widthx_widthy_arrlen = + // (block_width - 1) | ((block_height - 1) << 8); + // gpu::xetla::detail::xetla_set_block_widthx_widthy_arrlen( + // tdesc.xetla_format(), block_widthx_widthy_arrlen); reg_blk.xetla_select(final_start) - .xetla_format>() = xetla_tload_global< - load_dtype, - final_load_elems / scale_factor, - L1, - L2, + .xetla_format>() = xetla_load_global< + native_type_t, + block_width, + block_height, + arr_len, trans, mem_transform, - arch_tag>(tdesc); + L1, + L2>( + reinterpret_cast>( + payload.base_ptr), + payload.surface_width, + payload.surface_height, + payload.surface_pitch, + payload.offset_x + offset_x / scale_factor, + payload.offset_y + num_block_y * block_size_y + + remained_size_y / remained_ld_blk_size_y * + remained_ld_blk_size_y); + // xetla_tload_global< + // load_dtype, + // final_load_elems / scale_factor, + // L1, + // L2, + // trans, + // mem_transform, + // arch_tag>(tdesc); } } } @@ -426,7 +459,8 @@ tile_load(tile_t& tile, payload_t& payload) { /// @brief This function loads data from unaligned-2D memory surface. /// Loads an array of rectangular regions (X,Y)..(X+W,Y+H) from memory into -/// registers. Each block will be loaded serially by its corresponding payload. +/// registers. Each block will be loaded serially by its corresponding +/// payload. /// @tparam tile_t Is the tile_t struct contains registers. /// These registers will be the destination of load operation. /// @tparam payload_t Is the mem_payload_t struct describing the memory @@ -550,7 +584,8 @@ tile_load(tile_t& tile, payload_t& payload) { /// @brief This function loads data from unaligned-2D memory surface. /// Loads an array of rectangular regions (X,Y)..(X+W,Y+H) from memory into -/// registers. Each block will be loaded serially by its corresponding payload. +/// registers. Each block will be loaded serially by its corresponding +/// payload. /// @tparam tile_t Is the tile_t struct contains registers. /// These registers will be the destination of load operation. /// @tparam payload_t Is the mem_payload_t struct describing the memory @@ -615,7 +650,8 @@ tile_load(tile_t& tile, payload_t& payload) { /// @brief This function loads data from unaligned-2D memory surface. /// Loads an array of rectangular regions (X,Y)..(X+W,Y+H) from memory into -/// registers. Each block will be loaded serially by its corresponding payload. +/// registers. Each block will be loaded serially by its corresponding +/// payload. /// @tparam tile_t Is the tile_t struct contains registers. /// These registers will be the destination of load operation. /// @tparam payload_t Is the mem_payload_t struct describing the memory @@ -755,8 +791,8 @@ tile_load( } /// @brief Is the data load func from local shared memory to register file, -/// which supports the memory surface is 1d or 2d scenario. And we always assume -/// data in SLM is row major. +/// which supports the memory surface is 1d or 2d scenario. And we always +/// assume data in SLM is row major. /// @tparam tile_t Is the tile_t struct contains registers /// These registers will be the destination of load operation. /// @tparam payload_t Is the mem_payload_t struct describing the memory @@ -838,8 +874,8 @@ tile_load(tile_t& tile, payload_t& payload) { } /// @brief Is the data load func from shared local memory to register file, -/// which supports the memory surface is 1d scenario. And the src memory layout -/// is always row major. +/// which supports the memory surface is 1d scenario. 
And the src memory +/// layout is always row major. /// @tparam tile_t Is the tile_t struct contains registers. /// These registers will be the destination of load operation. /// @tparam payload_t Is the mem_payload_t struct describing the memory diff --git a/include/subgroup/tile/impl/payload_xe.hpp b/include/subgroup/tile/impl/payload_xe.hpp index 697aba49e..3480a7820 100644 --- a/include/subgroup/tile/impl/payload_xe.hpp +++ b/include/subgroup/tile/impl/payload_xe.hpp @@ -75,22 +75,53 @@ struct mem_payload_t< static constexpr bool trans = (mem_transpose ^ reg_transpose) && !(std::is_same_v || std::is_same_v); - static constexpr bool mem_transform = (sizeof(dtype) < 4) && !mem_transpose && + // Transformed and Transposed cannot be set to true at the same time. + // If Transformed is true then: + // sizeof(T) must be 1- or 2-byte (bytes or words). + static constexpr bool mem_transform = (sizeof(dtype) <= 2) && !trans && (register_layout == reg_layout::vnni_tiled || register_layout == reg_layout::vnni_tiled_col_major); - static constexpr bool mem_dword_transpose = (sizeof(dtype) < 4) && trans; - using mem_dtype = - typename std::conditional::type; + // If Transposed is true then: + // sizeof(T) must be 4- or 8-byte (dwords or qwords). + static constexpr bool mem_transpose_dtype_less4bytes = + (sizeof(dtype) < 4) && trans; + + using mem_dtype = typename std:: + conditional_t; static constexpr uint32_t scale_factor = sizeof(mem_dtype) / sizeof(dtype); + dtype* base_ptr; + uint32_t surface_width; + uint32_t surface_height; + uint32_t surface_pitch; + int32_t offset_x; + int32_t offset_y; + xetla_vector payloads; inline mem_payload_t(const this_payload_t& rhs) { + this->base_ptr = rhs.base_ptr; + this->surface_width = rhs.surface_width; + this->surface_height = rhs.surface_height; + this->surface_pitch = rhs.surface_pitch; + this->offset_x = rhs.offset_x; + this->offset_y = rhs.offset_y; + this->payloads = rhs.payloads; } inline mem_payload_t(mem_desc_t& mem_desc) { + this->base_ptr = (dtype*)mem_desc.base.base; + this->surface_width = + (mem_transpose ? mem_desc.shape.y : mem_desc.shape.x) * sizeof(dtype); + this->surface_height = + (mem_transpose ? mem_desc.shape.x : mem_desc.shape.y); + this->surface_pitch = mem_desc.shape.stride * sizeof(dtype); + this->offset_x = (mem_transpose ? mem_desc.coord.y : mem_desc.coord.x) / + int32_t(scale_factor); + this->offset_y = mem_transpose ? mem_desc.coord.x : mem_desc.coord.y; + xetla_tdescriptor base_tdesc = mem_desc.get_tdesc(); int32_t offset = gpu::xetla::detail::xetla_get_tensor_offset_x(base_tdesc) / int32_t(scale_factor); @@ -106,7 +137,15 @@ struct mem_payload_t< uint32_t surface_pitch, int32_t surface_offset_x = 0, int32_t surface_offset_y = 0) { + this->base_ptr = p; + this->surface_width = surface_width * sizeof(dtype); + this->surface_height = surface_height; + this->surface_pitch = surface_pitch * sizeof(dtype); + this->offset_x = surface_offset_x / int32_t(scale_factor); + this->offset_y = surface_offset_y; + xetla_tdescriptor base_tdesc; + xetla_fill_tdesc( base_tdesc.xetla_format(), p, @@ -119,6 +158,16 @@ struct mem_payload_t< } __XETLA_API void init(mem_desc_t& mem_desc) { + this->base_ptr = (dtype*)mem_desc.base.base; + this->surface_width = + (mem_transpose ? mem_desc.shape.y : mem_desc.shape.x) * sizeof(dtype); + this->surface_height = + (mem_transpose ? mem_desc.shape.x : mem_desc.shape.y); + this->surface_pitch = mem_desc.shape.stride * sizeof(dtype); + this->offset_x = (mem_transpose ? 
mem_desc.coord.y : mem_desc.coord.x) / + int32_t(scale_factor); + this->offset_y = (mem_transpose ? mem_desc.coord.x : mem_desc.coord.y); + xetla_tdescriptor base_tdesc = mem_desc.get_tdesc(); int32_t offset = gpu::xetla::detail::xetla_get_tensor_offset_x(base_tdesc) / int32_t(scale_factor); @@ -142,6 +191,13 @@ struct mem_payload_t< uint32_t surface_pitch, int32_t surface_offset_x = 0, int32_t surface_offset_y = 0) { + this->base_ptr = p; + this->surface_width = surface_width * sizeof(dtype); + this->surface_height = surface_height; + this->surface_pitch = surface_pitch * sizeof(dtype); + this->offset_x = surface_offset_x / int32_t(scale_factor); + this->offset_y = surface_offset_y; + xetla_tdescriptor base_tdesc; xetla_fill_tdesc( base_tdesc.xetla_format(), @@ -160,6 +216,13 @@ struct mem_payload_t< // ~mem_payload_t(){} inline this_payload_t& operator=(const this_payload_t& rhs) { + this->base_ptr = rhs.base_ptr; + this->surface_width = rhs.surface_width; + this->surface_height = rhs.surface_height; + this->surface_pitch = rhs.surface_pitch; + this->offset_x = rhs.offset_x; + this->offset_y = rhs.offset_y; + this->payloads = rhs.payloads; return *this; } @@ -168,12 +231,14 @@ struct mem_payload_t< __XETLA_API void update_tdesc(int offset) { auto payloads_2d = payloads.xetla_format(); if constexpr (update_dir == tdesc_update_dir::x_dir) { + offset_x += offset / scale_factor; #pragma unroll for (uint32_t i = 0; i < num_block; i++) { xetla_update_tdesc_offsetx( payloads_2d.row(i), offset / int32_t(scale_factor)); } } else { + offset_y += offset; #pragma unroll for (uint32_t i = 0; i < num_block; i++) { xetla_update_tdesc_offsety(payloads_2d.row(i), offset); @@ -1150,10 +1215,9 @@ struct mem_payload_t< static constexpr uint32_t tile_size_y = tile_desc::tile_size_y; static constexpr uint32_t block_size_x = tile_desc::block_size_x; static constexpr uint32_t block_size_y = tile_desc::block_size_y; - static constexpr uint32_t tile_bytes = - tile_size_x * tile_size_y * sizeof(dtype); + static constexpr uint32_t tile_bytes = tile_desc::tile_elems * sizeof(dtype); static constexpr uint32_t block_bytes = - block_size_x * block_size_y * sizeof(dtype); + tile_desc::block_elems * sizeof(dtype); using this_payload_t = mem_payload_t; @@ -1230,7 +1294,7 @@ struct mem_payload_t< base_offset = mem_transpose ? 
base_x * pitch_in_bytes + base_y * sizeof(dtype) : base_y * pitch_in_bytes + base_x * sizeof(dtype); - base_ptr = (mem_dtype*)mem_tdesc.base.base; + base_ptr = reinterpret_cast(mem_tdesc.base.base); xetla_vector channel_index = xetla_vector_gen(0, 1); @@ -1710,11 +1774,12 @@ struct prefetch_payload_t< reg_layout_>, num_coop_sg_, arch_tag_, - std::enable_if_t<(!arch_has_2d_load_store)&&( - ((block_size_y_ != 1 || tile_size_y_ != 1) && - mem_layout_ == mem_layout::row_major) || - ((block_size_x_ != 1 || tile_size_x_ != 1) && - mem_layout_ == mem_layout::col_major))>> { + std::enable_if_t< + (!arch_has_2d_load_store) && + (((block_size_y_ != 1 || tile_size_y_ != 1) && + mem_layout_ == mem_layout::row_major) || + ((block_size_x_ != 1 || tile_size_x_ != 1) && + mem_layout_ == mem_layout::col_major))>> { using dtype = native_type_t; using mem_desc_t = mem_desc_t; @@ -1735,10 +1800,8 @@ struct prefetch_payload_t< static constexpr uint32_t tile_size_y = tile_desc::tile_size_y; static constexpr uint32_t block_size_x = tile_desc::block_size_x; static constexpr uint32_t block_size_y = tile_desc::block_size_y; - static constexpr uint32_t tile_bytes = - tile_size_x * tile_size_y * sizeof(dtype); - static constexpr uint32_t block_bytes = - block_size_x * block_size_y * sizeof(dtype); + static constexpr uint32_t tile_bytes = tile_desc::block_elems * sizeof(dtype); + static constexpr uint32_t block_bytes = tile_desc::tile_elems * sizeof(dtype); private: using this_payload_t = @@ -1781,41 +1844,43 @@ struct prefetch_payload_t< static constexpr uint32_t num_channel = select_channel( std::min(mem_transpose ? block_size_x : block_size_y, max_channel)); - static constexpr uint32_t mem_tile_size_w = - mem_transpose ? tile_size_y : tile_size_x; - static constexpr uint32_t mem_tile_size_h = - mem_transpose ? tile_size_x : tile_size_y; - using load_store_attr = - typename arch_attr_t::template load_store_attr; - static constexpr uint32_t special_prefetch_width = - load_store_attr::special_prefetch_width_in_bytes / sizeof(dtype); - static constexpr uint32_t normal_prefetch_width = - load_store_attr::max_load_width_in_bytes / sizeof(dtype); - static constexpr bool is_special_prefetch = - (mem_tile_size_w % special_prefetch_width) == 0; - - static constexpr uint32_t block_size_w = is_special_prefetch - ? special_prefetch_width - : (normal_prefetch_width > mem_tile_size_w ? mem_tile_size_w - : normal_prefetch_width); - static constexpr uint32_t block_size_h = - load_store_attr::max_load_height_in_elem; - // could have over-prefetch, but that's should be fine - static constexpr uint32_t max_num_block_w = - (mem_tile_size_w + block_size_w - 1) / block_size_w; - static constexpr uint32_t num_coop_sg = num_coop_sg_; - static constexpr uint32_t num_coop_sg_w = - detail::gcd::value; - static constexpr uint32_t num_coop_sg_h = num_coop_sg / num_coop_sg_w; - - static constexpr uint32_t num_block_w = max_num_block_w / num_coop_sg_w; - static constexpr uint32_t tile_size_w = block_size_w * num_block_w; - static constexpr uint32_t tile_size_h = - (mem_tile_size_h + num_coop_sg_h - 1) / num_coop_sg_h; - static constexpr uint32_t num_block_h = - (tile_size_h + block_size_h - 1) / block_size_h; + // static constexpr uint32_t mem_tile_size_w = + // mem_transpose ? tile_size_y : tile_size_x; + // static constexpr uint32_t mem_tile_size_h = + // mem_transpose ? 
tile_size_x : tile_size_y; + + // static constexpr uint32_t special_prefetch_width = + // load_store_attr::special_prefetch_width_in_bytes / sizeof(dtype); + // static constexpr uint32_t normal_prefetch_width = + // load_store_attr::max_load_width_in_bytes / sizeof(dtype); + // static constexpr bool is_special_prefetch = + // (mem_tile_size_w % special_prefetch_width) == 0; + + // static constexpr uint32_t block_size_w = is_special_prefetch + // ? special_prefetch_width + // : (normal_prefetch_width > mem_tile_size_w ? mem_tile_size_w + // : normal_prefetch_width); + // static constexpr uint32_t block_size_h = + // load_store_attr::max_load_height_in_elem; + // // could have over-prefetch, but that's should be fine + // static constexpr uint32_t max_num_block_w = + // (mem_tile_size_w + block_size_w - 1) / block_size_w; + // static constexpr uint32_t num_coop_sg = num_coop_sg_; + // static constexpr uint32_t num_coop_sg_w = + // detail::gcd::value; + // static constexpr uint32_t num_coop_sg_h = num_coop_sg / num_coop_sg_w; + + // static constexpr uint32_t num_block_w = max_num_block_w / num_coop_sg_w; + // static constexpr uint32_t tile_size_w = block_size_w * num_block_w; + // static constexpr uint32_t tile_size_h = + // (mem_tile_size_h + num_coop_sg_h - 1) / num_coop_sg_h; + // static constexpr uint32_t num_block_h = + // (tile_size_h + block_size_h - 1) / block_size_h; xetla_vector channel_offset; + xetla_vector step_x; + xetla_vector step_y; + uint64_t base_offset; uint32_t base_x; uint32_t base_y; @@ -1852,13 +1917,15 @@ struct prefetch_payload_t< return *this; } - inline prefetch_payload_t(mem_desc_t& mem_desc, uint32_t coop_id = 0) { - uint32_t coop_id_x = coop_id % num_coop_sg_w; - uint32_t coop_id_y = coop_id / num_coop_sg_w; + inline prefetch_payload_t( + mem_desc_t& mem_desc, + [[maybe_unused]] uint32_t coop_id = 0) { + // uint32_t coop_id_x = coop_id % num_coop_sg_w; + // uint32_t coop_id_y = coop_id / num_coop_sg_w; pitch_in_bytes = mem_desc.shape.stride * sizeof(dtype); - base_x = mem_desc.coord.x + coop_id_x * tile_size_w; - base_y = mem_desc.coord.y + coop_id_y * tile_size_h; + base_x = mem_desc.coord.x; + base_y = mem_desc.coord.y; width_in_elems = mem_desc.shape.x; height_in_elems = mem_desc.shape.y; base_offset = mem_transpose @@ -1878,13 +1945,15 @@ struct prefetch_payload_t< int surface_pitch, int surface_offset_x, int surface_offset_y, - uint32_t coop_id = 0) { - uint32_t coop_id_x = coop_id % num_coop_sg_w; - uint32_t coop_id_y = coop_id / num_coop_sg_w; + [[maybe_unused]] uint32_t coop_id = 0) { + // uint32_t coop_id_x = coop_id % num_coop_sg_w; + // uint32_t coop_id_y = coop_id / num_coop_sg_w; + // base_x = surface_offset_x + coop_id_x * tile_size_w; + // base_y = surface_offset_y + coop_id_y * tile_size_h; pitch_in_bytes = surface_pitch * sizeof(dtype); - base_x = surface_offset_x + coop_id_x * tile_size_w; - base_y = surface_offset_y + coop_id_y * tile_size_h; + base_x = surface_offset_x; + base_y = surface_offset_y; width_in_elems = surface_width; height_in_elems = surface_height; base_offset = mem_transpose @@ -1897,13 +1966,17 @@ struct prefetch_payload_t< channel_offset = channel_index * pitch_in_bytes; } - inline void init(mem_desc_t& mem_desc, uint32_t coop_id = 0) { - uint32_t coop_id_x = coop_id % num_coop_sg_w; - uint32_t coop_id_y = coop_id / num_coop_sg_w; + inline void init( + mem_desc_t& mem_desc, + [[maybe_unused]] uint32_t coop_id = 0) { + // uint32_t coop_id_x = coop_id % num_coop_sg_w; + // uint32_t coop_id_y = coop_id / num_coop_sg_w; + // 
base_x = mem_desc.coord.x + coop_id_x * tile_size_w; + // base_y = mem_desc.coord.y + coop_id_y * tile_size_h; pitch_in_bytes = mem_desc.shape.stride * sizeof(dtype); - base_x = mem_desc.coord.x + coop_id_x * tile_size_w; - base_y = mem_desc.coord.y + coop_id_y * tile_size_h; + base_x = mem_desc.coord.x; + base_y = mem_desc.coord.y; width_in_elems = mem_desc.shape.x; height_in_elems = mem_desc.shape.y; base_offset = mem_transpose @@ -1960,9 +2033,10 @@ struct prefetch_payload_t< reg_layout_>, num_coop_sg_, arch_tag_, - std::enable_if_t<(arch_has_2d_load_store)&&( - ((tile_size_y_ != 1) && mem_layout_ == mem_layout::row_major) || - ((tile_size_x_ != 1) && mem_layout_ == mem_layout::col_major))>> { + std::enable_if_t< + (arch_has_2d_load_store) && + (((tile_size_y_ != 1) && mem_layout_ == mem_layout::row_major) || + ((tile_size_x_ != 1) && mem_layout_ == mem_layout::col_major))>> { using dtype = dtype_; using mem_desc_t = mem_desc_t; diff --git a/include/subgroup/tile/impl/prefetch_xe.hpp b/include/subgroup/tile/impl/prefetch_xe.hpp index 243379b9d..f14941518 100644 --- a/include/subgroup/tile/impl/prefetch_xe.hpp +++ b/include/subgroup/tile/impl/prefetch_xe.hpp @@ -104,8 +104,7 @@ tile_prefetch(payload_t& payload) { using prefetch_dtype = typename payload_t::prefetch_dtype; constexpr uint32_t num_channel = payload_t::num_channel; #pragma unroll - for (uint32_t i = 0; i < tile_desc::tile_size_y / tile_desc::block_size_y; - i++) { + for (uint32_t i = 0; i < tile_desc::num_block_y; i++) { uint32_t offset_y = i * tile_desc::block_size_y; #pragma unroll for (uint32_t j = 0; j < tile_desc::num_block_x; j++) { @@ -126,7 +125,6 @@ tile_prefetch(payload_t& payload) { L2>( payload.base_ptr, payload.channel_offset + payload.base_offset + address_offset); - // } } } } diff --git a/include/subgroup/tile/impl/store_xe.hpp b/include/subgroup/tile/impl/store_xe.hpp index bec5b007a..7f0645da9 100644 --- a/include/subgroup/tile/impl/store_xe.hpp +++ b/include/subgroup/tile/impl/store_xe.hpp @@ -98,51 +98,44 @@ tile_store(tile_t& tile, payload_t& payload) { static constexpr uint32_t num_block_x = tile_desc::num_block_x; static constexpr uint32_t num_block_y = tile_desc::num_block_y; - static constexpr uint32_t num_block = tile_desc::num_block; - using load_store_attr = typename arch_attr_t< - payload_t::arch_tag>::template load_store_attr; + static constexpr gpu_arch arch_tag = payload_t::arch_tag; - static constexpr int32_t max_block_width = - load_store_attr::max_load_width_in_bytes / sizeof(dtype); - static constexpr int32_t max_store_block_height = + using load_store_attr = load_store_attr_t; + static constexpr uint32_t max_store_width_in_elem = + load_store_attr::max_store_width_in_bytes / sizeof(dtype); + static constexpr uint32_t max_store_height_in_elem = load_store_attr::max_store_height_in_elem; - static_assert( - (max_block_width % block_size_x) == 0, - "max_block_width should be a multiply of block size x."); + static constexpr uint32_t elems_per_CL = load_store_attr::cache_line_size_in_bytes / sizeof(dtype); + + static_assert( + (max_store_width_in_elem % block_size_x) == 0, + "max_store_width_in_elem should be a multiply of block_size_x."); + static constexpr uint32_t st_blk_size_y = - block_size_y > max_store_block_height ? 
max_store_block_height - : block_size_y; + std::min(block_size_y, max_store_height_in_elem); + // to make sure full CL store - static constexpr uint32_t st_block_x = ((tile_size_x % elems_per_CL) == 0) + static constexpr uint32_t st_blk_size_x = ((tile_size_x % elems_per_CL) == 0) ? elems_per_CL : (((elems_per_CL % tile_size_x) == 0) ? tile_size_x : block_size_x); - static constexpr uint8_t arr_len_candidate = st_block_x / block_size_x; + static constexpr uint8_t arr_len_candidate = st_blk_size_x / block_size_x; static constexpr bool is_valid_arr_len_candidate = (arr_len_candidate == 1) || (arr_len_candidate == 2) || (arr_len_candidate == 4); static constexpr uint8_t arr_len = is_valid_arr_len_candidate ? arr_len_candidate : 1; - auto payload_2d = payload.payloads.xetla_format(); #pragma unroll for (uint32_t i = 0; i < num_block_y; ++i) { - constexpr uint32_t store_block_elems = block_elems * arr_len; - auto payload_row = - payload_2d.xetla_select(i * num_block_x, 0); - detail::reset_tile_desc_core< - num_block_x, - block_size_x * arr_len, - st_blk_size_y, - 1, - 1, - false>(payload_row); + int32_t offset_y = i * block_size_y; #pragma unroll for (uint32_t j = 0; j < num_block_x; j += arr_len) { - xetla_tdescriptor tdesc = payload_row.row(j); + int32_t offset_x = j * block_size_x; + constexpr uint32_t store_block_elems = block_elems * arr_len; auto reg_blk = tile.reg.xetla_select( (i * num_block_x + j) * block_elems); xetla_vector combine_blk; @@ -150,41 +143,50 @@ tile_store(tile_t& tile, payload_t& payload) { native_type_t, block_size_y, block_size_x * arr_len>(); + /* combine_blk_2d + ____________ ____________ + | || | + | block || block | + | || | + |____________||____________| + */ #pragma unroll - for (uint32_t combine_i = 0; combine_i < arr_len; ++combine_i) { + for (uint32_t block_id = 0; block_id < arr_len; ++block_id) { combine_blk_2d.xetla_select( - 0, combine_i * block_size_x) = - reg_blk.xetla_select(combine_i * block_elems); + 0, block_id * block_size_x) = + reg_blk.xetla_select(block_id * block_elems); } #pragma unroll - for (uint32_t ii = 0; ii < block_size_y / st_blk_size_y; ++ii) { - constexpr uint32_t store_elems = st_blk_size_y * block_size_x * arr_len; + for (uint32_t ii = 0; ii < block_size_y; ii += st_blk_size_y) { + constexpr uint32_t store_elems = st_blk_size_y * st_blk_size_x; auto st_blk = - combine_blk.xetla_select(ii * store_elems); - xetla_tstore_global( - tdesc, st_blk); - xetla_update_tdesc_offsety( - tdesc.xetla_format(), st_blk_size_y); + combine_blk.xetla_select(ii * st_blk_size_x); + xetla_store_global( + payload.base_ptr, + payload.surface_width, + payload.surface_height, + payload.surface_pitch, + payload.offset_x + offset_x, + payload.offset_y + offset_y + ii, + st_blk); } // exceed hardware limitation if constexpr ((block_size_y % st_blk_size_y) != 0) { - constexpr uint32_t blk_remained_start = block_size_y / st_blk_size_y * - st_blk_size_y * block_size_x * arr_len; + constexpr uint32_t blk_remained_start = + block_size_y / st_blk_size_y * st_blk_size_y * st_blk_size_x; constexpr uint8_t blk_remained_y = block_size_y % st_blk_size_y; - constexpr uint8_t blk_remained_elems = - blk_remained_y * block_size_x * arr_len; + constexpr uint8_t blk_remained_elems = blk_remained_y * st_blk_size_x; auto st_blk = combine_blk.xetla_select(blk_remained_start); - constexpr uint32_t block_widthx_widthy_arrlen = - (block_size_x * arr_len - 1) | ((blk_remained_y - 1) << 8); - gpu::xetla::detail::xetla_set_block_widthx_widthy_arrlen( - tdesc.xetla_format(), 
block_widthx_widthy_arrlen); - xetla_tstore_global< - dtype, - blk_remained_elems, - L1, - L2, - payload_t::arch_tag>(tdesc, st_blk); + xetla_store_global( + payload.base_ptr, + payload.surface_width, + payload.surface_height, + payload.surface_pitch, + payload.offset_x + offset_x, + payload.offset_y + offset_y + + block_size_y / st_blk_size_y * st_blk_size_y, + st_blk); } } } @@ -194,19 +196,10 @@ tile_store(tile_t& tile, payload_t& payload) { constexpr uint32_t processed_elems = num_block_y * num_block_x * block_elems; constexpr uint32_t remained_st_blk_size_y = - st_blk_size_y > remained_size_y ? remained_size_y : st_blk_size_y; - auto payload_row = payload_2d.xetla_select( - num_block_y * num_block_x, 0); - detail::reset_tile_desc_core< - num_block_x, - block_size_x * arr_len, - remained_st_blk_size_y, - 1, - 1, - false>(payload_row); + std::min(st_blk_size_y, remained_size_y); #pragma unroll for (uint32_t j = 0; j < num_block_x; j += arr_len) { - xetla_tdescriptor tdesc = payload_row.row(j); + int offset_x = j * block_size_x; auto reg_blk = tile.reg.xetla_select( processed_elems + j * remained_block_elems); // Do combination @@ -214,46 +207,53 @@ tile_store(tile_t& tile, payload_t& payload) { auto combine_blk_2d = combine_blk.xetla_format< native_type_t, remained_size_y, - block_size_x * arr_len>(); + st_blk_size_x>(); #pragma unroll - for (uint32_t combine_i = 0; combine_i < arr_len; ++combine_i) { + for (uint32_t block_id = 0; block_id < arr_len; ++block_id) { combine_blk_2d.xetla_select( - 0, combine_i * block_size_x) = + 0, block_id * block_size_x) = reg_blk.xetla_select( - combine_i * remained_block_elems); + block_id * remained_block_elems); } #pragma unroll - for (uint32_t ii = 0; ii < remained_size_y / remained_st_blk_size_y; - ++ii) { - constexpr uint32_t store_elems = - remained_st_blk_size_y * block_size_x * arr_len; + for (uint32_t ii = 0; ii < remained_size_y; + ii += remained_st_blk_size_y) { + constexpr uint32_t store_elems = remained_st_blk_size_y * st_blk_size_x; auto st_blk = - combine_blk.xetla_select(ii * store_elems); - xetla_tstore_global( - tdesc, st_blk); - xetla_update_tdesc_offsety( - tdesc.xetla_format(), remained_st_blk_size_y); + combine_blk.xetla_select(ii * st_blk_size_x); + xetla_store_global< + dtype, + st_blk_size_x, + remained_st_blk_size_y, + L1, + L2>( + payload.base_ptr, + payload.surface_width, + payload.surface_height, + payload.surface_pitch, + payload.offset_x + offset_x, + payload.offset_y + num_block_y * block_size_y + ii, + st_blk); } constexpr uint32_t final_st_blk_size_y = remained_size_y % remained_st_blk_size_y; if constexpr (final_st_blk_size_y != 0) { constexpr uint32_t final_start = remained_size_y / - remained_st_blk_size_y * remained_st_blk_size_y * block_size_x * - arr_len; + remained_st_blk_size_y * remained_st_blk_size_y * st_blk_size_x; constexpr uint32_t final_store_elems = - final_st_blk_size_y * block_size_x * arr_len; + final_st_blk_size_y * st_blk_size_x; auto st_blk = combine_blk.xetla_select(final_start); - constexpr uint32_t block_widthx_widthy_arrlen = - (block_size_x * arr_len - 1) | ((final_st_blk_size_y - 1) << 8); - gpu::xetla::detail::xetla_set_block_widthx_widthy_arrlen( - tdesc.xetla_format(), block_widthx_widthy_arrlen); - xetla_tstore_global< - dtype, - final_store_elems, - L1, - L2, - payload_t::arch_tag>(tdesc, st_blk); + xetla_store_global( + payload.base_ptr, + payload.surface_width, + payload.surface_height, + payload.surface_pitch, + payload.offset_x + offset_x, + payload.offset_y + num_block_y * 
block_size_y + + remained_size_y / remained_st_blk_size_y * + remained_st_blk_size_y, + st_blk); } } } diff --git a/tests/integration/default_config/group_gemm/kernel_func.hpp b/tests/integration/default_config/group_gemm/kernel_func.hpp index 26b2fb181..7b7b31404 100644 --- a/tests/integration/default_config/group_gemm/kernel_func.hpp +++ b/tests/integration/default_config/group_gemm/kernel_func.hpp @@ -108,6 +108,9 @@ struct default_config_group_gemm_test_func { using gemm_op_t = gemm_universal_t; + static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count(); + static constexpr uint32_t slm_size = gemm_op_t::get_slm_size(); + static const char* func_name() { return "default_config_group_gemm_test_func"; } diff --git a/tests/integration/default_config/kernel_gemm/kernel_func.hpp b/tests/integration/default_config/kernel_gemm/kernel_func.hpp index 3745343d0..24b82bb94 100644 --- a/tests/integration/default_config/kernel_gemm/kernel_func.hpp +++ b/tests/integration/default_config/kernel_gemm/kernel_func.hpp @@ -65,6 +65,9 @@ struct default_config_kernel_gemm_test_func { gpu_arch::XeHpc, // GPU arch tune_option>; + static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count(); + static constexpr uint32_t slm_size = gemm_op_t::get_slm_size(); + static const char* func_name() { return "default_config_kernel_gemm_test_func"; } diff --git a/tests/integration/gemm/bf16/kernel_func.hpp b/tests/integration/gemm/bf16/kernel_func.hpp index 6345047df..286e53dbb 100644 --- a/tests/integration/gemm/bf16/kernel_func.hpp +++ b/tests/integration/gemm/bf16/kernel_func.hpp @@ -76,6 +76,9 @@ struct bf16_gemm_test_func { using gemm_op_t = gemm_universal_t; + static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count(); + static constexpr uint32_t slm_size = gemm_op_t::get_slm_size(); + static const char* func_name() { return "bf16_gemm_test_func"; } diff --git a/tests/integration/gemm/fp16/common.hpp b/tests/integration/gemm/fp16/common.hpp index 0432fa349..cd7d52a7a 100644 --- a/tests/integration/gemm/fp16/common.hpp +++ b/tests/integration/gemm/fp16/common.hpp @@ -59,7 +59,7 @@ class TestBaseFP16f : public TestBase { using data_type_b = fp16; using data_type_c = fp16; using data_type_acc = float; - static constexpr mma_engine engine = mma_engine::fpu; + static constexpr mma_engine engine = mma_engine::xmx; }; class TestBaseFP16x : public TestBase { diff --git a/tests/integration/gemm/fp16/main.cpp b/tests/integration/gemm/fp16/main.cpp index 32ef12461..78dca8b4a 100644 --- a/tests/integration/gemm/fp16/main.cpp +++ b/tests/integration/gemm/fp16/main.cpp @@ -32,7 +32,6 @@ TYPED_TEST_P(fp16_gemm_test, esimd) { esimd_compile_string); } REGISTER_TYPED_TEST_SUITE_P(fp16_gemm_test, esimd); -using tests = - ::testing::Types; +using tests = ::testing::Types; INSTANTIATE_TYPED_TEST_SUITE_P(fp16_gemm_test_suite, fp16_gemm_test, tests); diff --git a/tests/integration/gemm/fp32/common.hpp b/tests/integration/gemm/fp32/common.hpp index 7ce5098fc..67303c398 100644 --- a/tests/integration/gemm/fp32/common.hpp +++ b/tests/integration/gemm/fp32/common.hpp @@ -97,7 +97,7 @@ class Test3 : public TestBase { static constexpr size_t mat_m = 16; static constexpr size_t mat_n = 64; static constexpr size_t mat_k = 32; - static constexpr size_t wg_m = 8; + static constexpr size_t wg_m = 16; static constexpr size_t wg_n = 64; static constexpr size_t sg_m = 1; static constexpr size_t sg_n = 64; @@ -205,7 +205,7 @@ class Test8 : public TestBase { static constexpr uint32_t global_kslicing = 2; static 
constexpr uint32_t local_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::col_major; + static constexpr mem_layout layout_b = mem_layout::row_major; using data_type_a = float; using data_type_b = float; using data_type_c = float; @@ -227,7 +227,6 @@ class Test9 : public TestBase { static constexpr uint32_t local_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; static constexpr mem_layout layout_b = mem_layout::row_major; - static constexpr mma_engine engine = mma_engine::xmx; using data_type_a = float; using data_type_b = float; using data_type_c = float; @@ -245,10 +244,10 @@ class Test10 : public TestBase { static constexpr size_t sg_m = 32; static constexpr size_t sg_n = 64; static constexpr size_t sg_k = 8; - static constexpr uint32_t global_kslicing = 2; + static constexpr uint32_t global_kslicing = 1; static constexpr uint32_t local_kslicing = 1; static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::col_major; + static constexpr mem_layout layout_b = mem_layout::row_major; using data_type_a = float; using data_type_b = float; using data_type_c = float; @@ -258,9 +257,9 @@ class Test10 : public TestBase { class Test11 : public TestBase { public: static constexpr size_t batch_size = 35; - static constexpr size_t mat_m = 4192; - static constexpr size_t mat_k = 1136; - static constexpr size_t mat_n = 688; + static constexpr size_t mat_m = 4193; + static constexpr size_t mat_k = 1134; + static constexpr size_t mat_n = 686; static constexpr size_t wg_m = 256; static constexpr size_t wg_n = 256; static constexpr size_t sg_m = 32; @@ -314,4 +313,4 @@ class result_validate { Test::layout_a, Test::layout_b); } -}; +}; \ No newline at end of file diff --git a/tests/integration/gemm/fp32/kernel_func.hpp b/tests/integration/gemm/fp32/kernel_func.hpp index 962acfe96..a83518076 100644 --- a/tests/integration/gemm/fp32/kernel_func.hpp +++ b/tests/integration/gemm/fp32/kernel_func.hpp @@ -77,6 +77,9 @@ struct fp32_gemm_test_func { using gemm_op_t = gemm_universal_t; + static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count(); + static constexpr uint32_t slm_size = gemm_op_t::get_slm_size(); + static const char* func_name() { return "fp32_gemm_test_func"; } diff --git a/tests/integration/gemm/int4_dequantization_bias/CMakeLists.txt b/tests/integration/gemm/int4_dequantization_bias/CMakeLists.txt index 4bdf0a42b..6a3a155f3 100644 --- a/tests/integration/gemm/int4_dequantization_bias/CMakeLists.txt +++ b/tests/integration/gemm/int4_dequantization_bias/CMakeLists.txt @@ -1,11 +1,6 @@ get_filename_component(ProjectId ${CMAKE_CURRENT_SOURCE_DIR} NAME) string(REPLACE " " "_" ProjectId ${ProjectId}) -set(ProjectIdClient ${ProjectId}) -set(ProjectIdXe ${ProjectId}) -string(PREPEND ProjectIdClient "gemm_client_") -string(PREPEND ProjectIdXe "gemm_xe_") +string(PREPEND ProjectIdClient "gemm_") -FILE(GLOB src_client main_client.cpp) -add_integration_test(${ProjectIdClient} ${src_client}) -FILE(GLOB src_xe main_xe.cpp) -add_integration_test(${ProjectIdXe} ${src_xe}) +FILE(GLOB src main.cpp) +add_integration_test(${ProjectId} ${src}) diff --git a/tests/integration/gemm/int4_dequantization_bias/main_client.cpp b/tests/integration/gemm/int4_dequantization_bias/main.cpp similarity index 100% rename from tests/integration/gemm/int4_dequantization_bias/main_client.cpp rename to tests/integration/gemm/int4_dequantization_bias/main.cpp diff 
--git a/tests/integration/gemm/int4_dequantization_bias/main_xe.cpp b/tests/integration/gemm/int4_dequantization_bias/main_xe.cpp deleted file mode 100644 index 0cc9d8d6f..000000000 --- a/tests/integration/gemm/int4_dequantization_bias/main_xe.cpp +++ /dev/null @@ -1,641 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2022-2023 Intel Corporation - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - *******************************************************************************/ - -#include -#include "xetla.hpp" -// #define UT_DEBUG 1 -using namespace gpu::xetla; -// The number of times the kernel is executed -constexpr int ITER = 100; - -class test1 { - public: - // Extract the parameters required by different test cases - static constexpr size_t mat_m = 1; - static constexpr size_t mat_n = 16384; - static constexpr size_t mat_k = 4096; - static constexpr size_t wg_m = 8; - static constexpr size_t wg_n = 64; - static constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 16; - static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 64; - static constexpr size_t num_buffer = 64; - static constexpr size_t local_kslicing = 8; - static constexpr size_t global_kslicing = 1; - static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = int4x2; - using data_type_c = fp16; -}; -class test2 { - public: - // Extract the parameters required by different test cases - static constexpr size_t mat_m = 1; - static constexpr size_t mat_n = 4096; - static constexpr size_t mat_k = 22016; - static constexpr size_t wg_m = 8; - static constexpr size_t wg_n = 128; - static constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 16; - static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 128; - static constexpr size_t num_buffer = 64; - static constexpr size_t local_kslicing = 4; - static constexpr size_t global_kslicing = 1; - static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = int4x2; - using data_type_c = fp16; -}; -class qkv1 { - public: - // Extract the parameters required by different test cases - static constexpr size_t mat_m = 8; - static constexpr size_t mat_n = 12288; - static constexpr size_t mat_k = 4096; - static constexpr size_t wg_m = 8; - static constexpr size_t wg_n = 64; - static constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 16; - static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 64; - static constexpr size_t num_buffer = 64; - static constexpr size_t local_kslicing = 8; - static constexpr size_t global_kslicing = 1; - static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = int4x2; - using 
data_type_c = fp16; -}; -class qkv2 { - public: - // Extract the parameters required by different test cases - static constexpr size_t mat_m = 8; - static constexpr size_t mat_n = 4096; - static constexpr size_t mat_k = 4096; - static constexpr size_t wg_m = 8; - static constexpr size_t wg_n = 64; - static constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 16; - static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 64; - static constexpr size_t num_buffer = 64; - static constexpr size_t local_kslicing = 8; - static constexpr size_t global_kslicing = 1; - static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = int4x2; - using data_type_c = fp16; -}; -class qkv3 { - public: - // Extract the parameters required by different test cases - static constexpr size_t mat_m = 8; - static constexpr size_t mat_n = 11008; - static constexpr size_t mat_k = 4096; - static constexpr size_t wg_m = 8; - static constexpr size_t wg_n = 64; - static constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 16; - static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 64; - static constexpr size_t num_buffer = 64; - static constexpr size_t local_kslicing = 8; - static constexpr size_t global_kslicing = 1; - static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = int4x2; - using data_type_c = fp16; -}; -class qkv4 { - public: - // Extract the parameters required by different test cases - static constexpr size_t mat_m = 1; - static constexpr size_t mat_n = 4096; - static constexpr size_t mat_k = 11008; - static constexpr size_t wg_m = 8; - static constexpr size_t wg_n = 64; - static constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 16; - static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 32; - static constexpr size_t num_buffer = 64; - static constexpr size_t local_kslicing = 4; - static constexpr size_t global_kslicing = 1; - static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = int4x2; - using data_type_c = fp16; -}; -class qkv5 { - public: - // Extract the parameters required by different test cases - static constexpr size_t mat_m = 8; - static constexpr size_t mat_n = 151936; - static constexpr size_t mat_k = 4096; - static constexpr size_t wg_m = 8; - static constexpr size_t wg_n = 64; - static constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 16; - static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 64; - static constexpr size_t num_buffer = 64; - static constexpr size_t local_kslicing = 8; - static constexpr size_t global_kslicing = 1; - static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = int4x2; - using data_type_c = fp16; -}; -class qkv6 { - public: - // Extract the parameters required by different test cases - static constexpr size_t mat_m = 1; - static constexpr size_t mat_n = 12288; - static constexpr size_t mat_k = 4096; - static constexpr size_t wg_m = 8; - static constexpr size_t wg_n = 64; - static constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 16; - static constexpr size_t sg_k = 16; - static 
constexpr size_t dequant_s = 64; - static constexpr size_t num_buffer = 64; - static constexpr size_t local_kslicing = 8; - static constexpr size_t global_kslicing = 1; - static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = int4x2; - using data_type_c = fp16; -}; -class qkv7 { - public: - // Extract the parameters required by different test cases - static constexpr size_t mat_m = 1; - static constexpr size_t mat_n = 4096; - static constexpr size_t mat_k = 4096; - static constexpr size_t wg_m = 8; - static constexpr size_t wg_n = 64; - static constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 16; - static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 64; - static constexpr size_t num_buffer = 64; - static constexpr size_t local_kslicing = 8; - static constexpr size_t global_kslicing = 1; - static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = int4x2; - using data_type_c = fp16; -}; -class qkv8 { - public: - // Extract the parameters required by different test cases - static constexpr size_t mat_m = 1; - static constexpr size_t mat_n = 11008; - static constexpr size_t mat_k = 4096; - static constexpr size_t wg_m = 8; - static constexpr size_t wg_n = 64; - static constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 16; - static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 64; - static constexpr size_t num_buffer = 64; - static constexpr size_t local_kslicing = 8; - static constexpr size_t global_kslicing = 1; - static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = int4x2; - using data_type_c = fp16; -}; -class qkv9 { - public: - // Extract the parameters required by different test cases - static constexpr size_t mat_m = 1; - static constexpr size_t mat_n = 4096; - static constexpr size_t mat_k = 11008; - static constexpr size_t wg_m = 8; - static constexpr size_t wg_n = 64; - static constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 16; - static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 64; - static constexpr size_t num_buffer = 64; - static constexpr size_t local_kslicing = 4; - static constexpr size_t global_kslicing = 1; - static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = int4x2; - using data_type_c = fp16; -}; -class qkv10 { - public: - // Extract the parameters required by different test cases - static constexpr size_t mat_m = 1; - static constexpr size_t mat_n = 151936; - static constexpr size_t mat_k = 4096; - static constexpr size_t wg_m = 8; - static constexpr size_t wg_n = 64; - static constexpr size_t sg_m = 8; - static constexpr size_t sg_n = 16; - static constexpr size_t sg_k = 16; - static constexpr size_t dequant_s = 64; - static constexpr size_t num_buffer = 64; - static constexpr size_t local_kslicing = 8; - static constexpr size_t global_kslicing = 1; - static constexpr mem_layout layout_a = mem_layout::row_major; - static constexpr mem_layout layout_b = mem_layout::row_major; - using data_type_a = fp16; - using data_type_b = int4x2; - using data_type_c = fp16; -}; - -template < - typename data_type_a, 
- typename data_type_b, - typename data_type_c, - typename data_type_acc = float, - typename data_type_bias = data_type_a> -int gemm_result_validate( - data_type_a* A, - data_type_b* B, - data_type_c* C, - data_type_bias* bias, - uint32_t m, - uint32_t k, - uint32_t n, - mem_layout mem_layout_a_ = mem_layout::row_major, - mem_layout mem_layout_b_ = mem_layout::row_major) { - buff_cmp::buff_vals data(C, m, n, n); - std::vector gold_C(m * n, 0); - get_gemm_gold( - m, n, k, mem_layout_a_, mem_layout_b_, A, B, gold_C.data()); - - // BiasAdd - for (uint32_t i = 0; i < gold_C.size(); ++i) { - uint32_t col = i % n; - gold_C[i] += bias[col]; - } - - buff_cmp::buff_vals other(gold_C.data(), m, n, n); - - bool result = buff_cmp::xetla_buff_cmp(data, other, "gemm validation"); - - std::cout << (!result ? "FAILED\n" : "PASSED\n"); - return result ? 0 : 1; -} - -template -void dequantize_gemm_run(int iter) { - using namespace gpu; - // Accept incoming parameters - constexpr size_t matrix_m = Test::mat_m; - constexpr size_t matrix_n = Test::mat_n; - constexpr size_t matrix_k = Test::mat_k; - constexpr uint32_t global_kslicing = Test::global_kslicing; - constexpr uint32_t local_kslicing = Test::local_kslicing; - static constexpr mem_layout layout_b = Test::layout_b; - constexpr size_t wg_tile_m = Test::wg_m; - constexpr size_t wg_tile_n = Test::wg_n; - constexpr size_t sg_tile_m = Test::sg_m; - constexpr size_t sg_tile_n = Test::sg_n; - constexpr size_t sg_tile_k = Test::sg_k; - constexpr size_t dequant_s = Test::dequant_s; - using data_type_a = typename Test::data_type_a; - using data_type_b = typename Test::data_type_b; - using data_type_c = typename Test::data_type_c; - using data_type_zero_pt = int4x2; - using data_type_scale = fp16; - using data_type_acc_in = fp16; - using data_type_acc = float; - using data_type_bias = fp16; - - constexpr size_t size_a = matrix_m * matrix_k; - constexpr size_t size_b = matrix_k * matrix_n / 2; - - constexpr size_t size_scale_m = matrix_k / dequant_s; - constexpr size_t size_scale_n = matrix_n; - constexpr size_t size_scale = size_scale_m * size_scale_n; - - constexpr size_t size_zero_pt_m = matrix_k / dequant_s; - constexpr size_t size_zero_pt_n = matrix_n / 2; - constexpr size_t size_zero_pt = size_zero_pt_m * size_zero_pt_n; - - constexpr size_t size_c = matrix_m * matrix_n; - constexpr size_t size_bias = matrix_n; - - // Turn on the enable_profiling property to facilitate subsequent profiling - sycl::property_list properties{sycl::property::queue::enable_profiling()}; - auto queue = sycl::queue(properties); - auto context = queue.get_info(); - auto device = queue.get_info(); - - std::cout << "Running on " << device.get_info() << "\n"; - - using tile_shape = - xetla::group::tile_shape_t; - static constexpr uint32_t periodic_sync_interval = 0; - static constexpr uint32_t prefetch_distance = 0; - - using mem_desc_a_t = xetla::mem_desc_t< - data_type_a, - mem_layout::row_major, - mem_space::global, - DEVICE_MEM_ALIGNMENT / sizeof(data_type_a)>; - using mem_desc_b_t = xetla::mem_desc_t< - data_type_b, - layout_b, - mem_space::global, - DEVICE_MEM_ALIGNMENT / sizeof(data_type_b)>; - using mem_desc_c_t = xetla::mem_desc_t< - data_type_c, - mem_layout::row_major, - mem_space::global, - DEVICE_MEM_ALIGNMENT / sizeof(data_type_c)>; - - using mem_desc_bias_t = xetla::mem_desc_t< - data_type_bias, - mem_layout::row_major, - mem_space::global, - DEVICE_MEM_ALIGNMENT / sizeof(data_type_bias)>; - - using compute_attr = xetla::group:: - compute_attr_t; - using perf_tuning_knob 
= xetla::group:: - perf_tuning_knob_t; - static constexpr quant_info quant_info{ - quant_mode::I4_SYM, Test::dequant_s, layout_b}; - - using compute_policy = xetla::group::compute_policy_int4_dequantize< - compute_attr, - perf_tuning_knob, - data_type_scale, - data_type_zero_pt, - quant_info, - mma_engine::xmx, - gpu_arch::XeHpc>; - - using gemm_t = xetla::group:: - gemm_t; - - using bias_op_t = - gpu::xetla::subgroup::bias_add_op_t; - using tile_op_t = gpu::xetla::subgroup::chained_tile_op_t; - - using epilogue_t = xetla::group::epilogue_t< - xetla::group::epilogue_policy_tile_op, - tile_shape, - mem_desc_c_t>; - - using group_swizzle = xetla::kernel::group_swizzle_default; - using gemm_op_t = xetla::kernel::gemm_universal_t< - gpu::xetla::kernel::dispatch_policy_int4_dequantize_kslicing< - group_swizzle, - global_kslicing, - local_kslicing>, - gemm_t, - epilogue_t>; - - size_t size_acc = gemm_op_t::get_acc_buf_size(matrix_m, matrix_n); - size_t size_cnt = gemm_op_t::get_cnt_buf_size(matrix_m, matrix_n); - - // Define and initialize the data required for the calculation - auto* A_h = static_cast( - malloc_host(size_a * sizeof(data_type_a), context)); - auto* B_h = static_cast( - malloc_host(size_b * sizeof(data_type_b), context)); - auto* C_h = static_cast( - malloc_host(size_c * sizeof(data_type_c), context)); - auto* Acc_h = static_cast( - malloc_host(size_acc * sizeof(data_type_acc), context)); - auto* Cnt_h = - static_cast(malloc_host(size_cnt * sizeof(uint32_t), context)); - auto* scale_h = static_cast( - malloc_host(size_scale * sizeof(data_type_scale), context)); - auto* zero_pt_h = static_cast( - malloc_host(size_zero_pt * sizeof(data_type_zero_pt), context)); - auto* bias_h = static_cast( - malloc_host(size_bias * sizeof(data_type_bias), context)); - - auto* A_d = static_cast(aligned_alloc_device( - DEVICE_MEM_ALIGNMENT, size_a * sizeof(data_type_a), device, context)); - auto* B_d = static_cast(aligned_alloc_device( - DEVICE_MEM_ALIGNMENT, size_b * sizeof(data_type_b), device, context)); - auto* C_d = static_cast(aligned_alloc_device( - DEVICE_MEM_ALIGNMENT, size_c * sizeof(data_type_c), device, context)); - auto* Acc_d = static_cast(aligned_alloc_device( - DEVICE_MEM_ALIGNMENT, size_acc * sizeof(data_type_acc), device, context)); - auto* Cnt_d = static_cast(aligned_alloc_device( - DEVICE_MEM_ALIGNMENT, size_cnt * sizeof(uint32_t), device, context)); - auto* scale_d = static_cast(aligned_alloc_device( - DEVICE_MEM_ALIGNMENT, - size_scale * sizeof(data_type_scale), - device, - context)); - auto* zero_pt_d = static_cast(aligned_alloc_device( - DEVICE_MEM_ALIGNMENT, - size_zero_pt * sizeof(data_type_zero_pt), - device, - context)); - auto* bias_d = static_cast(aligned_alloc_device( - DEVICE_MEM_ALIGNMENT, - size_bias * sizeof(data_type_bias), - device, - context)); - - for (unsigned i = 0; i < size_a; ++i) { - A_h[i] = random_float(); -#ifdef UT_DEBUG - A_h[i] = 1.f; -#endif - } - for (unsigned i = 0; i < size_b; ++i) { - B_h[i] = uint8_t(random_uint8()); -#ifdef UT_DEBUG - B_h[i] = 153; -#endif - } - for (unsigned i = 0; i < size_scale; ++i) { - scale_h[i] = random_float(); -#ifdef UT_DEBUG - scale_h[i] = 1.f; -#endif - } - for (unsigned i = 0; i < size_zero_pt; ++i) { - zero_pt_h[i] = 0.f; - } - for (unsigned i = 0; i < size_c; ++i) { - C_h[i] = 0; - } - for (unsigned i = 0; i < size_acc; ++i) { - Acc_h[i] = 0; - } - for (unsigned i = 0; i < size_cnt; ++i) { - Cnt_h[i] = 0; - } - for (unsigned i = 0; i < size_bias; ++i) { - bias_h[i] = random_float(); -#ifdef UT_DEBUG - bias_h[i] 
= 0.f; -#endif - } - - queue.memcpy((void*)A_d, (void*)A_h, size_a * sizeof(data_type_a)).wait(); - queue.memcpy((void*)B_d, (void*)B_h, size_b * sizeof(data_type_b)).wait(); - queue.memcpy((void*)C_d, (void*)C_h, size_c * sizeof(data_type_c)).wait(); - queue.memcpy((void*)Acc_d, (void*)Acc_h, size_acc * sizeof(data_type_acc)) - .wait(); - queue.memcpy((void*)Cnt_d, (void*)Cnt_h, size_cnt * sizeof(uint32_t)).wait(); - queue - .memcpy( - (void*)scale_d, (void*)scale_h, size_scale * sizeof(data_type_scale)) - .wait(); - queue - .memcpy( - (void*)zero_pt_d, - (void*)zero_pt_h, - size_zero_pt * sizeof(data_type_zero_pt)) - .wait(); - queue.memcpy((void*)bias_d, (void*)bias_h, size_bias * sizeof(data_type_bias)) - .wait(); - - // set up gemm arguments - typename bias_op_t::shape_t bias_add_shape(matrix_n, 1, matrix_n); - using epilogue_args_t = epilogue_t::arguments_t; - - epilogue_args_t epilogue_args( - {// epilogue_args init list - // It accepts the base pointer to matrix D, and its dimensions - {bias_d, bias_add_shape}}); - - typename gemm_op_t::template arguments_t gemm_arg( - matrix_m, - matrix_k, - matrix_n, - A_d, - matrix_k, - B_d, - matrix_n, - C_d, - matrix_n, - scale_d, - matrix_n, - Acc_d, - Cnt_d, - epilogue_args); - - cl::sycl::nd_range<3> nd_range = gemm_op_t::get_nd_range(gemm_arg); - if (!gemm_op_t::can_implement(gemm_arg)) { - std::cout << "The arguments cannot be supported, aborting ... " - << std::endl; - FAIL(); - } - - size_t ops = 2 * matrix_m * matrix_n * matrix_k + matrix_m * matrix_n; - profiling_helper prof("dequantize_gemm", ops, "gflops"); - try { - for (int i = 0; i < iter; i++) { - prof.cpu_start(); - auto e_esimd = queue.submit([&](handler& cgh) { - cgh.parallel_for( - nd_range, [=](nd_item<3> item) SYCL_ESIMD_KERNEL { - // allocate slm and nbarrier resource - slm_barrier_init(); - gemm_op_t gemm_op; - gemm_op(item, gemm_arg); - }); - }); - e_esimd.wait(); - prof.cpu_end(); - prof.add_gpu_event(e_esimd); - } - } catch (cl::sycl::exception const& e) { - std::cout << "SYCL exception caught: " << e.what() << '\n'; - FAIL(); - } - - // performance - prof.print_profiling_result(profiling_selector::GPU); - - std::vector dequantize_b(matrix_k * matrix_n, 0); - for (uint32_t i = 0; i < matrix_k / dequant_s; i++) { - for (uint32_t j = 0; j < matrix_n / 2; j++) { - int start_in = i * dequant_s * matrix_n / 2 + j; - int start_out = i * dequant_s * matrix_n + j * 2; - int start_scale = i * size_scale_n + j * 2; - for (uint32_t ii = 0; ii < dequant_s; ii++) { - uint8_t data_in = B_h[start_in + ii * matrix_n / 2]; - int8_t data_0 = int8_t(data_in & 0x0f) - 8; - int8_t data_1 = int8_t(data_in >> 4) - 8; - dequantize_b[start_out + ii * matrix_n] = - fp16(data_0) * scale_h[start_scale]; - dequantize_b[start_out + ii * matrix_n + 1] = - fp16(data_1) * scale_h[start_scale + 1]; - } - } - } - - queue.memcpy((void*)C_h, (void*)C_d, size_c * sizeof(data_type_c)).wait(); - ASSERT_EQ( - 0, - gemm_result_validate( - A_h, dequantize_b.data(), C_h, bias_h, matrix_m, matrix_k, matrix_n)); - - free(A_h, context); - free(B_h, context); - free(C_h, context); - free(scale_h, context); - free(zero_pt_h, context); - free(A_d, context); - free(B_d, context); - free(C_d, context); - free(scale_d, context); - free(zero_pt_d, context); - free(Acc_h, context); - free(Cnt_h, context); - free(Acc_d, context); - free(Cnt_d, context); -} - -template -class dequantize_gemm_test : public ::testing::Test {}; -TYPED_TEST_SUITE_P(dequantize_gemm_test); - -TYPED_TEST_P(dequantize_gemm_test, esimd) { - 
dequantize_gemm_run(ITER); -} - -REGISTER_TYPED_TEST_SUITE_P(dequantize_gemm_test, esimd); -using tests = ::testing::Types; -// using tests = ::testing::Types; - -INSTANTIATE_TYPED_TEST_SUITE_P( - dequantize_gemm_test_suite, - dequantize_gemm_test, - tests); diff --git a/tests/integration/gemm/int8/kernel_func.hpp b/tests/integration/gemm/int8/kernel_func.hpp index 3a9e595ca..da2eab012 100644 --- a/tests/integration/gemm/int8/kernel_func.hpp +++ b/tests/integration/gemm/int8/kernel_func.hpp @@ -72,6 +72,9 @@ struct int8gemm_test_func { using gemm_op_t = gemm_universal_t; + static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count(); + static constexpr uint32_t slm_size = gemm_op_t::get_slm_size(); + static const char* func_name() { return "int8gemm_test_func"; } diff --git a/tests/integration/gemm/tf32/kernel_func.hpp b/tests/integration/gemm/tf32/kernel_func.hpp index 8c2850b9b..42d69eb51 100644 --- a/tests/integration/gemm/tf32/kernel_func.hpp +++ b/tests/integration/gemm/tf32/kernel_func.hpp @@ -71,6 +71,9 @@ struct tf32_gemm_test_func { using gemm_op_t = gemm_universal_t; + static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count(); + static constexpr uint32_t slm_size = gemm_op_t::get_slm_size(); + static const char* func_name() { return "tf32_gemm_test_func"; } diff --git a/tests/integration/gemm/unaligned_bf16/kernel_func.hpp b/tests/integration/gemm/unaligned_bf16/kernel_func.hpp index d45ddc0b7..d5534ebdf 100644 --- a/tests/integration/gemm/unaligned_bf16/kernel_func.hpp +++ b/tests/integration/gemm/unaligned_bf16/kernel_func.hpp @@ -68,13 +68,20 @@ struct unaligned_gemm_test_func { using epilogue_t = epilogue_t< epilogue_policy_unaligned, tile_shape, - mem_desc_t>; + mem_desc_t< + dtype_c, + mem_layout::row_major, + mem_space::global, + ldc_alignment>>; using group_swizzle = gpu::xetla::kernel::group_swizzle_default; using dispatch_policy = dispatch_policy_kslicing; using gemm_op_t = gemm_universal_t; + static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count(); + static constexpr uint32_t slm_size = gemm_op_t::get_slm_size(); + static const char* func_name() { return "unaligned_gemm_test_func"; } diff --git a/tests/integration/gemm/unaligned_bf16/main.cpp b/tests/integration/gemm/unaligned_bf16/main.cpp index 9ceaf695e..3d23c78ce 100644 --- a/tests/integration/gemm/unaligned_bf16/main.cpp +++ b/tests/integration/gemm/unaligned_bf16/main.cpp @@ -31,10 +31,7 @@ TYPED_TEST_P(unaligned_gemm_test, esimd) { gemm_exec< TypeParam, result_validate, - unaligned_gemm_func, - unaligned_gemm_func::gemm_op_t::get_slm_size(), - unaligned_gemm_func::gemm_op_t::get_barrier_count()>( - esimd_compile_string); + unaligned_gemm_func>(esimd_compile_string); } REGISTER_TYPED_TEST_SUITE_P(unaligned_gemm_test, esimd); using tests = ::testing::Types< diff --git a/tests/integration/mlp/int4/int4_mlp_gate_mul_up_fwd.hpp b/tests/integration/mlp/int4/int4_mlp_gate_mul_up_fwd.hpp index bc7f63af9..8c47795fc 100644 --- a/tests/integration/mlp/int4/int4_mlp_gate_mul_up_fwd.hpp +++ b/tests/integration/mlp/int4/int4_mlp_gate_mul_up_fwd.hpp @@ -188,7 +188,7 @@ class global_sum_reduce_two_mat_t { mat_zero, matAcc1_payload); subgroup::tile_store( mat_zero, matAcc2_payload); - SW_BARRIER(); + sw_barrier(); } } }; diff --git a/tests/integration/vector_add/int32_1d/kernel_func.hpp b/tests/integration/vector_add/int32_1d/kernel_func.hpp index bca6071d1..3dbbbac47 100644 --- a/tests/integration/vector_add/int32_1d/kernel_func.hpp +++ 
b/tests/integration/vector_add/int32_1d/kernel_func.hpp @@ -36,7 +36,7 @@ KERNEL_FUNC inline void vector_add_func( /// use block prefetch for b xetla_prefetch_global( b, offset); - SW_BARRIER(); + sw_barrier(); /// use scattered load for a xetla_vector ivector1 = xetla_load_global< dtype, diff --git a/tests/integration/vector_add/tf32_1d/kernel_func.hpp b/tests/integration/vector_add/tf32_1d/kernel_func.hpp index f1a6f177d..79872dec3 100644 --- a/tests/integration/vector_add/tf32_1d/kernel_func.hpp +++ b/tests/integration/vector_add/tf32_1d/kernel_func.hpp @@ -36,7 +36,7 @@ KERNEL_FUNC inline void vector_add_func( /// use block prefetch for b xetla_prefetch_global( b, offset); - SW_BARRIER(); + sw_barrier(); /// use scattered load for a xetla_vector ivector1 = xetla_load_global< dtype, diff --git a/tests/unit/block_load_store/kernel_func.hpp b/tests/unit/block_load_store/kernel_func.hpp index 17465f08c..3c80ed99d 100644 --- a/tests/unit/block_load_store/kernel_func.hpp +++ b/tests/unit/block_load_store/kernel_func.hpp @@ -44,7 +44,7 @@ struct block_load_store_func { cache_hint::cached, cache_hint::cached, arch_tag>(src_tdesc); - SW_BARRIER(); + sw_barrier(); xetla_vector A_load_vec = xetla_tload_global< dtype, bwidth * bheight, diff --git a/tests/unit/tile_load_store/kernel_func.hpp b/tests/unit/tile_load_store/kernel_func.hpp index 89832a46f..cdd4a59f2 100644 --- a/tests/unit/tile_load_store/kernel_func.hpp +++ b/tests/unit/tile_load_store/kernel_func.hpp @@ -242,7 +242,7 @@ struct tile_load_store_atomic_func { matBias.reg = matA.reg; matA.reg = 0; tile_store(matA, payload_store); - SW_BARRIER(); + sw_barrier(); tile_store(matBias, payload_store_add, check_tag); } }; diff --git a/tests/unit/tile_load_store/main.cpp b/tests/unit/tile_load_store/main.cpp index a08f71d57..bc24278b7 100644 --- a/tests/unit/tile_load_store/main.cpp +++ b/tests/unit/tile_load_store/main.cpp @@ -271,34 +271,34 @@ TEST(tile_load_store_atomic_disable_oob_check, esimd) { false>>(nd_range, result_validate); } -TEST(tile_load_store_atomic_boundary, esimd) { - cl::sycl::nd_range<1> nd_range({1}, {1}); - auto result_validate = std::bind( - tile_load_store_result_validate, - _1, - _2, - _3, - 128, - 33554440, - 32, - 32, - 33554432); - kernel_run< - float, - tile_load_store_atomic_func< - float, - 128, - 33554440, - 128, - 32, - 32, - 16, - 16, - true>, - 128 * 1024, - 32, - 4294968320U>(nd_range, result_validate); -} +// TEST(tile_load_store_atomic_boundary, esimd) { +// cl::sycl::nd_range<1> nd_range({1}, {1}); +// auto result_validate = std::bind( +// tile_load_store_result_validate, +// _1, +// _2, +// _3, +// 128, +// 33554440, +// 32, +// 32, +// 33554432); +// kernel_run< +// float, +// tile_load_store_atomic_func< +// float, +// 128, +// 33554440, +// 128, +// 32, +// 32, +// 16, +// 16, +// true>, +// 128 * 1024, +// 32, +// 4294968320U>(nd_range, result_validate); +// } TEST(tile_load_broadcast_store, esimd) { cl::sycl::nd_range<1> nd_range({1}, {1}); @@ -318,25 +318,25 @@ TEST(tile_load_store_1d, esimd) { nd_range, result_validate); } -TEST(tile_load_store_1d_boundary, esimd) { - cl::sycl::nd_range<1> nd_range({1}, {1}); - auto result_validate = std::bind( - tile_load_store_result_validate, - _1, - _2, - _3, - 128, - 33554440, - 128, - 1, - 33554432); - kernel_run< - int, - tile_load_store_1d_func, - 128 * 1024, - 32, - 4294968320U>(nd_range, result_validate); -} +// TEST(tile_load_store_1d_boundary, esimd) { +// cl::sycl::nd_range<1> nd_range({1}, {1}); +// auto result_validate = std::bind( +// 
tile_load_store_result_validate,
+//       _1,
+//       _2,
+//       _3,
+//       128,
+//       33554440,
+//       128,
+//       1,
+//       33554432);
+//   kernel_run<
+//       int,
+//       tile_load_store_1d_func,
+//       128 * 1024,
+//       32,
+//       4294968320U>(nd_range, result_validate);
+// }

TEST(tile_load_store_unaligned_2d, esimd) {
  cl::sycl::nd_range<1> nd_range({1}, {1});
diff --git a/tests/unit/tile_mma/kernel_func.hpp b/tests/unit/tile_mma/kernel_func.hpp
index d8130c4a5..926b20c17 100644
--- a/tests/unit/tile_mma/kernel_func.hpp
+++ b/tests/unit/tile_mma/kernel_func.hpp
@@ -103,9 +103,9 @@ struct tile_mma_func {
        matA, matA_payload);
    subgroup::tile_load(
        matB, matB_payload);
-    SW_BARRIER();
+    sw_barrier();
    tile_mma::mma(matAcc, matAcc, matB, matA);
-    SW_BARRIER();
+    sw_barrier();
    matC.reg = xetla_cvt(
        matAcc.reg);
    matC_payload.init(c, n, m, n, 0, 0);
diff --git a/tests/utils/common.hpp b/tests/utils/common.hpp
index 4a79de7c4..991c9cc9a 100644
--- a/tests/utils/common.hpp
+++ b/tests/utils/common.hpp
@@ -90,6 +90,7 @@ inline auto getTypeName() {
}

enum class test_result : uint8_t { complete = 0, skip = 1, fail = 2 };
+enum Direction { FWD = 0, BWDD = 1, BWDW = 2 };

template
inline result_type generate_real_random(
@@ -129,6 +130,28 @@ inline data_type* alloc_host(size_t size) {
  return host_ptr;
}

+template
+using init_func_t = std::function;
+
+template
+void index_init_func(data_type* data, size_t idx) {
+  data[idx] = static_cast(idx);
+}
+template
+void no_init_func(
+    [[maybe_unused]] data_type* data,
+    [[maybe_unused]] size_t idx) {}
+
+template
+void rand_init_func(data_type* data, size_t idx) {
+  data[idx] = static_cast(random_float() - 0.5f);
+}
+
+template
+void zero_init_func(data_type* data, size_t idx) {
+  data[idx] = 0;
+}
+
template
inline data_type* alloc_host_and_init(
    size_t size,
diff --git a/tests/utils/execution.hpp b/tests/utils/execution.hpp
index 85d722ce9..843420373 100644
--- a/tests/utils/execution.hpp
+++ b/tests/utils/execution.hpp
@@ -1,5 +1,5 @@
/*******************************************************************************
- * Copyright (c) 2022-2023 Intel Corporation
+ * Copyright (c) 2022-2024 Intel Corporation
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
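A note on the recurring barrier_count / slm_size additions in the test functors above: each functor now publishes the resources its gemm_op_t requires, and the reworked gemm_exec in the execution.hpp hunks that follow reads them from the kernel type instead of taking SLMSIZE/BARNUM template parameters. A minimal compile-only sketch of that pattern follows; fake_gemm_op_t, my_gemm_test_func and slm_bytes are placeholders, not code from this patch.

#include <cstdint>

// Stand-in for the real gpu::xetla gemm_universal_t; only the two queries the
// test harness relies on are modelled here.
struct fake_gemm_op_t {
  static constexpr uint32_t get_barrier_count() { return 16; }
  static constexpr uint32_t get_slm_size() { return 8 * 1024; }
};

// The two-line pattern this patch adds to every *_gemm_test_func.
struct my_gemm_test_func {
  using gemm_op_t = fake_gemm_op_t;
  static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count();
  static constexpr uint32_t slm_size = gemm_op_t::get_slm_size();
};

// gemm_exec<Test, validate_func, kernel_t> can then pull the values off the
// type, e.g. when sizing local memory and named barriers at kernel launch.
template <typename kernel_t>
constexpr uint32_t slm_bytes() {
  return kernel_t::slm_size;
}

static_assert(slm_bytes<my_gemm_test_func>() == 8 * 1024, "published by the functor");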
@@ -20,22 +20,26 @@ #include #include "common.hpp" #include "profiling.hpp" +#ifdef _WIN32 +#include "windows_functions.hpp" +#endif #include "xetla.hpp" using namespace cl::sycl; using namespace gpu; using namespace gpu::xetla; -template < - class Test, - typename validate_func, - typename KERNEL, - int SLMSIZE = arch_attr_t::local_mem_size, - int BARNUM = 32> -void gemm_exec(const std::string& compile_str, size_t batch = 1) { +template +void gemm_exec( + const std::string& compile_str, + size_t batch = 1, + size_t scaling = 1) { test_result result = test_result::complete; - using gemm_op_t = typename KERNEL::gemm_op_t; + using gemm_op_t = typename kernel_t::gemm_op_t; + + constexpr uint32_t slm_size = kernel_t::slm_size; + constexpr uint32_t barrier_num = kernel_t::barrier_count; using data_type_a = typename Test::data_type_a; using data_type_b = typename Test::data_type_b; @@ -46,6 +50,12 @@ void gemm_exec(const std::string& compile_str, size_t batch = 1) { constexpr size_t matrix_n = Test::mat_n; constexpr size_t matrix_k = Test::mat_k; + [[maybe_unused]] constexpr size_t wg_tile_m = Test::wg_m; + [[maybe_unused]] constexpr size_t wg_tile_n = Test::wg_n; + [[maybe_unused]] constexpr size_t sg_tile_m = Test::sg_m; + [[maybe_unused]] constexpr size_t sg_tile_n = Test::sg_n; + [[maybe_unused]] constexpr size_t sg_tile_k = Test::sg_k; + size_t size_a = matrix_m * matrix_k; size_t size_b = matrix_k * matrix_n; size_t size_c = matrix_m * matrix_n; @@ -59,16 +69,16 @@ void gemm_exec(const std::string& compile_str, size_t batch = 1) { auto A = alloc_device_and_init( batch * size_a, - [](data_type_a* data, size_t idx) { - data[idx] = static_cast(random_float()); + [&scaling](data_type_a* data, size_t idx) { + data[idx] = static_cast(scaling * (random_float() - 0.5f)); }, queue, device, context); auto B = alloc_device_and_init( batch * size_b, - [](data_type_b* data, size_t idx) { - data[idx] = static_cast(random_float()); + [&scaling](data_type_b* data, size_t idx) { + data[idx] = static_cast(scaling * (random_float() - 0.5f)); }, queue, device, @@ -81,6 +91,7 @@ void gemm_exec(const std::string& compile_str, size_t batch = 1) { queue, device, context); + size_t size_acc = gemm_op_t::get_acc_buf_size(matrix_m, matrix_n); size_t size_cnt = gemm_op_t::get_cnt_buf_size(matrix_m, matrix_n); auto Acc = alloc_device_and_init( @@ -97,20 +108,20 @@ void gemm_exec(const std::string& compile_str, size_t batch = 1) { queue, device, context); - - size_t ops = 2 * matrix_m * matrix_n * matrix_k; + long ops = 2 * static_cast(matrix_m) * matrix_n * matrix_k; profiling_helper prof("gemm", ops, "gflops"); - try { std::vector kernelId = {get_kernel_id()}; auto inputBundle = get_kernel_bundle(context, kernelId); - static const std::string env_set_str = - "SYCL_PROGRAM_COMPILE_OPTIONS=" + compile_str; - putenv(const_cast(env_set_str.c_str())); + char* value = getenv("GOGRITS"); + if (value == NULL || strcmp(value, "on") != 0) { + setenv("SYCL_PROGRAM_COMPILE_OPTIONS", compile_str.c_str(), 1); + } kernel_bundle exeBundle = build(inputBundle); - static const std::string env_unset_str = "SYCL_PROGRAM_COMPILE_OPTIONS="; - putenv(const_cast(env_unset_str.c_str())); + if (value == NULL || strcmp(value, "on") != 0) { + unsetenv("SYCL_PROGRAM_COMPILE_OPTIONS"); + } using namespace gpu::xetla::group; using namespace gpu::xetla::kernel; @@ -130,10 +141,10 @@ void gemm_exec(const std::string& compile_str, size_t batch = 1) { nullptr); cl::sycl::nd_range<3> nd_range = gemm_op_t::get_nd_range(arg); - int constexpr warm_up = 
10; int constexpr iters = 100; for (size_t i = 0; i < batch; i++) { + prof.cpu_start(); auto A_ptr = A + i * size_a; auto B_ptr = B + i * size_b; auto C_ptr = C + i * size_c; @@ -157,11 +168,13 @@ void gemm_exec(const std::string& compile_str, size_t batch = 1) { prof.cpu_start(); } auto e_esimd = queue.submit([&](handler& cgh) { - cgh.use_kernel_bundle(exeBundle); + if (value == NULL || strcmp(value, "on") != 0) { + cgh.use_kernel_bundle(exeBundle); + } cgh.parallel_for(nd_range, [=](nd_item<3> item) KERNEL_MAIN { - gpu::xetla::xetla_local_init(); - gpu::xetla::xetla_nbarrier_init(); - KERNEL::run( + gpu::xetla::xetla_local_init(); + gpu::xetla::xetla_nbarrier_init(); + kernel_t::run( item, A_ptr, B_ptr, @@ -184,9 +197,7 @@ void gemm_exec(const std::string& compile_str, size_t batch = 1) { std::cout << "SYCL exception caught: " << e.what() << '\n'; result = test_result::fail; } - - // performance - prof.print_profiling_result(profiling_selector::GPU); + unsetenv("SYCL_PROGRAM_COMPILE_OPTIONS"); // validation if (result == test_result::complete) { validate_func vfunc; @@ -204,6 +215,7 @@ void gemm_exec(const std::string& compile_str, size_t batch = 1) { } else if (result != test_result::complete) { FAIL(); } + prof.print_profiling_result(profiling_selector::GPU); } /// @brief The template function to execute kernel in esimd way for unit test @@ -211,54 +223,41 @@ void gemm_exec(const std::string& compile_str, size_t batch = 1) { /// /// @tparam data_type data_type The data type of buffer used in kernel and /// buffer allocation -/// @tparam KERNEL the kernel function struct +/// @tparam kernel_t the kernel function struct /// @param nd_range the range of workitems -/// @param validate_result validation function, taking 3 parameters buffer A, B -/// as input C as output +/// @param validate_result validation function, taking 3 parameters buffer A, +/// B as input C as output /// template < typename data_type, - class KERNEL, - size_t SLMSIZE = 8 * 1024, - size_t BARNUM = 32, - size_t Size = 4096> -void kernel_run(auto nd_range, auto validate_result) { + typename kernel_t, + size_t slm_size = 8 * 1024, + size_t barrier_num = 32, + size_t size = 4096> +void kernel_run( + auto nd_range, + auto validate_result, + init_func_t init_func_a = index_init_func, + init_func_t init_func_b = index_init_func, + init_func_t init_func_c = no_init_func) { queue queue{}; auto context = queue.get_info(); auto device = queue.get_info(); std::cout << "Running on " << device.get_info() << "\n"; auto A = alloc_device_and_init( - Size, - [](data_type* data, size_t idx) { - data[idx] = static_cast(idx); - }, - queue, - device, - context); + size, init_func_a, queue, device, context); auto B = alloc_device_and_init( - Size, - [](data_type* data, size_t idx) { - data[idx] = static_cast(idx); - }, - queue, - device, - context); + size, init_func_b, queue, device, context); auto C = alloc_device_and_init( - Size, - [](data_type* data, size_t idx) { - data[idx] = static_cast(idx); - }, - queue, - device, - context); + size, init_func_c, queue, device, context); try { auto e_esimd = queue.submit([&](handler& cgh) { cgh.parallel_for<>(nd_range, [=](nd_item<1> ndi) KERNEL_MAIN { - gpu::xetla::xetla_local_init(); - gpu::xetla::xetla_nbarrier_init(); - KERNEL::run(&ndi, A, B, C); + gpu::xetla::xetla_local_init(); + gpu::xetla::xetla_nbarrier_init(); + kernel_t::run(&ndi, A, B, C); }); }); e_esimd.wait(); @@ -267,9 +266,9 @@ void kernel_run(auto nd_range, auto validate_result) { FAIL(); } - auto A_host = 
alloc_host_and_copy(A, Size, queue);
-  auto B_host = alloc_host_and_copy(B, Size, queue);
-  auto C_host = alloc_host_and_copy(C, Size, queue);
+  auto A_host = alloc_host_and_copy(A, size, queue);
+  auto B_host = alloc_host_and_copy(B, size, queue);
+  auto C_host = alloc_host_and_copy(C, size, queue);

  ASSERT_EQ(0, validate_result(A_host, B_host, C_host));
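To make the intent of the new initializer hooks concrete, here is a rough, self-contained sketch of the mechanism kernel_run now exposes. The init_func_t alias loses its template arguments in this extract, so the std::function signature below is an inference from the helper functions' own signatures, and init_buffer plus the example kernel name are illustrative only.

#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Presumed shape of the init_func_t alias added to tests/utils/common.hpp.
template <typename data_type>
using init_func_t = std::function<void(data_type*, size_t)>;

template <typename data_type>
void index_init_func(data_type* data, size_t idx) {
  data[idx] = static_cast<data_type>(idx);
}

template <typename data_type>
void zero_init_func(data_type* data, size_t idx) {
  data[idx] = 0;
}

// kernel_run / alloc_device_and_init apply one initializer per element; this
// host-side loop mirrors that behaviour.
template <typename data_type>
void init_buffer(std::vector<data_type>& buf, init_func_t<data_type> f) {
  for (size_t i = 0; i < buf.size(); ++i) {
    f(buf.data(), i);
  }
}

int main() {
  std::vector<int> a(8), c(8, 42);
  init_buffer<int>(a, index_init_func<int>);  // a[i] = i
  init_buffer<int>(c, zero_init_func<int>);   // c is cleared before being written
  assert(a[3] == 3 && c[5] == 0);
  return 0;
}

// In the unit tests themselves the call would look roughly like:
//   kernel_run<int, some_kernel_func, 8 * 1024, 32, 4096>(
//       nd_range, validate, index_init_func<int>, index_init_func<int>, zero_init_func<int>);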