[Stablehlo] Enhance broadcast pattern in matmul Ops (llvm#3161)
To pass the test "MatmulStaticBroadcast_basic" in StableHLO:
```python
class MatmulStaticBroadcast(torch.nn.Module):
    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([4, 1, 6, 7], torch.float32, True),
        ([8, 1, 5, 7, 6], torch.float32, True),
    ])
    def forward(self, lhs, rhs):
        return torch.matmul(lhs, rhs)


@register_test_case(module_factory=lambda: MatmulStaticBroadcast())
def MatmulStaticBroadcast_basic(module, tu: TestUtils):
    module.forward(tu.rand(4, 1, 6, 7), tu.rand(8, 1, 5, 7, 6))
```
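For context, the two operands broadcast under NumPy-style batch rules: the lhs batch dims (4, 1) align against the rhs batch dims (8, 1, 5) to give (8, 4, 5), and the matrix product (6, 7) @ (7, 6) gives (6, 6). A quick standalone check of the expected result shape (plain PyTorch, independent of the e2e harness):
```python
import torch

# lhs batch dims (4, 1) broadcast against rhs batch dims (8, 1, 5)
# to (8, 4, 5); the matrix product (6, 7) @ (7, 6) gives (6, 6).
lhs = torch.rand(4, 1, 6, 7)
rhs = torch.rand(8, 1, 5, 7, 6)
out = torch.matmul(lhs, rhs)
assert out.shape == (8, 4, 5, 6, 6)
```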
Xinyu Yang authored Apr 16, 2024
1 parent 5e564b5 commit ae47247
Showing 2 changed files with 70 additions and 0 deletions.
69 changes: 69 additions & 0 deletions lib/Conversion/TorchToStablehlo/Linear.cpp
@@ -134,6 +134,7 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,

auto lhsRank = lhsRankTy.getRank();
auto rhsRank = rhsRankTy.getRank();
int64_t nBatchDims = std::max(lhsRank - 2, rhsRank - 2);

// The non-matrix (i.e. batch) dimensions are broadcasted (and thus must be
// broadcastable).
@@ -143,6 +144,7 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
llvm::seq<int64_t>(leadingRank, minRank + leadingRank));
auto lhsShape = lhsRankTy.getShape();
auto rhsShape = rhsRankTy.getShape();

if (lhsRank < rhsRank) {
std::vector<int64_t> newShape(rhsShape.begin(),
rhsShape.begin() + leadingRank);
@@ -169,6 +171,73 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
broadcastDims);
}

// If either operand was originally rank <= 2 (a matrix or vector), the
// leading-rank broadcast above already made the batch dims compatible.
if (lhsRank <= 2 || rhsRank <= 2) {
  inpLhs = lhs;
  inpRhs = rhs;
  return;
}

lhsShape = lhs.getType().cast<RankedTensorType>().getShape();
rhsShape = rhs.getType().cast<RankedTensorType>().getShape();

// Check shape compatibility and decide whether broadcasting is needed.
// First, compute the new batch shape by walking dims [0, nBatchDims).
SmallVector<int64_t> lhsBroadcastDims;
SmallVector<int64_t> rhsBroadcastDims;
SmallVector<int64_t> newBatchShape;

for (int64_t i = 0; i < nBatchDims; i++) {
if (lhsShape[i] != rhsShape[i]) {
if (lhsShape[i] == 1) {
lhsBroadcastDims.push_back(i);
newBatchShape.push_back(rhsShape[i]);
} else if (rhsShape[i] == 1) {
rhsBroadcastDims.push_back(i);
newBatchShape.push_back(lhsShape[i]);
} else {
assert(false && "shape mismatch in matmul op");
}
} else {
newBatchShape.push_back(lhsShape[i]);
}
}

if (lhsBroadcastDims.empty() && rhsBroadcastDims.empty()) {
inpLhs = lhs;
inpRhs = rhs;
return;
}

// Fetch per-dim sizes as SSA values so getBroadcastTensor can also handle
// dynamically sized batch dims.
auto lhsDimSizes =
    *hlo::getDimSizesOfTensor(rewriter, op, lhs, dimSizeIndexBits);
auto rhsDimSizes =
    *hlo::getDimSizesOfTensor(rewriter, op, rhs, dimSizeIndexBits);

if (!lhsBroadcastDims.empty()) {
SmallVector<int64_t> lhsNewShape(newBatchShape);
lhsNewShape.insert(lhsNewShape.end(), lhsShape.begin() + nBatchDims,
lhsShape.end());
for (auto i : lhsBroadcastDims) {
lhsDimSizes[i] = rhsDimSizes[i];
}
broadcastDims =
llvm::to_vector<4>(llvm::seq<int64_t>(0, lhsNewShape.size()));
lhs = getBroadcastTensor(rewriter, op, lhs, lhsNewShape, lhsDimSizes,
broadcastDims);
}
if (!rhsBroadcastDims.empty()) {
SmallVector<int64_t> rhsNewShape(newBatchShape);
rhsNewShape.insert(rhsNewShape.end(), rhsShape.begin() + nBatchDims,
rhsShape.end());
for (auto i : rhsBroadcastDims) {
rhsDimSizes[i] = lhsDimSizes[i];
}
broadcastDims =
llvm::to_vector<4>(llvm::seq<int64_t>(0, rhsNewShape.size()));
rhs = getBroadcastTensor(rewriter, op, rhs, rhsNewShape, rhsDimSizes,
broadcastDims);
}

inpLhs = lhs;
inpRhs = rhs;
}
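As a reading aid, here is a minimal Python sketch of the per-dimension rule the new code implements (names are illustrative, not part of the patch): a batch dim of size 1 on one side is recorded for broadcasting and takes the other side's extent, and a mismatch where neither side is 1 trips the same condition the C++ assert guards.
```python
def reconcile_batch_dims(lhs_shape, rhs_shape, n_batch_dims):
    """Illustrative mirror of the per-dim loop added above."""
    lhs_bcast, rhs_bcast, new_batch = [], [], []
    for i in range(n_batch_dims):
        l, r = lhs_shape[i], rhs_shape[i]
        if l == r:
            new_batch.append(l)
        elif l == 1:  # lhs broadcasts along batch dim i
            lhs_bcast.append(i)
            new_batch.append(r)
        elif r == 1:  # rhs broadcasts along batch dim i
            rhs_bcast.append(i)
            new_batch.append(l)
        else:  # mirrors the assert in the C++ above
            raise ValueError("shape mismatch in matmul op")
    return lhs_bcast, rhs_bcast, new_batch

# For the test shapes, the earlier rank alignment pads lhs with rhs's
# leading dim, so lhs (8, 4, 1, 6, 7) meets rhs (8, 1, 5, 7, 6),
# i.e. 3 batch dims:
print(reconcile_batch_dims((8, 4, 1), (8, 1, 5), 3))
# ([2], [1], [8, 4, 5])
```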
1 change: 1 addition & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -913,6 +913,7 @@
"Matmul_dot",
"Matmul_matvec",
"Matmul_vecmat",
"MatmulStaticBroadcast_basic",
"MaxPool2dStaticModule_basic",
"MeanDimAllReduceModule_basic",
"MeanDimEmptyDimModule_basic",
