Skip to content

[mlir][memref] Fix computeCollapsedLayoutMap for contiguous dynamic dim #136485

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 14 additions & 5 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallBitVector.h"
#include <algorithm>

using namespace mlir;
using namespace mlir::memref;
Expand Down Expand Up @@ -2401,11 +2402,19 @@ computeCollapsedLayoutMap(MemRefType srcType,
if (!ShapedType::isDynamic(srcShape[ref.back()]) || ref.size() == 1) {
resultStrides.push_back(srcStrides[ref.back()]);
} else {
// Dynamically-sized dims may turn out to be dims of size 1 at runtime, so
// the corresponding stride may have to be skipped. (See above comment.)
// Therefore, the result stride cannot be statically determined and must
// be dynamic.
resultStrides.push_back(ShapedType::kDynamic);
bool contiguousSrcDim = srcStrides[ref.back()] == 1;
bool dynamicSizeIsPreserved =
std::all_of(ref.begin(), ref.end() - 1,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is it ref.begin() -> ref.end() - 1? All dimensions except for one must be 1, right? In that case, it does not matter where non-unit dimension is?

[srcShape](int64_t dim) { return srcShape[dim] == 1; });
if (contiguousSrcDim && dynamicSizeIsPreserved)
resultStrides.push_back(1);
else {
// Dynamically-sized dims may turn out to be dims of size 1 at runtime,
// so the corresponding stride may have to be skipped. (See above
// comment.) Therefore, the result stride cannot be statically
// determined and must be dynamic.
resultStrides.push_back(ShapedType::kDynamic);
}
}
}

Expand Down
8 changes: 8 additions & 0 deletions mlir/test/Dialect/MemRef/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,14 @@ func.func @collapse_shape_wrong_collapsed_type(%arg0: memref<?x?x?xf32>) {

// -----

func.func @collapse_shape_infer_stride_of_dynamic_dim(%arg0: memref<1x1x?x1xsi32, strided<[36960, 4620, 1, 330]>, 1>, %dim : index) -> (memref<?xsi32, strided<[?]>, 1>) {
// expected-error @+1 {{expected collapsed type to be 'memref<?xsi32, strided<[1]>, 1>' but found 'memref<?xsi32, strided<[?]>, 1>'}}
%collapse_shape = memref.collapse_shape %arg0 [[0, 1, 2, 3]] : memref<1x1x?x1xsi32, strided<[36960, 4620, 1, 330]>, 1> into memref<?xsi32, strided<[?]>, 1>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add another example where not all of the static source dimensions are 1?

return %collapse_shape : memref<?xsi32, strided<[?]>, 1>
}

// -----

func.func @expand_shape_illegal_static_memref
(%arg0: memref<2x3x20xf32>) -> memref<2x3x2x4x5xf32> {
// expected-error @+1 {{collapsed dim size (20) must equal reassociation group size (40)}}
Expand Down
Loading