From 1ccfbfc7fe4330ecfe5b9d74b9ecde824af651a3 Mon Sep 17 00:00:00 2001 From: Soren Lassen Date: Sun, 8 Oct 2023 21:04:26 -0700 Subject: [PATCH] fix reshapeStrides bug (#2553) Signed-off-by: Soren Lassen --- src/Dialect/ONNX/ElementsAttr/Strides.cpp | 14 +++++++++----- test/unit/Strides/TestStrides.cpp | 16 ++++++++++++++++ 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/src/Dialect/ONNX/ElementsAttr/Strides.cpp b/src/Dialect/ONNX/ElementsAttr/Strides.cpp index 11bcb0dfaf..440b78b4a0 100644 --- a/src/Dialect/ONNX/ElementsAttr/Strides.cpp +++ b/src/Dialect/ONNX/ElementsAttr/Strides.cpp @@ -112,18 +112,22 @@ std::optional> reshapeStrides(ArrayRef shape, ++a1; } } while (a1 < rank1 && last == shape[a1] * strides[a1]); - assert(total == n * strides[a1 - 1]); + assert(total == n * last); // Add contiguous strides for axes in reshapedShape with dimSizes product n. int64_t n2 = 1; while (a2 < rank2 && n2 * reshapedShape[a2] <= n) { - n2 *= reshapedShape[a2]; - total /= reshapedShape[a2]; - reshapedStrides.push_back(total); + if (reshapedShape[a2] == 1) { + reshapedStrides.push_back(0); + } else { + n2 *= reshapedShape[a2]; + total /= reshapedShape[a2]; + reshapedStrides.push_back(total); + } ++a2; } if (n2 < n) return std::nullopt; - assert(strides[a1 - 1] == reshapedStrides[a2 - 1]); + assert(last == total); } while (a1 < rank1); assert(a2 == rank2); assert(a2 == reshapedStrides.size()); diff --git a/test/unit/Strides/TestStrides.cpp b/test/unit/Strides/TestStrides.cpp index 5c9785485f..2e75f24b5a 100644 --- a/test/unit/Strides/TestStrides.cpp +++ b/test/unit/Strides/TestStrides.cpp @@ -35,6 +35,21 @@ class Test { return 0; } + // This example triggered a bug. + int test_reshapeStrides_unsqueeze_last() { + std::cout << "test_reshapeStrides_unsqueeze_last:" << std::endl; + + SmallVector shape{1, 2, 124, 1}; + SmallVector strides{0, 0, 1, 0}; + + SmallVector reshapedShape{1, 2, 124, 1, 1}; + SmallVector expectedReshapedStrides{0, 0, 1, 0, 0}; + auto reshapedStrides = reshapeStrides(shape, strides, reshapedShape); + assert(reshapedStrides == expectedReshapedStrides); + + return 0; + } + int test_reshapeStrides_failure() { std::cout << "test_reshapeStrides_failure:" << std::endl; @@ -60,6 +75,7 @@ int main(int argc, char *argv[]) { Test test; int failures = 0; failures += test.test_reshapeStrides_success(); + failures += test.test_reshapeStrides_unsqueeze_last(); failures += test.test_reshapeStrides_failure(); if (failures != 0) { std::cerr << failures << " test failures\n";