diff --git a/lib/Conversion/TorchToStablehlo/Linear.cpp b/lib/Conversion/TorchToStablehlo/Linear.cpp
index b1749ee1c074..f95184833841 100644
--- a/lib/Conversion/TorchToStablehlo/Linear.cpp
+++ b/lib/Conversion/TorchToStablehlo/Linear.cpp
@@ -134,6 +134,7 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
 
   auto lhsRank = lhsRankTy.getRank();
   auto rhsRank = rhsRankTy.getRank();
+  int64_t nBatchDims = std::max(lhsRank - 2, rhsRank - 2);
 
   // The non-matrix (i.e. batch) dimensions are broadcasted (and thus must be
   // broadcastable).
@@ -143,6 +144,7 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
       llvm::seq<int64_t>(leadingRank, minRank + leadingRank));
   auto lhsShape = lhsRankTy.getShape();
   auto rhsShape = rhsRankTy.getShape();
+
   if (lhsRank < rhsRank) {
     std::vector<int64_t> newShape(rhsShape.begin(),
                                   rhsShape.begin() + leadingRank);
@@ -169,6 +171,73 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
                              broadcastDims);
   }
 
+  if (lhsRank <= 2 || rhsRank <= 2) {
+    inpLhs = lhs;
+    inpRhs = rhs;
+    return;
+  }
+
+  lhsShape = lhs.getType().cast<RankedTensorType>().getShape();
+  rhsShape = rhs.getType().cast<RankedTensorType>().getShape();
+
+  // Check shape compatibility and whether broadcasting is needed:
+  // first, compute the new batch shape by checking dims in [0, nBatchDims).
+  SmallVector<int64_t> lhsBroadcastDims;
+  SmallVector<int64_t> rhsBroadcastDims;
+  SmallVector<int64_t> newBatchShape;
+
+  for (int64_t i = 0; i < nBatchDims; i++) {
+    if (lhsShape[i] != rhsShape[i]) {
+      if (lhsShape[i] == 1) {
+        lhsBroadcastDims.push_back(i);
+        newBatchShape.push_back(rhsShape[i]);
+      } else if (rhsShape[i] == 1) {
+        rhsBroadcastDims.push_back(i);
+        newBatchShape.push_back(lhsShape[i]);
+      } else {
+        assert(false && "shape mismatch in matmul op");
+      }
+    } else {
+      newBatchShape.push_back(lhsShape[i]);
+    }
+  }
+
+  if (lhsBroadcastDims.empty() && rhsBroadcastDims.empty()) {
+    inpLhs = lhs;
+    inpRhs = rhs;
+    return;
+  }
+
+  auto lhsDimSizes =
+      *hlo::getDimSizesOfTensor(rewriter, op, lhs, dimSizeIndexBits);
+  auto rhsDimSizes =
+      *hlo::getDimSizesOfTensor(rewriter, op, rhs, dimSizeIndexBits);
+
+  if (!lhsBroadcastDims.empty()) {
+    SmallVector<int64_t> lhsNewShape(newBatchShape);
+    lhsNewShape.insert(lhsNewShape.end(), lhsShape.begin() + nBatchDims,
+                       lhsShape.end());
+    for (auto i : lhsBroadcastDims) {
+      lhsDimSizes[i] = rhsDimSizes[i];
+    }
+    broadcastDims =
+        llvm::to_vector<4>(llvm::seq<int64_t>(0, lhsNewShape.size()));
+    lhs = getBroadcastTensor(rewriter, op, lhs, lhsNewShape, lhsDimSizes,
+                             broadcastDims);
+  }
+  if (!rhsBroadcastDims.empty()) {
+    SmallVector<int64_t> rhsNewShape(newBatchShape);
+    rhsNewShape.insert(rhsNewShape.end(), rhsShape.begin() + nBatchDims,
+                       rhsShape.end());
+    for (auto i : rhsBroadcastDims) {
+      rhsDimSizes[i] = lhsDimSizes[i];
+    }
+    broadcastDims =
+        llvm::to_vector<4>(llvm::seq<int64_t>(0, rhsNewShape.size()));
+    rhs = getBroadcastTensor(rewriter, op, rhs, rhsNewShape, rhsDimSizes,
+                             broadcastDims);
+  }
+
   inpLhs = lhs;
   inpRhs = rhs;
 }
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 0f9bca677431..e6cd80e3a275 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -913,6 +913,7 @@
     "Matmul_dot",
     "Matmul_matvec",
     "Matmul_vecmat",
+    "MatmulStaticBroadcast_basic",
     "MaxPool2dStaticModule_basic",
     "MeanDimAllReduceModule_basic",
     "MeanDimEmptyDimModule_basic",