diff --git a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt index 9c5bbae1022f7..cfad07e57021f 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/LLVMIR/CMakeLists.txt @@ -87,3 +87,13 @@ mlir_tablegen(VCIXConversions.inc -gen-llvmir-conversions) mlir_tablegen(VCIXOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=vcix) mlir_tablegen(VCIXOpsAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=vcix) add_public_tablegen_target(MLIRVCIXConversionsIncGen) + +add_mlir_dialect(XeVMOps xevm) +add_mlir_doc(XeVMOps XeVMDialect Dialects/ -gen-dialect-doc -dialect=xevm) +set(LLVM_TARGET_DEFINITIONS XeVMOps.td) +mlir_tablegen(XeVMConversions.inc -gen-llvmir-conversions) +mlir_tablegen(XeVMOpsEnums.h.inc -gen-enum-decls) +mlir_tablegen(XeVMOpsEnums.cpp.inc -gen-enum-defs) +mlir_tablegen(XeVMOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=xevm) +mlir_tablegen(XeVMOpsAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=xevm) +add_public_tablegen_target(MLIRXeVMConversionsIncGen) diff --git a/mlir/include/mlir/Dialect/LLVMIR/XeVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/XeVMDialect.h new file mode 100644 index 0000000000000..a83d4248c862c --- /dev/null +++ b/mlir/include/mlir/Dialect/LLVMIR/XeVMDialect.h @@ -0,0 +1,28 @@ +//===-- XeVMDialect.h - MLIR XeVM target definitions ------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_DIALECT_LLVMIR_XEVMDIALECT_H_
+#define MLIR_DIALECT_LLVMIR_XEVMDIALECT_H_
+
+#include "mlir/Bytecode/BytecodeOpInterface.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+
+#include "mlir/Dialect/LLVMIR/XeVMOpsEnums.h.inc"
+
+#define GET_ATTRDEF_CLASSES
+#include "mlir/Dialect/LLVMIR/XeVMOpsAttributes.h.inc"
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/LLVMIR/XeVMOps.h.inc"
+
+#include "mlir/Dialect/LLVMIR/XeVMOpsDialect.h.inc"
+
+#endif /* MLIR_DIALECT_LLVMIR_XEVMDIALECT_H_ */
diff --git a/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
new file mode 100644
index 0000000000000..ca055670a9527
--- /dev/null
+++ b/mlir/include/mlir/Dialect/LLVMIR/XeVMOps.td
@@ -0,0 +1,589 @@
+//===-- XeVMOps.td - XeVM dialect definition ---------------*- tablegen -*-===//
+//
+// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef XEVMIR_OPS
+#define XEVMIR_OPS
+
+include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
+include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+
+include "mlir/IR/OpBase.td"
+include "mlir/IR/EnumAttr.td"
+
+def XeVM_Dialect : Dialect {
+  let name = "xevm";
+  let cppNamespace = "::mlir::xevm";
+  let summary = "The XeVM dialect that extends LLVM dialect and models Intel "
+                "GPU's hardware features.";
+  let description = [{
+    The XeVM dialect is extension to the LLVM dialect that models hardware
+    features of Intel GPUs. The dialect is designed to work with the Xe
+    architecture for Intel GPUs, supporting advanced operations like 2D block
+    loads, stores, prefetch and matrix multiply-add (MMA) operations.
+  }];
+  let dependentDialects = ["LLVM::LLVMDialect"];
+
+  let extraClassDeclaration = [{
+    /// Get the name for the attribute used to specify cache control
+    /// decorations.
+    static constexpr ::llvm::StringRef getCacheControlsAttrName() {
+      return ::llvm::StringLiteral("xevm.DecorationCacheControl");
+    }
+  }];
+
+  let useDefaultAttributePrinterParser = 1;
+}
+
+class XeVM_Attr<string attrName, string attrMnemonic,
+                list<Trait> traits = []>
+    : AttrDef<XeVM_Dialect, attrName, traits> {
+  let mnemonic = attrMnemonic;
+}
+
+class XeVM_Op<string mnemonic, list<Trait> traits = []>
+    : Op<XeVM_Dialect, mnemonic, traits> {
+
+  code extraBaseClassDeclaration = [{
+    void printProperties(::mlir::MLIRContext *ctx,
+                         ::mlir::OpAsmPrinter &p, const Properties &prop,
+                         ::mlir::ArrayRef<::llvm::StringRef> elidedProps) {
+      Attribute propAttr = getPropertiesAsAttr(ctx, prop);
+      if (propAttr)
+        p << "<" << propAttr << ">";
+    }
+
+    static ::mlir::ParseResult parseProperties(::mlir::OpAsmParser &parser,
+                                               ::mlir::OperationState &result) {
+      if (mlir::succeeded(parser.parseOptionalLess())) {
+        if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
+          return failure();
+      }
+      return success();
+    }
+
+  }];
+}
+
+def XeVM_ElemType : AnyTypeOf<[AnyI8, AnyI16, AnyI32, F32, TF32, F16, BF16]>;
+
+//===----------------------------------------------------------------------===//
+// XeVM Load Cache Control
+// L1, L2, L3 - cache levels
+// uc - uncached
+// c - cached
+// s - streaming
+// ir - invalidated after read
+// Default - default cache behavior for L1, L2 and L3 cache
+//===----------------------------------------------------------------------===//
+
+def LoadCacheControlDefault : I32EnumAttrCase<"DEFAULT", 0, "Default">;
+def LoadCacheControl_L1uc_L2uc_L3uc
+    : I32EnumAttrCase<"L1UC_L2UC_L3UC", 1, "L1uc_L2uc_L3uc">;
+def LoadCacheControl_L1uc_L2uc_L3c
+    : I32EnumAttrCase<"L1UC_L2UC_L3C", 2, "L1uc_L2uc_L3c">;
+def LoadCacheControl_L1uc_L2c_L3uc
+    : I32EnumAttrCase<"L1UC_L2C_L3UC", 3, "L1uc_L2c_L3uc">;
+def LoadCacheControl_L1uc_L2c_L3c
+    : I32EnumAttrCase<"L1UC_L2C_L3C", 4, "L1uc_L2c_L3c">;
+def
LoadCacheControl_L1c_L2uc_L3uc + : I32EnumAttrCase<"L1C_L2UC_L3UC", 5, "L1c_L2uc_L3uc">; +def LoadCacheControl_L1c_L2uc_L3c + : I32EnumAttrCase<"L1C_L2UC_L3C", 6, "L1c_L2uc_L3c">; +def LoadCacheControl_L1c_L2c_L3uc + : I32EnumAttrCase<"L1C_L2C_L3UC", 7, "L1c_L2c_L3uc">; +def LoadCacheControl_L1c_L2c_L3c + : I32EnumAttrCase<"L1C_L2C_L3C", 8, "L1c_L2c_L3c">; +def LoadCacheControl_L1s_L2uc_L3uc + : I32EnumAttrCase<"L1S_L2UC_L3UC", 9, "L1s_L2uc_L3uc">; +def LoadCacheControl_L1s_L2uc_L3c + : I32EnumAttrCase<"L1S_L2UC_L3C", 10, "L1s_L2uc_L3c">; +def LoadCacheControl_L1s_L2c_L3uc + : I32EnumAttrCase<"L1S_L2C_L3UC", 11, "L1s_L2c_L3uc">; +def LoadCacheControl_L1s_L2c_L3c + : I32EnumAttrCase<"L1S_L2C_L3C", 12, "L1s_L2c_L3c">; +def LoadCacheControlInvalidateRead + : I32EnumAttrCase<"INVALIDATE_READ", 13, "ir">; + +def XeVM_LoadCacheControl + : I32EnumAttr< + "LoadCacheControl", "XeVM load ops cache control", + [LoadCacheControlDefault, LoadCacheControl_L1uc_L2uc_L3uc, + LoadCacheControl_L1uc_L2uc_L3c, LoadCacheControl_L1uc_L2c_L3uc, + LoadCacheControl_L1uc_L2c_L3c, LoadCacheControl_L1c_L2uc_L3uc, + LoadCacheControl_L1c_L2uc_L3c, LoadCacheControl_L1c_L2c_L3uc, + LoadCacheControl_L1c_L2c_L3c, LoadCacheControl_L1s_L2uc_L3uc, + LoadCacheControl_L1s_L2uc_L3c, LoadCacheControl_L1s_L2c_L3uc, + LoadCacheControl_L1s_L2c_L3c, LoadCacheControlInvalidateRead]> { + let cppNamespace = "::mlir::xevm"; + let genSpecializedAttr = 0; +} + +def XeVM_LoadCacheControlAttr + : EnumAttr { + let summary = [{Describe the cache settings for load operators}]; + let assemblyFormat = "`<` $value `>`"; +} + +//===----------------------------------------------------------------------===// +// XeVM Store Cache Control +// L1, L2, L3 - cache levels +// uc - uncached +// wb - write-back +// wt - write-through +// s - streaming +// Default - default cache behavior for L1, L2 and L3 cache +//===----------------------------------------------------------------------===// + +def StoreCacheControlDefault : 
I32EnumAttrCase<"DEFAULT", 0, "Default">; +def StoreCacheControl_L1uc_L2uc_L3uc + : I32EnumAttrCase<"L1UC_L2UC_L3UC", 1, "L1uc_L2uc_L3uc">; +def StoreCacheControl_L1uc_L2uc_L3wb + : I32EnumAttrCase<"L1UC_L2UC_L3WB", 2, "L1uc_L2uc_L3wb">; +def StoreCacheControl_L1uc_L2wb_L3uc + : I32EnumAttrCase<"L1UC_L2WB_L3UC", 3, "L1uc_L2wb_L3uc">; +def StoreCacheControl_L1uc_L2wb_L3wb + : I32EnumAttrCase<"L1UC_L2WB_L3WB", 4, "L1uc_L2wb_L3wb">; +def StoreCacheControl_L1wt_L2uc_L3uc + : I32EnumAttrCase<"L1WT_L2UC_L3UC", 5, "L1wt_L2uc_L3uc">; +def StoreCacheControl_L1wt_L2uc_L3wb + : I32EnumAttrCase<"L1WT_L2UC_L3WB", 6, "L1wt_L2uc_L3wb">; +def StoreCacheControl_L1wt_L2wb_L3uc + : I32EnumAttrCase<"L1WT_L2WB_L3UC", 7, "L1wt_L2wb_L3uc">; +def StoreCacheControl_L1wt_L2wb_L3wb + : I32EnumAttrCase<"L1WT_L2WB_L3WB", 8, "L1wt_L2wb_L3wb">; +def StoreCacheControl_L1s_L2uc_L3uc + : I32EnumAttrCase<"L1S_L2UC_L3UC", 9, "L1s_L2uc_L3uc">; +def StoreCacheControl_L1s_L2uc_L3wb + : I32EnumAttrCase<"L1S_L2UC_L3WB", 10, "L1s_L2uc_L3wb">; +def StoreCacheControl_L1s_L2wb_L3uc + : I32EnumAttrCase<"L1S_L2WB_L3UC", 11, "L1s_L2wb_L3uc">; +def StoreCacheControl_L1s_L2wb_L3wb + : I32EnumAttrCase<"L1S_L2WB_L3WB", 12, "L1s_L2wb_L3wb">; +def StoreCacheControl_L1wb_L2uc_L3uc + : I32EnumAttrCase<"L1WB_L2UC_L3UC", 13, "L1wb_L2uc_L3uc">; +def StoreCacheControl_L1wb_L2wb_L3uc + : I32EnumAttrCase<"L1WB_L2WB_L3UC", 14, "L1wb_L2wb_L3uc">; +def StoreCacheControl_L1wb_L2uc_L3wb + : I32EnumAttrCase<"L1WB_L2UC_L3WB", 15, "L1wb_L2uc_L3wb">; + +def XeVM_StoreCacheControl + : I32EnumAttr< + "StoreCacheControl", "XeVM store ops cache control", + [StoreCacheControlDefault, StoreCacheControl_L1uc_L2uc_L3uc, + StoreCacheControl_L1uc_L2uc_L3wb, StoreCacheControl_L1uc_L2wb_L3uc, + StoreCacheControl_L1uc_L2wb_L3wb, StoreCacheControl_L1wt_L2uc_L3uc, + StoreCacheControl_L1wt_L2uc_L3wb, StoreCacheControl_L1wt_L2wb_L3uc, + StoreCacheControl_L1wt_L2wb_L3wb, StoreCacheControl_L1s_L2uc_L3uc, + StoreCacheControl_L1s_L2uc_L3wb, 
StoreCacheControl_L1s_L2wb_L3uc, + StoreCacheControl_L1s_L2wb_L3wb, StoreCacheControl_L1wb_L2uc_L3uc, + StoreCacheControl_L1wb_L2wb_L3uc, + StoreCacheControl_L1wb_L2uc_L3wb]> { + let cppNamespace = "::mlir::xevm"; + let genSpecializedAttr = 0; +} + +def XeVM_StoreCacheControlAttr + : EnumAttr { + let summary = [{Describe the cache settings for store operators}]; + let assemblyFormat = "`<` $value `>`"; +} + +def XeVM_BlockLoad2dOp + : XeVM_Op<"blockload2d">, + Results<(outs FixedVectorOfRankAndType<[1], [XeVM_ElemType]>:$res)>, + Arguments<(ins Arg:$ptr, I32:$base_width, + I32:$base_height, I32:$base_pitch, I32:$x, I32:$y, + I32Attr:$elem_size_in_bits, I32Attr:$tile_width, I32Attr:$tile_height, + I32Attr:$v_blocks, I1Attr:$transpose, I1Attr:$pack_register, + OptionalAttr:$cache_control)> { + + let summary = "2D block load"; + + let description = [{ + The `xevm.blockload2d` operation loads a two dimensional matrix tile + from a base matrix residing in memory. The parameters are: + $ptr - the base address of the base matrix containing the tile to load + $base_width, $base_height, $base_pitch - the shape of the base matrix. + pitch is the physical stride between the first columns of the current row + and the subsequent row. All units are in bytes. + $x, $y, $tile_width, $tile_height - the starting offsets and shape of + the tile to load in number of elements. + $elem_size_in_bits - the size in bits of the matrix element type + - 32 for f32, tf32 + - 16 for f16, int16, bf16 + - 8 for int8 + $v_blocks - number of consecutive tiles in innermost dimension direction to load + $transpose - transpose the tile in registers (useful for 32 bit element type) + $pack_register - pack element types narrower than register bit width. 
+ [M, N] => [M/factor, N, factor] where factor is register_size_in_bits / elem_size_in_bits + $cache_control - an enumerator that sets the cache behaviour + + Notes: + - the $transpose and $pack_register parameters are mutual exclusive + - transposing the tile loaded is used for A matrix in backward path or used for the B matrix operand + (D = C + A * B), where A has row-major layout and B should have column-major layout in memory. + - if the tile loaded contains out of bound elements of the matrix, they are filled with 0. + + Example: + ```mlir + %base_width_a = arith.constant 32 : i32 + %base_height_a = arith.constant 8 : i32 + %base_pitch_a = arith.constant 32 : i32 + %x = arith.constant 0 : i32 + %y = arith.constant 0 : i32 + %loaded_a = xevm.blockload2d %src, %base_width_a, %base_height_a, %base_pitch_a, %x, %y + <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, + v_blocks=1 : i32, transpose=false : i32, pack_register=false, + cache_control=#xevm.load_cache_control}> + : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16> + ``` + }]; + + let assemblyFormat = [{ + operands prop-dict attr-dict `:` functional-type(operands, results) + }]; + + let extraClassDeclaration = extraBaseClassDeclaration#[{ + /// Get cache control or return default if not set. + ::mlir::xevm::LoadCacheControl getCacheControlOrDefault() { + if(getCacheControl()) + return *getCacheControl(); + return ::mlir::xevm::LoadCacheControl::DEFAULT; + } + }]; + + let hasVerifier = 1; +} + +def XeVM_BlockStore2dOp + : XeVM_Op<"blockstore2d">, + Arguments<(ins Arg:$ptr, I32:$base_width, + I32:$base_height, I32:$base_pitch, I32:$x, I32:$y, + I32Attr:$elem_size_in_bits, I32Attr:$tile_width, I32Attr:$tile_height, + FixedVectorOfRankAndType<[1], [XeVM_ElemType]>:$stored_val, + OptionalAttr:$cache_control)> { + + let summary = "2D block store"; + + let description = [{ + The `xevm.blockstore2d` operation stores a two dimensional tile into a + larger matrix residing in memory. 
The parameters are: + $ptr - the base address of the target matrix where to store the tile + $base_width, $base_height, $base_pitch - the shape of the target matrix. pitch is the + physical stride between the first columns of the current row and the subsequent row. + All units are in bytes. + $x, $y, $tile_width, $tile_height - the starting offsets and shape of the tile to store + in number of elements. + $elem_size_in_bits - the size in bits of the matrix element + - 32 for f32, tf32 + - 16 for f16, int16, bf16 + - 8 for int8 + $cache_control - an enumerator that sets the cache behaviour + $stored_val - the tile to store + + Example: + ```mlir + %base_width_c = arith.constant 64 : i32 + %base_height_c = arith.constant 8 : i32 + %base_pitch_c = arith.constant 64 : i32 + %x = arith.constant 0 : i32 + %y = arith.constant 0 : i32 + xevm.blockstore2d %dst, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %src + <{elem_size_in_bits=32 : i32, tile_width=16 : i32, tile_height=8 : i32, + cache_control=#xevm.load_cache_control}> + : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>) + ``` + }]; + + let assemblyFormat = [{ + operands prop-dict attr-dict `:` `(` type(operands) `)` + }]; + + let extraClassDeclaration = extraBaseClassDeclaration#[{ + /// Get cache control or return default if not set. + ::mlir::xevm::StoreCacheControl getCacheControlOrDefault() { + if(getCacheControl()) + return *getCacheControl(); + return ::mlir::xevm::StoreCacheControl::DEFAULT; + } + + /// Default value for v_blocks is 1. 
+ constexpr uint32_t getVBlocks() { + return 1; + } + }]; + + let hasVerifier = 1; +} + +def MemScopeLane : I32EnumAttrCase<"LANE", 0, "lane">; +def MemScopeSg : I32EnumAttrCase<"SUBGROUP", 1, "subgroup">; +def MemScopeWg : I32EnumAttrCase<"WORKGROUP", 2, "workgroup">; +def MemScopeCluster : I32EnumAttrCase<"CLUSTER", 3, "cluster">; +def MemScopeDevice : I32EnumAttrCase<"DEVICE", 4, "device">; +def MemScopeSystem : I32EnumAttrCase<"SYSTEM", 5, "system">; + +def XeVM_MemScope + : I32EnumAttr<"MemScope", "XeVM memory scope", + [MemScopeLane, MemScopeSg, MemScopeWg, MemScopeCluster, + MemScopeDevice, MemScopeSystem]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::xevm"; +} +def XeVM_MemScopeAttr : EnumAttr { + let summary = [{Describe memory scopes}]; + let assemblyFormat = "`<` $value `>`"; +} + +def AddrSpacePrivate : I32EnumAttrCase<"PRIVATE", 0, "private">; +def AddrSpaceGlobal : I32EnumAttrCase<"GLOBAL", 1, "global">; +def AddrSpaceConstant : I32EnumAttrCase<"CONSTANT", 2, "constant">; +def AddrSpaceShared : I32EnumAttrCase<"SHARED", 3, "shared">; +def AddrSpaceGeneric : I32EnumAttrCase<"GENERIC", 4, "generic">; + +def XeVM_AddrSpace + : I32EnumAttr<"AddrSpace", "Address spaces", + [AddrSpacePrivate, AddrSpaceGlobal, AddrSpaceConstant, + AddrSpaceShared, AddrSpaceGeneric]> { + let genSpecializedAttr = 0; + let cppNamespace = "mlir::xevm"; +} +def XeVM_AddrSpaceAttr : EnumAttr { + let summary = [{Describe address spaces}]; + let assemblyFormat = "`<` $value `>`"; +} + +def XeVM_MemfenceOp + : XeVM_Op<"memfence">, + Arguments<(ins XeVM_MemScopeAttr:$scope, + DefaultValuedAttr:$addrspace)> { + let summary = "Work-item's memory fence."; + let description = [{ + This operation ensures that all prior memory accesses of this + work-item to `addrspace` are visible to all other work-items in `scope`. + Parameters description: + $scope - specify the memory scope at which all other work-items should observe + memory operations prior to the fence. 
+ $addrspace - specify the address space of work-item's memory accesses + to be affected by the fence. + }]; + let assemblyFormat = [{prop-dict attr-dict}]; + + let extraClassDeclaration = extraBaseClassDeclaration#[{ + }]; +} + +def XeVM_PrefetchOp + : XeVM_Op<"prefetch">, + Arguments<(ins Arg:$ptr, + OptionalAttr:$cache_control)> { + let summary = "Prefetch data into a cache subsystem."; + let description = [{ + Work-item issues a prefetch from global memory to cache: + $ptr - LLVM pointer with address space. Address space must be 1 (global) + or 4 (generic) + $cache_control - specify caching options + }]; + let assemblyFormat = [{ + operands prop-dict attr-dict `:` `(` type(operands) `)` + }]; + + let extraClassDeclaration = extraBaseClassDeclaration#[{ + /// Get cache control or return default if not set. + ::mlir::xevm::LoadCacheControl getCacheControlOrDefault() { + if(getCacheControl()) + return *getCacheControl(); + return ::mlir::xevm::LoadCacheControl::DEFAULT; + } + }]; + + let hasVerifier = 1; +} + +def XeVM_BlockPrefetch2dOp + : XeVM_Op<"blockprefetch2d">, + Arguments<(ins Arg:$ptr, I32:$base_width, + I32:$base_height, I32:$base_pitch, I32:$x, I32:$y, + I32Attr:$elem_size_in_bits, I32Attr:$tile_width, I32Attr:$tile_height, + I32Attr:$v_blocks, + OptionalAttr:$cache_control)> { + + let summary = "2D block prefetch"; + + let description = [{ + The `xevm.blockprefetch2d` operation prefetches a two dimensional tile + from a larger base matrix residing in memory. The parameters are: + $ptr - the base address of the base matrix containing the tile to prefetch + $base_width, $base_height, $base_pitch - the shape of the base matrix. + pitch is the physical stride between the first columns of the current row + and the subsequent row. All units are in bytes. + $x, $y, $tile_width, $tile_height - the starting offsets and shape of tile + to prefetch in number of elements. 
+ $elem_size_in_bits - the size in bits of the matrix element + - 32 for f32, bf32 + - 16 for f16, int16, bf16 + - 8 for int8, int4, int2 + $v_blocks - number of tiles in innermost dimension direction to prefetch + $cache_control - an enumerator that sets the cache behaviour + + Example: + ```mlir + xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y + <{elem_size_in_bits=8 : i32, tile_width=32 : i32, tile_height=8 : i32, + v_blocks=1 : i32, cache_control=#xevm.load_cache_control}> + : (!llvm.ptr<1>, i32, i32, i32, i32, i32) + ``` + }]; + + let assemblyFormat = [{ + operands prop-dict attr-dict `:` `(` type(operands) `)` + }]; + + let extraClassDeclaration = extraBaseClassDeclaration#[{ + /// Get cache control or return default if not set. + ::mlir::xevm::LoadCacheControl getCacheControlOrDefault() { + if(getCacheControl()) + return *getCacheControl(); + return ::mlir::xevm::LoadCacheControl::DEFAULT; + } + }]; + + let hasVerifier = 1; +} + +def XeVM_MatrixElemType + : AnyTypeOf<[AnyI8, AnyI16, AnyI32, F32, TF32, F16, BF16]>; + +/// Enum attribute of the different element types. +def XeVM_ET_BF16 : I32EnumAttrCase<"BF16", 8, "bf16">; +def XeVM_ET_F16 : I32EnumAttrCase<"F16", 9, "f16">; +def XeVM_ET_S8 : I32EnumAttrCase<"S8", 10, "s8">; +def XeVM_ET_U8 : I32EnumAttrCase<"U8", 11, "u8">; +def XeVM_ET_S4 : I32EnumAttrCase<"S4", 12, "s4">; +def XeVM_ET_U4 : I32EnumAttrCase<"U4", 13, "u4">; +def XeVM_ET_TF32 : I32EnumAttrCase<"TF32", 14, "tf32">; +def XeVM_ET_F32 : I32EnumAttrCase<"F32", 15, "f32">; +def XeVM_ET_S32 : I32EnumAttrCase<"S32", 16, "s32">; + +def XeVM_ElemTypeAttr : I32EnumAttr<"ElemType", "XeVM element type", + [XeVM_ET_BF16, XeVM_ET_F16, XeVM_ET_S8, + XeVM_ET_U8, XeVM_ET_S4, XeVM_ET_U4, + XeVM_ET_TF32, XeVM_ET_F32, XeVM_ET_S32]> { + let cppNamespace = "::mlir::xevm"; +} + +def XeVM_MMAShapeAttr : XeVM_Attr<"MMAShape", "mma_shape"> { + let description = [{ + MMA operation is represented as D=AxB+C, where + A has the shape MxK. 
+ B has the shape KxN. + D and C have the shape MxN. + This attribute encodes the shape of all matrices that participate in MMA. + }]; + let parameters = (ins "int":$m, "int":$n, "int":$k); + let assemblyFormat = "`<` struct(params) `>`"; +} + +def XeVM_MMATypesAttr : XeVM_Attr<"MMATypes", "mma_types"> { + let parameters = (ins "xevm::ElemType":$d, "xevm::ElemType":$a, + "xevm::ElemType":$b, OptionalParameter<"xevm::ElemType">:$c); + let assemblyFormat = "`<` struct(params) `>`"; +} + +def XeVM_MMAOp + : XeVM_Op<"mma">, + Results<(outs FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$d)>, + Arguments<(ins FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$a, + FixedVectorOfRankAndType<[1], [XeVM_MatrixElemType]>:$b, + Optional>:$c, + XeVM_MMAShapeAttr:$shape, XeVM_MMATypesAttr:$types)> { + + let summary = "Subgroup matrix multiply-add"; + + let description = [{ + The `xevm.mma` is a cooperative operation where all threads/lanes in + a subgroup participates and carries out matrix multiplication plus accumulation: + + D = C + A x B + + where the A, B, C input matrices and the result D have shapes: + D : MxN + C : MxN + A : MxK + B : KxN + + Parameters: + `a` - vector of matrix A elements. + `b` - vector of matrix B elements. + `c` - (optional) vector of matrix C elements. + `shape` - the shape of the matrices, specified as `M`, `N`, and `K` values. + `types` - the data types of the matrices, specified as `D`, `A`, `B`, and optionally `C`. + + Example: + ```mlir + %d = xevm.mma %a, %b, %c { shape=, types= } + : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> + ``` + }]; + + let assemblyFormat = [{ + $a `,` $b (`,` $c^)? ` ` + `{` + `shape` `=` $shape `,` + `types` `=` $types + `}` attr-dict `:` functional-type(operands, results) + }]; + + let hasVerifier = 1; +} + +//===----------------------------------------------------------------------===// +// XeVM target attribute. 
+//===----------------------------------------------------------------------===// + +def XeVM_TargetAttr : XeVM_Attr<"XeVMTarget", "target"> { + let description = [{ + GPU target attribute for controlling compilation of Intel GPU targets. All + parameters decay into default values if not present. + + Examples: + + 1. Target with default values. + ``` + gpu.module @mymodule [#xevm.target] attributes {...} { + ... + } + ``` + }]; + let parameters = + (ins DefaultValuedParameter<"int", "2", + "Optimization level to apply.">:$O, + StringRefParameter<"Target triple.", + "\"spirv64-unknown-unknown\"">:$triple, + StringRefParameter<"Target chip.", "\"bmg\"">:$chip, + OptionalParameter<"::mlir::DictionaryAttr", + "Target specific flags.">:$flags, + OptionalParameter<"::mlir::ArrayAttr", + "Files to link to the LLVM module.">:$linkFiles); + let assemblyFormat = [{ + (`<` struct($O, $triple, $chip, $flags, $linkFiles)^ `>`)? + }]; + let builders = [AttrBuilder< + (ins CArg<"int", "2">:$optLevel, + CArg<"::llvm::StringRef", "\"spirv64-unknown-unknown\"">:$triple, + CArg<"::llvm::StringRef", "\"bmg\"">:$chip, + CArg<"::mlir::DictionaryAttr", "nullptr">:$targetFlags, + CArg<"::mlir::ArrayAttr", "nullptr">:$linkFiles), + [{ + return Base::get($_ctxt, optLevel, triple, chip, targetFlags, linkFiles); + }]>]; + let skipDefaultBuilders = 1; + let genVerifyDecl = 1; +} + +#endif // XEVMIR_OPS diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h index 261b0e00bdf86..c6fcf1a0d510b 100644 --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -46,6 +46,7 @@ #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.h" +#include "mlir/Dialect/LLVMIR/XeVMDialect.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/AllInterfaces.h" #include "mlir/Dialect/Linalg/Transforms/RuntimeOpVerification.h" @@ 
-152,7 +153,8 @@ inline void registerAllDialects(DialectRegistry ®istry) { ub::UBDialect, vector::VectorDialect, x86vector::X86VectorDialect, - xegpu::XeGPUDialect>(); + xegpu::XeGPUDialect, + xevm::XeVMDialect>(); // clang-format on // Register all external models. diff --git a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt index d83fd3800eb91..67081ca61e6e5 100644 --- a/mlir/lib/Dialect/LLVMIR/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/CMakeLists.txt @@ -110,3 +110,25 @@ add_mlir_dialect_library(MLIRVCIXDialect MLIRLLVMDialect MLIRSideEffectInterfaces ) + +add_mlir_dialect_library(MLIRXeVMDialect + IR/XeVMDialect.cpp + + ADDITIONAL_HEADER_DIRS + ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/LLVMIR + + DEPENDS + MLIRGPUCompilationAttrInterfacesIncGen + MLIRXeVMOpsIncGen + MLIRXeVMConversionsIncGen + intrinsics_gen + + LINK_COMPONENTS + AsmParser + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRLLVMDialect + MLIRSideEffectInterfaces +) diff --git a/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp new file mode 100644 index 0000000000000..d10fa5cdbc2f5 --- /dev/null +++ b/mlir/lib/Dialect/LLVMIR/IR/XeVMDialect.cpp @@ -0,0 +1,377 @@ +//===-- XeVMDialect.cpp - XeVM dialect registration -------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
+#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/DialectImplementation.h"
+#include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/FileSystem.h"
+#include "llvm/Support/MathExtras.h"
+
+using namespace mlir;
+using namespace mlir::xevm;
+
+#include "mlir/Dialect/LLVMIR/XeVMOpsDialect.cpp.inc"
+#include "mlir/Dialect/LLVMIR/XeVMOpsEnums.cpp.inc"
+
+namespace {
+constexpr uint32_t subgroupSize = 16;
+
+template <typename Op>
+LogicalResult verifyMatrixInput(Op op) {
+  static_assert(llvm::is_one_of<Op, BlockLoad2dOp, BlockStore2dOp,
+                                BlockPrefetch2dOp>::value,
+                "Unexpected template parameter");
+
+  std::optional<int64_t> width = getConstantIntValue(op.getBaseWidth());
+  std::optional<int64_t> pitch = getConstantIntValue(op.getBasePitch());
+  if (pitch && width && *pitch < *width)
+    return op->emitOpError(
+        "4th operand (base pitch) should be >= 2nd operand (base width)");
+
+  uint32_t elemSize = op.getElemSizeInBits();
+  if (elemSize < 8 || !llvm::isPowerOf2_32(elemSize) || elemSize > 32)
+    return op->emitOpError("expecting 'elem_size_in_bits' to be 8, 16, or 32");
+
+  uint32_t tileHeight = op.getTileHeight();
+  if (tileHeight > 32 || !llvm::isPowerOf2_32(tileHeight))
+    return op->emitOpError("expecting tile_height to be 1, 2, 4, 8, 16, or 32");
+
+  uint32_t vBlocks = op.getVBlocks();
+  if (vBlocks > 8 || !llvm::isPowerOf2_32(vBlocks))
+    return op->emitOpError("expecting v_blocks to be 1, 2, 4, or 8");
+
+  return success();
+}
+
+LogicalResult verify2DBlockLoadRestriction(BlockLoad2dOp op) {
+  VectorType resTy = op.getRes().getType();
+  if (!resTy.getElementType().isIntOrFloat())
+    return op.emitOpError()
+           << "expecting result element type to be int or float";
+  unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
+  unsigned resSize = resTy.getNumElements() * resElemTySize;
+  unsigned expectedSize =
op.getElemSizeInBits() * op.getTileHeight() * + op.getTileWidth() * op.getVBlocks() / subgroupSize; + if (resSize != expectedSize) + return op.emitOpError() << "result size of " << resSize + << " bits does not match the expected size of " + << expectedSize << " bits"; + + if (op.getTranspose() && op.getPackRegister()) + return op.emitOpError("transpose and pack_register are mutually exclusive"); + + if (!op.getTranspose() && !op.getPackRegister()) { + uint32_t tileHeight = op.getTileHeight(); + if (tileHeight < 1 || tileHeight > 32) + return op.emitOpError("expecting tile_height to be between 1 and 32"); + + uint32_t tileWidth = op.getTileWidth(); + uint32_t vBlocks = op.getVBlocks(); + switch (op.getElemSizeInBits()) { + case 8: + if (tileWidth < 4 || tileWidth > 64) + return op.emitOpError("expecting tile_width to be between 4 and 64"); + if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4) + return op.emitOpError("expecting v_blocks to be 1, 2, or 4"); + if (tileWidth * vBlocks > 64) + return op.emitOpError( + "tile_width * v_blocks should be less than or equal " + "to 64 for 8 bit elements"); + break; + case 16: + if (tileWidth < 2 || tileWidth > 32) + return op.emitOpError("expecting tile_width to be between 2 and 32"); + if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4) + return op.emitOpError("expecting v_blocks to be 1, 2, or 4"); + if (tileWidth * vBlocks > 32) + return op.emitOpError( + "tile_width * v_blocks should be less than or equal " + "to 32 for 16 bit elements"); + break; + case 32: + if (tileWidth < 1 || tileWidth > 16) + return op.emitOpError("expecting tile_width to be between 1 and 16"); + if (vBlocks != 1 && vBlocks != 2) + return op.emitOpError("expecting v_blocks to be 1 or 2"); + if (tileWidth * vBlocks > 16) + return op.emitOpError( + "tile_width * v_blocks should be less than or equal " + "to 16 for 32 bit elements"); + break; + case 64: + if (tileWidth < 1 || tileWidth > 8) + return op.emitOpError("expecting tile_width to be between 1 and 
8"); + if (vBlocks != 1) + return op.emitOpError("expecting v_blocks to be 1"); + break; + default: + return op.emitOpError( + "expecting elem_size_in_bits to be 8, 16, 32, or 64"); + } + + return success(); + } + + if (op.getTranspose()) { + assert(!op.getPackRegister() && "Expecting pack_register should be false"); + + uint32_t vBlocks = op.getVBlocks(); + if (vBlocks != 1) + return op.emitOpError("expecting v_blocks to be 1"); + + uint32_t tileHeight = op.getTileHeight(); + uint32_t tileWidth = op.getTileWidth(); + switch (op.getElemSizeInBits()) { + case 32: + if (tileHeight < 1 || tileHeight > 32) + return op.emitOpError("expecting tile_height to be between 1 and 32"); + if (tileWidth < 1 || tileWidth > 8) + return op.emitOpError("expecting tile_width to be between 1 and 8"); + break; + case 64: + if (tileHeight != 8) + return op.emitOpError( + "expecting tile_height to be 8 for 64 bit elements"); + if (tileWidth != 1 && tileWidth != 2 && tileWidth != 4) + return op.emitOpError("expecting tile_width to be 1, 2, or 4"); + break; + default: + return op.emitOpError("transpose is only supported for 32 and 64 bit " + "elements"); + } + + return success(); + } + + assert(op.getPackRegister() && !op.getTranspose() && + "Expecting pack_register should be true and transpose should be " + "false"); + + uint32_t vBlocks = op.getVBlocks(); + if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4) + return op.emitOpError("expecting v_blocks to be 1, 2, or 4"); + + uint32_t tileHeight = op.getTileHeight(); + uint32_t tileWidth = op.getTileWidth(); + switch (op.getElemSizeInBits()) { + case 8: + if (tileHeight < 4 || tileHeight > 32) + return op.emitOpError("expecting tile_height to be between 4 and 32"); + if (tileWidth < 4 || tileWidth > 16) + return op.emitOpError("expecting tile_width to be between 4 and 16"); + break; + case 16: + if (tileHeight < 2 || tileHeight > 32) + return op.emitOpError("expecting tile_height to be between 2 and 32"); + if (tileWidth < 2 || tileWidth > 
16) + return op.emitOpError("expecting tile_width to be between 2 and 16"); + if (tileWidth * vBlocks > 32) + return op.emitOpError( + "tile_width * v_blocks should be less than or equal " + "to 32 for 16 bit elements"); + break; + default: + return op.emitOpError("pack_register is only supported for 8 and 16 bit " + "elements"); + } + + return success(); +} + +static LogicalResult verify2DBlockStoreRestriction(BlockStore2dOp op) { + uint32_t tileHeight = op.getTileHeight(); + if (tileHeight < 1 || tileHeight > 8) + return op.emitOpError("expecting tile_height to be between 1 and 8"); + + uint32_t tileWidth = op.getTileWidth(); + switch (op.getElemSizeInBits()) { + case 8: + if (tileWidth < 4 || tileWidth > 64) + return op.emitOpError("expecting tile_width to be between 4 and 64"); + break; + case 16: + if (tileWidth < 2 || tileWidth > 32) + return op.emitOpError("expecting tile_width to be between 2 and 32"); + break; + case 32: + if (tileWidth < 1 || tileWidth > 16) + return op.emitOpError("expecting tile_width to be between 1 and 16"); + break; + case 64: + if (tileWidth < 1 || tileWidth > 8) + return op.emitOpError("expecting tile_width to be between 1 and 8"); + break; + default: + return op.emitOpError("expecting elem_size_in_bits to be 8, 16, 32, or 64"); + } + + uint32_t vBlocks = op.getVBlocks(); + if (vBlocks != 1) + return op.emitOpError("expecting v_blocks to be 1"); + return success(); +} + +} // namespace + +LogicalResult BlockLoad2dOp::verify() { + if (verify2DBlockLoadRestriction(*this).failed()) + return failure(); + + if (verifyMatrixInput(*this).failed()) + return failure(); + + VectorType resTy = getRes().getType(); + if (!resTy.getElementType().isIntOrFloat()) + return emitOpError() << "expecting result element type to be int of float"; + unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth(); + if (getElemSizeInBits() == 32 || getPackRegister()) { + if (resElemTySize != 32) + return emitOpError() << "expecting result element 
type to be 32 bits"; + } + + uint32_t tileWidth = getTileWidth(); + if (getPackRegister()) { + if (tileWidth != 16) + return emitOpError( + "tile_width when pack_register is true should be equal " + "to subgroup size (16 elements)"); + return success(); + } + + return success(); +} + +LogicalResult BlockStore2dOp::verify() { + if (verify2DBlockStoreRestriction(*this).failed()) + return failure(); + + if (verifyMatrixInput(*this).failed()) + return failure(); + + uint32_t tileWidth = getTileWidth(); + switch (getElemSizeInBits()) { + case 8: + if (tileWidth != 16 && tileWidth != 32) + return emitOpError("tile_width for 8 bit elements should be equal to " + "16 or 32"); + break; + case 16: + if (tileWidth != 16) + return emitOpError("tile_width for 16 bit elements should be equal " + "to 16"); + break; + case 32: + if (tileWidth != 16) + return emitOpError("tile_width for 32 bit elements should be equal " + "to 16"); + break; + default: + llvm_unreachable("unexpected element size"); + } + + return success(); +} + +LogicalResult BlockPrefetch2dOp::verify() { + if (verifyMatrixInput(*this).failed()) + return failure(); + + uint32_t tileWidth = getTileWidth(); + switch (getElemSizeInBits()) { + case 8: + if (tileWidth != 16 && tileWidth != 32) + return emitOpError("tile_width for 8 bit elements should be equal to " + "16 or 32"); + break; + case 16: + if (tileWidth != 16) + return emitOpError("tile_width for 16 bit elements should be equal " + "to 16"); + break; + case 32: + if (tileWidth != 8 && tileWidth != 16) + return emitOpError( + "tile_width for 32 bit elements should be equal to 8 or 16"); + break; + default: + llvm_unreachable("unexpected element size"); + } + + return success(); +} + +LogicalResult MMAOp::verify() { + if (getC()) { + if (getResult().getType() != getC().getType()) + return emitOpError("type of C operand must match result type"); + } + return success(); +} + +LogicalResult PrefetchOp::verify() { + auto ptrTy = 
mlir::dyn_cast(getOperand().getType()); + auto addrSpace = ptrTy.getAddressSpace(); + if (addrSpace != 1 && addrSpace != 4) + return emitOpError( + "LLVM pointer type address space must be 1 (global) or 4 (generic)"); + return success(); +} + +LogicalResult +XeVMTargetAttr::verify(function_ref emitError, int O, + StringRef triple, StringRef chip, DictionaryAttr flags, + ArrayAttr linkFiles) { + if (O < 0 || O > 3) { + return emitError() + << "The optimization level must be a number between 0 and 3."; + } + if (triple.empty()) { + return emitError() << "The target triple cannot be empty."; + } + if (chip.empty()) { + return emitError() << "The target chip cannot be empty."; + } + if (linkFiles) { + for (Attribute fileAttr : linkFiles) { + if (auto fileStrAttr = llvm::dyn_cast(fileAttr)) { + StringRef filePath = fileStrAttr.getValue(); + if (filePath.empty()) { + return emitError() << "File paths in linkFiles cannot be empty."; + } + if (!llvm::sys::fs::exists(filePath)) { + return emitError() << "File '" << filePath << "' does not exist."; + } + } + } + } + return success(); +} + +void XeVMDialect::initialize() { + // NOLINTBEGIN + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/LLVMIR/XeVMOps.cpp.inc" + >(); + + addAttributes< +#define GET_ATTRDEF_LIST +#include "mlir/Dialect/LLVMIR/XeVMOpsAttributes.cpp.inc" + >(); + // NOLINTEND + declarePromisedInterface(); +} + +#define GET_OP_CLASSES +#include "mlir/Dialect/LLVMIR/XeVMOps.cpp.inc" + +#define GET_ATTRDEF_CLASSES +#include "mlir/Dialect/LLVMIR/XeVMOpsAttributes.cpp.inc" diff --git a/mlir/test/Dialect/LLVMIR/invalid.mlir b/mlir/test/Dialect/LLVMIR/invalid.mlir index 251ca716c7a7a..174f925fea317 100644 --- a/mlir/test/Dialect/LLVMIR/invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/invalid.mlir @@ -1875,7 +1875,7 @@ llvm.mlir.global @bad_struct_array_init_elements() : !llvm.array<1x!llvm.struct< llvm.return %0 : !llvm.array<1x!llvm.struct<(i32, f32)>> } -// ---- +// ----- llvm.mlir.global internal constant 
@bad_array_attr_simple_type() : !llvm.array<2 x f64> { // expected-error@below {{'llvm.mlir.constant' op for array with an array attribute must have a struct element type}} @@ -1883,10 +1883,51 @@ llvm.mlir.global internal constant @bad_array_attr_simple_type() : !llvm.array<2 llvm.return %0 : !llvm.array<2 x f64> } -// ---- +// ----- llvm.func @inlineAsmMustTail(%arg0: i32, %arg1 : !llvm.ptr) { // expected-error@+1 {{op tail call kind 'musttail' is not supported}} %8 = llvm.inline_asm tail_call_kind = "foo", "=r,=r,r" %arg0 : (i32) -> !llvm.struct<(i8, i8)> llvm.return } + +// ----- + +llvm.func @invalid_xevm_prefetch(%arg0: !llvm.ptr) { + // expected-error@+1 {{LLVM pointer type address space must be 1 (global) or 4 (generic)}} + xevm.prefetch %arg0 <{cache_control = #xevm.load_cache_control}> : (!llvm.ptr) + llvm.return +} + +// ----- + +llvm.func @invalid_xevm_mma(%loaded_c_casted: vector<4xf32>, %loaded_a: vector<8xi16>, %loaded_b_casted: vector<8xi32>) -> vector<8xf32> { + // expected-error@+1 {{op type of C operand must match result type}} + %c_result = xevm.mma %loaded_a, %loaded_b_casted, %loaded_c_casted {shape = , types = } : (vector<8xi16>, vector<8xi32>, vector<4xf32>) -> vector<8xf32> + llvm.return %c_result : vector<8xf32> +} + +// ----- + +llvm.func @invalid_xevm_matrix_1(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i32, %base_pitch_c: i32, %x: i32, %y: i32, %c_result_casted: vector<8xi32>) { + // expected-error@+1 {{op expecting tile_width to be between 1 and 8}} + xevm.blockstore2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %c_result_casted <{elem_size_in_bits=64 : i32, tile_width=16 : i32, tile_height=8 : i32, cache_control=#xevm.store_cache_control}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>) + llvm.return +} + +// ----- + +llvm.func @invalid_xevm_matrix_2(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i32, %base_pitch_c: i32, %x: i32, %y: i32, %c_result_casted: vector<8xi32>) { + // expected-error@+1 
{{op expecting elem_size_in_bits to be 8, 16, 32, or 64}} + xevm.blockstore2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %c_result_casted <{elem_size_in_bits=18 : i32, tile_width=16 : i32, tile_height=8 : i32, cache_control=#xevm.store_cache_control}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>) + llvm.return +} + +// ----- + +llvm.func @invalid_xevm_matrix_3(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> { + // expected-error@+1 {{op result size of 128 bits does not match the expected size of 208 bits}} + %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y <{elem_size_in_bits=16 : i32, tile_width=26 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false, pack_register=false, cache_control=#xevm.load_cache_control}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16> + llvm.return %loaded_a : vector<8xi16> +} + diff --git a/mlir/test/Dialect/LLVMIR/xevm.mlir b/mlir/test/Dialect/LLVMIR/xevm.mlir new file mode 100644 index 0000000000000..bf10bd45d58a0 --- /dev/null +++ b/mlir/test/Dialect/LLVMIR/xevm.mlir @@ -0,0 +1,75 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK: func.func @blockload2d(%[[ARG0:.*]]: !llvm.ptr<1>, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32) +func.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> { + // CHECK: %[[VAR0:.*]] = xevm.blockload2d %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]] + // CHECK-DAG: elem_size_in_bits = 16 : i32 + // CHECK-DAG: tile_width = 16 : i32 + // CHECK-DAG: tile_height = 8 : i32 + // CHECK-DAG: v_blocks = 1 : i32 + // CHECK-DAG: transpose = false + // CHECK-DAG: pack_register = false + // CHECK-DAG: cache_control = #xevm.load_cache_control + // CHECK: (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> 
vector<8xi16> + %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false, pack_register=false, cache_control=#xevm.load_cache_control}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16> + return %loaded_a : vector<8xi16> +} + +// ----- +// CHECK: func.func @blockstore2d(%[[ARG0:.*]]: !llvm.ptr<1>, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32, %[[ARG6:.*]]: vector<8xi32>) +func.func @blockstore2d(%c: !llvm.ptr<1>, %base_width_c: i32, %base_height_c: i32, %base_pitch_c: i32, %x: i32, %y: i32, %c_result_casted: vector<8xi32>) { + // CHECK: xevm.blockstore2d %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]], %[[ARG6]] + // CHECK-DAG: elem_size_in_bits = 32 : i32 + // CHECK-DAG: tile_width = 16 : i32 + // CHECK-DAG: tile_height = 8 : i32 + // CHECK-DAG: cache_control = #xevm.store_cache_control + // CHECK: (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>) + xevm.blockstore2d %c, %base_width_c, %base_height_c, %base_pitch_c, %x, %y, %c_result_casted <{elem_size_in_bits=32 : i32, tile_width=16 : i32, tile_height=8 : i32, cache_control=#xevm.store_cache_control}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32, vector<8xi32>) + return +} + +// ----- +// CHECK: func.func @blockprefetch2d(%[[ARG0:.*]]: !llvm.ptr<1>, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32, %[[ARG3:.*]]: i32, %[[ARG4:.*]]: i32, %[[ARG5:.*]]: i32) +func.func @blockprefetch2d(%ptr: !llvm.ptr<1>, %base_width: i32, %base_height: i32, %base_pitch: i32, %x: i32, %y: i32) { + // CHECK: xevm.blockprefetch2d %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[ARG4]], %[[ARG5]] + // CHECK-DAG: elem_size_in_bits = 8 : i32 + // CHECK-DAG: tile_width = 32 : i32 + // CHECK-DAG: tile_height = 8 : i32 + // CHECK-DAG: v_blocks = 1 : i32 + // CHECK-DAG: cache_control = #xevm.load_cache_control + // CHECK: (!llvm.ptr<1>, i32, i32, 
i32, i32, i32) + xevm.blockprefetch2d %ptr, %base_width, %base_height, %base_pitch, %x, %y <{elem_size_in_bits=8 : i32, tile_width=32 : i32, tile_height=8 : i32, v_blocks=1 : i32, cache_control=#xevm.load_cache_control}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) + return +} + +// ----- +// CHECK: func.func @mma(%[[ARG0:.*]]: vector<8xf32>, %[[ARG1:.*]]: vector<8xi16>, %[[ARG2:.*]]: vector<8xi32>) +func.func @mma(%loaded_c_casted: vector<8xf32>, %loaded_a: vector<8xi16>, %loaded_b_casted: vector<8xi32>) -> vector<8xf32> { + // CHECK: %0 = xevm.mma %[[ARG1]], %[[ARG2]], %[[ARG0]] {shape = , types = } : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> + %c_result = xevm.mma %loaded_a, %loaded_b_casted, %loaded_c_casted { shape=, types= } : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> + return %c_result : vector<8xf32> +} + +// ----- +func.func @memfence() { + // CHECK: xevm.memfence + // CHECK-DAG: addrspace = #xevm.addr_space + // CHECK-DAG: scope = #xevm.mem_scope + xevm.memfence <{addrspace=#xevm.addr_space, scope=#xevm.mem_scope}> + return +} + +// ----- +// CHECK: func.func @prefetch(%[[ARG0:.*]]: !llvm.ptr<1>) +func.func @prefetch(%ptr: !llvm.ptr<1>) { + // CHECK: xevm.prefetch %[[ARG0]] + // CHECK-DAG: cache_control = #xevm.load_cache_control + // CHECK: (!llvm.ptr<1>) + xevm.prefetch %ptr <{cache_control = #xevm.load_cache_control}> : (!llvm.ptr<1>) + return +} + +// ----- +// CHECK: @xevm_module [#xevm.target] { +gpu.module @xevm_module [#xevm.target]{ +}