Skip to content

Commit

Permalink
fix reshapeStrides bug (#2553)
Browse files Browse the repository at this point in the history
Signed-off-by: Soren Lassen <[email protected]>
  • Loading branch information
sorenlassen authored Oct 9, 2023
1 parent 7cacebc commit 1ccfbfc
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 5 deletions.
14 changes: 9 additions & 5 deletions src/Dialect/ONNX/ElementsAttr/Strides.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,18 +112,22 @@ std::optional<SmallVector<int64_t, 4>> reshapeStrides(ArrayRef<int64_t> shape,
++a1;
}
} while (a1 < rank1 && last == shape[a1] * strides[a1]);
assert(total == n * strides[a1 - 1]);
assert(total == n * last);
// Add contiguous strides for axes in reshapedShape with dimSizes product n.
int64_t n2 = 1;
while (a2 < rank2 && n2 * reshapedShape[a2] <= n) {
n2 *= reshapedShape[a2];
total /= reshapedShape[a2];
reshapedStrides.push_back(total);
if (reshapedShape[a2] == 1) {
reshapedStrides.push_back(0);
} else {
n2 *= reshapedShape[a2];
total /= reshapedShape[a2];
reshapedStrides.push_back(total);
}
++a2;
}
if (n2 < n)
return std::nullopt;
assert(strides[a1 - 1] == reshapedStrides[a2 - 1]);
assert(last == total);
} while (a1 < rank1);
assert(a2 == rank2);
assert(a2 == reshapedStrides.size());
Expand Down
16 changes: 16 additions & 0 deletions test/unit/Strides/TestStrides.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,21 @@ class Test {
return 0;
}

// This example triggered a bug.
int test_reshapeStrides_unsqueeze_last() {
std::cout << "test_reshapeStrides_unsqueeze_last:" << std::endl;

SmallVector<int64_t, 4> shape{1, 2, 124, 1};
SmallVector<int64_t, 4> strides{0, 0, 1, 0};

SmallVector<int64_t, 4> reshapedShape{1, 2, 124, 1, 1};
SmallVector<int64_t, 4> expectedReshapedStrides{0, 0, 1, 0, 0};
auto reshapedStrides = reshapeStrides(shape, strides, reshapedShape);
assert(reshapedStrides == expectedReshapedStrides);

return 0;
}

int test_reshapeStrides_failure() {
std::cout << "test_reshapeStrides_failure:" << std::endl;

Expand All @@ -60,6 +75,7 @@ int main(int argc, char *argv[]) {
Test test;
int failures = 0;
failures += test.test_reshapeStrides_success();
failures += test.test_reshapeStrides_unsqueeze_last();
failures += test.test_reshapeStrides_failure();
if (failures != 0) {
std::cerr << failures << " test failures\n";
Expand Down

0 comments on commit 1ccfbfc

Please sign in to comment.