[mlir][spirv] Add support for SPV_ARM_tensors #144667
Conversation
Signed-off-by: Davide Grohmann <[email protected]>
Signed-off-by: Mohammadreza Ameri Mahabadian <[email protected]>
Change-Id: If78909a47417ef3dda710847cfe90c34b984ff09
@llvm/pr-subscribers-mlir-spirv
Author: Davide Grohmann (davidegrohmann)
Changes
Patch is 32.41 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/144667.diff
11 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
index d2ba76cdad904..d874817e6888d 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td
@@ -422,6 +422,8 @@ def SPV_NV_ray_tracing_motion_blur : I32EnumAttrCase<"SPV_NV_ray_tracing_m
def SPV_NVX_multiview_per_view_attributes : I32EnumAttrCase<"SPV_NVX_multiview_per_view_attributes", 5015>;
+def SPV_ARM_tensors : I32EnumAttrCase<"SPV_ARM_tensors", 6000>;
+
def SPIRV_ExtensionAttr :
SPIRV_I32EnumAttr<"Extension", "supported SPIR-V extensions", "ext", [
SPV_KHR_16bit_storage, SPV_KHR_8bit_storage, SPV_KHR_device_group,
@@ -445,6 +447,7 @@ def SPIRV_ExtensionAttr :
SPV_EXT_shader_atomic_float_add, SPV_EXT_shader_atomic_float_min_max,
SPV_EXT_shader_image_int64, SPV_EXT_shader_atomic_float16_add,
SPV_EXT_mesh_shader,
+ SPV_ARM_tensors,
SPV_AMD_gpu_shader_half_float_fetch, SPV_AMD_shader_ballot,
SPV_AMD_shader_explicit_vertex_parameter, SPV_AMD_shader_fragment_mask,
SPV_AMD_shader_image_load_store_lod, SPV_AMD_texture_gather_bias_lod,
@@ -1311,6 +1314,24 @@ def SPIRV_C_GeometryStreams : I32EnumAttrCase<"Geome
def SPIRV_C_MultiViewport : I32EnumAttrCase<"MultiViewport", 57> {
list<I32EnumAttrCase> implies = [SPIRV_C_Geometry];
}
+def SPIRV_C_TensorsARM : I32EnumAttrCase<"TensorsARM", 4174> {
+ list<I32EnumAttrCase> implies = [SPIRV_C_Int8];
+ list<Availability> availability = [
+ Extension<[SPV_ARM_tensors]>
+ ];
+}
+def SPIRV_C_StorageTensorArrayDynamicIndexingEXT : I32EnumAttrCase<"StorageTensorArrayDynamicIndexingEXT", 4175> {
+ list<I32EnumAttrCase> implies = [SPIRV_C_TensorsARM, SPIRV_C_Shader];
+ list<Availability> availability = [
+ Extension<[SPV_ARM_tensors]>
+ ];
+}
+def SPIRV_C_StorageTensorArrayNonUniformIndexingEXT : I32EnumAttrCase<"StorageTensorArrayNonUniformIndexingEXT", 4176> {
+ list<I32EnumAttrCase> implies = [SPIRV_C_TensorsARM, SPIRV_C_ShaderNonUniform];
+ list<Availability> availability = [
+ Extension<[SPV_ARM_tensors]>
+ ];
+}
def SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR : I32EnumAttrCase<"WorkgroupMemoryExplicitLayout8BitAccessKHR", 4429> {
list<I32EnumAttrCase> implies = [SPIRV_C_WorkgroupMemoryExplicitLayoutKHR];
list<Availability> availability = [
@@ -1523,6 +1544,8 @@ def SPIRV_CapabilityAttr :
SPIRV_C_IntegerFunctions2INTEL, SPIRV_C_TessellationPointSize,
SPIRV_C_GeometryPointSize, SPIRV_C_ImageCubeArray, SPIRV_C_ImageRect,
SPIRV_C_GeometryStreams, SPIRV_C_MultiViewport,
+ SPIRV_C_TensorsARM, SPIRV_C_StorageTensorArrayDynamicIndexingEXT,
+ SPIRV_C_StorageTensorArrayNonUniformIndexingEXT,
SPIRV_C_WorkgroupMemoryExplicitLayout8BitAccessKHR, SPIRV_C_VariablePointers,
SPIRV_C_RayTraversalPrimitiveCullingKHR, SPIRV_C_SampleMaskOverrideCoverageNV,
SPIRV_C_GeometryShaderPassthroughNV, SPIRV_C_PerViewAttributesNV,
@@ -4179,7 +4202,7 @@ def SPIRV_IsPtrType : CPred<"::llvm::isa<::mlir::spirv::PointerType>($_self)">;
def SPIRV_IsRTArrayType : CPred<"::llvm::isa<::mlir::spirv::RuntimeArrayType>($_self)">;
def SPIRV_IsSampledImageType : CPred<"::llvm::isa<::mlir::spirv::SampledImageType>($_self)">;
def SPIRV_IsStructType : CPred<"::llvm::isa<::mlir::spirv::StructType>($_self)">;
-
+def SPIRV_IsTensorArmType : CPred<"::llvm::isa<::mlir::spirv::TensorArmType>($_self)">;
// See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_types
// for the definition of the following types and type categories.
@@ -4217,6 +4240,8 @@ def SPIRV_AnyStruct : DialectType<SPIRV_Dialect, SPIRV_IsStructType,
"any SPIR-V struct type">;
def SPIRV_AnySampledImage : DialectType<SPIRV_Dialect, SPIRV_IsSampledImageType,
"any SPIR-V sampled image type">;
+def SPIRV_AnyTensorArm : DialectType<SPIRV_Dialect, SPIRV_IsTensorArmType,
+ "any SPIR-V tensorArm type">;
def SPIRV_Numerical : AnyTypeOf<[SPIRV_Integer, SPIRV_AnyFloat]>;
def SPIRV_Scalar : AnyTypeOf<[SPIRV_Numerical, SPIRV_Bool]>;
@@ -4228,7 +4253,7 @@ def SPIRV_Type : AnyTypeOf<[
SPIRV_Void, SPIRV_Bool, SPIRV_Integer, SPIRV_AnyFloat, SPIRV_Vector,
SPIRV_AnyPtr, SPIRV_AnyArray, SPIRV_AnyRTArray, SPIRV_AnyStruct,
SPIRV_AnyCooperativeMatrix, SPIRV_AnyMatrix, SPIRV_AnySampledImage,
- SPIRV_AnyImage
+ SPIRV_AnyImage, SPIRV_AnyTensorArm
]>;
def SPIRV_SignedInt : SignedIntOfWidths<[8, 16, 32, 64]>;
@@ -4525,6 +4550,7 @@ def SPIRV_OC_OpGroupNonUniformBitwiseXor : I32EnumAttrCase<"OpGroupNonUnifo
def SPIRV_OC_OpGroupNonUniformLogicalAnd : I32EnumAttrCase<"OpGroupNonUniformLogicalAnd", 362>;
def SPIRV_OC_OpGroupNonUniformLogicalOr : I32EnumAttrCase<"OpGroupNonUniformLogicalOr", 363>;
def SPIRV_OC_OpGroupNonUniformLogicalXor : I32EnumAttrCase<"OpGroupNonUniformLogicalXor", 364>;
+def SPIRV_OC_OpTypeTensorARM : I32EnumAttrCase<"OpTypeTensorARM", 4163>;
def SPIRV_OC_OpSubgroupBallotKHR : I32EnumAttrCase<"OpSubgroupBallotKHR", 4421>;
def SPIRV_OC_OpGroupNonUniformRotateKHR : I32EnumAttrCase<"OpGroupNonUniformRotateKHR", 4431>;
def SPIRV_OC_OpSDot : I32EnumAttrCase<"OpSDot", 4450>;
@@ -4638,7 +4664,9 @@ def SPIRV_OpcodeAttr :
SPIRV_OC_OpGroupNonUniformFMax, SPIRV_OC_OpGroupNonUniformBitwiseAnd,
SPIRV_OC_OpGroupNonUniformBitwiseOr, SPIRV_OC_OpGroupNonUniformBitwiseXor,
SPIRV_OC_OpGroupNonUniformLogicalAnd, SPIRV_OC_OpGroupNonUniformLogicalOr,
- SPIRV_OC_OpGroupNonUniformLogicalXor, SPIRV_OC_OpSubgroupBallotKHR,
+ SPIRV_OC_OpGroupNonUniformLogicalXor,
+ SPIRV_OC_OpTypeTensorARM,
+ SPIRV_OC_OpSubgroupBallotKHR,
SPIRV_OC_OpGroupNonUniformRotateKHR, SPIRV_OC_OpSDot, SPIRV_OC_OpUDot,
SPIRV_OC_OpSUDot, SPIRV_OC_OpSDotAccSat, SPIRV_OC_OpUDotAccSat,
SPIRV_OC_OpSUDotAccSat, SPIRV_OC_OpTypeCooperativeMatrixKHR,
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
index 787535d0a6bd2..7ffea6e7dba81 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVTypes.h
@@ -29,6 +29,7 @@ namespace spirv {
namespace detail {
struct ArrayTypeStorage;
struct CooperativeMatrixTypeStorage;
+struct TensorArmTypeStorage;
struct ImageTypeStorage;
struct MatrixTypeStorage;
struct PointerTypeStorage;
@@ -96,7 +97,8 @@ class ScalarType : public SPIRVType {
std::optional<int64_t> getSizeInBytes();
};
-// SPIR-V composite type: VectorType, SPIR-V ArrayType, or SPIR-V StructType.
+// SPIR-V composite type: TensorArmType, VectorType, SPIR-V ArrayType, or SPIR-V
+// StructType.
class CompositeType : public SPIRVType {
public:
using SPIRVType::SPIRVType;
@@ -477,6 +479,37 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
std::optional<StorageClass> storage = std::nullopt);
};
+// SPIR-V TensorARM Type
+class TensorArmType
+ : public Type::TypeBase<TensorArmType, CompositeType,
+ detail::TensorArmTypeStorage, ShapedType::Trait> {
+public:
+ using Base::Base;
+
+ static constexpr StringLiteral name = "spirv.arm.tensor";
+
+ // TensorArm supports minimum rank of 1, hence an empty shape here means
+ // unranked.
+ static TensorArmType get(ArrayRef<int64_t> shape, Type elementType);
+ TensorArmType cloneWith(std::optional<ArrayRef<int64_t>> shape,
+ Type elementType) const;
+
+ static LogicalResult
+ verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
+ ArrayRef<int64_t> shape, Type elementType);
+
+ Type getElementType() const;
+ ArrayRef<int64_t> getShape() const;
+ unsigned getNumElements() const;
+ bool hasRank() const { return !getShape().empty(); }
+ operator ShapedType() const { return llvm::cast<ShapedType>(*this); }
+
+ void getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
+ std::optional<StorageClass> storage = std::nullopt);
+ void getCapabilities(SPIRVType::CapabilityArrayRefVector &capabilities,
+ std::optional<StorageClass> storage = std::nullopt);
+};
+
} // namespace spirv
} // namespace mlir
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
index a21acef1c4b43..15002f1d5d16e 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVDialect.cpp
@@ -194,6 +194,13 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
<< t.getNumElements();
return Type();
}
+ } else if (auto t = llvm::dyn_cast<TensorArmType>(type)) {
+ if (!llvm::isa<ScalarType>(t.getElementType())) {
+ parser.emitError(
+ typeLoc, "only scalar element type allowed in tensor type but found ")
+ << t.getElementType();
+ return Type();
+ }
} else {
parser.emitError(typeLoc, "cannot use ")
<< type << " to compose SPIR-V types";
@@ -363,6 +370,54 @@ static Type parseCooperativeMatrixType(SPIRVDialect const &dialect,
return CooperativeMatrixType::get(elementTy, dims[0], dims[1], scope, use);
}
+// tensor-arm-type ::=
+// `!spirv.arm.tensor` `<` dim0 `x` dim1 `x` ... `x` dimN `x` element-type`>`
+static Type parseTensorArmType(SPIRVDialect const &dialect,
+ DialectAsmParser &parser) {
+ if (parser.parseLess())
+ return {};
+
+ bool unranked = false;
+ SmallVector<int64_t, 4> dims;
+ SMLoc countLoc = parser.getCurrentLocation();
+
+ if (parser.parseOptionalStar().succeeded()) {
+ unranked = true;
+ if (parser.parseXInDimensionList())
+ return {};
+ } else if (parser.parseDimensionList(dims, /*allowDynamic=*/true))
+ return {};
+
+ if (!unranked && dims.empty()) {
+ parser.emitError(countLoc, "arm.tensors do not support rank zero");
+ return {};
+ }
+
+ if (std::any_of(dims.begin(), dims.end(),
+ [](int64_t dim) { return dim == 0; })) {
+ parser.emitError(countLoc, "arm.tensors do not support zero dimensions");
+ return {};
+ }
+
+ if (std::any_of(dims.begin(), dims.end(),
+ [](int64_t dim) { return dim < 0; }) &&
+ std::any_of(dims.begin(), dims.end(),
+ [](int64_t dim) { return dim > 0; })) {
+ parser.emitError(countLoc, "arm.tensor shape dimensions must be either "
+ "fully dynamic or completed shaped");
+ return {};
+ }
+
+ auto elementTy = parseAndVerifyType(dialect, parser);
+ if (!elementTy)
+ return {};
+
+ if (parser.parseGreater())
+ return {};
+
+ return TensorArmType::get(dims, elementTy);
+}
+
// TODO: Reorder methods to be utilities first and parse*Type
// methods in alphabetical order
//
@@ -759,6 +814,8 @@ Type SPIRVDialect::parseType(DialectAsmParser &parser) const {
return parseStructType(*this, parser);
if (keyword == "matrix")
return parseMatrixType(*this, parser);
+ if (keyword == "arm.tensor")
+ return parseTensorArmType(*this, parser);
parser.emitError(parser.getNameLoc(), "unknown SPIR-V type: ") << keyword;
return Type();
}
@@ -855,10 +912,28 @@ static void print(MatrixType type, DialectAsmPrinter &os) {
os << ">";
}
+static void print(TensorArmType type, DialectAsmPrinter &os) {
+ os << "arm.tensor<";
+
+ llvm::interleave(
+ type.getShape(), os,
+ [&](int64_t dim) {
+ if (ShapedType::isDynamic(dim))
+ os << '?';
+ else
+ os << dim;
+ },
+ "x");
+ if (!type.hasRank()) {
+ os << "*";
+ }
+ os << "x" << type.getElementType() << ">";
+}
+
void SPIRVDialect::printType(Type type, DialectAsmPrinter &os) const {
TypeSwitch<Type>(type)
.Case<ArrayType, CooperativeMatrixType, PointerType, RuntimeArrayType,
- ImageType, SampledImageType, StructType, MatrixType>(
+ ImageType, SampledImageType, StructType, MatrixType, TensorArmType>(
[&](auto type) { print(type, os); })
.Default([](Type) { llvm_unreachable("unhandled SPIR-V type"); });
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 7148027dae78d..eb2974d62fdd1 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -547,6 +547,12 @@ ParseResult spirv::ConstantOp::parse(OpAsmParser &parser,
return failure();
}
+ if (llvm::isa<TensorArmType>(type)) {
+ if (parser.parseOptionalColon().succeeded())
+ if (parser.parseType(type))
+ return failure();
+ }
+
return parser.addTypeToList(type, result.types);
}
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
index 93e0c9b33c546..e4eeb0a7f37d5 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVTypes.cpp
@@ -18,8 +18,10 @@
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
+#include <algorithm>
#include <cstdint>
#include <iterator>
+#include <numeric>
using namespace mlir;
using namespace mlir::spirv;
@@ -96,7 +98,7 @@ bool CompositeType::classof(Type type) {
return isValid(vectorType);
return llvm::isa<spirv::ArrayType, spirv::CooperativeMatrixType,
spirv::MatrixType, spirv::RuntimeArrayType,
- spirv::StructType>(type);
+ spirv::StructType, spirv::TensorArmType>(type);
}
bool CompositeType::isValid(VectorType type) {
@@ -107,8 +109,8 @@ bool CompositeType::isValid(VectorType type) {
Type CompositeType::getElementType(unsigned index) const {
return TypeSwitch<Type, Type>(*this)
- .Case<ArrayType, CooperativeMatrixType, RuntimeArrayType, VectorType>(
- [](auto type) { return type.getElementType(); })
+ .Case<ArrayType, CooperativeMatrixType, RuntimeArrayType, VectorType,
+ TensorArmType>([](auto type) { return type.getElementType(); })
.Case<MatrixType>([](MatrixType type) { return type.getColumnType(); })
.Case<StructType>(
[index](StructType type) { return type.getElementType(index); })
@@ -125,6 +127,8 @@ unsigned CompositeType::getNumElements() const {
return structType.getNumElements();
if (auto vectorType = llvm::dyn_cast<VectorType>(*this))
return vectorType.getNumElements();
+ if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this))
+ return tensorArmType.getNumElements();
if (llvm::isa<CooperativeMatrixType>(*this)) {
llvm_unreachable(
"invalid to query number of elements of spirv Cooperative Matrix type");
@@ -151,6 +155,14 @@ void CompositeType::getExtensions(
return llvm::cast<ScalarType>(type.getElementType())
.getExtensions(extensions, storage);
})
+ .Case<TensorArmType>([&](TensorArmType type) {
+ static const Extension exts[] = {Extension::SPV_ARM_tensors};
+ ArrayRef<Extension> ref(exts, std::size(exts));
+ extensions.push_back(ref);
+ return llvm::cast<ScalarType>(type.getElementType())
+ .getExtensions(extensions, storage);
+ })
+
.Default([](Type) { llvm_unreachable("invalid composite type"); });
}
@@ -171,6 +183,13 @@ void CompositeType::getCapabilities(
return llvm::cast<ScalarType>(type.getElementType())
.getCapabilities(capabilities, storage);
})
+ .Case<TensorArmType>([&](TensorArmType type) {
+ static const Capability caps[] = {Capability::TensorsARM};
+ ArrayRef<Capability> ref(caps, std::size(caps));
+ capabilities.push_back(ref);
+ return llvm::cast<ScalarType>(type.getElementType())
+ .getCapabilities(capabilities, storage);
+ })
.Default([](Type) { llvm_unreachable("invalid composite type"); });
}
@@ -186,6 +205,13 @@ std::optional<int64_t> CompositeType::getSizeInBytes() {
return std::nullopt;
return *elementSize * vectorType.getNumElements();
}
+ if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
+ std::optional<int64_t> elementSize =
+ llvm::cast<ScalarType>(tensorArmType.getElementType()).getSizeInBytes();
+ if (!elementSize)
+ return std::nullopt;
+ return *elementSize * tensorArmType.getNumElements();
+ }
return std::nullopt;
}
@@ -691,6 +717,9 @@ bool SPIRVType::classof(Type type) {
return true;
if (auto vectorType = llvm::dyn_cast<VectorType>(type))
return CompositeType::isValid(vectorType);
+ if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(type)) {
+ return llvm::isa<ScalarType>(tensorArmType.getElementType());
+ }
return false;
}
@@ -712,6 +741,8 @@ void SPIRVType::getExtensions(SPIRVType::ExtensionArrayRefVector &extensions,
matrixType.getExtensions(extensions, storage);
} else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
ptrType.getExtensions(extensions, storage);
+ } else if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
+ tensorArmType.getExtensions(extensions, storage);
} else {
llvm_unreachable("invalid SPIR-V Type to getExtensions");
}
@@ -732,6 +763,8 @@ void SPIRVType::getCapabilities(
matrixType.getCapabilities(capabilities, storage);
} else if (auto ptrType = llvm::dyn_cast<PointerType>(*this)) {
ptrType.getCapabilities(capabilities, storage);
+ } else if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this)) {
+ tensorArmType.getCapabilities(capabilities, storage);
} else {
llvm_unreachable("invalid SPIR-V Type to getCapabilities");
}
@@ -1203,11 +1236,94 @@ void MatrixType::getCapabilities(
llvm::cast<SPIRVType>(getColumnType()).getCapabilities(capabilities, storage);
}
+//===----------------------------------------------------------------------===//
+// TensorArmType
+//===----------------------------------------------------------------------===//
+
+struct spirv::detail::TensorArmTypeStorage final : TypeStorage {
+ using KeyTy = std::tuple<ArrayRef<int64_t>, Type>;
+
+ static TensorArmTypeStorage *construct(TypeStorageAllocator &allocator,
+ const KeyTy &key) {
+ auto shape = std::get<0>(key);
+ auto elementType = std::get<1>(key);
+ shape = allocator.copyInto(shape);
+ return new (allocator.allocate<TensorArmTypeStorage>())
+ TensorArmTypeStorage(std::move(shape), std::move(elementType));
+ }
+
+ static llvm::hash_code hashKey(const KeyTy &key) {
+ return llvm::hash_combine(std::get<0>(key), std::get<1>(key));
+ }
+
+ bool operator==(const KeyTy &key) const {
+ return key == KeyTy(shape, elementType);
+ }
+
+ TensorArmTypeStorage(ArrayRef<int64_t> shape, Type elementType)
+ : shape(std::move(shape)), elementType(std::move(elementType)) {}
+
+ ArrayRef<int64_t> shape;
+ Type elementType;
+};
+
+TensorArmType TensorArmType::get(ArrayRef<int64_t> shape, Type elementType) {
+ return Base::get(elementType.getContext(), shape, elementType);
+}
+
+TensorArmType TensorArmType::cloneWith(std::optional<ArrayRef<int64_t>> shape,
+ Type elementType) const {
+ return TensorArmType::get(shape.value_or(getShape()), elementType);
+}
+
+Type TensorArmType::getElementType() const { return getImpl()->elementType; }
+ArrayRef<int64_t> TensorArmType::getShape() const { return getImpl()->shape; }
+
+unsigned TensorArmType::getNumElements() const {
+ auto shape = getShape();
+ return std::accumulate(shape.begin(), shape.end(), unsigned(1),
+ std::multiplies<unsigned>());
+}
+
+void TensorArmType::getExtensions(
+ SPIRVType::ExtensionArrayRefVector &extensions,
+ std::optional<StorageClass> storage) {
+
+ llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
+ static constexpr Extension exts[] = {Extension::SPV_ARM_tensors};
+ extensions.push_back(exts);
+}
+
+void TensorArmType::getCapabilities(
+ SPIRVType::CapabilityArrayRefVector &capabilities,
+ std::optional<StorageClass> storage) {
+ llvm::cast<SPIRVType>(getElementType())
+ .getCapabilities(capabilities, storage);
+ static constexpr Capab...
[truncated]
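For reference, a minimal sketch of building the new type through the C++ API added above, assuming the patch is applied; the context setup is ordinary MLIR boilerplate and not part of the patch:

#include <cassert>
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;

int main() {
  MLIRContext context;
  context.loadDialect<spirv::SPIRVDialect>();
  Type f32 = Float32Type::get(&context);

  // Ranked form; the printer above renders this as !spirv.arm.tensor<2x3xf32>.
  auto ranked = spirv::TensorArmType::get({2, 3}, f32);
  assert(ranked.hasRank() && ranked.getNumElements() == 6);

  // Per the header comment, an empty shape means unranked; this prints as
  // !spirv.arm.tensor<*xf32>.
  auto unranked = spirv::TensorArmType::get({}, f32);
  assert(!unranked.hasRank());
  return 0;
}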
Implementation as agreed on in this RFC discussion: https://discourse.llvm.org/t/rfc-enable-use-of-mlir-tensor-type-with-spv-arm-tensors-extension-in-spir-v-dialect/86789
Thanks for contributing this. I added a few comments.
I pointed it out a few times, but I believe you don't need llvm:: with casts, etc. It's an artifact of refactoring that was done in the past: #141458 (comment)
Also, we probably should add a test in here: https://github.com/llvm/llvm-project/blob/main/mlir/test/Dialect/SPIRV/IR/availability.mlir
@@ -4179,7 +4202,7 @@ def SPIRV_IsPtrType : CPred<"::llvm::isa<::mlir::spirv::PointerType>($_self)">;
 def SPIRV_IsRTArrayType : CPred<"::llvm::isa<::mlir::spirv::RuntimeArrayType>($_self)">;
 def SPIRV_IsSampledImageType : CPred<"::llvm::isa<::mlir::spirv::SampledImageType>($_self)">;
 def SPIRV_IsStructType : CPred<"::llvm::isa<::mlir::spirv::StructType>($_self)">;
-
+def SPIRV_IsTensorArmType : CPred<"::llvm::isa<::mlir::spirv::TensorArmType>($_self)">;
meta: I know it comes from the spec so probably should stay this way, but ArmTensor sounds more natural than TensorArm.
@@ -477,6 +479,37 @@ class MatrixType : public Type::TypeBase<MatrixType, CompositeType,
       std::optional<StorageClass> storage = std::nullopt);
 };
 
+// SPIR-V TensorARM Type
Here and below it probably should be /// instead of //.
@@ -194,6 +194,13 @@ static Type parseAndVerifyType(SPIRVDialect const &dialect,
           << t.getNumElements();
       return Type();
     }
+  } else if (auto t = llvm::dyn_cast<TensorArmType>(type)) {
+    if (!llvm::isa<ScalarType>(t.getElementType())) {
I don't think there is a need for llvm::. Also the line above.
@@ -125,6 +127,8 @@ unsigned CompositeType::getNumElements() const {
     return structType.getNumElements();
   if (auto vectorType = llvm::dyn_cast<VectorType>(*this))
     return vectorType.getNumElements();
+  if (auto tensorArmType = llvm::dyn_cast<TensorArmType>(*this))
No need for llvm:: ... I think :) Also below.
+  TensorArmTypeStorage(ArrayRef<int64_t> shape, Type elementType)
+      : shape(std::move(shape)), elementType(std::move(elementType)) {}
+
+  ArrayRef<int64_t> shape;
This looks suspicious to me. I don't think an ArrayRef should be a storage type, since the ArrayRef itself doesn't keep any data - it's only a "pointer" to data. I think this should be a SmallVector or something similar. But please wait for someone to confirm that I'm not wrong :)
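For context, the construct path in the patch does copy the shape into the type allocator before the ArrayRef is stored, so the reference points at context-owned memory; the open question is whether that is preferred style over a SmallVector member:

  // From TensorArmTypeStorage::construct in this patch: copyInto makes a
  // permanent copy inside the allocator, so the stored ArrayRef stays valid
  // for the lifetime of the MLIRContext.
  shape = allocator.copyInto(shape);
  return new (allocator.allocate<TensorArmTypeStorage>())
      TensorArmTypeStorage(std::move(shape), std::move(elementType));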
@@ -138,6 +138,7 @@ LogicalResult spirv::Deserializer::processHeader() {
     MIN_VERSION_CASE(3);
     MIN_VERSION_CASE(4);
     MIN_VERSION_CASE(5);
+    MIN_VERSION_CASE(6);
This probably should land as a separate PR.
@@ -1238,6 +1241,55 @@ spirv::Deserializer::processMatrixType(ArrayRef<uint32_t> operands) {
   return success();
 }
 
+LogicalResult
+spirv::Deserializer::processTensorARMType(ArrayRef<uint32_t> operands) {
Some of the if statements probably don't need braces: https://llvm.org/docs/CodingStandards.html#don-t-use-braces-on-simple-single-statement-bodies-of-if-else-loop-statements
+    return success();
+  }
+
+  auto rankAttr = getConstantInt(operands[2]);
Please spell out the type: https://llvm.org/docs/CodingStandards.html#use-auto-type-deduction-to-make-code-more-readable
+    return success();
+  }
+
+  auto shapeInfo = getConstant(operands[3]);
Same as above: auto.
@@ -96,7 +97,8 @@ class ScalarType : public SPIRVType {
   std::optional<int64_t> getSizeInBytes();
 };
 
-// SPIR-V composite type: VectorType, SPIR-V ArrayType, or SPIR-V StructType.
+// SPIR-V composite type: TensorArmType, VectorType, SPIR-V ArrayType, or SPIR-V
Can you put the tensor type at the end, since it's going to be the least commonly used one?
+    unranked = true;
+    if (parser.parseXInDimensionList())
+      return {};
+  } else if (parser.parseDimensionList(dims, /*allowDynamic=*/true))
The LLVM coding style requires that when either branch of an if/else uses braces, the other statement should too.
+  if (std::any_of(dims.begin(), dims.end(),
+                  [](int64_t dim) { return dim == 0; })) {
use llvm::is_contained
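A sketch of that suggestion applied to the zero-dimension check (llvm::is_contained lives in llvm/ADT/STLExtras.h):

  if (llvm::is_contained(dims, 0)) {
    parser.emitError(countLoc, "arm.tensors do not support zero dimensions");
    return {};
  }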
+  if (std::any_of(dims.begin(), dims.end(),
+                  [](int64_t dim) { return dim < 0; }) &&
+      std::any_of(dims.begin(), dims.end(),
+                  [](int64_t dim) { return dim > 0; })) {
use llvm::any_of
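The range-based form being suggested would look roughly like this, keeping the diagnostic unchanged:

  if (llvm::any_of(dims, [](int64_t dim) { return dim < 0; }) &&
      llvm::any_of(dims, [](int64_t dim) { return dim > 0; })) {
    parser.emitError(countLoc, "arm.tensor shape dimensions must be either "
                               "fully dynamic or completed shaped");
    return {};
  }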
+
+  Type getElementType() const;
+  ArrayRef<int64_t> getShape() const;
+  unsigned getNumElements() const;
What happens when this is called on an unranked tensor?
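For reference, with the implementation as posted an unranked tensor has an empty shape, so the accumulation falls through to its initial value:

  SmallVector<int64_t> shape; // empty, i.e. unranked in this patch
  unsigned n = std::accumulate(shape.begin(), shape.end(), unsigned(1),
                               std::multiplies<unsigned>());
  // n == 1 here, which is probably not a meaningful element count for an
  // unranked tensor.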
+ArrayRef<int64_t> TensorArmType::getShape() const { return getImpl()->shape; }
+
+unsigned TensorArmType::getNumElements() const {
+  auto shape = getShape();
don't use auto here since the type is not obvious based on the RHS alone: https://llvm.org/docs/CodingStandards.html#use-auto-type-deduction-to-make-code-more-readable
+    std::optional<StorageClass> storage) {
+
+  llvm::cast<SPIRVType>(getElementType()).getExtensions(extensions, storage);
+  static constexpr Extension exts[] = {Extension::SPV_ARM_tensors};
this doesn't have to be an array
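One reading of this suggestion (an assumption, since the ArrayRef must still point at storage that outlives the call) is a single static object instead of a static array:

  // ArrayRef has an implicit single-element constructor, and the static
  // object keeps the pointee alive after the function returns.
  static constexpr Extension ext = Extension::SPV_ARM_tensors;
  extensions.push_back(ext);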
+    std::optional<StorageClass> storage) {
+  llvm::cast<SPIRVType>(getElementType())
+      .getCapabilities(capabilities, storage);
+  static constexpr Capability caps[] = {Capability::TensorsARM};
also here
+  Type elementTy = getType(operands[1]);
+  if (!elementTy) {
+    return emitError(unknownLoc,
+                     "OpTypeTensorARM references undefined element type.")
I don't think we want to print the . before the operand?
+  if (size < 2 || size > 4) {
+    return emitError(unknownLoc, "OpTypeTensorARM must have 2-4 operands "
+                                 "(result_id, element_type, (rank), (shape))")
+           << size;
should we add a space before printing the size?
This patch introduces a new custom type !spirv.arm.tensor<> to the MLIR SPIR-V dialect to represent OpTypeTensorARM as defined in the SPV_ARM_tensors extension.

The type models a shaped tensor with an element type and an optional shape, and implements the ShapedType interface to enable reuse of MLIR's existing shape-aware infrastructure.

The type supports serialization to and from SPIR-V binary as OpTypeTensorARM, and emits the required capability (TensorsARM) and extension (SPV_ARM_tensors) declarations automatically.

This addition lays the foundation for supporting structured tensor values natively in SPIR-V and will enable future support for operations defined in the SPV_ARM_tensors extension, such as OpTensorReadARM, OpTensorWriteARM, and OpTensorQuerySizeARM.

Reference: KhronosGroup/SPIRV-Registry#342