From 01fa7a292177dbbd81de9ca2fdc01fa0cf0839d9 Mon Sep 17 00:00:00 2001 From: Yaniv Kaniel Date: Thu, 15 May 2025 12:33:48 +0300 Subject: [PATCH 1/2] Validate type consistency in reintepret cast sizes Ensure that when peforming a reinterpret cast, the expected size and the result size are of the same type. Emit an error if one of the dimensions has a static size and the corresponding dimension has a dynamic size in the other. --- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 13 +++++++++++-- .../MemRefToSPIRV/memref-to-spirv.mlir | 9 ++++++--- mlir/test/Dialect/MemRef/canonicalize.mlir | 14 +++++++------- mlir/test/Dialect/MemRef/invalid.mlir | 18 ++++++++++++++++++ mlir/test/Dialect/MemRef/ops.mlir | 12 ++++++------ 5 files changed, 48 insertions(+), 18 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index a0237c18cf2fe..5a348b823d02b 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1052,7 +1052,7 @@ struct DimOfMemRefReshape : public OpRewritePattern { } } // else dim.getIndex is a block argument to reshape->getBlock and // dominates reshape - } // Check condition 2 + } // Check condition 2 else if (dim->getBlock() != reshape->getBlock() && !dim.getIndex().getParentRegion()->isProperAncestor( reshape->getParentRegion())) { @@ -1835,6 +1835,15 @@ LogicalResult ReinterpretCastOp::verify() { // Match sizes in result memref type and in static_sizes attribute. for (auto [idx, resultSize, expectedSize] : llvm::enumerate(resultType.getShape(), getStaticSizes())) { + // Check that dynamic sizes are not mixed with static sizes + if (ShapedType::isDynamic(resultSize) && + !ShapedType::isDynamic(expectedSize)) + return emitError( + "expectedSize is static but received a dynamic resultSize "); + if (!ShapedType::isDynamic(resultSize) && + ShapedType::isDynamic(expectedSize)) + return emitError( + "expectedSize is dynamic but received a static resultSize "); if (!ShapedType::isDynamic(resultSize) && resultSize != expectedSize) return emitError("expected result type with size = ") << (ShapedType::isDynamic(expectedSize) @@ -2008,7 +2017,7 @@ struct ReinterpretCastOpExtractStridedMetadataFolder // Second, check the sizes. if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(), op.getConstifiedMixedSizes())) - return false; + return false; // Finally, check the offset. assert(op.getMixedOffsets().size() == 1 && diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir index 8906de9db3724..18b151c469da6 100644 --- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir @@ -339,7 +339,8 @@ func.func @reinterpret_cast(%arg: memref, i32 // CHECK: %[[RET1:.*]] = builtin.unrealized_conversion_cast %[[RET]] : !spirv.ptr to memref, #spirv.storage_class> // CHECK: return %[[RET1]] - %ret = memref.reinterpret_cast %arg to offset: [%arg1], sizes: [10], strides: [1] : memref> to memref, #spirv.storage_class> + %c10 = arith.constant 10 : index + %ret = memref.reinterpret_cast %arg to offset: [%arg1], sizes: [%c10], strides: [1] : memref> to memref, #spirv.storage_class> return %ret : memref, #spirv.storage_class> } @@ -349,7 +350,8 @@ func.func @reinterpret_cast_0(%arg: memref> to !spirv.ptr // CHECK-DAG: %[[RET:.*]] = builtin.unrealized_conversion_cast %[[MEM1]] : !spirv.ptr to memref, #spirv.storage_class> // CHECK: return %[[RET]] - %ret = memref.reinterpret_cast %arg to offset: [0], sizes: [10], strides: [1] : memref> to memref, #spirv.storage_class> + %c10 = arith.constant 10 : index + %ret = memref.reinterpret_cast %arg to offset: [0], sizes: [%c10], strides: [1] : memref> to memref, #spirv.storage_class> return %ret : memref, #spirv.storage_class> } @@ -361,7 +363,8 @@ func.func @reinterpret_cast_5(%arg: memref, i32 // CHECK: %[[RET1:.*]] = builtin.unrealized_conversion_cast %[[RET]] : !spirv.ptr to memref, #spirv.storage_class> // CHECK: return %[[RET1]] - %ret = memref.reinterpret_cast %arg to offset: [5], sizes: [10], strides: [1] : memref> to memref, #spirv.storage_class> + %c10 = arith.constant 10 : index + %ret = memref.reinterpret_cast %arg to offset: [5], sizes: [%c10], strides: [1] : memref> to memref, #spirv.storage_class> return %ret : memref, #spirv.storage_class> } diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir index e7cee7cd85426..a53a5d10eceb5 100644 --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -923,13 +923,13 @@ func.func @reinterpret_of_extract_strided_metadata_w_type_mistach(%arg0 : memref // same constant value, the match is valid. // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_constants // CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>) -// CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : memref<8x2xf32> to memref to memref) -> memref> { +func.func @reinterpret_of_extract_strided_metadata_w_constants(%arg0 : memref<8x2xf32>) -> memref> { %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref, index, index, index, index, index %c8 = arith.constant 8: index - %m2 = memref.reinterpret_cast %base to offset: [0], sizes: [%c8, 2], strides: [2, %strides#1] : memref to memref> - return %m2 : memref> + %m2 = memref.reinterpret_cast %base to offset: [0], sizes: [%c8, 2], strides: [2, %strides#1] : memref to memref> + return %m2 : memref> } // ----- @@ -954,10 +954,10 @@ func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref) -> memref> { +func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<4x2x2xf32, strided<[?, ?, ?], offset: ?>> { %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref, index, index, index, index, index - %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref to memref> - return %m2 : memref> + %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref to memref<4x2x2xf32, strided<[?, ?, ?], offset: ?>> + return %m2 : memref<4x2x2xf32, strided<[?, ?, ?], offset: ?>> } // ----- diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir index 34fc4775924e7..c98d4913dc5d2 100644 --- a/mlir/test/Dialect/MemRef/invalid.mlir +++ b/mlir/test/Dialect/MemRef/invalid.mlir @@ -245,6 +245,24 @@ func.func @memref_reinterpret_cast_no_map_but_strides(%in: memref) { // ----- +func.func @memref_reinterpret_cast_static_dynamic_size_mismatch(%in: memref<1x?x2x1xf32>) { + // expected-error@+1 {{expectedSize is static but received a dynamic resultSize}} + %out = memref.reinterpret_cast %in to + offset: [0], sizes: [1, 4672, 1, 1], strides: [4672, 8, 8, 1] + : memref<1x?x2x1xf32> to memref<1x4672x?x1xf32> +} + +// ----- + +func.func @memref_reinterpret_cast_dynamic_static_size_mismatch(%in: memref<1x?x2x1xf32>, %size: index) { + // expected-error@+1 {{expectedSize is dynamic but received a static resultSize}} + %out = memref.reinterpret_cast %in to + offset: [0], sizes: [1, %size, 1, 1], strides: [4672, 8, 8, 1] + : memref<1x?x2x1xf32> to memref<1x4672x2x1xf32> + return +} + +// ----- func.func @memref_reshape_element_type_mismatch( %buf: memref<*xf32>, %shape: memref<1xi32>) { // expected-error @+1 {{element types of source and destination memref types should be the same}} diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir index 7038a6ff744e4..03e344e0e9cf2 100644 --- a/mlir/test/Dialect/MemRef/ops.mlir +++ b/mlir/test/Dialect/MemRef/ops.mlir @@ -131,20 +131,20 @@ func.func @memref_reinterpret_cast(%in: memref) // CHECK-LABEL: func @memref_reinterpret_cast_static_to_dynamic_sizes func.func @memref_reinterpret_cast_static_to_dynamic_sizes(%in: memref) - -> memref<10x?xf32, strided<[?, 1], offset: ?>> { + -> memref<10x10xf32, strided<[?, 1], offset: ?>> { %out = memref.reinterpret_cast %in to offset: [1], sizes: [10, 10], strides: [1, 1] - : memref to memref<10x?xf32, strided<[?, 1], offset: ?>> - return %out : memref<10x?xf32, strided<[?, 1], offset: ?>> + : memref to memref<10x10xf32, strided<[?, 1], offset: ?>> + return %out : memref<10x10xf32, strided<[?, 1], offset: ?>> } // CHECK-LABEL: func @memref_reinterpret_cast_dynamic_offset func.func @memref_reinterpret_cast_dynamic_offset(%in: memref, %offset: index) - -> memref<10x?xf32, strided<[?, 1], offset: ?>> { + -> memref<10x10xf32, strided<[?, 1], offset: ?>> { %out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10, 10], strides: [1, 1] - : memref to memref<10x?xf32, strided<[?, 1], offset: ?>> - return %out : memref<10x?xf32, strided<[?, 1], offset: ?>> + : memref to memref<10x10xf32, strided<[?, 1], offset: ?>> + return %out : memref<10x10xf32, strided<[?, 1], offset: ?>> } // CHECK-LABEL: func @memref_reshape( From c9814d98ee08699ce244f9bf473519dce9a3d7d3 Mon Sep 17 00:00:00 2001 From: Yaniv Kaniel Date: Wed, 21 May 2025 14:57:08 +0300 Subject: [PATCH 2/2] Validate type consistency in reintepret cast offsets Ensure that when peforming a reinterpret cast, the expected offset and the result offset are of the same type. Emit an error if one of the dimensions has a static offset and the corresponding dimension has a dynamic offset in the other. Delete previous test that is a specific instance of this case. --- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 13 ++++++-- .../MemRefToSPIRV/memref-to-spirv.mlir | 16 ++++----- mlir/test/Dialect/MemRef/canonicalize.mlir | 12 +++---- mlir/test/Dialect/MemRef/invalid.mlir | 33 ++++++++++++------- mlir/test/Dialect/MemRef/ops.mlir | 6 ++-- 5 files changed, 49 insertions(+), 31 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 5a348b823d02b..82fc4eac5b40b 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1839,11 +1839,11 @@ LogicalResult ReinterpretCastOp::verify() { if (ShapedType::isDynamic(resultSize) && !ShapedType::isDynamic(expectedSize)) return emitError( - "expectedSize is static but received a dynamic resultSize "); + "expected size is static, but result type dimension is dynamic "); if (!ShapedType::isDynamic(resultSize) && ShapedType::isDynamic(expectedSize)) return emitError( - "expectedSize is dynamic but received a static resultSize "); + "expected size is dynamic, but result type dimension is static "); if (!ShapedType::isDynamic(resultSize) && resultSize != expectedSize) return emitError("expected result type with size = ") << (ShapedType::isDynamic(expectedSize) @@ -1863,6 +1863,15 @@ LogicalResult ReinterpretCastOp::verify() { // Match offset in result memref type and in static_offsets attribute. int64_t expectedOffset = getStaticOffsets().front(); + // Check that dynamic offset is not mixed with static offset + if (ShapedType::isDynamic(resultOffset) && + !ShapedType::isDynamic(expectedOffset)) + return emitError( + "expected offset is static, but result type offset is dynamic"); + if (!ShapedType::isDynamic(resultOffset) && + ShapedType::isDynamic(expectedOffset)) + return emitError( + "expected offset is dynamic, but result type offset is static"); if (!ShapedType::isDynamic(resultOffset) && resultOffset != expectedOffset) return emitError("expected result type with offset = ") << (ShapedType::isDynamic(expectedOffset) diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir index 18b151c469da6..fbc1b8ca42377 100644 --- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir @@ -346,26 +346,26 @@ func.func @reinterpret_cast(%arg: memref>) -func.func @reinterpret_cast_0(%arg: memref>) -> memref, #spirv.storage_class> { +func.func @reinterpret_cast_0(%arg: memref>) -> memref, #spirv.storage_class> { // CHECK-DAG: %[[MEM1:.*]] = builtin.unrealized_conversion_cast %[[MEM]] : memref> to !spirv.ptr -// CHECK-DAG: %[[RET:.*]] = builtin.unrealized_conversion_cast %[[MEM1]] : !spirv.ptr to memref, #spirv.storage_class> +// CHECK-DAG: %[[RET:.*]] = builtin.unrealized_conversion_cast %[[MEM1]] : !spirv.ptr to memref, #spirv.storage_class> // CHECK: return %[[RET]] %c10 = arith.constant 10 : index - %ret = memref.reinterpret_cast %arg to offset: [0], sizes: [%c10], strides: [1] : memref> to memref, #spirv.storage_class> - return %ret : memref, #spirv.storage_class> + %ret = memref.reinterpret_cast %arg to offset: [0], sizes: [%c10], strides: [1] : memref> to memref, #spirv.storage_class> + return %ret : memref, #spirv.storage_class> } // CHECK-LABEL: func.func @reinterpret_cast_5 // CHECK-SAME: (%[[MEM:.*]]: memref>) -func.func @reinterpret_cast_5(%arg: memref>) -> memref, #spirv.storage_class> { +func.func @reinterpret_cast_5(%arg: memref>) -> memref, #spirv.storage_class> { // CHECK: %[[MEM1:.*]] = builtin.unrealized_conversion_cast %[[MEM]] : memref> to !spirv.ptr // CHECK: %[[OFF:.*]] = spirv.Constant 5 : i32 // CHECK: %[[RET:.*]] = spirv.InBoundsPtrAccessChain %[[MEM1]][%[[OFF]]] : !spirv.ptr, i32 -// CHECK: %[[RET1:.*]] = builtin.unrealized_conversion_cast %[[RET]] : !spirv.ptr to memref, #spirv.storage_class> +// CHECK: %[[RET1:.*]] = builtin.unrealized_conversion_cast %[[RET]] : !spirv.ptr to memref, #spirv.storage_class> // CHECK: return %[[RET1]] %c10 = arith.constant 10 : index - %ret = memref.reinterpret_cast %arg to offset: [5], sizes: [%c10], strides: [1] : memref> to memref, #spirv.storage_class> - return %ret : memref, #spirv.storage_class> + %ret = memref.reinterpret_cast %arg to offset: [5], sizes: [%c10], strides: [1] : memref> to memref, #spirv.storage_class> + return %ret : memref, #spirv.storage_class> } } // end module diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir index a53a5d10eceb5..7d293fcec0083 100644 --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -925,11 +925,11 @@ func.func @reinterpret_of_extract_strided_metadata_w_type_mistach(%arg0 : memref // CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>) // CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : memref<8x2xf32> to memref) -> memref> { +func.func @reinterpret_of_extract_strided_metadata_w_constants(%arg0 : memref<8x2xf32>) -> memref> { %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref, index, index, index, index, index %c8 = arith.constant 8: index - %m2 = memref.reinterpret_cast %base to offset: [0], sizes: [%c8, 2], strides: [2, %strides#1] : memref to memref> - return %m2 : memref> + %m2 = memref.reinterpret_cast %base to offset: [0], sizes: [%c8, 2], strides: [2, %strides#1] : memref to memref> + return %m2 : memref> } // ----- @@ -970,10 +970,10 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : me // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]] // CHECK: return %[[RES]] -func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref> { +func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref> { %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref, index, index, index, index, index - %m2 = memref.reinterpret_cast %base to offset: [1], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref to memref> - return %m2 : memref> + %m2 = memref.reinterpret_cast %base to offset: [1], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref to memref> + return %m2 : memref> } // ----- diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir index c98d4913dc5d2..68d88e9214705 100644 --- a/mlir/test/Dialect/MemRef/invalid.mlir +++ b/mlir/test/Dialect/MemRef/invalid.mlir @@ -214,16 +214,6 @@ func.func @memref_reinterpret_cast_no_map_but_offset(%in: memref) { : memref to memref<10xf32> return } - -// ----- - -func.func @memref_reinterpret_cast_offset_mismatch_dynamic(%in: memref, %offset : index) { - // expected-error @+1 {{expected result type with offset = dynamic instead of 0}} - %out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10], strides: [1] - : memref to memref<10xf32> - return -} - // ----- func.func @memref_reinterpret_cast_no_map_but_stride(%in: memref) { @@ -246,7 +236,7 @@ func.func @memref_reinterpret_cast_no_map_but_strides(%in: memref) { // ----- func.func @memref_reinterpret_cast_static_dynamic_size_mismatch(%in: memref<1x?x2x1xf32>) { - // expected-error@+1 {{expectedSize is static but received a dynamic resultSize}} + // expected-error@+1 {{expected size is static, but result type dimension is dynamic }} %out = memref.reinterpret_cast %in to offset: [0], sizes: [1, 4672, 1, 1], strides: [4672, 8, 8, 1] : memref<1x?x2x1xf32> to memref<1x4672x?x1xf32> @@ -255,13 +245,32 @@ func.func @memref_reinterpret_cast_static_dynamic_size_mismatch(%in: memref<1x?x // ----- func.func @memref_reinterpret_cast_dynamic_static_size_mismatch(%in: memref<1x?x2x1xf32>, %size: index) { - // expected-error@+1 {{expectedSize is dynamic but received a static resultSize}} + // expected-error@+1 {{expected size is dynamic, but result type dimension is static }} %out = memref.reinterpret_cast %in to offset: [0], sizes: [1, %size, 1, 1], strides: [4672, 8, 8, 1] : memref<1x?x2x1xf32> to memref<1x4672x2x1xf32> return } +// ----- + +func.func @memref_reinterpret_cast_static_dynamic_offset_mismatch(%in: memref) { + // expected-error@+1 {{expected offset is static, but result type offset is dynamic}} + %out = memref.reinterpret_cast %in to + offset: [0], sizes: [10], strides: [1] + : memref to memref<10xf32, strided<[1], offset: ?>> +} + +// ----- + +func.func @memref_reinterpret_cast_dynamic_static_offset_mismatch(%in: memref, %offset: index) { + // expected-error@+1 {{expected offset is dynamic, but result type offset is static}} + %out = memref.reinterpret_cast %in to + offset: [%offset], sizes: [10], strides: [1] + : memref to memref<10xf32, strided<[1], offset: 0>> + return +} + // ----- func.func @memref_reshape_element_type_mismatch( %buf: memref<*xf32>, %shape: memref<1xi32>) { diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir index 03e344e0e9cf2..0685334cd20ea 100644 --- a/mlir/test/Dialect/MemRef/ops.mlir +++ b/mlir/test/Dialect/MemRef/ops.mlir @@ -131,11 +131,11 @@ func.func @memref_reinterpret_cast(%in: memref) // CHECK-LABEL: func @memref_reinterpret_cast_static_to_dynamic_sizes func.func @memref_reinterpret_cast_static_to_dynamic_sizes(%in: memref) - -> memref<10x10xf32, strided<[?, 1], offset: ?>> { + -> memref<10x10xf32, strided<[?, 1], offset: 1>> { %out = memref.reinterpret_cast %in to offset: [1], sizes: [10, 10], strides: [1, 1] - : memref to memref<10x10xf32, strided<[?, 1], offset: ?>> - return %out : memref<10x10xf32, strided<[?, 1], offset: ?>> + : memref to memref<10x10xf32, strided<[?, 1], offset: 1>> + return %out : memref<10x10xf32, strided<[?, 1], offset: 1>> } // CHECK-LABEL: func @memref_reinterpret_cast_dynamic_offset