diff --git a/examples/09_gate_recurrent_unit/kernel_func.hpp b/examples/09_gate_recurrent_unit/kernel_func.hpp
index 0dd76188a..dd5fb4abc 100644
--- a/examples/09_gate_recurrent_unit/kernel_func.hpp
+++ b/examples/09_gate_recurrent_unit/kernel_func.hpp
@@ -86,7 +86,7 @@ struct fused_config_t {
       {start_x_b, start_y_b});                                   \
   gemm_args.init(mem_desc_a, mem_desc_b, inner_loop_count_##id); \
   op(g, matAcc_##acc_id, gemm_args);                             \
-  SW_BARRIER();
+  sw_barrier();
 
 #define MATC_STORE(ptr_c)                                               \
   mem_desc_c.init(                                                      \
@@ -229,7 +229,7 @@ struct gru_layer {
         int start_n = (j)*wg_tile_n;
         CONFIG_SETTING(batch_size, -1, hidden_size);
         matAcc_0.init(0);
-        SW_BARRIER();
+        sw_barrier();
 
         // calculate reset gate: r_t = \sigmoid(X_t x W_ir + h_{t - 1} x W_hr)
         // acc0 = X_t x W_ir
@@ -278,19 +278,19 @@ struct gru_layer {
         matAcc_0.reg = matAcc_0.reg * (1 - matAcc_1.reg) +
             matAcc_1.reg *
                 xetla_cvt<Act_T, T, matAcc_t::tile_elems>(mat_hidden.reg);
-        SW_BARRIER();
+        sw_barrier();
 
         if (seq_id == seq_len - 1) {
           MATC_STORE(args->layer_output);
-          SW_BARRIER();
+          sw_barrier();
           __esimd_barrier();
         }
         MATC_STORE(args->cell_out_ptr + seq_id * io_size);
-        SW_BARRIER();
+        sw_barrier();
         __esimd_barrier();
 
         MATC_STORE(args->one_cell_ptr + (seq_id % 2) * io_size);
-        SW_BARRIER();
+        sw_barrier();
         __esimd_barrier();
       }
       args->hx_ptr = args->one_cell_ptr + (seq_id % 2) * io_size;
@@ -386,7 +386,7 @@ struct kernel_xcoder_gru_fusion {
     args.W_hz_ptr = (W_hz_ptr);
     args.W_in_ptr = (W_in_ptr);
     args.W_hn_ptr = (W_hn_ptr);
-    SW_BARRIER();
+    sw_barrier();
     fused_op::call(item, &args);
     ping = (ping + 1) % 2;
     pong = (pong + 1) % 2;
@@ -411,7 +411,7 @@ struct kernel_xcoder_gru_fusion {
           ? hidden_out_ptr
           : (ping_pong_buffer + ping * one_layer_size);
       args.layer_ptr = ((ping_pong_buffer + pong * one_layer_size));
-      SW_BARRIER();
+      sw_barrier();
       fused_op::call(item, &args);
       ping = (ping + 1) % 2;
       pong = (pong + 1) % 2;
diff --git a/include/common/core/arch_config.hpp b/include/common/core/arch_config.hpp
index 315e5d01c..b6d14a148 100644
--- a/include/common/core/arch_config.hpp
+++ b/include/common/core/arch_config.hpp
@@ -35,16 +35,41 @@ template <>
 struct load_store_attr_t<msg_type::block_2d, gpu_arch::XeHpc> {
   /// HW limitation checks https://gfxspecs.intel.com/Predator/Home/Index/55490
   static constexpr bool has_hw_block_2d = true;
+  // If Transposed and Transformed are both set to false
+  // BlockHeight must not exceed 32.
   static constexpr uint32_t max_load_height_in_elem = 32;
+
+  // BlockWidth * NBlocks must not exceed 64 for bytes, 32 for words, 16 for
+  // dwords, and 8 for qwords.
   static constexpr uint32_t max_load_width_in_bytes = 64;
+
+  // If Transposed is true then
+  // BlockWidth must be 1,2,4 for qwords and be in range [1..8] for dwords.
   static constexpr uint32_t max_trans_load_width_in_bytes = 32;
+
+  // BlockHeight must be 8 for qwords and be in range [1..32] for dwords.
+  static constexpr uint32_t max_trans_load_height_in_elem = 32;
+
+  // If Transformed is true
+  // BlockWidth must be in range [4..16] for bytes and [2..16] for words.
   static constexpr uint32_t max_vnni_load_width_in_elems = 16;
+
+  // BlockHeight must be in range [4..32] for bytes and [2..32] for words.
   static constexpr uint32_t min_vnni_load_height_in_bytes = 4;
 
+  // BlockHeight must not exceed 8.
   static constexpr uint32_t max_store_height_in_elem = 8;
+
+  // BlockWidth must not exceed 64 for bytes, 32 for words, 16 for dwords, and 8
+  // for qwords.
   static constexpr uint32_t max_store_width_in_bytes = 64;
 
+  // BlockHeight must not exceed 32.
+  // BlockWidth * NBlocks must not exceed 64 for bytes, 32 for words, 16 for
+  // dwords, and 8 for qwords.
   static constexpr uint32_t max_load_size_in_bytes = 2048;
+
+  // BlockWidth * BlockHeight * sizeof(T) must not exceed 512.
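+  // e.g., the largest dword store, 16 dwords wide x 8 rows, is exactly
+  // 16 * 8 * 4 = 512 bytes.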
   static constexpr uint32_t max_store_size_in_bytes = 512;
 
   static constexpr uint32_t special_prefetch_width_in_bytes = 64;
@@ -97,7 +122,7 @@ struct load_store_attr_t<msg_type::block_1d, arch_tag> {
   static constexpr uint32_t max_aligned_load_vec_len = 256;
   static constexpr uint32_t max_store_vec_len = 256;
   static constexpr uint32_t max_aligned_store_vec_len = 256;
-  static constexpr uint32_t max_prefetch_vec_len = 32;
+  static constexpr uint32_t max_prefetch_vec_len = 256;
   static constexpr uint32_t max_channel_num = 16;
 };
 
diff --git a/include/common/core/barrier.hpp b/include/common/core/barrier.hpp
index 4ad0eaa69..3956798ed 100644
--- a/include/common/core/barrier.hpp
+++ b/include/common/core/barrier.hpp
@@ -26,6 +26,16 @@ namespace gpu::xetla {
 /// @addtogroup xetla_core_barrier
 /// @{
 
+/// @brief Inserts a software scheduling barrier, which prevents the compiler
+/// from reordering instructions across this point.
+__XETLA_API void sw_barrier() {
+#if __INTEL_LLVM_COMPILER >= 20250000
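+  // fence_mask::sw_barrier is no longer available on 2025.0+ compilers, so
+  // this becomes a no-op there.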
+#else
+  __ESIMD_NS::fence<__ESIMD_NS::fence_mask::sw_barrier>();
+#endif
+}
+
 /// @brief Initialize the number of named barrier index for a kernel.
 /// Available only on PVC. Only need to initialize once at the beginning.
 ///
diff --git a/include/common/core/memory.hpp b/include/common/core/memory.hpp
index 93bedbfe0..e5ff7bf17 100644
--- a/include/common/core/memory.hpp
+++ b/include/common/core/memory.hpp
@@ -320,6 +320,53 @@ __XETLA_API void xetla_prefetch_global(
 #endif
 }
 
+/// 2D USM pointer block prefetch.
+/// Supported platforms: PVC
+/// VISA instruction: lsc_load_block2d.ugm
+///
+/// Prefetches elements located at specified address.
+///
+/// @tparam T is element type.
+/// @tparam BlockWidth is the block width in number of elements.
+/// @tparam BlockHeight is the block height in number of elements.
+/// @tparam NBlocks is the number of blocks.
+/// @tparam L1H is L1 cache hint.
+/// @tparam L2H is L2 cache hint.
+/// @tparam N is the data size in elements, deduced from the block shape.
+/// @param Ptr is the surface base address for this operation.
+/// @param SurfaceWidth is the surface width minus 1 in bytes
+/// @param SurfaceHeight is the surface height minus 1 in rows
+/// @param SurfacePitch is the surface pitch minus 1 in bytes
+/// @param X is zero based X-coordinate of the left upper rectangle corner in
+/// number of elements.
+/// @param Y is zero based Y-coordinate of the left upper rectangle corner in
+/// rows.
+///
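+/// A minimal usage sketch (illustrative values; note that this wrapper
+/// forwards the minus-1 surface sizes directly to lsc_prefetch_2d):
+/// \code
+///   // prefetch a 16x8 fp16 block at (0, 0) of a 64x64 surface
+///   xetla_prefetch_global<fp16, 16, 8>(
+///       ptr, 64 * sizeof(fp16) - 1, 64 - 1, 64 * sizeof(fp16) - 1, 0, 0);
+/// \endcode
+///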
+template <
+    typename T,
+    int BlockWidth,
+    int BlockHeight = 1,
+    int NBlocks = 1,
+    cache_hint L1H = cache_hint::none,
+    cache_hint L2H = cache_hint::none,
+    int N = __ESIMD_ENS::detail::get_lsc_block_2d_data_size<
+        T,
+        NBlocks,
+        BlockHeight,
+        BlockWidth,
+        false,
+        false>()>
+__XETLA_API void xetla_prefetch_global(
+    const T* Ptr,
+    unsigned SurfaceWidth,
+    unsigned SurfaceHeight,
+    unsigned SurfacePitch,
+    int X,
+    int Y) {
+  return __ESIMD_ENS::lsc_prefetch_2d<
+      T,
+      BlockWidth,
+      BlockHeight,
+      NBlocks,
+      gpu::xetla::detail::get_cache_hint(L1H),
+      gpu::xetla::detail::get_cache_hint(L2H),
+      N>(Ptr, SurfaceWidth, SurfaceHeight, SurfacePitch, X, Y);
+}
+
 /// template <typename T, int VS = 1, typename OffsetT,
 ///           typename PropertyListT = empty_properties_t>
 /// void prefetch(const T *p, OffsetT byte_offset,
@@ -358,14 +405,102 @@ __XETLA_API void xetla_prefetch_global(T* p, uint64_t byte_offset = 0) {
 #endif
 }
 
+/// 2D USM pointer block load.
+/// Supported platforms: PVC
+/// VISA instruction: lsc_load_block2d.ugm
+///
+/// Collects elements located at specified address and returns them
+/// as a single \ref simd object.
+///
+/// @tparam T is element type.
+/// @tparam BlockWidth is the block width in number of elements.
+/// @tparam BlockHeight is the block height in number of elements.
+/// @tparam NBlocks is the number of blocks.
+/// @tparam Transposed is the transposed version or not.
+/// @tparam Transformed is apply VNNI transform or not.
+/// @tparam L1H is L1 cache hint.
+/// @tparam L2H is L2 cache hint.
+/// @tparam N is the data size in elements, deduced from the block shape.
+/// @param Ptr is the surface base address for this operation.
+/// @param SurfaceWidth is the surface width in bytes (the minus-1
+/// adjustment is applied internally)
+/// @param SurfaceHeight is the surface height in rows
+/// @param SurfacePitch is the surface pitch in bytes
+/// @param X is zero based X-coordinate of the left upper rectangle corner in
+/// number of elements.
+/// @param Y is zero based Y-coordinate of the left upper rectangle corner in
+/// rows.
+/// @return is a vector of type T and size N, where N is
+///  BlockWidth * BlockHeight * NBlocks, if transformed;
+///  otherwise,
+///  N = roundUpNextMultiple(BlockHeight, 4 / sizeof(T)) *
+///   getNextPowerOf2(BlockWidth) * NBlocks
+///
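+/// A worked size example plus a usage sketch (illustrative values; this
+/// wrapper takes the real surface sizes and applies the minus-1 adjustment
+/// internally): for fp16 with BlockWidth = 16, BlockHeight = 8, NBlocks = 1
+/// and no transpose or transform,
+/// N = roundUpNextMultiple(8, 4 / 2) * getNextPowerOf2(16) * 1 = 8 * 16 = 128.
+/// \code
+///   // load a 16x8 fp16 block at (0, 0) of a 64x64 row-major surface
+///   xetla_vector<fp16, 128> blk = xetla_load_global<fp16, 16, 8>(
+///       ptr, 64 * sizeof(fp16), 64, 64 * sizeof(fp16), 0, 0);
+/// \endcode
+///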
+template <
+    typename T,
+    int BlockWidth,
+    int BlockHeight = 1,
+    int NBlocks = 1,
+    bool Transposed = false,
+    bool Transformed = false,
+    cache_hint L1H = cache_hint::none,
+    cache_hint L2H = cache_hint::none,
+    int N = __ESIMD_ENS::detail::get_lsc_block_2d_data_size<
+        T,
+        NBlocks,
+        BlockHeight,
+        BlockWidth,
+        Transposed,
+        Transformed>()>
+__XETLA_API xetla_vector<T, N> xetla_load_global(
+    const T* Ptr,
+    size_t SurfaceWidth,
+    size_t SurfaceHeight,
+    size_t SurfacePitch,
+    int X,
+    int Y) {
+  if constexpr (std::is_same_v<T, bf16>) {
+    auto ret = xetla_load_global<
+        fp16,
+        BlockWidth,
+        BlockHeight,
+        NBlocks,
+        Transposed,
+        Transformed,
+        L1H,
+        L2H>(
+        reinterpret_cast<const fp16*>(Ptr),
+        SurfaceWidth,
+        SurfaceHeight,
+        SurfacePitch,
+        X,
+        Y);
+    return ret.xetla_format<T>();
+  } else if constexpr (BlockWidth * sizeof(T) < sizeof(uint32_t)) {
+    xetla_vector<uint32_t, BlockHeight> byte_offsets =
+        xetla_vector_gen<uint32_t, BlockHeight>(0, SurfacePitch);
+    return xetla_load_global<T, N, BlockWidth, L1H, L2H>(Ptr, byte_offsets);
+  } else {
+    return __ESIMD_ENS::lsc_load_2d<
+        T,
+        BlockWidth,
+        BlockHeight,
+        NBlocks,
+        Transposed,
+        Transformed,
+        gpu::xetla::detail::get_cache_hint(L1H),
+        gpu::xetla::detail::get_cache_hint(L2H),
+        N>(Ptr, SurfaceWidth - 1, SurfaceHeight - 1, SurfacePitch - 1, X, Y);
+  }
+}
+
 /// simd<T, N> block_load(const T* ptr, size_t byte_offset,
 ///                       props={});  // (usm-bl-2)
 /// This function loads a contiguous memory block from address referenced
 /// by USM pointer \p ptr and the given \p byte_offset.
 ///
 /// There may be temporary restrictions depending on L1, L2 cache hints,
-/// See details in the 'Restrictions' section below. The restrictions will be
-/// relaxed in the future.
+/// See details in the 'Restrictions' section below. The restrictions will
+/// be relaxed in the future.
 ///
 /// The parameter \p props specifies the optional compile-time properties
 /// of the type esimd::properties and may include esimd::cache_hint_L1,
@@ -383,7 +518,8 @@ __XETLA_API void xetla_prefetch_global(T* p, uint64_t byte_offset = 0) {
 ///
 /// Restrictions - cache hint imposed - temporary:
 /// If L1 or L2 cache hint is passed, then:
/// R1: The pointer must be at least 4-byte aligned for elements of 4-bytes or
 ///     smaller and 8-byte aligned for 8-byte elements.
 /// R2: The number of elements for 8-byte data: 1, 2, 3, 4, 8, 16, 32, 64;
 ///     for 4-byte data: 1, 2, 3, 4, 8, 16, 32, 64,
@@ -574,6 +710,71 @@ __XETLA_API xetla_vector<T, N> xetla_load_global(
 #endif
 }
 
+/// 2D USM pointer block store.
+/// Supported platforms: PVC
+/// VISA instruction: lsc_store_block2d.ugm
+///
+/// Stores elements at specified address.
+///
+/// @tparam T is element type.
+/// @tparam BlockWidth is the block width in number of elements.
+/// @tparam BlockHeight is the block height in number of elements.
+/// @tparam L1H is L1 cache hint.
+/// @tparam L2H is L2 cache hint.
+/// @tparam N is the data size in elements, deduced from the block shape.
+/// @param Ptr is the surface base address for this operation.
+/// @param SurfaceWidth is the surface width in bytes (the minus-1
+/// adjustment is applied internally)
+/// @param SurfaceHeight is the surface height in rows
+/// @param SurfacePitch is the surface pitch in bytes
+/// @param X is zero based X-coordinate of the left upper rectangle corner in
+/// number of elements.
+/// @param Y is zero based Y-coordinate of the left upper rectangle corner in
+/// rows.
+/// @param Vals is the vector to store, of type T and size N, where
+///  N = roundUpNextMultiple(BlockHeight, 4 / sizeof(T)) *
+///   getNextPowerOf2(BlockWidth)
+///
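+/// A minimal usage sketch (illustrative values; as with the 2D load, the
+/// real surface sizes are passed and the minus-1 adjustment happens
+/// internally):
+/// \code
+///   // store a 16x8 float block at (0, 0) of a 64x64 row-major surface;
+///   // 16 * 8 * 4 = 512 bytes, the maximum legal 2D store size
+///   xetla_vector<float, 128> vals = 0;
+///   xetla_store_global<float, 16, 8>(
+///       ptr, 64 * sizeof(float), 64, 64 * sizeof(float), 0, 0, vals);
+/// \endcode
+///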
+template <
+    typename T,
+    int BlockWidth,
+    int BlockHeight = 1,
+    cache_hint L1H = cache_hint::none,
+    cache_hint L2H = cache_hint::none,
+    int N = __ESIMD_ENS::detail::get_lsc_block_2d_data_size<
+        T,
+        1u,
+        BlockHeight,
+        BlockWidth,
+        false,
+        false>()>
+__XETLA_API void xetla_store_global(
+    T* Ptr,
+    unsigned SurfaceWidth,
+    unsigned SurfaceHeight,
+    unsigned SurfacePitch,
+    int X,
+    int Y,
+    auto&& Vals) {
+  if constexpr (std::is_same_v<T, bf16>) {
+    xetla_store_global<fp16, BlockWidth, BlockHeight, L1H, L2H>(
+        reinterpret_cast<fp16*>(Ptr),
+        SurfaceWidth,
+        SurfaceHeight,
+        SurfacePitch,
+        X,
+        Y,
+        Vals.xetla_format<fp16>());
+  } else {
+    __ESIMD_ENS::lsc_store_2d<
+        T,
+        BlockWidth,
+        BlockHeight,
+        gpu::xetla::detail::get_cache_hint(L1H),
+        gpu::xetla::detail::get_cache_hint(L2H),
+        N>(
+        Ptr, SurfaceWidth - 1, SurfaceHeight - 1, SurfacePitch - 1, X, Y, Vals);
+  }
+}
+
 /// template <typename T, int N, int VS = 1, typename OffsetT,
 /// 	  typename PropertyListT = empty_properties_t>
 /// void scatter(T *p, simd<OffsetT, N / VS> byte_offsets, simd<T, N> vals,
@@ -951,6 +1152,10 @@ __XETLA_API xetla_vector<Ty, N * NElts> xetla_load_local(
     xetla_vector<uint32_t, N> offsets,
     xetla_mask<N> pred = 1) {
   using T = native_type_t<Ty>;
+  DEBUG_INVOKE(
+      dbg_level::core,
+      core::general_1d<gpu_arch::XeHpc, Ty>::
+          template check_restriction<NElts, N>(offsets));
 
   return __ESIMD_ENS::
       lsc_slm_gather<T, NElts, gpu::xetla::detail::get_data_size(DS), N>(
@@ -975,6 +1180,11 @@ __XETLA_API xetla_vector<Ty, N * NElts> xetla_load_local(
 template <typename Ty, int NElts = 1, data_size DS = data_size::default_size>
 __XETLA_API xetla_vector<Ty, NElts> xetla_load_local(uint32_t offset) {
   using T = native_type_t<Ty>;
 
   return __ESIMD_NS::slm_block_load<T, NElts>(offset);
 }
@@ -1005,6 +1215,10 @@ __XETLA_API void xetla_store_local(
     xetla_vector<Ty, N * NElts> vals,
     xetla_mask<N> pred = 1) {
   using T = native_type_t<Ty>;
+  DEBUG_INVOKE(
+      dbg_level::core,
+      core::general_1d<gpu_arch::XeHpc, Ty>::
+          template check_restriction<NElts, N, uint32_t>(offsets));
 
   __ESIMD_ENS::
       lsc_slm_scatter<T, NElts, gpu::xetla::detail::get_data_size(DS), N>(
diff --git a/include/common/utils/common.hpp b/include/common/utils/common.hpp
index 7d3b8fe15..d3bca4818 100644
--- a/include/common/utils/common.hpp
+++ b/include/common/utils/common.hpp
@@ -275,8 +275,8 @@ enum class reg_layout : uint8_t {
   tiled = 1,
   vnni_tiled = 2,
   transpose_tiled = 3,
-  /// this is vnni tiled format, but for each block, they are stored in col
-  /// major order
+  /// this is the vnni tiled format, but the blocks themselves are stored
+  /// in col-major order
   vnni_tiled_col_major = 4
 };
 enum class store_op : uint8_t {
diff --git a/include/subgroup/tile/impl/load_xe.hpp b/include/subgroup/tile/impl/load_xe.hpp
index c793e2acb..140e61593 100644
--- a/include/subgroup/tile/impl/load_xe.hpp
+++ b/include/subgroup/tile/impl/load_xe.hpp
@@ -89,12 +89,13 @@ tile_load(tile_t& tile, payload_t& payload) {
 
   static constexpr uint32_t num_block_x = tile_desc::num_block_x;
   static constexpr uint32_t num_block_y = tile_desc::num_block_y;
-  static constexpr uint32_t num_block = tile_desc::num_block;
 
   static constexpr gpu_arch arch_tag = payload_t::arch_tag;
 
   static constexpr reg_layout reg_layout_ = tile_desc::register_layout;
-  static constexpr bool is_vnni_reverse = payload_t::mem_dword_transpose &&
+  // For sub-dword dtypes a transposed (packed) load returns data in VNNI
+  // order, which must be reversed for tiled register layouts.
+  static constexpr bool is_vnni_reverse =
+      payload_t::mem_transpose_dtype_less4bytes &&
       ((reg_layout_ == reg_layout::tiled) ||
        (reg_layout_ == reg_layout::transpose_tiled));
   static constexpr bool reg_transpose = tile_desc::reg_transpose;
@@ -106,36 +107,62 @@ tile_load(tile_t& tile, payload_t& payload) {
   static constexpr bool mem_transform = payload_t::mem_transform;
 
   using load_store_attr = load_store_attr_t<msg_type::block_2d, arch_tag>;
   static constexpr uint32_t elems_per_CL =
       load_store_attr::cache_line_size_in_bytes / sizeof(dtype);
+
   static constexpr uint32_t elems_per_reg =
       register_bytes_t<arch_tag>::reg_in_bytes / sizeof(dtype);
-  static constexpr int32_t max_load_block_height =
-      load_store_attr::max_load_height_in_elem;
-  static constexpr int32_t max_block_width =
-      load_store_attr::max_load_width_in_bytes / sizeof(dtype);
-  static constexpr int32_t max_trans_block_width =
-      load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype);
-
-  static constexpr uint32_t ld_blk_size_y_limit =
-      mem_transpose ? max_trans_block_width : max_load_block_height;
-  static constexpr uint32_t ld_blk_size_y = reg_transpose
-      ? block_size_y
-      : (block_size_y > ld_blk_size_y_limit ? ld_blk_size_y_limit
-                                            : block_size_y);
+
+  static constexpr uint32_t max_load_width_in_elem = trans
+      ? load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype)
+      : load_store_attr::max_load_width_in_bytes / sizeof(dtype);
+
+  static constexpr uint32_t max_load_blk_height_in_elem = trans
+      ? load_store_attr::max_trans_load_height_in_elem
+      : load_store_attr::max_load_height_in_elem;
+
+  static constexpr uint32_t ld_blk_width = std::min(
+      (mem_transpose ? block_size_y : block_size_x), max_load_width_in_elem);
+
+  static constexpr uint32_t ld_blk_height = std::min(
+      (mem_transpose ? block_size_x : block_size_y),
+      max_load_blk_height_in_elem);
+
+  static constexpr uint32_t ld_blk_size_y =
+      mem_transpose ? ld_blk_width : ld_blk_height;
+
+  static constexpr uint32_t ld_blk_size_y_limit = mem_transpose
+      ? load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype)
+      : load_store_attr::max_load_height_in_elem;
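+  // e.g., for a transposed dword load the y-limit is 32 / 4 = 8 rows,
+  // while a non-transposed load allows up to 32 rows.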
 
   // array len is used to make sure memory load is cache line aligned
   // disabled while register or memory transpose
   static constexpr uint8_t arr_len_candidate =
-      (reg_transpose ||
-       mem_transpose
+      ((reg_transpose || mem_transpose)
        // block elements should be integer
        // times of register bytes
-       || ((block_size_y * block_size_x) % elems_per_reg != 0)
+       || ((block_elems) % elems_per_reg != 0)
        // tail blocks also need to meet above condition
-       ||
-       (((tile_size_y % block_size_y) * block_size_x) % elems_per_reg != 0)) ||
-          (block_size_y > ld_blk_size_y_limit)
+       || (((tile_size_y % block_size_y) * block_size_x) % elems_per_reg != 0))
       ? 1
       : (((tile_size_x % elems_per_CL) == 0)
              ? (((elems_per_CL % block_size_x) == 0)
@@ -143,98 +170,82 @@ tile_load(tile_t& tile, payload_t& payload) {
                     : 1)
              : ((tile_size_x < elems_per_CL) ? (tile_size_x / block_size_x)
                                              : 1));
-  static constexpr bool is_valid_arr_len_candidate = (arr_len_candidate == 1) ||
-      (arr_len_candidate == 2) || (arr_len_candidate == 4);
-
-  static constexpr uint8_t arr_len =
-      is_valid_arr_len_candidate ? arr_len_candidate : 1;
-
-  static_assert(
-      reg_transpose || mem_transpose ||
-          (!mem_transpose && (block_size_x * arr_len) <= max_block_width),
-      "When reg_transpose was disabled, check 2d block width "
-      "restriction");
-  static_assert(
-      !reg_transpose ||
-          (!mem_transpose &&
-           (block_size_x * arr_len) <= max_trans_block_width) ||
-          (mem_transpose && (block_size_y * arr_len) <= max_block_width),
-      "When reg_transpose was enabled, check 2d block width "
-      "restriction");
-  static_assert(
-      !reg_transpose ||
-          (!mem_transpose && (block_size_y <= max_load_block_height)) ||
-          (mem_transpose && (block_size_x) <= max_load_block_height),
-      "When reg_transpose was enabled, check 2d block height "
-      "restriction");
-  static_assert(
-      tile_size_x % (block_size_x * arr_len) == 0,
-      "tile_size_x should be a multiple of (block_size_x * arr_len)");
+  // NBlocks must be {1,2,4} for bytes and words, {1,2} for dwords, 1 for
+  // qwords.
+  static constexpr uint8_t arr_len =
+      ((arr_len_candidate == 1) ||
+       (arr_len_candidate == 2 && sizeof(dtype) <= 4) ||
+       (arr_len_candidate == 4 && sizeof(dtype) <= 2))
+      ? arr_len_candidate
+      : 1;
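+  // e.g., a dword (4-byte) dtype with arr_len_candidate == 4 falls back to
+  // arr_len == 1, since NBlocks == 4 is only legal for bytes and words.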
+
+  if constexpr (!trans && !mem_transform) {
+    static_assert(
+        (ld_blk_width * arr_len) <= max_load_width_in_elem,
+        "When Transposed and Transformed are both set to false, BlockWidth * NBlocks must not exceed 64 for bytes, 32 for words, 16 for dwords, and 8 for qwords");
+  } else if constexpr (mem_transform) {
+    static_assert(
+        (ld_blk_width * arr_len) <= max_load_width_in_elem,
+        "When Transformed is true then, BlockWidth * NBlocks must not exceed 64 for bytes and 32 for words.");
+  }
   static_assert(
       (reg_transpose &&
        ((block_size_x * sizeof(dtype)) % sizeof(load_dtype) == 0)) ||
           ((block_size_y * sizeof(dtype)) % sizeof(load_dtype) == 0),
       "check vnni limitation for DW transpose");
 
-  auto payload_2d = payload.payloads.xetla_format<uint32_t, num_block, 16>();
 #pragma unroll
   for (uint32_t i = 0; i < num_block_y; ++i) {
-    constexpr uint32_t load_block_elems = block_elems * arr_len;
-    auto payload_row =
-        payload_2d.xetla_select<num_block_x, 1, 16, 1>(i * num_block_x, 0);
-    detail::reset_tile_desc_core<
-        num_block_x,
-        block_size_x,
-        ld_blk_size_y,
-        scale_factor,
-        arr_len,
-        mem_transpose>(payload_row);
+    int32_t offset_y = i * block_size_y;
 #pragma unroll
     for (uint32_t j = 0; j < num_block_x; j += arr_len) {
-      xetla_tdescriptor tdesc = payload_row.row(j);
+      int32_t offset_x = j * block_size_x;
+      constexpr uint32_t load_block_elems = block_elems * arr_len;
       auto reg_blk = tile.reg.xetla_select<load_block_elems, 1>(
           (i * num_block_x + j) * block_elems);
-      constexpr uint32_t ld_blk_height = (reg_transpose && trans)
-          ? detail::getNextPowerOf2<ld_blk_size_y>()
-          : ld_blk_size_y;
-      constexpr uint32_t tmp_size = ld_blk_height * block_size_x * arr_len;
+      constexpr uint32_t tmp_size = ld_blk_width * ld_blk_height * arr_len;
       xetla_vector<dtype, tmp_size> reg_tmp;
 #pragma unroll
       for (uint32_t ii = 0; ii < block_size_y / ld_blk_size_y; ++ii) {
         constexpr uint32_t load_elems = ld_blk_size_y * block_size_x * arr_len;
-
-        reg_tmp.xetla_format<native_type_t<load_dtype>>() = xetla_tload_global<
-            load_dtype,
-            ld_blk_height * block_size_x * arr_len / scale_factor,
-            L1,
-            L2,
+        uint32_t address_offset_x =
+            (mem_transpose ? (offset_y + ii * ld_blk_size_y) : offset_x) /
+            scale_factor;
+        uint32_t address_offset_y =
+            mem_transpose ? offset_x : (offset_y + ii * ld_blk_size_y);
+        reg_tmp.xetla_format<native_type_t<load_dtype>>() = xetla_load_global<
+            native_type_t<load_dtype>,
+            ld_blk_width / scale_factor,
+            ld_blk_height,
+            arr_len,
             trans,
             mem_transform,
-            arch_tag>(tdesc);
+            L1,
+            L2>(
+            reinterpret_cast<const native_type_t<load_dtype>*>(
+                payload.base_ptr),
+            payload.surface_width,
+            payload.surface_height,
+            payload.surface_pitch,
+            payload.offset_x + address_offset_x,
+            payload.offset_y + address_offset_y);
+
         if constexpr (reg_transpose && trans) {
           reg_blk.xetla_select<load_elems, 1>(ii * load_elems)
               .xetla_format<native_type_t<load_dtype>>() =
               reg_tmp
                   .xetla_format<
                       native_type_t<load_dtype>,
-                      block_size_x / scale_factor,
+                      ld_blk_width / scale_factor,
                       ld_blk_height>()
                   .xetla_select<
-                      block_size_x / scale_factor,
+                      ld_blk_width / scale_factor,
                       1,
                       ld_blk_size_y,
                       1>(0, 0);
         } else {
           reg_blk.xetla_select<tmp_size, 1>(ii * tmp_size) = reg_tmp;
         }
-
-        if constexpr (mem_transpose) {
-          xetla_update_tdesc_offsetx(
-              tdesc.xetla_format<uint32_t>(), ld_blk_size_y / scale_factor);
-        } else {
-          xetla_update_tdesc_offsety(
-              tdesc.xetla_format<uint32_t>(), ld_blk_size_y);
-        }
       }
       // exceed HW limitation
       if constexpr (block_size_y % ld_blk_size_y != 0) {
@@ -244,26 +255,29 @@ tile_load(tile_t& tile, payload_t& payload) {
             remained_start_y * block_size_x * arr_len;
         constexpr uint32_t remained_blk_size_y = block_size_y % ld_blk_size_y;
         constexpr uint32_t load_elems =
            remained_blk_size_y * block_size_x * arr_len;
 
         constexpr uint8_t block_width =
-            mem_transpose ? (remained_blk_size_y / scale_factor) : block_size_x;
+            (mem_transpose ? remained_blk_size_y : block_size_x) / scale_factor;
         constexpr uint8_t block_height =
-            trans ? block_size_x : remained_blk_size_y;
-        constexpr uint32_t block_widthx_widthy_arrlen =
-            (block_width - 1) | ((block_height - 1) << 8);
-        gpu::xetla::detail::xetla_set_block_widthx_widthy_arrlen(
-            tdesc.xetla_format<uint32_t>(), block_widthx_widthy_arrlen);
-
+            mem_transpose ? block_size_x : remained_blk_size_y;
         reg_blk.xetla_select<load_elems, 1>(remained_start)
-            .xetla_format<native_type_t<load_dtype>>() = xetla_tload_global<
-            load_dtype,
-            (load_elems / scale_factor),
-            L1,
-            L2,
+            .xetla_format<native_type_t<load_dtype>>() = xetla_load_global<
+            native_type_t<load_dtype>,
+            block_width,
+            block_height,
+            arr_len,
             trans,
             mem_transform,
-            arch_tag>(tdesc);
+            L1,
+            L2>(
+            reinterpret_cast<const native_type_t<load_dtype>*>(
+                payload.base_ptr),
+            payload.surface_width,
+            payload.surface_height,
+            payload.surface_pitch,
+            payload.offset_x + offset_x / scale_factor,
+            payload.offset_y + offset_y + remained_start_y);
       }
     }
   }
@@ -276,18 +290,11 @@ tile_load(tile_t& tile, payload_t& payload) {
         (!reg_transpose && (remained_size_y > ld_blk_size_y_limit))
         ? ld_blk_size_y_limit
         : remained_size_y;
-    auto payload_row = payload_2d.xetla_select<num_block_x, 1, 16, 1>(
-        num_block_y * num_block_x, 0);
-    detail::reset_tile_desc_core<
-        num_block_x,
-        block_size_x,
-        remained_ld_blk_size_y,
-        scale_factor,
-        arr_len,
-        mem_transpose>(payload_row);
+
 #pragma unroll
     for (uint32_t j = 0; j < num_block_x; j += arr_len) {
-      xetla_tdescriptor tdesc = payload_row.row(j);
+      int32_t offset_x = j * block_size_x;
       auto reg_blk = tile.reg.xetla_select<remained_block_elems * arr_len, 1>(
           processed_elems + j * remained_block_elems);
       constexpr uint32_t ld_blk_height = (reg_transpose && trans)
@@ -301,15 +308,23 @@ tile_load(tile_t& tile, payload_t& payload) {
         constexpr uint32_t load_elems =
             remained_ld_blk_size_y * block_size_x * arr_len;
 
-        reg_tmp.xetla_format<native_type_t<load_dtype>>() = xetla_tload_global<
-            load_dtype,
-            (ld_blk_height * block_size_x * arr_len / scale_factor),
-            L1,
-            L2,
+        reg_tmp.xetla_format<native_type_t<load_dtype>>() = xetla_load_global<
+            native_type_t<load_dtype>,
+            block_size_x / scale_factor,
+            remained_ld_blk_size_y,
+            arr_len,
             trans,
             mem_transform,
-            arch_tag>(tdesc);
-
+            L1,
+            L2>(
+            reinterpret_cast<const native_type_t<load_dtype>*>(
+                payload.base_ptr),
+            payload.surface_width,
+            payload.surface_height,
+            payload.surface_pitch,
+            payload.offset_x + offset_x / scale_factor,
+            payload.offset_y + num_block_y * block_size_y +
+                ii * remained_ld_blk_size_y);
         if constexpr (reg_transpose && trans) {
           reg_blk.xetla_select<load_elems, 1>(ii * load_elems)
               .xetla_format<native_type_t<load_dtype>>() =
@@ -326,14 +341,14 @@ tile_load(tile_t& tile, payload_t& payload) {
         } else {
           reg_blk.xetla_select<tmp_size, 1>(ii * tmp_size) = reg_tmp;
         }
-        if constexpr (mem_transpose) {
-          xetla_update_tdesc_offsetx(
-              tdesc.xetla_format<uint32_t>(),
-              remained_ld_blk_size_y / scale_factor);
-        } else {
-          xetla_update_tdesc_offsety(
-              tdesc.xetla_format<uint32_t>(), remained_ld_blk_size_y);
-        }
       }
       constexpr uint32_t final_ld_blk_size_y =
           remained_size_y % remained_ld_blk_size_y;
@@ -344,22 +359,40 @@ tile_load(tile_t& tile, payload_t& payload) {
         constexpr uint32_t final_load_elems =
             final_ld_blk_size_y * block_size_x * arr_len;
         constexpr uint8_t block_width =
-            mem_transpose ? (final_ld_blk_size_y / scale_factor) : block_size_x;
+            (mem_transpose ? final_ld_blk_size_y : block_size_x) / scale_factor;
         constexpr uint8_t block_height =
-            trans ? block_size_x : final_ld_blk_size_y;
-        constexpr uint32_t block_widthx_widthy_arrlen =
-            (block_width - 1) | ((block_height - 1) << 8);
-        gpu::xetla::detail::xetla_set_block_widthx_widthy_arrlen(
-            tdesc.xetla_format<uint32_t>(), block_widthx_widthy_arrlen);
+            mem_transpose ? block_size_x : final_ld_blk_size_y;
         reg_blk.xetla_select<final_load_elems, 1>(final_start)
-            .xetla_format<native_type_t<load_dtype>>() = xetla_tload_global<
-            load_dtype,
-            final_load_elems / scale_factor,
-            L1,
-            L2,
+            .xetla_format<native_type_t<load_dtype>>() = xetla_load_global<
+            native_type_t<load_dtype>,
+            block_width,
+            block_height,
+            arr_len,
             trans,
             mem_transform,
-            arch_tag>(tdesc);
+            L1,
+            L2>(
+            reinterpret_cast<const native_type_t<load_dtype>*>(
+                payload.base_ptr),
+            payload.surface_width,
+            payload.surface_height,
+            payload.surface_pitch,
+            payload.offset_x + offset_x / scale_factor,
+            payload.offset_y + num_block_y * block_size_y +
+                remained_size_y / remained_ld_blk_size_y *
+                    remained_ld_blk_size_y);
       }
     }
   }
@@ -426,7 +459,8 @@ tile_load(tile_t& tile, payload_t& payload) {
 
 /// @brief This function loads data from unaligned-2D memory surface.
 /// Loads an array of rectangular regions (X,Y)..(X+W,Y+H) from memory into
-/// registers. Each block will be loaded serially by its corresponding payload.
+/// registers. Each block will be loaded serially by its corresponding
+/// payload.
 /// @tparam tile_t Is the tile_t struct contains registers.
 /// These registers will be the destination of load operation.
 /// @tparam payload_t Is the mem_payload_t struct describing the memory
@@ -550,7 +584,8 @@ tile_load(tile_t& tile, payload_t& payload) {
 
 /// @brief This function loads data from unaligned-2D memory surface.
 /// Loads an array of rectangular regions (X,Y)..(X+W,Y+H) from memory into
-/// registers. Each block will be loaded serially by its corresponding payload.
+/// registers. Each block will be loaded serially by its corresponding
+/// payload.
 /// @tparam tile_t Is the tile_t struct contains registers.
 /// These registers will be the destination of load operation.
 /// @tparam payload_t Is the mem_payload_t struct describing the memory
@@ -615,7 +650,8 @@ tile_load(tile_t& tile, payload_t& payload) {
 
 /// @brief This function loads data from unaligned-2D memory surface.
 /// Loads an array of rectangular regions (X,Y)..(X+W,Y+H) from memory into
-/// registers. Each block will be loaded serially by its corresponding payload.
+/// registers. Each block will be loaded serially by its corresponding
+/// payload.
 /// @tparam tile_t Is the tile_t struct contains registers.
 /// These registers will be the destination of load operation.
 /// @tparam payload_t Is the mem_payload_t struct describing the memory
@@ -755,8 +791,8 @@ tile_load(
 }
 
 /// @brief Is the data load func from local shared memory to register file,
-/// which supports the memory surface is 1d or 2d scenario. And we always assume
-/// data in SLM is row major.
+/// which supports 1D or 2D memory surfaces. Data in SLM is always assumed
+/// to be row-major.
 /// @tparam tile_t Is the tile_t struct contains registers
 /// These registers will be the destination of load operation.
 /// @tparam payload_t Is the mem_payload_t struct describing the memory
@@ -838,8 +874,8 @@ tile_load(tile_t& tile, payload_t& payload) {
 }
 
 /// @brief Is the data load func from shared local memory to register file,
-/// which supports the memory surface is 1d scenario. And the src memory layout
-/// is always row major.
+/// which supports the 1D memory surface scenario. The source memory layout
+/// is always row-major.
 /// @tparam tile_t Is the tile_t struct contains registers.
 /// These registers will be the destination of load operation.
 /// @tparam payload_t Is the mem_payload_t struct describing the memory
diff --git a/include/subgroup/tile/impl/payload_xe.hpp b/include/subgroup/tile/impl/payload_xe.hpp
index 697aba49e..3480a7820 100644
--- a/include/subgroup/tile/impl/payload_xe.hpp
+++ b/include/subgroup/tile/impl/payload_xe.hpp
@@ -75,22 +75,53 @@ struct mem_payload_t<
   static constexpr bool trans = (mem_transpose ^ reg_transpose) &&
       !(std::is_same_v<dtype_, int4x2> || std::is_same_v<dtype_, int4x8>);
 
-  static constexpr bool mem_transform = (sizeof(dtype) < 4) && !mem_transpose &&
+  // Transformed and Transposed cannot be set to true at the same time.
+  // If Transformed is true then:
+  // sizeof(T) must be 1- or 2-byte (bytes or words).
+  static constexpr bool mem_transform = (sizeof(dtype) <= 2) && !trans &&
       (register_layout == reg_layout::vnni_tiled ||
        register_layout == reg_layout::vnni_tiled_col_major);
-  static constexpr bool mem_dword_transpose = (sizeof(dtype) < 4) && trans;
 
-  using mem_dtype =
-      typename std::conditional<mem_dword_transpose, uint32_t, dtype>::type;
+  // If Transposed is true then:
+  // sizeof(T) must be 4- or 8-byte (dwords or qwords).
+  static constexpr bool mem_transpose_dtype_less4bytes =
+      (sizeof(dtype) < 4) && trans;
+
+  using mem_dtype = typename std::
+      conditional_t<mem_transpose_dtype_less4bytes, uint32_t, dtype>;
   static constexpr uint32_t scale_factor = sizeof(mem_dtype) / sizeof(dtype);
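+  // e.g., a transposed bf16 load: mem_dtype is uint32_t and scale_factor is
+  // 2, so x-coordinates and block widths are expressed in dword units.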
 
+  dtype* base_ptr;
+  uint32_t surface_width;
+  uint32_t surface_height;
+  uint32_t surface_pitch;
+  int32_t offset_x;
+  int32_t offset_y;
+
   xetla_vector<uint32_t, 16 * num_block> payloads;
 
   inline mem_payload_t(const this_payload_t& rhs) {
+    this->base_ptr = rhs.base_ptr;
+    this->surface_width = rhs.surface_width;
+    this->surface_height = rhs.surface_height;
+    this->surface_pitch = rhs.surface_pitch;
+    this->offset_x = rhs.offset_x;
+    this->offset_y = rhs.offset_y;
+
     this->payloads = rhs.payloads;
   }
 
   inline mem_payload_t(mem_desc_t& mem_desc) {
+    this->base_ptr = (dtype*)mem_desc.base.base;
+    this->surface_width =
+        (mem_transpose ? mem_desc.shape.y : mem_desc.shape.x) * sizeof(dtype);
+    this->surface_height =
+        (mem_transpose ? mem_desc.shape.x : mem_desc.shape.y);
+    this->surface_pitch = mem_desc.shape.stride * sizeof(dtype);
+    this->offset_x = (mem_transpose ? mem_desc.coord.y : mem_desc.coord.x) /
+        int32_t(scale_factor);
+    this->offset_y = mem_transpose ? mem_desc.coord.x : mem_desc.coord.y;
+
     xetla_tdescriptor base_tdesc = mem_desc.get_tdesc();
     int32_t offset = gpu::xetla::detail::xetla_get_tensor_offset_x(base_tdesc) /
         int32_t(scale_factor);
@@ -106,7 +137,15 @@ struct mem_payload_t<
       uint32_t surface_pitch,
       int32_t surface_offset_x = 0,
       int32_t surface_offset_y = 0) {
+    this->base_ptr = p;
+    this->surface_width = surface_width * sizeof(dtype);
+    this->surface_height = surface_height;
+    this->surface_pitch = surface_pitch * sizeof(dtype);
+    this->offset_x = surface_offset_x / int32_t(scale_factor);
+    this->offset_y = surface_offset_y;
+
     xetla_tdescriptor base_tdesc;
+
     xetla_fill_tdesc(
         base_tdesc.xetla_format<uint32_t>(),
         p,
@@ -119,6 +158,16 @@ struct mem_payload_t<
   }
 
   __XETLA_API void init(mem_desc_t& mem_desc) {
+    this->base_ptr = (dtype*)mem_desc.base.base;
+    this->surface_width =
+        (mem_transpose ? mem_desc.shape.y : mem_desc.shape.x) * sizeof(dtype);
+    this->surface_height =
+        (mem_transpose ? mem_desc.shape.x : mem_desc.shape.y);
+    this->surface_pitch = mem_desc.shape.stride * sizeof(dtype);
+    this->offset_x = (mem_transpose ? mem_desc.coord.y : mem_desc.coord.x) /
+        int32_t(scale_factor);
+    this->offset_y = (mem_transpose ? mem_desc.coord.x : mem_desc.coord.y);
+
     xetla_tdescriptor base_tdesc = mem_desc.get_tdesc();
     int32_t offset = gpu::xetla::detail::xetla_get_tensor_offset_x(base_tdesc) /
         int32_t(scale_factor);
@@ -142,6 +191,13 @@ struct mem_payload_t<
       uint32_t surface_pitch,
       int32_t surface_offset_x = 0,
       int32_t surface_offset_y = 0) {
+    this->base_ptr = p;
+    this->surface_width = surface_width * sizeof(dtype);
+    this->surface_height = surface_height;
+    this->surface_pitch = surface_pitch * sizeof(dtype);
+    this->offset_x = surface_offset_x / int32_t(scale_factor);
+    this->offset_y = surface_offset_y;
+
     xetla_tdescriptor base_tdesc;
     xetla_fill_tdesc(
         base_tdesc.xetla_format<uint32_t>(),
@@ -160,6 +216,13 @@ struct mem_payload_t<
   // ~mem_payload_t(){}
 
   inline this_payload_t& operator=(const this_payload_t& rhs) {
+    this->base_ptr = rhs.base_ptr;
+    this->surface_width = rhs.surface_width;
+    this->surface_height = rhs.surface_height;
+    this->surface_pitch = rhs.surface_pitch;
+    this->offset_x = rhs.offset_x;
+    this->offset_y = rhs.offset_y;
+
     this->payloads = rhs.payloads;
     return *this;
   }
@@ -168,12 +231,14 @@ struct mem_payload_t<
   __XETLA_API void update_tdesc(int offset) {
     auto payloads_2d = payloads.xetla_format<uint32_t, num_block, 16>();
     if constexpr (update_dir == tdesc_update_dir::x_dir) {
+      offset_x += offset / scale_factor;
 #pragma unroll
       for (uint32_t i = 0; i < num_block; i++) {
         xetla_update_tdesc_offsetx(
             payloads_2d.row(i), offset / int32_t(scale_factor));
       }
     } else {
+      offset_y += offset;
 #pragma unroll
       for (uint32_t i = 0; i < num_block; i++) {
         xetla_update_tdesc_offsety(payloads_2d.row(i), offset);
@@ -1150,10 +1215,9 @@ struct mem_payload_t<
   static constexpr uint32_t tile_size_y = tile_desc::tile_size_y;
   static constexpr uint32_t block_size_x = tile_desc::block_size_x;
   static constexpr uint32_t block_size_y = tile_desc::block_size_y;
-  static constexpr uint32_t tile_bytes =
-      tile_size_x * tile_size_y * sizeof(dtype);
+  static constexpr uint32_t tile_bytes = tile_desc::tile_elems * sizeof(dtype);
   static constexpr uint32_t block_bytes =
-      block_size_x * block_size_y * sizeof(dtype);
+      tile_desc::block_elems * sizeof(dtype);
   using this_payload_t =
       mem_payload_t<mem_desc_t, tile_desc, msg_type::block_2d, arch_tag_>;
 
@@ -1230,7 +1294,7 @@ struct mem_payload_t<
     base_offset = mem_transpose
         ? base_x * pitch_in_bytes + base_y * sizeof(dtype)
         : base_y * pitch_in_bytes + base_x * sizeof(dtype);
-    base_ptr = (mem_dtype*)mem_tdesc.base.base;
+    base_ptr = reinterpret_cast<mem_dtype*>(mem_tdesc.base.base);
 
     xetla_vector<uint32_t, num_channel> channel_index =
         xetla_vector_gen<uint32_t, num_channel>(0, 1);
@@ -1710,11 +1774,12 @@ struct prefetch_payload_t<
         reg_layout_>,
     num_coop_sg_,
     arch_tag_,
-    std::enable_if_t<(!arch_has_2d_load_store<arch_tag_>)&&(
-        ((block_size_y_ != 1 || tile_size_y_ != 1) &&
-         mem_layout_ == mem_layout::row_major) ||
-        ((block_size_x_ != 1 || tile_size_x_ != 1) &&
-         mem_layout_ == mem_layout::col_major))>> {
+    std::enable_if_t<
+        (!arch_has_2d_load_store<arch_tag_>) &&
+        (((block_size_y_ != 1 || tile_size_y_ != 1) &&
+          mem_layout_ == mem_layout::row_major) ||
+         ((block_size_x_ != 1 || tile_size_x_ != 1) &&
+          mem_layout_ == mem_layout::col_major))>> {
   using dtype = native_type_t<dtype_>;
   using mem_desc_t =
       mem_desc_t<dtype_, mem_layout_, mem_space::global, alignment_, use_mask_>;
@@ -1735,10 +1800,8 @@ struct prefetch_payload_t<
   static constexpr uint32_t tile_size_y = tile_desc::tile_size_y;
   static constexpr uint32_t block_size_x = tile_desc::block_size_x;
   static constexpr uint32_t block_size_y = tile_desc::block_size_y;
-  static constexpr uint32_t tile_bytes =
-      tile_size_x * tile_size_y * sizeof(dtype);
-  static constexpr uint32_t block_bytes =
-      block_size_x * block_size_y * sizeof(dtype);
+  static constexpr uint32_t tile_bytes = tile_desc::tile_elems * sizeof(dtype);
+  static constexpr uint32_t block_bytes =
+      tile_desc::block_elems * sizeof(dtype);
 
  private:
   using this_payload_t =
@@ -1781,41 +1844,43 @@ struct prefetch_payload_t<
   static constexpr uint32_t num_channel = select_channel(
       std::min(mem_transpose ? block_size_x : block_size_y, max_channel));
 
-  static constexpr uint32_t mem_tile_size_w =
-      mem_transpose ? tile_size_y : tile_size_x;
-  static constexpr uint32_t mem_tile_size_h =
-      mem_transpose ? tile_size_x : tile_size_y;
-  using load_store_attr =
-      typename arch_attr_t<arch_tag>::template load_store_attr<message_type>;
-  static constexpr uint32_t special_prefetch_width =
-      load_store_attr::special_prefetch_width_in_bytes / sizeof(dtype);
-  static constexpr uint32_t normal_prefetch_width =
-      load_store_attr::max_load_width_in_bytes / sizeof(dtype);
-  static constexpr bool is_special_prefetch =
-      (mem_tile_size_w % special_prefetch_width) == 0;
-
-  static constexpr uint32_t block_size_w = is_special_prefetch
-      ? special_prefetch_width
-      : (normal_prefetch_width > mem_tile_size_w ? mem_tile_size_w
-                                                 : normal_prefetch_width);
-  static constexpr uint32_t block_size_h =
-      load_store_attr::max_load_height_in_elem;
-  // could have over-prefetch, but that's should be fine
-  static constexpr uint32_t max_num_block_w =
-      (mem_tile_size_w + block_size_w - 1) / block_size_w;
-  static constexpr uint32_t num_coop_sg = num_coop_sg_;
-  static constexpr uint32_t num_coop_sg_w =
-      detail::gcd<num_coop_sg, max_num_block_w>::value;
-  static constexpr uint32_t num_coop_sg_h = num_coop_sg / num_coop_sg_w;
-
-  static constexpr uint32_t num_block_w = max_num_block_w / num_coop_sg_w;
-  static constexpr uint32_t tile_size_w = block_size_w * num_block_w;
-  static constexpr uint32_t tile_size_h =
-      (mem_tile_size_h + num_coop_sg_h - 1) / num_coop_sg_h;
-  static constexpr uint32_t num_block_h =
-      (tile_size_h + block_size_h - 1) / block_size_h;
 
   xetla_vector<uint32_t, num_channel> channel_offset;
+  xetla_vector<uint32_t, num_channel> step_x;
+  xetla_vector<uint32_t, num_channel> step_y;
+
   uint64_t base_offset;
   uint32_t base_x;
   uint32_t base_y;
@@ -1852,13 +1917,15 @@ struct prefetch_payload_t<
     return *this;
   }
 
-  inline prefetch_payload_t(mem_desc_t& mem_desc, uint32_t coop_id = 0) {
-    uint32_t coop_id_x = coop_id % num_coop_sg_w;
-    uint32_t coop_id_y = coop_id / num_coop_sg_w;
+  inline prefetch_payload_t(
+      mem_desc_t& mem_desc,
+      [[maybe_unused]] uint32_t coop_id = 0) {
 
     pitch_in_bytes = mem_desc.shape.stride * sizeof(dtype);
-    base_x = mem_desc.coord.x + coop_id_x * tile_size_w;
-    base_y = mem_desc.coord.y + coop_id_y * tile_size_h;
+    base_x = mem_desc.coord.x;
+    base_y = mem_desc.coord.y;
     width_in_elems = mem_desc.shape.x;
     height_in_elems = mem_desc.shape.y;
     base_offset = mem_transpose
@@ -1878,13 +1945,15 @@ struct prefetch_payload_t<
       int surface_pitch,
       int surface_offset_x,
       int surface_offset_y,
-      uint32_t coop_id = 0) {
-    uint32_t coop_id_x = coop_id % num_coop_sg_w;
-    uint32_t coop_id_y = coop_id / num_coop_sg_w;
+      [[maybe_unused]] uint32_t coop_id = 0) {
 
     pitch_in_bytes = surface_pitch * sizeof(dtype);
-    base_x = surface_offset_x + coop_id_x * tile_size_w;
-    base_y = surface_offset_y + coop_id_y * tile_size_h;
+    base_x = surface_offset_x;
+    base_y = surface_offset_y;
     width_in_elems = surface_width;
     height_in_elems = surface_height;
     base_offset = mem_transpose
@@ -1897,13 +1966,17 @@ struct prefetch_payload_t<
     channel_offset = channel_index * pitch_in_bytes;
   }
 
-  inline void init(mem_desc_t& mem_desc, uint32_t coop_id = 0) {
-    uint32_t coop_id_x = coop_id % num_coop_sg_w;
-    uint32_t coop_id_y = coop_id / num_coop_sg_w;
+  inline void init(
+      mem_desc_t& mem_desc,
+      [[maybe_unused]] uint32_t coop_id = 0) {
 
     pitch_in_bytes = mem_desc.shape.stride * sizeof(dtype);
-    base_x = mem_desc.coord.x + coop_id_x * tile_size_w;
-    base_y = mem_desc.coord.y + coop_id_y * tile_size_h;
+    base_x = mem_desc.coord.x;
+    base_y = mem_desc.coord.y;
     width_in_elems = mem_desc.shape.x;
     height_in_elems = mem_desc.shape.y;
     base_offset = mem_transpose
@@ -1960,9 +2033,10 @@ struct prefetch_payload_t<
         reg_layout_>,
     num_coop_sg_,
     arch_tag_,
-    std::enable_if_t<(arch_has_2d_load_store<arch_tag_>)&&(
-        ((tile_size_y_ != 1) && mem_layout_ == mem_layout::row_major) ||
-        ((tile_size_x_ != 1) && mem_layout_ == mem_layout::col_major))>> {
+    std::enable_if_t<
+        (arch_has_2d_load_store<arch_tag_>) &&
+        (((tile_size_y_ != 1) && mem_layout_ == mem_layout::row_major) ||
+         ((tile_size_x_ != 1) && mem_layout_ == mem_layout::col_major))>> {
   using dtype = dtype_;
   using mem_desc_t =
       mem_desc_t<dtype_, mem_layout_, mem_space::global, alignment_, use_mask_>;
diff --git a/include/subgroup/tile/impl/prefetch_xe.hpp b/include/subgroup/tile/impl/prefetch_xe.hpp
index 243379b9d..f14941518 100644
--- a/include/subgroup/tile/impl/prefetch_xe.hpp
+++ b/include/subgroup/tile/impl/prefetch_xe.hpp
@@ -104,8 +104,7 @@ tile_prefetch(payload_t& payload) {
   using prefetch_dtype = typename payload_t::prefetch_dtype;
   constexpr uint32_t num_channel = payload_t::num_channel;
 #pragma unroll
-  for (uint32_t i = 0; i < tile_desc::tile_size_y / tile_desc::block_size_y;
-       i++) {
+  for (uint32_t i = 0; i < tile_desc::num_block_y; i++) {
     uint32_t offset_y = i * tile_desc::block_size_y;
 #pragma unroll
     for (uint32_t j = 0; j < tile_desc::num_block_x; j++) {
@@ -126,7 +125,6 @@ tile_prefetch(payload_t& payload) {
           L2>(
           payload.base_ptr,
           payload.channel_offset + payload.base_offset + address_offset);
-      //   }
     }
   }
 }
diff --git a/include/subgroup/tile/impl/store_xe.hpp b/include/subgroup/tile/impl/store_xe.hpp
index bec5b007a..7f0645da9 100644
--- a/include/subgroup/tile/impl/store_xe.hpp
+++ b/include/subgroup/tile/impl/store_xe.hpp
@@ -98,51 +98,44 @@ tile_store(tile_t& tile, payload_t& payload) {
 
   static constexpr uint32_t num_block_x = tile_desc::num_block_x;
   static constexpr uint32_t num_block_y = tile_desc::num_block_y;
-  static constexpr uint32_t num_block = tile_desc::num_block;
 
-  using load_store_attr = typename arch_attr_t<
-      payload_t::arch_tag>::template load_store_attr<msg_type::block_2d>;
+  static constexpr gpu_arch arch_tag = payload_t::arch_tag;
 
-  static constexpr int32_t max_block_width =
-      load_store_attr::max_load_width_in_bytes / sizeof(dtype);
-  static constexpr int32_t max_store_block_height =
+  using load_store_attr = load_store_attr_t<msg_type::block_2d, arch_tag>;
+  static constexpr uint32_t max_store_width_in_elem =
+      load_store_attr::max_store_width_in_bytes / sizeof(dtype);
+  static constexpr uint32_t max_store_height_in_elem =
       load_store_attr::max_store_height_in_elem;
-  static_assert(
-      (max_block_width % block_size_x) == 0,
-      "max_block_width should be a multiply of block size x.");
+
   static constexpr uint32_t elems_per_CL =
       load_store_attr::cache_line_size_in_bytes / sizeof(dtype);
+
+  static_assert(
+      (max_store_width_in_elem % block_size_x) == 0,
+      "max_store_width_in_elem should be a multiply of block_size_x.");
+
   static constexpr uint32_t st_blk_size_y =
-      block_size_y > max_store_block_height ? max_store_block_height
-                                            : block_size_y;
+      std::min(block_size_y, max_store_height_in_elem);
+
   // to make sure full CL store
-  static constexpr uint32_t st_block_x = ((tile_size_x % elems_per_CL) == 0)
+  static constexpr uint32_t st_blk_size_x = ((tile_size_x % elems_per_CL) == 0)
       ? elems_per_CL
       : (((elems_per_CL % tile_size_x) == 0) ? tile_size_x : block_size_x);
 
-  static constexpr uint8_t arr_len_candidate = st_block_x / block_size_x;
+  static constexpr uint8_t arr_len_candidate = st_blk_size_x / block_size_x;
   static constexpr bool is_valid_arr_len_candidate = (arr_len_candidate == 1) ||
       (arr_len_candidate == 2) || (arr_len_candidate == 4);
 
   static constexpr uint8_t arr_len =
       is_valid_arr_len_candidate ? arr_len_candidate : 1;
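+  // e.g., for fp32 (16 elements per 64-byte cache line) with
+  // block_size_x == 8 and tile_size_x a multiple of 16, st_blk_size_x is 16
+  // and arr_len is 2: two 8-wide blocks combine into one full-CL store.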
 
-  auto payload_2d = payload.payloads.xetla_format<uint32_t, num_block, 16>();
 #pragma unroll
   for (uint32_t i = 0; i < num_block_y; ++i) {
-    constexpr uint32_t store_block_elems = block_elems * arr_len;
-    auto payload_row =
-        payload_2d.xetla_select<num_block_x, 1, 16, 1>(i * num_block_x, 0);
-    detail::reset_tile_desc_core<
-        num_block_x,
-        block_size_x * arr_len,
-        st_blk_size_y,
-        1,
-        1,
-        false>(payload_row);
+    int32_t offset_y = i * block_size_y;
 #pragma unroll
     for (uint32_t j = 0; j < num_block_x; j += arr_len) {
-      xetla_tdescriptor tdesc = payload_row.row(j);
+      int32_t offset_x = j * block_size_x;
+      constexpr uint32_t store_block_elems = block_elems * arr_len;
       auto reg_blk = tile.reg.xetla_select<store_block_elems, 1>(
           (i * num_block_x + j) * block_elems);
       xetla_vector<dtype, store_block_elems> combine_blk;
@@ -150,41 +143,50 @@ tile_store(tile_t& tile, payload_t& payload) {
           native_type_t<dtype>,
           block_size_y,
           block_size_x * arr_len>();
+      /*     combine_blk_2d
+        ____________  ____________
+       |            ||            |
+       |   block    ||   block    |
+       |            ||            |
+       |____________||____________|
+      */
 #pragma unroll
-      for (uint32_t combine_i = 0; combine_i < arr_len; ++combine_i) {
+      for (uint32_t block_id = 0; block_id < arr_len; ++block_id) {
         combine_blk_2d.xetla_select<block_size_y, 1, block_size_x, 1>(
-            0, combine_i * block_size_x) =
-            reg_blk.xetla_select<block_elems, 1>(combine_i * block_elems);
+            0, block_id * block_size_x) =
+            reg_blk.xetla_select<block_elems, 1>(block_id * block_elems);
       }
 #pragma unroll
-      for (uint32_t ii = 0; ii < block_size_y / st_blk_size_y; ++ii) {
-        constexpr uint32_t store_elems = st_blk_size_y * block_size_x * arr_len;
+      for (uint32_t ii = 0; ii < block_size_y; ii += st_blk_size_y) {
+        constexpr uint32_t store_elems = st_blk_size_y * st_blk_size_x;
         auto st_blk =
-            combine_blk.xetla_select<store_elems, 1>(ii * store_elems);
-        xetla_tstore_global<dtype, store_elems, L1, L2, payload_t::arch_tag>(
-            tdesc, st_blk);
-        xetla_update_tdesc_offsety(
-            tdesc.xetla_format<uint32_t>(), st_blk_size_y);
+            combine_blk.xetla_select<store_elems, 1>(ii * st_blk_size_x);
+        xetla_store_global<dtype, st_blk_size_x, st_blk_size_y, L1, L2>(
+            payload.base_ptr,
+            payload.surface_width,
+            payload.surface_height,
+            payload.surface_pitch,
+            payload.offset_x + offset_x,
+            payload.offset_y + offset_y + ii,
+            st_blk);
       }
       // exceed hardware limitation
       if constexpr ((block_size_y % st_blk_size_y) != 0) {
-        constexpr uint32_t blk_remained_start = block_size_y / st_blk_size_y *
-            st_blk_size_y * block_size_x * arr_len;
+        constexpr uint32_t blk_remained_start =
+            block_size_y / st_blk_size_y * st_blk_size_y * st_blk_size_x;
         constexpr uint8_t blk_remained_y = block_size_y % st_blk_size_y;
-        constexpr uint8_t blk_remained_elems =
-            blk_remained_y * block_size_x * arr_len;
+        constexpr uint8_t blk_remained_elems = blk_remained_y * st_blk_size_x;
         auto st_blk =
             combine_blk.xetla_select<blk_remained_elems, 1>(blk_remained_start);
-        constexpr uint32_t block_widthx_widthy_arrlen =
-            (block_size_x * arr_len - 1) | ((blk_remained_y - 1) << 8);
-        gpu::xetla::detail::xetla_set_block_widthx_widthy_arrlen(
-            tdesc.xetla_format<uint32_t>(), block_widthx_widthy_arrlen);
-        xetla_tstore_global<
-            dtype,
-            blk_remained_elems,
-            L1,
-            L2,
-            payload_t::arch_tag>(tdesc, st_blk);
+        xetla_store_global<dtype, st_blk_size_x, blk_remained_y, L1, L2>(
+            payload.base_ptr,
+            payload.surface_width,
+            payload.surface_height,
+            payload.surface_pitch,
+            payload.offset_x + offset_x,
+            payload.offset_y + offset_y +
+                block_size_y / st_blk_size_y * st_blk_size_y,
+            st_blk);
       }
     }
   }
@@ -194,19 +196,10 @@ tile_store(tile_t& tile, payload_t& payload) {
     constexpr uint32_t processed_elems =
         num_block_y * num_block_x * block_elems;
     constexpr uint32_t remained_st_blk_size_y =
-        st_blk_size_y > remained_size_y ? remained_size_y : st_blk_size_y;
-    auto payload_row = payload_2d.xetla_select<num_block_x, 1, 16, 1>(
-        num_block_y * num_block_x, 0);
-    detail::reset_tile_desc_core<
-        num_block_x,
-        block_size_x * arr_len,
-        remained_st_blk_size_y,
-        1,
-        1,
-        false>(payload_row);
+        std::min(st_blk_size_y, remained_size_y);
 #pragma unroll
     for (uint32_t j = 0; j < num_block_x; j += arr_len) {
-      xetla_tdescriptor tdesc = payload_row.row(j);
+      int32_t offset_x = j * block_size_x;
       auto reg_blk = tile.reg.xetla_select<remained_block_elems * arr_len, 1>(
           processed_elems + j * remained_block_elems);
       // Do combination
@@ -214,46 +207,53 @@ tile_store(tile_t& tile, payload_t& payload) {
       auto combine_blk_2d = combine_blk.xetla_format<
           native_type_t<dtype>,
           remained_size_y,
-          block_size_x * arr_len>();
+          st_blk_size_x>();
 #pragma unroll
-      for (uint32_t combine_i = 0; combine_i < arr_len; ++combine_i) {
+      for (uint32_t block_id = 0; block_id < arr_len; ++block_id) {
         combine_blk_2d.xetla_select<remained_size_y, 1, block_size_x, 1>(
-            0, combine_i * block_size_x) =
+            0, block_id * block_size_x) =
             reg_blk.xetla_select<remained_block_elems, 1>(
-                combine_i * remained_block_elems);
+                block_id * remained_block_elems);
       }
 #pragma unroll
-      for (uint32_t ii = 0; ii < remained_size_y / remained_st_blk_size_y;
-           ++ii) {
-        constexpr uint32_t store_elems =
-            remained_st_blk_size_y * block_size_x * arr_len;
+      for (uint32_t ii = 0; ii < remained_size_y;
+           ii += remained_st_blk_size_y) {
+        constexpr uint32_t store_elems = remained_st_blk_size_y * st_blk_size_x;
         auto st_blk =
-            combine_blk.xetla_select<store_elems, 1>(ii * store_elems);
-        xetla_tstore_global<dtype, store_elems, L1, L2, payload_t::arch_tag>(
-            tdesc, st_blk);
-        xetla_update_tdesc_offsety(
-            tdesc.xetla_format<uint32_t>(), remained_st_blk_size_y);
+            combine_blk.xetla_select<store_elems, 1>(ii * st_blk_size_x);
+        xetla_store_global<
+            dtype,
+            st_blk_size_x,
+            remained_st_blk_size_y,
+            L1,
+            L2>(
+            payload.base_ptr,
+            payload.surface_width,
+            payload.surface_height,
+            payload.surface_pitch,
+            payload.offset_x + offset_x,
+            payload.offset_y + num_block_y * block_size_y + ii,
+            st_blk);
       }
       constexpr uint32_t final_st_blk_size_y =
           remained_size_y % remained_st_blk_size_y;
       if constexpr (final_st_blk_size_y != 0) {
         constexpr uint32_t final_start = remained_size_y /
-            remained_st_blk_size_y * remained_st_blk_size_y * block_size_x *
-            arr_len;
+            remained_st_blk_size_y * remained_st_blk_size_y * st_blk_size_x;
         constexpr uint32_t final_store_elems =
-            final_st_blk_size_y * block_size_x * arr_len;
+            final_st_blk_size_y * st_blk_size_x;
         auto st_blk =
             combine_blk.xetla_select<final_store_elems, 1>(final_start);
-        constexpr uint32_t block_widthx_widthy_arrlen =
-            (block_size_x * arr_len - 1) | ((final_st_blk_size_y - 1) << 8);
-        gpu::xetla::detail::xetla_set_block_widthx_widthy_arrlen(
-            tdesc.xetla_format<uint32_t>(), block_widthx_widthy_arrlen);
-        xetla_tstore_global<
-            dtype,
-            final_store_elems,
-            L1,
-            L2,
-            payload_t::arch_tag>(tdesc, st_blk);
+        xetla_store_global<dtype, st_blk_size_x, final_st_blk_size_y, L1, L2>(
+            payload.base_ptr,
+            payload.surface_width,
+            payload.surface_height,
+            payload.surface_pitch,
+            payload.offset_x + offset_x,
+            payload.offset_y + num_block_y * block_size_y +
+                remained_size_y / remained_st_blk_size_y *
+                    remained_st_blk_size_y,
+            st_blk);
       }
     }
   }
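
Note on the tile_store rework above: the tdesc-based path is replaced by direct
2D block stores addressed with explicit (x, y) surface coordinates, and each
register block is written as full-height chunks capped by the hardware store
height plus one remainder chunk. A minimal standalone sketch of that splitting
logic, with hypothetical names (store_2d stands in for xetla_store_global and
simply copies rows); a sketch, not the library implementation:

  #include <algorithm>
  #include <cstdint>

  template <typename T>
  void store_2d(const T* src, T* base, uint32_t pitch,
                uint32_t x, uint32_t y, uint32_t w, uint32_t h) {
    for (uint32_t r = 0; r < h; ++r) // stand-in for the HW 2D block store
      std::copy(src + r * w, src + (r + 1) * w, base + (y + r) * pitch + x);
  }

  template <typename T>
  void store_block(const T* blk, T* base, uint32_t pitch,
                   uint32_t x, uint32_t y, uint32_t W, uint32_t H,
                   uint32_t max_h /* cf. max_store_height_in_elem */) {
    const uint32_t st_h = std::min(H, max_h); // cf. st_blk_size_y
    uint32_t r = 0;
    for (; r + st_h <= H; r += st_h) // full-height chunks
      store_2d(blk + r * W, base, pitch, x, y + r, W, st_h);
    if (r < H) // remainder of H % st_h rows ("exceed hardware limitation")
      store_2d(blk + r * W, base, pitch, x, y + r, W, H - r);
  }
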
diff --git a/tests/integration/default_config/group_gemm/kernel_func.hpp b/tests/integration/default_config/group_gemm/kernel_func.hpp
index 26b2fb181..7b7b31404 100644
--- a/tests/integration/default_config/group_gemm/kernel_func.hpp
+++ b/tests/integration/default_config/group_gemm/kernel_func.hpp
@@ -108,6 +108,9 @@ struct default_config_group_gemm_test_func {
 
   using gemm_op_t = gemm_universal_t<dispatch_policy, gemm_t, epilogue_t>;
 
+  static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count();
+  static constexpr uint32_t slm_size = gemm_op_t::get_slm_size();
+
   static const char* func_name() {
     return "default_config_group_gemm_test_func";
   }
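
The barrier_count / slm_size pair added here (and to the other gemm test
functors below) lets the shared runner derive SLM and named-barrier sizing
from the kernel type rather than taking SLMSIZE/BARNUM template parameters,
as the tests/utils/execution.hpp changes at the end of this patch show. A
condensed, self-contained sketch of the pattern with hypothetical names:

  #include <cstdint>

  // Stub standing in for gemm_universal_t (hypothetical, for illustration).
  struct some_gemm_op_t {
    static constexpr uint32_t get_barrier_count() { return 32; }
    static constexpr uint32_t get_slm_size() { return 8 * 1024; }
  };

  struct my_gemm_func {
    using gemm_op_t = some_gemm_op_t;
    static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count();
    static constexpr uint32_t slm_size = gemm_op_t::get_slm_size();
  };

  // The runner reads the sizes off the kernel type (cf. gemm_exec below).
  template <typename kernel_t>
  void launch() {
    constexpr uint32_t slm = kernel_t::slm_size;
    constexpr uint32_t barriers = kernel_t::barrier_count;
    static_assert(slm > 0 && barriers > 0);
  }
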
diff --git a/tests/integration/default_config/kernel_gemm/kernel_func.hpp b/tests/integration/default_config/kernel_gemm/kernel_func.hpp
index 3745343d0..24b82bb94 100644
--- a/tests/integration/default_config/kernel_gemm/kernel_func.hpp
+++ b/tests/integration/default_config/kernel_gemm/kernel_func.hpp
@@ -65,6 +65,9 @@ struct default_config_kernel_gemm_test_func {
       gpu_arch::XeHpc, // GPU arch
       tune_option>;
 
+  static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count();
+  static constexpr uint32_t slm_size = gemm_op_t::get_slm_size();
+
   static const char* func_name() {
     return "default_config_kernel_gemm_test_func";
   }
diff --git a/tests/integration/gemm/bf16/kernel_func.hpp b/tests/integration/gemm/bf16/kernel_func.hpp
index 6345047df..286e53dbb 100644
--- a/tests/integration/gemm/bf16/kernel_func.hpp
+++ b/tests/integration/gemm/bf16/kernel_func.hpp
@@ -76,6 +76,9 @@ struct bf16_gemm_test_func {
 
   using gemm_op_t = gemm_universal_t<dispatch_policy, gemm_t, epilogue_t>;
 
+  static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count();
+  static constexpr uint32_t slm_size = gemm_op_t::get_slm_size();
+
   static const char* func_name() {
     return "bf16_gemm_test_func";
   }
diff --git a/tests/integration/gemm/fp16/common.hpp b/tests/integration/gemm/fp16/common.hpp
index 0432fa349..cd7d52a7a 100644
--- a/tests/integration/gemm/fp16/common.hpp
+++ b/tests/integration/gemm/fp16/common.hpp
@@ -59,7 +59,7 @@ class TestBaseFP16f : public TestBase {
   using data_type_b = fp16;
   using data_type_c = fp16;
   using data_type_acc = float;
-  static constexpr mma_engine engine = mma_engine::fpu;
+  static constexpr mma_engine engine = mma_engine::xmx;
 };
 
 class TestBaseFP16x : public TestBase {
diff --git a/tests/integration/gemm/fp16/main.cpp b/tests/integration/gemm/fp16/main.cpp
index 32ef12461..78dca8b4a 100644
--- a/tests/integration/gemm/fp16/main.cpp
+++ b/tests/integration/gemm/fp16/main.cpp
@@ -32,7 +32,6 @@ TYPED_TEST_P(fp16_gemm_test, esimd) {
       esimd_compile_string);
 }
 REGISTER_TYPED_TEST_SUITE_P(fp16_gemm_test, esimd);
-using tests =
-    ::testing::Types<Test0, Test0x, Test0f, Test2f, Test2fx1, Test4x, Test4x1, Test4f>;
+using tests = ::testing::Types<Test0>;
 
 INSTANTIATE_TYPED_TEST_SUITE_P(fp16_gemm_test_suite, fp16_gemm_test, tests);
diff --git a/tests/integration/gemm/fp32/common.hpp b/tests/integration/gemm/fp32/common.hpp
index 7ce5098fc..67303c398 100644
--- a/tests/integration/gemm/fp32/common.hpp
+++ b/tests/integration/gemm/fp32/common.hpp
@@ -97,7 +97,7 @@ class Test3 : public TestBase {
   static constexpr size_t mat_m = 16;
   static constexpr size_t mat_n = 64;
   static constexpr size_t mat_k = 32;
-  static constexpr size_t wg_m = 8;
+  static constexpr size_t wg_m = 16;
   static constexpr size_t wg_n = 64;
   static constexpr size_t sg_m = 1;
   static constexpr size_t sg_n = 64;
@@ -205,7 +205,7 @@ class Test8 : public TestBase {
   static constexpr uint32_t global_kslicing = 2;
   static constexpr uint32_t local_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::row_major;
-  static constexpr mem_layout layout_b = mem_layout::col_major;
+  static constexpr mem_layout layout_b = mem_layout::row_major;
   using data_type_a = float;
   using data_type_b = float;
   using data_type_c = float;
@@ -227,7 +227,6 @@ class Test9 : public TestBase {
   static constexpr uint32_t local_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::row_major;
   static constexpr mem_layout layout_b = mem_layout::row_major;
-  static constexpr mma_engine engine = mma_engine::xmx;
   using data_type_a = float;
   using data_type_b = float;
   using data_type_c = float;
@@ -245,10 +244,10 @@ class Test10 : public TestBase {
   static constexpr size_t sg_m = 32;
   static constexpr size_t sg_n = 64;
   static constexpr size_t sg_k = 8;
-  static constexpr uint32_t global_kslicing = 2;
+  static constexpr uint32_t global_kslicing = 1;
   static constexpr uint32_t local_kslicing = 1;
   static constexpr mem_layout layout_a = mem_layout::row_major;
-  static constexpr mem_layout layout_b = mem_layout::col_major;
+  static constexpr mem_layout layout_b = mem_layout::row_major;
   using data_type_a = float;
   using data_type_b = float;
   using data_type_c = float;
@@ -258,9 +257,9 @@ class Test10 : public TestBase {
 class Test11 : public TestBase {
  public:
   static constexpr size_t batch_size = 35;
-  static constexpr size_t mat_m = 4192;
-  static constexpr size_t mat_k = 1136;
-  static constexpr size_t mat_n = 688;
+  static constexpr size_t mat_m = 4193;
+  static constexpr size_t mat_k = 1134;
+  static constexpr size_t mat_n = 686;
   static constexpr size_t wg_m = 256;
   static constexpr size_t wg_n = 256;
   static constexpr size_t sg_m = 32;
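
Test11's dimensions are nudged off the tile multiples (4193 x 1134 x 686
against 256 x 256 workgroup tiles), so the trailing workgroups cover partial
tiles and the boundary/remainder paths get exercised. A small sketch of the
resulting grid, assuming the dispatcher rounds the group count up:

  #include <cstdint>
  #include <cstdio>

  constexpr uint32_t ceil_div(uint32_t a, uint32_t b) {
    return (a + b - 1) / b;
  }

  int main() {
    constexpr uint32_t m = 4193, n = 686, wg_m = 256, wg_n = 256;
    // 17 x 3 workgroups; the last row/column covers 97 x 174 elements.
    std::printf("%u x %u groups, remainder %u x %u\n",
                ceil_div(m, wg_m), ceil_div(n, wg_n), m % wg_m, n % wg_n);
  }
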
diff --git a/tests/integration/gemm/fp32/kernel_func.hpp b/tests/integration/gemm/fp32/kernel_func.hpp
index 962acfe96..a83518076 100644
--- a/tests/integration/gemm/fp32/kernel_func.hpp
+++ b/tests/integration/gemm/fp32/kernel_func.hpp
@@ -77,6 +77,9 @@ struct fp32_gemm_test_func {
 
   using gemm_op_t = gemm_universal_t<dispatch_policy, gemm_t, epilogue_t>;
 
+  static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count();
+  static constexpr uint32_t slm_size = gemm_op_t::get_slm_size();
+
   static const char* func_name() {
     return "fp32_gemm_test_func";
   }
diff --git a/tests/integration/gemm/int4_dequantization_bias/CMakeLists.txt b/tests/integration/gemm/int4_dequantization_bias/CMakeLists.txt
index 4bdf0a42b..6a3a155f3 100644
--- a/tests/integration/gemm/int4_dequantization_bias/CMakeLists.txt
+++ b/tests/integration/gemm/int4_dequantization_bias/CMakeLists.txt
@@ -1,11 +1,6 @@
 get_filename_component(ProjectId ${CMAKE_CURRENT_SOURCE_DIR} NAME)
 string(REPLACE " " "_" ProjectId ${ProjectId})
-set(ProjectIdClient ${ProjectId})
-set(ProjectIdXe ${ProjectId})
-string(PREPEND ProjectIdClient "gemm_client_")
-string(PREPEND ProjectIdXe "gemm_xe_")
+string(PREPEND ProjectId "gemm_")
 
-FILE(GLOB src_client main_client.cpp)
-add_integration_test(${ProjectIdClient} ${src_client})
-FILE(GLOB src_xe main_xe.cpp)
-add_integration_test(${ProjectIdXe} ${src_xe})
+FILE(GLOB src main.cpp)
+add_integration_test(${ProjectId} ${src})
diff --git a/tests/integration/gemm/int4_dequantization_bias/main_client.cpp b/tests/integration/gemm/int4_dequantization_bias/main.cpp
similarity index 100%
rename from tests/integration/gemm/int4_dequantization_bias/main_client.cpp
rename to tests/integration/gemm/int4_dequantization_bias/main.cpp
diff --git a/tests/integration/gemm/int4_dequantization_bias/main_xe.cpp b/tests/integration/gemm/int4_dequantization_bias/main_xe.cpp
deleted file mode 100644
index 0cc9d8d6f..000000000
--- a/tests/integration/gemm/int4_dequantization_bias/main_xe.cpp
+++ /dev/null
@@ -1,641 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2022-2023 Intel Corporation
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- *******************************************************************************/
-
-#include <utils/utils.hpp>
-#include "xetla.hpp"
-// #define UT_DEBUG 1
-using namespace gpu::xetla;
-// The number of times the kernel is executed
-constexpr int ITER = 100;
-
-class test1 {
- public:
-  // Extract the parameters required by different test cases
-  static constexpr size_t mat_m = 1;
-  static constexpr size_t mat_n = 16384;
-  static constexpr size_t mat_k = 4096;
-  static constexpr size_t wg_m = 8;
-  static constexpr size_t wg_n = 64;
-  static constexpr size_t sg_m = 8;
-  static constexpr size_t sg_n = 16;
-  static constexpr size_t sg_k = 16;
-  static constexpr size_t dequant_s = 64;
-  static constexpr size_t num_buffer = 64;
-  static constexpr size_t local_kslicing = 8;
-  static constexpr size_t global_kslicing = 1;
-  static constexpr mem_layout layout_a = mem_layout::row_major;
-  static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = fp16;
-  using data_type_b = int4x2;
-  using data_type_c = fp16;
-};
-class test2 {
- public:
-  // Extract the parameters required by different test cases
-  static constexpr size_t mat_m = 1;
-  static constexpr size_t mat_n = 4096;
-  static constexpr size_t mat_k = 22016;
-  static constexpr size_t wg_m = 8;
-  static constexpr size_t wg_n = 128;
-  static constexpr size_t sg_m = 8;
-  static constexpr size_t sg_n = 16;
-  static constexpr size_t sg_k = 16;
-  static constexpr size_t dequant_s = 128;
-  static constexpr size_t num_buffer = 64;
-  static constexpr size_t local_kslicing = 4;
-  static constexpr size_t global_kslicing = 1;
-  static constexpr mem_layout layout_a = mem_layout::row_major;
-  static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = fp16;
-  using data_type_b = int4x2;
-  using data_type_c = fp16;
-};
-class qkv1 {
- public:
-  // Extract the parameters required by different test cases
-  static constexpr size_t mat_m = 8;
-  static constexpr size_t mat_n = 12288;
-  static constexpr size_t mat_k = 4096;
-  static constexpr size_t wg_m = 8;
-  static constexpr size_t wg_n = 64;
-  static constexpr size_t sg_m = 8;
-  static constexpr size_t sg_n = 16;
-  static constexpr size_t sg_k = 16;
-  static constexpr size_t dequant_s = 64;
-  static constexpr size_t num_buffer = 64;
-  static constexpr size_t local_kslicing = 8;
-  static constexpr size_t global_kslicing = 1;
-  static constexpr mem_layout layout_a = mem_layout::row_major;
-  static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = fp16;
-  using data_type_b = int4x2;
-  using data_type_c = fp16;
-};
-class qkv2 {
- public:
-  // Extract the parameters required by different test cases
-  static constexpr size_t mat_m = 8;
-  static constexpr size_t mat_n = 4096;
-  static constexpr size_t mat_k = 4096;
-  static constexpr size_t wg_m = 8;
-  static constexpr size_t wg_n = 64;
-  static constexpr size_t sg_m = 8;
-  static constexpr size_t sg_n = 16;
-  static constexpr size_t sg_k = 16;
-  static constexpr size_t dequant_s = 64;
-  static constexpr size_t num_buffer = 64;
-  static constexpr size_t local_kslicing = 8;
-  static constexpr size_t global_kslicing = 1;
-  static constexpr mem_layout layout_a = mem_layout::row_major;
-  static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = fp16;
-  using data_type_b = int4x2;
-  using data_type_c = fp16;
-};
-class qkv3 {
- public:
-  // Extract the parameters required by different test cases
-  static constexpr size_t mat_m = 8;
-  static constexpr size_t mat_n = 11008;
-  static constexpr size_t mat_k = 4096;
-  static constexpr size_t wg_m = 8;
-  static constexpr size_t wg_n = 64;
-  static constexpr size_t sg_m = 8;
-  static constexpr size_t sg_n = 16;
-  static constexpr size_t sg_k = 16;
-  static constexpr size_t dequant_s = 64;
-  static constexpr size_t num_buffer = 64;
-  static constexpr size_t local_kslicing = 8;
-  static constexpr size_t global_kslicing = 1;
-  static constexpr mem_layout layout_a = mem_layout::row_major;
-  static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = fp16;
-  using data_type_b = int4x2;
-  using data_type_c = fp16;
-};
-class qkv4 {
- public:
-  // Extract the parameters required by different test cases
-  static constexpr size_t mat_m = 1;
-  static constexpr size_t mat_n = 4096;
-  static constexpr size_t mat_k = 11008;
-  static constexpr size_t wg_m = 8;
-  static constexpr size_t wg_n = 64;
-  static constexpr size_t sg_m = 8;
-  static constexpr size_t sg_n = 16;
-  static constexpr size_t sg_k = 16;
-  static constexpr size_t dequant_s = 32;
-  static constexpr size_t num_buffer = 64;
-  static constexpr size_t local_kslicing = 4;
-  static constexpr size_t global_kslicing = 1;
-  static constexpr mem_layout layout_a = mem_layout::row_major;
-  static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = fp16;
-  using data_type_b = int4x2;
-  using data_type_c = fp16;
-};
-class qkv5 {
- public:
-  // Extract the parameters required by different test cases
-  static constexpr size_t mat_m = 8;
-  static constexpr size_t mat_n = 151936;
-  static constexpr size_t mat_k = 4096;
-  static constexpr size_t wg_m = 8;
-  static constexpr size_t wg_n = 64;
-  static constexpr size_t sg_m = 8;
-  static constexpr size_t sg_n = 16;
-  static constexpr size_t sg_k = 16;
-  static constexpr size_t dequant_s = 64;
-  static constexpr size_t num_buffer = 64;
-  static constexpr size_t local_kslicing = 8;
-  static constexpr size_t global_kslicing = 1;
-  static constexpr mem_layout layout_a = mem_layout::row_major;
-  static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = fp16;
-  using data_type_b = int4x2;
-  using data_type_c = fp16;
-};
-class qkv6 {
- public:
-  // Extract the parameters required by different test cases
-  static constexpr size_t mat_m = 1;
-  static constexpr size_t mat_n = 12288;
-  static constexpr size_t mat_k = 4096;
-  static constexpr size_t wg_m = 8;
-  static constexpr size_t wg_n = 64;
-  static constexpr size_t sg_m = 8;
-  static constexpr size_t sg_n = 16;
-  static constexpr size_t sg_k = 16;
-  static constexpr size_t dequant_s = 64;
-  static constexpr size_t num_buffer = 64;
-  static constexpr size_t local_kslicing = 8;
-  static constexpr size_t global_kslicing = 1;
-  static constexpr mem_layout layout_a = mem_layout::row_major;
-  static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = fp16;
-  using data_type_b = int4x2;
-  using data_type_c = fp16;
-};
-class qkv7 {
- public:
-  // Extract the parameters required by different test cases
-  static constexpr size_t mat_m = 1;
-  static constexpr size_t mat_n = 4096;
-  static constexpr size_t mat_k = 4096;
-  static constexpr size_t wg_m = 8;
-  static constexpr size_t wg_n = 64;
-  static constexpr size_t sg_m = 8;
-  static constexpr size_t sg_n = 16;
-  static constexpr size_t sg_k = 16;
-  static constexpr size_t dequant_s = 64;
-  static constexpr size_t num_buffer = 64;
-  static constexpr size_t local_kslicing = 8;
-  static constexpr size_t global_kslicing = 1;
-  static constexpr mem_layout layout_a = mem_layout::row_major;
-  static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = fp16;
-  using data_type_b = int4x2;
-  using data_type_c = fp16;
-};
-class qkv8 {
- public:
-  // Extract the parameters required by different test cases
-  static constexpr size_t mat_m = 1;
-  static constexpr size_t mat_n = 11008;
-  static constexpr size_t mat_k = 4096;
-  static constexpr size_t wg_m = 8;
-  static constexpr size_t wg_n = 64;
-  static constexpr size_t sg_m = 8;
-  static constexpr size_t sg_n = 16;
-  static constexpr size_t sg_k = 16;
-  static constexpr size_t dequant_s = 64;
-  static constexpr size_t num_buffer = 64;
-  static constexpr size_t local_kslicing = 8;
-  static constexpr size_t global_kslicing = 1;
-  static constexpr mem_layout layout_a = mem_layout::row_major;
-  static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = fp16;
-  using data_type_b = int4x2;
-  using data_type_c = fp16;
-};
-class qkv9 {
- public:
-  // Extract the parameters required by different test cases
-  static constexpr size_t mat_m = 1;
-  static constexpr size_t mat_n = 4096;
-  static constexpr size_t mat_k = 11008;
-  static constexpr size_t wg_m = 8;
-  static constexpr size_t wg_n = 64;
-  static constexpr size_t sg_m = 8;
-  static constexpr size_t sg_n = 16;
-  static constexpr size_t sg_k = 16;
-  static constexpr size_t dequant_s = 64;
-  static constexpr size_t num_buffer = 64;
-  static constexpr size_t local_kslicing = 4;
-  static constexpr size_t global_kslicing = 1;
-  static constexpr mem_layout layout_a = mem_layout::row_major;
-  static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = fp16;
-  using data_type_b = int4x2;
-  using data_type_c = fp16;
-};
-class qkv10 {
- public:
-  // Extract the parameters required by different test cases
-  static constexpr size_t mat_m = 1;
-  static constexpr size_t mat_n = 151936;
-  static constexpr size_t mat_k = 4096;
-  static constexpr size_t wg_m = 8;
-  static constexpr size_t wg_n = 64;
-  static constexpr size_t sg_m = 8;
-  static constexpr size_t sg_n = 16;
-  static constexpr size_t sg_k = 16;
-  static constexpr size_t dequant_s = 64;
-  static constexpr size_t num_buffer = 64;
-  static constexpr size_t local_kslicing = 8;
-  static constexpr size_t global_kslicing = 1;
-  static constexpr mem_layout layout_a = mem_layout::row_major;
-  static constexpr mem_layout layout_b = mem_layout::row_major;
-  using data_type_a = fp16;
-  using data_type_b = int4x2;
-  using data_type_c = fp16;
-};
-
-template <
-    typename data_type_a,
-    typename data_type_b,
-    typename data_type_c,
-    typename data_type_acc = float,
-    typename data_type_bias = data_type_a>
-int gemm_result_validate(
-    data_type_a* A,
-    data_type_b* B,
-    data_type_c* C,
-    data_type_bias* bias,
-    uint32_t m,
-    uint32_t k,
-    uint32_t n,
-    mem_layout mem_layout_a_ = mem_layout::row_major,
-    mem_layout mem_layout_b_ = mem_layout::row_major) {
-  buff_cmp::buff_vals<data_type_c> data(C, m, n, n);
-  std::vector<data_type_acc> gold_C(m * n, 0);
-  get_gemm_gold<data_type_a, data_type_b, data_type_acc>(
-      m, n, k, mem_layout_a_, mem_layout_b_, A, B, gold_C.data());
-
-  // BiasAdd
-  for (uint32_t i = 0; i < gold_C.size(); ++i) {
-    uint32_t col = i % n;
-    gold_C[i] += bias[col];
-  }
-
-  buff_cmp::buff_vals<data_type_c, data_type_acc> other(gold_C.data(), m, n, n);
-
-  bool result = buff_cmp::xetla_buff_cmp(data, other, "gemm validation");
-
-  std::cout << (!result ? "FAILED\n" : "PASSED\n");
-  return result ? 0 : 1;
-}
-
-template <class Test>
-void dequantize_gemm_run(int iter) {
-  using namespace gpu;
-  // Accept incoming parameters
-  constexpr size_t matrix_m = Test::mat_m;
-  constexpr size_t matrix_n = Test::mat_n;
-  constexpr size_t matrix_k = Test::mat_k;
-  constexpr uint32_t global_kslicing = Test::global_kslicing;
-  constexpr uint32_t local_kslicing = Test::local_kslicing;
-  static constexpr mem_layout layout_b = Test::layout_b;
-  constexpr size_t wg_tile_m = Test::wg_m;
-  constexpr size_t wg_tile_n = Test::wg_n;
-  constexpr size_t sg_tile_m = Test::sg_m;
-  constexpr size_t sg_tile_n = Test::sg_n;
-  constexpr size_t sg_tile_k = Test::sg_k;
-  constexpr size_t dequant_s = Test::dequant_s;
-  using data_type_a = typename Test::data_type_a;
-  using data_type_b = typename Test::data_type_b;
-  using data_type_c = typename Test::data_type_c;
-  using data_type_zero_pt = int4x2;
-  using data_type_scale = fp16;
-  using data_type_acc_in = fp16;
-  using data_type_acc = float;
-  using data_type_bias = fp16;
-
-  constexpr size_t size_a = matrix_m * matrix_k;
-  constexpr size_t size_b = matrix_k * matrix_n / 2;
-
-  constexpr size_t size_scale_m = matrix_k / dequant_s;
-  constexpr size_t size_scale_n = matrix_n;
-  constexpr size_t size_scale = size_scale_m * size_scale_n;
-
-  constexpr size_t size_zero_pt_m = matrix_k / dequant_s;
-  constexpr size_t size_zero_pt_n = matrix_n / 2;
-  constexpr size_t size_zero_pt = size_zero_pt_m * size_zero_pt_n;
-
-  constexpr size_t size_c = matrix_m * matrix_n;
-  constexpr size_t size_bias = matrix_n;
-
-  // Turn on the enable_profiling property to facilitate subsequent profiling
-  sycl::property_list properties{sycl::property::queue::enable_profiling()};
-  auto queue = sycl::queue(properties);
-  auto context = queue.get_info<info::queue::context>();
-  auto device = queue.get_info<info::queue::device>();
-
-  std::cout << "Running on " << device.get_info<info::device::name>() << "\n";
-
-  using tile_shape =
-      xetla::group::tile_shape_t<wg_tile_n, wg_tile_m, sg_tile_n, sg_tile_m>;
-  static constexpr uint32_t periodic_sync_interval = 0;
-  static constexpr uint32_t prefetch_distance = 0;
-
-  using mem_desc_a_t = xetla::mem_desc_t<
-      data_type_a,
-      mem_layout::row_major,
-      mem_space::global,
-      DEVICE_MEM_ALIGNMENT / sizeof(data_type_a)>;
-  using mem_desc_b_t = xetla::mem_desc_t<
-      data_type_b,
-      layout_b,
-      mem_space::global,
-      DEVICE_MEM_ALIGNMENT / sizeof(data_type_b)>;
-  using mem_desc_c_t = xetla::mem_desc_t<
-      data_type_c,
-      mem_layout::row_major,
-      mem_space::global,
-      DEVICE_MEM_ALIGNMENT / sizeof(data_type_c)>;
-
-  using mem_desc_bias_t = xetla::mem_desc_t<
-      data_type_bias,
-      mem_layout::row_major,
-      mem_space::global,
-      DEVICE_MEM_ALIGNMENT / sizeof(data_type_bias)>;
-
-  using compute_attr = xetla::group::
-      compute_attr_t<data_type_acc_in, data_type_acc_in, data_type_acc>;
-  using perf_tuning_knob = xetla::group::
-      perf_tuning_knob_t<sg_tile_k, prefetch_distance, periodic_sync_interval>;
-  static constexpr quant_info quant_info{
-      quant_mode::I4_SYM, Test::dequant_s, layout_b};
-
-  using compute_policy = xetla::group::compute_policy_int4_dequantize<
-      compute_attr,
-      perf_tuning_knob,
-      data_type_scale,
-      data_type_zero_pt,
-      quant_info,
-      mma_engine::xmx,
-      gpu_arch::XeHpc>;
-
-  using gemm_t = xetla::group::
-      gemm_t<compute_policy, tile_shape, mem_desc_a_t, mem_desc_b_t>;
-
-  using bias_op_t =
-      gpu::xetla::subgroup::bias_add_op_t<mem_desc_bias_t, gpu_arch::XeHpc>;
-  using tile_op_t = gpu::xetla::subgroup::chained_tile_op_t<bias_op_t>;
-
-  using epilogue_t = xetla::group::epilogue_t<
-      xetla::group::epilogue_policy_tile_op<tile_op_t, gpu_arch::XeHpc>,
-      tile_shape,
-      mem_desc_c_t>;
-
-  using group_swizzle = xetla::kernel::group_swizzle_default<gpu_arch::XeHpc>;
-  using gemm_op_t = xetla::kernel::gemm_universal_t<
-      gpu::xetla::kernel::dispatch_policy_int4_dequantize_kslicing<
-          group_swizzle,
-          global_kslicing,
-          local_kslicing>,
-      gemm_t,
-      epilogue_t>;
-
-  size_t size_acc = gemm_op_t::get_acc_buf_size(matrix_m, matrix_n);
-  size_t size_cnt = gemm_op_t::get_cnt_buf_size(matrix_m, matrix_n);
-
-  // Define and initialize the data required for the calculation
-  auto* A_h = static_cast<data_type_a*>(
-      malloc_host(size_a * sizeof(data_type_a), context));
-  auto* B_h = static_cast<data_type_b*>(
-      malloc_host(size_b * sizeof(data_type_b), context));
-  auto* C_h = static_cast<data_type_c*>(
-      malloc_host(size_c * sizeof(data_type_c), context));
-  auto* Acc_h = static_cast<data_type_acc*>(
-      malloc_host(size_acc * sizeof(data_type_acc), context));
-  auto* Cnt_h =
-      static_cast<uint32_t*>(malloc_host(size_cnt * sizeof(uint32_t), context));
-  auto* scale_h = static_cast<data_type_scale*>(
-      malloc_host(size_scale * sizeof(data_type_scale), context));
-  auto* zero_pt_h = static_cast<data_type_zero_pt*>(
-      malloc_host(size_zero_pt * sizeof(data_type_zero_pt), context));
-  auto* bias_h = static_cast<data_type_bias*>(
-      malloc_host(size_bias * sizeof(data_type_bias), context));
-
-  auto* A_d = static_cast<data_type_a*>(aligned_alloc_device(
-      DEVICE_MEM_ALIGNMENT, size_a * sizeof(data_type_a), device, context));
-  auto* B_d = static_cast<data_type_b*>(aligned_alloc_device(
-      DEVICE_MEM_ALIGNMENT, size_b * sizeof(data_type_b), device, context));
-  auto* C_d = static_cast<data_type_c*>(aligned_alloc_device(
-      DEVICE_MEM_ALIGNMENT, size_c * sizeof(data_type_c), device, context));
-  auto* Acc_d = static_cast<data_type_acc*>(aligned_alloc_device(
-      DEVICE_MEM_ALIGNMENT, size_acc * sizeof(data_type_acc), device, context));
-  auto* Cnt_d = static_cast<uint32_t*>(aligned_alloc_device(
-      DEVICE_MEM_ALIGNMENT, size_cnt * sizeof(uint32_t), device, context));
-  auto* scale_d = static_cast<data_type_scale*>(aligned_alloc_device(
-      DEVICE_MEM_ALIGNMENT,
-      size_scale * sizeof(data_type_scale),
-      device,
-      context));
-  auto* zero_pt_d = static_cast<data_type_zero_pt*>(aligned_alloc_device(
-      DEVICE_MEM_ALIGNMENT,
-      size_zero_pt * sizeof(data_type_zero_pt),
-      device,
-      context));
-  auto* bias_d = static_cast<data_type_bias*>(aligned_alloc_device(
-      DEVICE_MEM_ALIGNMENT,
-      size_bias * sizeof(data_type_bias),
-      device,
-      context));
-
-  for (unsigned i = 0; i < size_a; ++i) {
-    A_h[i] = random_float();
-#ifdef UT_DEBUG
-    A_h[i] = 1.f;
-#endif
-  }
-  for (unsigned i = 0; i < size_b; ++i) {
-    B_h[i] = uint8_t(random_uint8());
-#ifdef UT_DEBUG
-    B_h[i] = 153;
-#endif
-  }
-  for (unsigned i = 0; i < size_scale; ++i) {
-    scale_h[i] = random_float();
-#ifdef UT_DEBUG
-    scale_h[i] = 1.f;
-#endif
-  }
-  for (unsigned i = 0; i < size_zero_pt; ++i) {
-    zero_pt_h[i] = 0.f;
-  }
-  for (unsigned i = 0; i < size_c; ++i) {
-    C_h[i] = 0;
-  }
-  for (unsigned i = 0; i < size_acc; ++i) {
-    Acc_h[i] = 0;
-  }
-  for (unsigned i = 0; i < size_cnt; ++i) {
-    Cnt_h[i] = 0;
-  }
-  for (unsigned i = 0; i < size_bias; ++i) {
-    bias_h[i] = random_float();
-#ifdef UT_DEBUG
-    bias_h[i] = 0.f;
-#endif
-  }
-
-  queue.memcpy((void*)A_d, (void*)A_h, size_a * sizeof(data_type_a)).wait();
-  queue.memcpy((void*)B_d, (void*)B_h, size_b * sizeof(data_type_b)).wait();
-  queue.memcpy((void*)C_d, (void*)C_h, size_c * sizeof(data_type_c)).wait();
-  queue.memcpy((void*)Acc_d, (void*)Acc_h, size_acc * sizeof(data_type_acc))
-      .wait();
-  queue.memcpy((void*)Cnt_d, (void*)Cnt_h, size_cnt * sizeof(uint32_t)).wait();
-  queue
-      .memcpy(
-          (void*)scale_d, (void*)scale_h, size_scale * sizeof(data_type_scale))
-      .wait();
-  queue
-      .memcpy(
-          (void*)zero_pt_d,
-          (void*)zero_pt_h,
-          size_zero_pt * sizeof(data_type_zero_pt))
-      .wait();
-  queue.memcpy((void*)bias_d, (void*)bias_h, size_bias * sizeof(data_type_bias))
-      .wait();
-
-  // set up gemm arguments
-  typename bias_op_t::shape_t bias_add_shape(matrix_n, 1, matrix_n);
-  using epilogue_args_t = epilogue_t::arguments_t;
-
-  epilogue_args_t epilogue_args(
-      {// epilogue_args init list
-       // It accepts the base pointer to matrix D, and its dimensions
-       {bias_d, bias_add_shape}});
-
-  typename gemm_op_t::template arguments_t<compute_policy::quant_mode> gemm_arg(
-      matrix_m,
-      matrix_k,
-      matrix_n,
-      A_d,
-      matrix_k,
-      B_d,
-      matrix_n,
-      C_d,
-      matrix_n,
-      scale_d,
-      matrix_n,
-      Acc_d,
-      Cnt_d,
-      epilogue_args);
-
-  cl::sycl::nd_range<3> nd_range = gemm_op_t::get_nd_range(gemm_arg);
-  if (!gemm_op_t::can_implement(gemm_arg)) {
-    std::cout << "The arguments cannot be supported, aborting ... "
-              << std::endl;
-    FAIL();
-  }
-
-  size_t ops = 2 * matrix_m * matrix_n * matrix_k + matrix_m * matrix_n;
-  profiling_helper prof("dequantize_gemm", ops, "gflops");
-  try {
-    for (int i = 0; i < iter; i++) {
-      prof.cpu_start();
-      auto e_esimd = queue.submit([&](handler& cgh) {
-        cgh.parallel_for<Test>(
-            nd_range, [=](nd_item<3> item) SYCL_ESIMD_KERNEL {
-              // allocate slm and nbarrier resource
-              slm_barrier_init<gemm_op_t>();
-              gemm_op_t gemm_op;
-              gemm_op(item, gemm_arg);
-            });
-      });
-      e_esimd.wait();
-      prof.cpu_end();
-      prof.add_gpu_event(e_esimd);
-    }
-  } catch (cl::sycl::exception const& e) {
-    std::cout << "SYCL exception caught: " << e.what() << '\n';
-    FAIL();
-  }
-
-  // performance
-  prof.print_profiling_result(profiling_selector::GPU);
-
-  std::vector<fp16> dequantize_b(matrix_k * matrix_n, 0);
-  for (uint32_t i = 0; i < matrix_k / dequant_s; i++) {
-    for (uint32_t j = 0; j < matrix_n / 2; j++) {
-      int start_in = i * dequant_s * matrix_n / 2 + j;
-      int start_out = i * dequant_s * matrix_n + j * 2;
-      int start_scale = i * size_scale_n + j * 2;
-      for (uint32_t ii = 0; ii < dequant_s; ii++) {
-        uint8_t data_in = B_h[start_in + ii * matrix_n / 2];
-        int8_t data_0 = int8_t(data_in & 0x0f) - 8;
-        int8_t data_1 = int8_t(data_in >> 4) - 8;
-        dequantize_b[start_out + ii * matrix_n] =
-            fp16(data_0) * scale_h[start_scale];
-        dequantize_b[start_out + ii * matrix_n + 1] =
-            fp16(data_1) * scale_h[start_scale + 1];
-      }
-    }
-  }
-
-  queue.memcpy((void*)C_h, (void*)C_d, size_c * sizeof(data_type_c)).wait();
-  ASSERT_EQ(
-      0,
-      gemm_result_validate(
-          A_h, dequantize_b.data(), C_h, bias_h, matrix_m, matrix_k, matrix_n));
-
-  free(A_h, context);
-  free(B_h, context);
-  free(C_h, context);
-  free(scale_h, context);
-  free(zero_pt_h, context);
-  free(A_d, context);
-  free(B_d, context);
-  free(C_d, context);
-  free(scale_d, context);
-  free(zero_pt_d, context);
-  free(Acc_h, context);
-  free(Cnt_h, context);
-  free(Acc_d, context);
-  free(Cnt_d, context);
-}
-
-template <typename T>
-class dequantize_gemm_test : public ::testing::Test {};
-TYPED_TEST_SUITE_P(dequantize_gemm_test);
-
-TYPED_TEST_P(dequantize_gemm_test, esimd) {
-  dequantize_gemm_run<TypeParam>(ITER);
-}
-
-REGISTER_TYPED_TEST_SUITE_P(dequantize_gemm_test, esimd);
-using tests = ::testing::Types<qkv6>;
-// using tests = ::testing::Types<qkv1, qkv2, qkv3, qkv4, qkv5, qkv6, qkv7,
-// qkv8,
-//         qkv9, qkv10>;
-
-INSTANTIATE_TYPED_TEST_SUITE_P(
-    dequantize_gemm_test_suite,
-    dequantize_gemm_test,
-    tests);
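
For reference, the host-side int4x2 dequantization that the deleted file used
for validation (and that presumably survives in the renamed main.cpp) unpacks
two signed 4-bit values per byte, low nibble first, maps the raw [0..15] range
to [-8..7] by subtracting 8, and applies a per-group, per-column scale. A
standalone sketch with hypothetical function and parameter names:

  #include <cstddef>
  #include <cstdint>
  #include <vector>

  std::vector<float> dequantize_int4x2(
      const uint8_t* packed, // k * n / 2 bytes, row-major over n
      const float* scale,    // (k / group) rows of n per-column scales
      size_t k, size_t n, size_t group) {
    std::vector<float> out(k * n, 0.f);
    for (size_t row = 0; row < k; ++row) {
      for (size_t col = 0; col < n / 2; ++col) {
        const uint8_t byte = packed[row * n / 2 + col];
        const int8_t lo = int8_t(byte & 0x0f) - 8; // even output column
        const int8_t hi = int8_t(byte >> 4) - 8;   // odd output column
        const float* s = scale + (row / group) * n + col * 2;
        out[row * n + col * 2] = float(lo) * s[0];
        out[row * n + col * 2 + 1] = float(hi) * s[1];
      }
    }
    return out;
  }
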
diff --git a/tests/integration/gemm/int8/kernel_func.hpp b/tests/integration/gemm/int8/kernel_func.hpp
index 3a9e595ca..da2eab012 100644
--- a/tests/integration/gemm/int8/kernel_func.hpp
+++ b/tests/integration/gemm/int8/kernel_func.hpp
@@ -72,6 +72,9 @@ struct int8gemm_test_func {
 
   using gemm_op_t = gemm_universal_t<dispatch_policy, gemm_t, epilogue_t>;
 
+  static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count();
+  static constexpr uint32_t slm_size = gemm_op_t::get_slm_size();
+
   static const char* func_name() {
     return "int8gemm_test_func";
   }
diff --git a/tests/integration/gemm/tf32/kernel_func.hpp b/tests/integration/gemm/tf32/kernel_func.hpp
index 8c2850b9b..42d69eb51 100644
--- a/tests/integration/gemm/tf32/kernel_func.hpp
+++ b/tests/integration/gemm/tf32/kernel_func.hpp
@@ -71,6 +71,9 @@ struct tf32_gemm_test_func {
 
   using gemm_op_t = gemm_universal_t<dispatch_policy, gemm_t, epilogue_t>;
 
+  static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count();
+  static constexpr uint32_t slm_size = gemm_op_t::get_slm_size();
+
   static const char* func_name() {
     return "tf32_gemm_test_func";
   }
diff --git a/tests/integration/gemm/unaligned_bf16/kernel_func.hpp b/tests/integration/gemm/unaligned_bf16/kernel_func.hpp
index d45ddc0b7..d5534ebdf 100644
--- a/tests/integration/gemm/unaligned_bf16/kernel_func.hpp
+++ b/tests/integration/gemm/unaligned_bf16/kernel_func.hpp
@@ -68,13 +68,20 @@ struct unaligned_gemm_test_func {
   using epilogue_t = epilogue_t<
       epilogue_policy_unaligned<arch_tag>,
       tile_shape,
-      mem_desc_t<dtype_c, mem_layout::row_major, mem_space::global, ldc_alignment>>;
+      mem_desc_t<
+          dtype_c,
+          mem_layout::row_major,
+          mem_space::global,
+          ldc_alignment>>;
 
   using group_swizzle = gpu::xetla::kernel::group_swizzle_default<arch_tag>;
   using dispatch_policy =
       dispatch_policy_kslicing<group_swizzle, global_kslicing, local_kslicing>;
   using gemm_op_t = gemm_universal_t<dispatch_policy, gemm_t, epilogue_t>;
 
+  static constexpr uint32_t barrier_count = gemm_op_t::get_barrier_count();
+  static constexpr uint32_t slm_size = gemm_op_t::get_slm_size();
+
   static const char* func_name() {
     return "unaligned_gemm_test_func";
   }
diff --git a/tests/integration/gemm/unaligned_bf16/main.cpp b/tests/integration/gemm/unaligned_bf16/main.cpp
index 9ceaf695e..3d23c78ce 100644
--- a/tests/integration/gemm/unaligned_bf16/main.cpp
+++ b/tests/integration/gemm/unaligned_bf16/main.cpp
@@ -31,10 +31,7 @@ TYPED_TEST_P(unaligned_gemm_test, esimd) {
   gemm_exec<
       TypeParam,
       result_validate<TypeParam>,
-      unaligned_gemm_func<TypeParam>,
-      unaligned_gemm_func<TypeParam>::gemm_op_t::get_slm_size(),
-      unaligned_gemm_func<TypeParam>::gemm_op_t::get_barrier_count()>(
-      esimd_compile_string);
+      unaligned_gemm_func<TypeParam>>(esimd_compile_string);
 }
 REGISTER_TYPED_TEST_SUITE_P(unaligned_gemm_test, esimd);
 using tests = ::testing::Types<
diff --git a/tests/integration/mlp/int4/int4_mlp_gate_mul_up_fwd.hpp b/tests/integration/mlp/int4/int4_mlp_gate_mul_up_fwd.hpp
index bc7f63af9..8c47795fc 100644
--- a/tests/integration/mlp/int4/int4_mlp_gate_mul_up_fwd.hpp
+++ b/tests/integration/mlp/int4/int4_mlp_gate_mul_up_fwd.hpp
@@ -188,7 +188,7 @@ class global_sum_reduce_two_mat_t {
           mat_zero, matAcc1_payload);
       subgroup::tile_store<cache_hint::uncached, cache_hint::write_back>(
           mat_zero, matAcc2_payload);
-      SW_BARRIER();
+      sw_barrier();
     }
   }
 };
diff --git a/tests/integration/vector_add/int32_1d/kernel_func.hpp b/tests/integration/vector_add/int32_1d/kernel_func.hpp
index bca6071d1..3dbbbac47 100644
--- a/tests/integration/vector_add/int32_1d/kernel_func.hpp
+++ b/tests/integration/vector_add/int32_1d/kernel_func.hpp
@@ -36,7 +36,7 @@ KERNEL_FUNC inline void vector_add_func(
   /// use block prefetch for b
   xetla_prefetch_global<dtype, SIMD, cache_hint::cached, cache_hint::cached>(
       b, offset);
-  SW_BARRIER();
+  sw_barrier();
   /// use scattered load for a
   xetla_vector<dtype, SIMD> ivector1 = xetla_load_global<
       dtype,
diff --git a/tests/integration/vector_add/tf32_1d/kernel_func.hpp b/tests/integration/vector_add/tf32_1d/kernel_func.hpp
index f1a6f177d..79872dec3 100644
--- a/tests/integration/vector_add/tf32_1d/kernel_func.hpp
+++ b/tests/integration/vector_add/tf32_1d/kernel_func.hpp
@@ -36,7 +36,7 @@ KERNEL_FUNC inline void vector_add_func(
   /// use block prefetch for b
   xetla_prefetch_global<dtype, SIMD, cache_hint::cached, cache_hint::cached>(
       b, offset);
-  SW_BARRIER();
+  sw_barrier();
   /// use scattered load for a
   xetla_vector<dtype, SIMD> ivector1 = xetla_load_global<
       dtype,
diff --git a/tests/unit/block_load_store/kernel_func.hpp b/tests/unit/block_load_store/kernel_func.hpp
index 17465f08c..3c80ed99d 100644
--- a/tests/unit/block_load_store/kernel_func.hpp
+++ b/tests/unit/block_load_store/kernel_func.hpp
@@ -44,7 +44,7 @@ struct block_load_store_func {
         cache_hint::cached,
         cache_hint::cached,
         arch_tag>(src_tdesc);
-    SW_BARRIER();
+    sw_barrier();
     xetla_vector<dtype, bwidth* bheight> A_load_vec = xetla_tload_global<
         dtype,
         bwidth * bheight,
diff --git a/tests/unit/tile_load_store/kernel_func.hpp b/tests/unit/tile_load_store/kernel_func.hpp
index 89832a46f..cdd4a59f2 100644
--- a/tests/unit/tile_load_store/kernel_func.hpp
+++ b/tests/unit/tile_load_store/kernel_func.hpp
@@ -242,7 +242,7 @@ struct tile_load_store_atomic_func {
     matBias.reg = matA.reg;
     matA.reg = 0;
     tile_store(matA, payload_store);
-    SW_BARRIER();
+    sw_barrier();
     tile_store(matBias, payload_store_add, check_tag);
   }
 };
diff --git a/tests/unit/tile_load_store/main.cpp b/tests/unit/tile_load_store/main.cpp
index a08f71d57..bc24278b7 100644
--- a/tests/unit/tile_load_store/main.cpp
+++ b/tests/unit/tile_load_store/main.cpp
@@ -271,34 +271,34 @@ TEST(tile_load_store_atomic_disable_oob_check, esimd) {
           false>>(nd_range, result_validate);
 }
 
-TEST(tile_load_store_atomic_boundary, esimd) {
-  cl::sycl::nd_range<1> nd_range({1}, {1});
-  auto result_validate = std::bind(
-      tile_load_store_result_validate<float>,
-      _1,
-      _2,
-      _3,
-      128,
-      33554440,
-      32,
-      32,
-      33554432);
-  kernel_run<
-      float,
-      tile_load_store_atomic_func<
-          float,
-          128,
-          33554440,
-          128,
-          32,
-          32,
-          16,
-          16,
-          true>,
-      128 * 1024,
-      32,
-      4294968320U>(nd_range, result_validate);
-}
+// TEST(tile_load_store_atomic_boundary, esimd) {
+//   cl::sycl::nd_range<1> nd_range({1}, {1});
+//   auto result_validate = std::bind(
+//       tile_load_store_result_validate<float>,
+//       _1,
+//       _2,
+//       _3,
+//       128,
+//       33554440,
+//       32,
+//       32,
+//       33554432);
+//   kernel_run<
+//       float,
+//       tile_load_store_atomic_func<
+//           float,
+//           128,
+//           33554440,
+//           128,
+//           32,
+//           32,
+//           16,
+//           16,
+//           true>,
+//       128 * 1024,
+//       32,
+//       4294968320U>(nd_range, result_validate);
+// }
 
 TEST(tile_load_broadcast_store, esimd) {
   cl::sycl::nd_range<1> nd_range({1}, {1});
@@ -318,25 +318,25 @@ TEST(tile_load_store_1d, esimd) {
       nd_range, result_validate);
 }
 
-TEST(tile_load_store_1d_boundary, esimd) {
-  cl::sycl::nd_range<1> nd_range({1}, {1});
-  auto result_validate = std::bind(
-      tile_load_store_result_validate<int>,
-      _1,
-      _2,
-      _3,
-      128,
-      33554440,
-      128,
-      1,
-      33554432);
-  kernel_run<
-      int,
-      tile_load_store_1d_func<int, 128, 33554440, 128, 128, 1, 128, 1, true>,
-      128 * 1024,
-      32,
-      4294968320U>(nd_range, result_validate);
-}
+// TEST(tile_load_store_1d_boundary, esimd) {
+//   cl::sycl::nd_range<1> nd_range({1}, {1});
+//   auto result_validate = std::bind(
+//       tile_load_store_result_validate<int>,
+//       _1,
+//       _2,
+//       _3,
+//       128,
+//       33554440,
+//       128,
+//       1,
+//       33554432);
+//   kernel_run<
+//       int,
+//       tile_load_store_1d_func<int, 128, 33554440, 128, 128, 1, 128, 1, true>,
+//       128 * 1024,
+//       32,
+//       4294968320U>(nd_range, result_validate);
+// }
 
 TEST(tile_load_store_unaligned_2d, esimd) {
   cl::sycl::nd_range<1> nd_range({1}, {1});
diff --git a/tests/unit/tile_mma/kernel_func.hpp b/tests/unit/tile_mma/kernel_func.hpp
index d8130c4a5..926b20c17 100644
--- a/tests/unit/tile_mma/kernel_func.hpp
+++ b/tests/unit/tile_mma/kernel_func.hpp
@@ -103,9 +103,9 @@ struct tile_mma_func {
         matA, matA_payload);
     subgroup::tile_load<cache_hint::cached, cache_hint::cached>(
         matB, matB_payload);
-    SW_BARRIER();
+    sw_barrier();
     tile_mma::mma(matAcc, matAcc, matB, matA);
-    SW_BARRIER();
+    sw_barrier();
     matC.reg = xetla_cvt<dtypeC, dtypeAcc, matAcc_t::tile_desc::tile_elems>(
         matAcc.reg);
     matC_payload.init(c, n, m, n, 0, 0);
diff --git a/tests/utils/common.hpp b/tests/utils/common.hpp
index 4a79de7c4..991c9cc9a 100644
--- a/tests/utils/common.hpp
+++ b/tests/utils/common.hpp
@@ -90,6 +90,7 @@ inline auto getTypeName<gpu::xetla::tf32>() {
 }
 
 enum class test_result : uint8_t { complete = 0, skip = 1, fail = 2 };
+enum Direction { FWD = 0, BWDD = 1, BWDW = 2 };
 
 template <typename result_type>
 inline result_type generate_real_random(
@@ -129,6 +130,28 @@ inline data_type* alloc_host(size_t size) {
   return host_ptr;
 }
 
+template <typename data_type>
+using init_func_t = std::function<void(data_type* data, size_t idx)>;
+
+template <typename data_type>
+void index_init_func(data_type* data, size_t idx) {
+  data[idx] = static_cast<data_type>(idx);
+}
+template <typename data_type>
+void no_init_func(
+    [[maybe_unused]] data_type* data,
+    [[maybe_unused]] size_t idx) {}
+
+template <typename data_type>
+void rand_init_func(data_type* data, size_t idx) {
+  data[idx] = static_cast<data_type>(random_float() - 0.5f);
+}
+
+template <typename data_type>
+void zero_init_func(data_type* data, size_t idx) {
+  data[idx] = 0;
+}
+
 template <typename data_type>
 inline data_type* alloc_host_and_init(
     size_t size,
diff --git a/tests/utils/execution.hpp b/tests/utils/execution.hpp
index 85d722ce9..843420373 100644
--- a/tests/utils/execution.hpp
+++ b/tests/utils/execution.hpp
@@ -1,5 +1,5 @@
 /*******************************************************************************
- * Copyright (c) 2022-2023 Intel Corporation
+ * Copyright (c) 2022-2024 Intel Corporation
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -20,22 +20,26 @@
 #include <stdexcept>
 #include "common.hpp"
 #include "profiling.hpp"
+#ifdef _WIN32
+#include "windows_functions.hpp"
+#endif
 #include "xetla.hpp"
 
 using namespace cl::sycl;
 using namespace gpu;
 using namespace gpu::xetla;
 
-template <
-    class Test,
-    typename validate_func,
-    typename KERNEL,
-    int SLMSIZE = arch_attr_t<Test::gpu_arch>::local_mem_size,
-    int BARNUM = 32>
-void gemm_exec(const std::string& compile_str, size_t batch = 1) {
+template <typename Test, typename validate_func, typename kernel_t>
+void gemm_exec(
+    const std::string& compile_str,
+    size_t batch = 1,
+    size_t scaling = 1) {
   test_result result = test_result::complete;
 
-  using gemm_op_t = typename KERNEL::gemm_op_t;
+  using gemm_op_t = typename kernel_t::gemm_op_t;
+
+  constexpr uint32_t slm_size = kernel_t::slm_size;
+  constexpr uint32_t barrier_num = kernel_t::barrier_count;
 
   using data_type_a = typename Test::data_type_a;
   using data_type_b = typename Test::data_type_b;
@@ -46,6 +50,12 @@ void gemm_exec(const std::string& compile_str, size_t batch = 1) {
   constexpr size_t matrix_n = Test::mat_n;
   constexpr size_t matrix_k = Test::mat_k;
 
+  [[maybe_unused]] constexpr size_t wg_tile_m = Test::wg_m;
+  [[maybe_unused]] constexpr size_t wg_tile_n = Test::wg_n;
+  [[maybe_unused]] constexpr size_t sg_tile_m = Test::sg_m;
+  [[maybe_unused]] constexpr size_t sg_tile_n = Test::sg_n;
+  [[maybe_unused]] constexpr size_t sg_tile_k = Test::sg_k;
+
   size_t size_a = matrix_m * matrix_k;
   size_t size_b = matrix_k * matrix_n;
   size_t size_c = matrix_m * matrix_n;
@@ -59,16 +69,16 @@ void gemm_exec(const std::string& compile_str, size_t batch = 1) {
 
   auto A = alloc_device_and_init<data_type_a>(
       batch * size_a,
-      [](data_type_a* data, size_t idx) {
-        data[idx] = static_cast<data_type_a>(random_float());
+      [&scaling](data_type_a* data, size_t idx) {
+        data[idx] = static_cast<data_type_a>(scaling * (random_float() - 0.5f));
       },
       queue,
       device,
       context);
   auto B = alloc_device_and_init<data_type_b>(
       batch * size_b,
-      [](data_type_b* data, size_t idx) {
-        data[idx] = static_cast<data_type_b>(random_float());
+      [&scaling](data_type_b* data, size_t idx) {
+        data[idx] = static_cast<data_type_b>(scaling * (random_float() - 0.5f));
       },
       queue,
       device,
@@ -81,6 +91,7 @@ void gemm_exec(const std::string& compile_str, size_t batch = 1) {
       queue,
       device,
       context);
+
   size_t size_acc = gemm_op_t::get_acc_buf_size(matrix_m, matrix_n);
   size_t size_cnt = gemm_op_t::get_cnt_buf_size(matrix_m, matrix_n);
   auto Acc = alloc_device_and_init<data_type_acc>(
@@ -97,20 +108,20 @@ void gemm_exec(const std::string& compile_str, size_t batch = 1) {
       queue,
       device,
       context);
-
-  size_t ops = 2 * matrix_m * matrix_n * matrix_k;
+  int64_t ops = 2 * static_cast<int64_t>(matrix_m) * matrix_n * matrix_k;
   profiling_helper prof("gemm", ops, "gflops");
-
   try {
     std::vector<kernel_id> kernelId = {get_kernel_id<Test>()};
     auto inputBundle =
         get_kernel_bundle<bundle_state::input>(context, kernelId);
-    static const std::string env_set_str =
-        "SYCL_PROGRAM_COMPILE_OPTIONS=" + compile_str;
-    putenv(const_cast<char*>(env_set_str.c_str()));
+    char* value = getenv("GOGRITS");
+    if (value == NULL || strcmp(value, "on") != 0) {
+      setenv("SYCL_PROGRAM_COMPILE_OPTIONS", compile_str.c_str(), 1);
+    }
     kernel_bundle<bundle_state::executable> exeBundle = build(inputBundle);
-    static const std::string env_unset_str = "SYCL_PROGRAM_COMPILE_OPTIONS=";
-    putenv(const_cast<char*>(env_unset_str.c_str()));
+    if (value == NULL || strcmp(value, "on") != 0) {
+      unsetenv("SYCL_PROGRAM_COMPILE_OPTIONS");
+    }
 
     using namespace gpu::xetla::group;
     using namespace gpu::xetla::kernel;
@@ -130,10 +141,10 @@ void gemm_exec(const std::string& compile_str, size_t batch = 1) {
         nullptr);
 
     cl::sycl::nd_range<3> nd_range = gemm_op_t::get_nd_range(arg);
-
     int constexpr warm_up = 10;
     int constexpr iters = 100;
     for (size_t i = 0; i < batch; i++) {
+      prof.cpu_start();
       auto A_ptr = A + i * size_a;
       auto B_ptr = B + i * size_b;
       auto C_ptr = C + i * size_c;
@@ -157,11 +168,13 @@ void gemm_exec(const std::string& compile_str, size_t batch = 1) {
           prof.cpu_start();
         }
         auto e_esimd = queue.submit([&](handler& cgh) {
-          cgh.use_kernel_bundle(exeBundle);
+          if (value == NULL || strcmp(value, "on") != 0) {
+            cgh.use_kernel_bundle(exeBundle);
+          }
           cgh.parallel_for<Test>(nd_range, [=](nd_item<3> item) KERNEL_MAIN {
-            gpu::xetla::xetla_local_init<SLMSIZE>();
-            gpu::xetla::xetla_nbarrier_init<BARNUM>();
-            KERNEL::run(
+            gpu::xetla::xetla_local_init<slm_size>();
+            gpu::xetla::xetla_nbarrier_init<barrier_num>();
+            kernel_t::run(
                 item,
                 A_ptr,
                 B_ptr,
@@ -184,9 +197,7 @@ void gemm_exec(const std::string& compile_str, size_t batch = 1) {
     std::cout << "SYCL exception caught: " << e.what() << '\n';
     result = test_result::fail;
   }
-
-  // performance
-  prof.print_profiling_result(profiling_selector::GPU);
+  unsetenv("SYCL_PROGRAM_COMPILE_OPTIONS");
   // validation
   if (result == test_result::complete) {
     validate_func vfunc;
@@ -204,6 +215,7 @@ void gemm_exec(const std::string& compile_str, size_t batch = 1) {
   } else if (result != test_result::complete) {
     FAIL();
   }
+  prof.print_profiling_result(profiling_selector::GPU);
 }
 
 /// @brief The template function to execute kernel in esimd way for unit test
@@ -211,54 +223,41 @@ void gemm_exec(const std::string& compile_str, size_t batch = 1) {
 ///
 /// @tparam data_type data_type The data type of buffer used in kernel and
 /// buffer allocation
-/// @tparam KERNEL the kernel function struct
+/// @tparam kernel_t the kernel function struct
 /// @param nd_range the range of workitems
-/// @param validate_result validation function, taking 3 parameters buffer A, B
-/// as input C as output
+/// @param validate_result validation function, taking 3 parameters: buffers
+/// A and B as input, C as output
 ///
 template <
     typename data_type,
-    class KERNEL,
-    size_t SLMSIZE = 8 * 1024,
-    size_t BARNUM = 32,
-    size_t Size = 4096>
-void kernel_run(auto nd_range, auto validate_result) {
+    typename kernel_t,
+    size_t slm_size = 8 * 1024,
+    size_t barrier_num = 32,
+    size_t size = 4096>
+void kernel_run(
+    auto nd_range,
+    auto validate_result,
+    init_func_t<data_type> init_func_a = index_init_func<data_type>,
+    init_func_t<data_type> init_func_b = index_init_func<data_type>,
+    init_func_t<data_type> init_func_c = no_init_func<data_type>) {
   queue queue{};
   auto context = queue.get_info<info::queue::context>();
   auto device = queue.get_info<info::queue::device>();
   std::cout << "Running on " << device.get_info<info::device::name>() << "\n";
 
   auto A = alloc_device_and_init<data_type>(
-      Size,
-      [](data_type* data, size_t idx) {
-        data[idx] = static_cast<data_type>(idx);
-      },
-      queue,
-      device,
-      context);
+      size, init_func_a, queue, device, context);
   auto B = alloc_device_and_init<data_type>(
-      Size,
-      [](data_type* data, size_t idx) {
-        data[idx] = static_cast<data_type>(idx);
-      },
-      queue,
-      device,
-      context);
+      size, init_func_b, queue, device, context);
   auto C = alloc_device_and_init<data_type>(
-      Size,
-      [](data_type* data, size_t idx) {
-        data[idx] = static_cast<data_type>(idx);
-      },
-      queue,
-      device,
-      context);
+      size, init_func_c, queue, device, context);
 
   try {
     auto e_esimd = queue.submit([&](handler& cgh) {
       cgh.parallel_for<>(nd_range, [=](nd_item<1> ndi) KERNEL_MAIN {
-        gpu::xetla::xetla_local_init<SLMSIZE>();
-        gpu::xetla::xetla_nbarrier_init<BARNUM>();
-        KERNEL::run(&ndi, A, B, C);
+        gpu::xetla::xetla_local_init<slm_size>();
+        gpu::xetla::xetla_nbarrier_init<barrier_num>();
+        kernel_t::run(&ndi, A, B, C);
       });
     });
     e_esimd.wait();
@@ -267,9 +266,9 @@ void kernel_run(auto nd_range, auto validate_result) {
     FAIL();
   }
 
-  auto A_host = alloc_host_and_copy<data_type>(A, Size, queue);
-  auto B_host = alloc_host_and_copy<data_type>(B, Size, queue);
-  auto C_host = alloc_host_and_copy<data_type>(C, Size, queue);
+  auto A_host = alloc_host_and_copy<data_type>(A, size, queue);
+  auto B_host = alloc_host_and_copy<data_type>(B, size, queue);
+  auto C_host = alloc_host_and_copy<data_type>(C, size, queue);
 
   ASSERT_EQ(0, validate_result(A_host, B_host, C_host));
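
With the init_func_t hooks from tests/utils/common.hpp, kernel_run callers can
now swap the per-buffer initialization (index values for A/B, untouched C by
default). A self-contained host-side sketch of the mechanism; make_buffer is a
hypothetical stand-in for alloc_device_and_init applying the functor per index:

  #include <cstddef>
  #include <functional>
  #include <vector>

  template <typename T>
  using init_func_t = std::function<void(T*, size_t)>;

  template <typename T>
  std::vector<T> make_buffer(size_t n, init_func_t<T> f) {
    std::vector<T> buf(n);
    for (size_t i = 0; i < n; ++i)
      f(buf.data(), i); // each functor writes buf[i]
    return buf;
  }

  int main() {
    auto A = make_buffer<float>(16, [](float* d, size_t i) { d[i] = float(i); });
    auto C = make_buffer<float>(16, [](float* d, size_t i) { d[i] = 0.f; });
    return int(A[3]) + int(C[0]); // 3
  }
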