diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index cab0ab8d15d5d..6a663071d0d74 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -1057,7 +1057,7 @@ struct DimOfMemRefReshape : public OpRewritePattern { } } // else dim.getIndex is a block argument to reshape->getBlock and // dominates reshape - } // Check condition 2 + } // Check condition 2 else if (dim->getBlock() != reshape->getBlock() && !dim.getIndex().getParentRegion()->isProperAncestor( reshape->getParentRegion())) { @@ -1840,6 +1840,15 @@ LogicalResult ReinterpretCastOp::verify() { // Match sizes in result memref type and in static_sizes attribute. for (auto [idx, resultSize, expectedSize] : llvm::enumerate(resultType.getShape(), getStaticSizes())) { + // Check that dynamic sizes are not mixed with static sizes + if (ShapedType::isDynamic(resultSize) && + !ShapedType::isDynamic(expectedSize)) + return emitError( + "expected size is static, but result type dimension is dynamic "); + if (!ShapedType::isDynamic(resultSize) && + ShapedType::isDynamic(expectedSize)) + return emitError( + "expected size is dynamic, but result type dimension is static "); if (!ShapedType::isDynamic(resultSize) && resultSize != expectedSize) return emitError("expected result type with size = ") << (ShapedType::isDynamic(expectedSize) @@ -1859,6 +1868,15 @@ LogicalResult ReinterpretCastOp::verify() { // Match offset in result memref type and in static_offsets attribute. int64_t expectedOffset = getStaticOffsets().front(); + // Check that dynamic offset is not mixed with static offset + if (ShapedType::isDynamic(resultOffset) && + !ShapedType::isDynamic(expectedOffset)) + return emitError( + "expected offset is static, but result type offset is dynamic"); + if (!ShapedType::isDynamic(resultOffset) && + ShapedType::isDynamic(expectedOffset)) + return emitError( + "expected offset is dynamic, but result type offset is static"); if (!ShapedType::isDynamic(resultOffset) && resultOffset != expectedOffset) return emitError("expected result type with offset = ") << (ShapedType::isDynamic(expectedOffset) @@ -2011,7 +2029,7 @@ struct ReinterpretCastOpExtractStridedMetadataFolder // Second, check the sizes. if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(), op.getConstifiedMixedSizes())) - return false; + return false; // Finally, check the offset. assert(op.getMixedOffsets().size() == 1 && diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir index 8906de9db3724..fbc1b8ca42377 100644 --- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir @@ -339,30 +339,33 @@ func.func @reinterpret_cast(%arg: memref, i32 // CHECK: %[[RET1:.*]] = builtin.unrealized_conversion_cast %[[RET]] : !spirv.ptr to memref, #spirv.storage_class> // CHECK: return %[[RET1]] - %ret = memref.reinterpret_cast %arg to offset: [%arg1], sizes: [10], strides: [1] : memref> to memref, #spirv.storage_class> + %c10 = arith.constant 10 : index + %ret = memref.reinterpret_cast %arg to offset: [%arg1], sizes: [%c10], strides: [1] : memref> to memref, #spirv.storage_class> return %ret : memref, #spirv.storage_class> } // CHECK-LABEL: func.func @reinterpret_cast_0 // CHECK-SAME: (%[[MEM:.*]]: memref>) -func.func @reinterpret_cast_0(%arg: memref>) -> memref, #spirv.storage_class> { +func.func @reinterpret_cast_0(%arg: memref>) -> memref, #spirv.storage_class> { // CHECK-DAG: %[[MEM1:.*]] = builtin.unrealized_conversion_cast %[[MEM]] : memref> to !spirv.ptr -// CHECK-DAG: %[[RET:.*]] = builtin.unrealized_conversion_cast %[[MEM1]] : !spirv.ptr to memref, #spirv.storage_class> +// CHECK-DAG: %[[RET:.*]] = builtin.unrealized_conversion_cast %[[MEM1]] : !spirv.ptr to memref, #spirv.storage_class> // CHECK: return %[[RET]] - %ret = memref.reinterpret_cast %arg to offset: [0], sizes: [10], strides: [1] : memref> to memref, #spirv.storage_class> - return %ret : memref, #spirv.storage_class> + %c10 = arith.constant 10 : index + %ret = memref.reinterpret_cast %arg to offset: [0], sizes: [%c10], strides: [1] : memref> to memref, #spirv.storage_class> + return %ret : memref, #spirv.storage_class> } // CHECK-LABEL: func.func @reinterpret_cast_5 // CHECK-SAME: (%[[MEM:.*]]: memref>) -func.func @reinterpret_cast_5(%arg: memref>) -> memref, #spirv.storage_class> { +func.func @reinterpret_cast_5(%arg: memref>) -> memref, #spirv.storage_class> { // CHECK: %[[MEM1:.*]] = builtin.unrealized_conversion_cast %[[MEM]] : memref> to !spirv.ptr // CHECK: %[[OFF:.*]] = spirv.Constant 5 : i32 // CHECK: %[[RET:.*]] = spirv.InBoundsPtrAccessChain %[[MEM1]][%[[OFF]]] : !spirv.ptr, i32 -// CHECK: %[[RET1:.*]] = builtin.unrealized_conversion_cast %[[RET]] : !spirv.ptr to memref, #spirv.storage_class> +// CHECK: %[[RET1:.*]] = builtin.unrealized_conversion_cast %[[RET]] : !spirv.ptr to memref, #spirv.storage_class> // CHECK: return %[[RET1]] - %ret = memref.reinterpret_cast %arg to offset: [5], sizes: [10], strides: [1] : memref> to memref, #spirv.storage_class> - return %ret : memref, #spirv.storage_class> + %c10 = arith.constant 10 : index + %ret = memref.reinterpret_cast %arg to offset: [5], sizes: [%c10], strides: [1] : memref> to memref, #spirv.storage_class> + return %ret : memref, #spirv.storage_class> } } // end module diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir index e7cee7cd85426..7d293fcec0083 100644 --- a/mlir/test/Dialect/MemRef/canonicalize.mlir +++ b/mlir/test/Dialect/MemRef/canonicalize.mlir @@ -923,13 +923,13 @@ func.func @reinterpret_of_extract_strided_metadata_w_type_mistach(%arg0 : memref // same constant value, the match is valid. // CHECK-LABEL: func @reinterpret_of_extract_strided_metadata_w_constants // CHECK-SAME: (%[[ARG:.*]]: memref<8x2xf32>) -// CHECK: %[[CAST:.*]] = memref.cast %[[ARG]] : memref<8x2xf32> to memref to memref) -> memref> { +func.func @reinterpret_of_extract_strided_metadata_w_constants(%arg0 : memref<8x2xf32>) -> memref> { %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref, index, index, index, index, index %c8 = arith.constant 8: index - %m2 = memref.reinterpret_cast %base to offset: [0], sizes: [%c8, 2], strides: [2, %strides#1] : memref to memref> - return %m2 : memref> + %m2 = memref.reinterpret_cast %base to offset: [0], sizes: [%c8, 2], strides: [2, %strides#1] : memref to memref> + return %m2 : memref> } // ----- @@ -954,10 +954,10 @@ func.func @reinterpret_of_extract_strided_metadata_same_type(%arg0 : memref) -> memref> { +func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : memref<8x2xf32>) -> memref<4x2x2xf32, strided<[?, ?, ?], offset: ?>> { %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref, index, index, index, index, index - %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref to memref> - return %m2 : memref> + %m2 = memref.reinterpret_cast %base to offset: [%offset], sizes: [4, 2, 2], strides: [1, 1, %strides#1] : memref to memref<4x2x2xf32, strided<[?, ?, ?], offset: ?>> + return %m2 : memref<4x2x2xf32, strided<[?, ?, ?], offset: ?>> } // ----- @@ -970,10 +970,10 @@ func.func @reinterpret_of_extract_strided_metadata_w_different_stride(%arg0 : me // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK: %[[RES:.*]] = memref.reinterpret_cast %[[ARG]] to offset: [1], sizes: [%[[C8]], %[[C2]]], strides: [%[[C2]], %[[C1]]] // CHECK: return %[[RES]] -func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref> { +func.func @reinterpret_of_extract_strided_metadata_w_different_offset(%arg0 : memref<8x2xf32>) -> memref> { %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0 : memref<8x2xf32> -> memref, index, index, index, index, index - %m2 = memref.reinterpret_cast %base to offset: [1], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref to memref> - return %m2 : memref> + %m2 = memref.reinterpret_cast %base to offset: [1], sizes: [%sizes#0, %sizes#1], strides: [%strides#0, %strides#1] : memref to memref> + return %m2 : memref> } // ----- diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir index f908efb638446..5352b8d8ce60a 100644 --- a/mlir/test/Dialect/MemRef/invalid.mlir +++ b/mlir/test/Dialect/MemRef/invalid.mlir @@ -214,16 +214,6 @@ func.func @memref_reinterpret_cast_no_map_but_offset(%in: memref) { : memref to memref<10xf32> return } - -// ----- - -func.func @memref_reinterpret_cast_offset_mismatch_dynamic(%in: memref, %offset : index) { - // expected-error @+1 {{expected result type with offset = dynamic instead of 0}} - %out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10], strides: [1] - : memref to memref<10xf32> - return -} - // ----- func.func @memref_reinterpret_cast_no_map_but_stride(%in: memref) { @@ -245,6 +235,43 @@ func.func @memref_reinterpret_cast_no_map_but_strides(%in: memref) { // ----- +func.func @memref_reinterpret_cast_static_dynamic_size_mismatch(%in: memref<1x?x2x1xf32>) { + // expected-error@+1 {{expected size is static, but result type dimension is dynamic }} + %out = memref.reinterpret_cast %in to + offset: [0], sizes: [1, 4672, 1, 1], strides: [4672, 8, 8, 1] + : memref<1x?x2x1xf32> to memref<1x4672x?x1xf32> +} + +// ----- + +func.func @memref_reinterpret_cast_dynamic_static_size_mismatch(%in: memref<1x?x2x1xf32>, %size: index) { + // expected-error@+1 {{expected size is dynamic, but result type dimension is static }} + %out = memref.reinterpret_cast %in to + offset: [0], sizes: [1, %size, 1, 1], strides: [4672, 8, 8, 1] + : memref<1x?x2x1xf32> to memref<1x4672x2x1xf32> + return +} + +// ----- + +func.func @memref_reinterpret_cast_static_dynamic_offset_mismatch(%in: memref) { + // expected-error@+1 {{expected offset is static, but result type offset is dynamic}} + %out = memref.reinterpret_cast %in to + offset: [0], sizes: [10], strides: [1] + : memref to memref<10xf32, strided<[1], offset: ?>> +} + +// ----- + +func.func @memref_reinterpret_cast_dynamic_static_offset_mismatch(%in: memref, %offset: index) { + // expected-error@+1 {{expected offset is dynamic, but result type offset is static}} + %out = memref.reinterpret_cast %in to + offset: [%offset], sizes: [10], strides: [1] + : memref to memref<10xf32, strided<[1], offset: 0>> + return +} + +// ----- func.func @memref_reshape_element_type_mismatch( %buf: memref<*xf32>, %shape: memref<1xi32>) { // expected-error @+1 {{element types of source and destination memref types should be the same}} diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir index 38ee363a7d424..33efd66fb25a4 100644 --- a/mlir/test/Dialect/MemRef/ops.mlir +++ b/mlir/test/Dialect/MemRef/ops.mlir @@ -131,20 +131,20 @@ func.func @memref_reinterpret_cast(%in: memref) // CHECK-LABEL: func @memref_reinterpret_cast_static_to_dynamic_sizes func.func @memref_reinterpret_cast_static_to_dynamic_sizes(%in: memref) - -> memref<10x?xf32, strided<[?, 1], offset: ?>> { + -> memref<10x10xf32, strided<[?, 1], offset: 1>> { %out = memref.reinterpret_cast %in to offset: [1], sizes: [10, 10], strides: [1, 1] - : memref to memref<10x?xf32, strided<[?, 1], offset: ?>> - return %out : memref<10x?xf32, strided<[?, 1], offset: ?>> + : memref to memref<10x10xf32, strided<[?, 1], offset: 1>> + return %out : memref<10x10xf32, strided<[?, 1], offset: 1>> } // CHECK-LABEL: func @memref_reinterpret_cast_dynamic_offset func.func @memref_reinterpret_cast_dynamic_offset(%in: memref, %offset: index) - -> memref<10x?xf32, strided<[?, 1], offset: ?>> { + -> memref<10x10xf32, strided<[?, 1], offset: ?>> { %out = memref.reinterpret_cast %in to offset: [%offset], sizes: [10, 10], strides: [1, 1] - : memref to memref<10x?xf32, strided<[?, 1], offset: ?>> - return %out : memref<10x?xf32, strided<[?, 1], offset: ?>> + : memref to memref<10x10xf32, strided<[?, 1], offset: ?>> + return %out : memref<10x10xf32, strided<[?, 1], offset: ?>> } // CHECK-LABEL: func @memref_reshape(