Skip to content

Commit bceefd1

Browse files
committed
feat: partial lowering for getrf
1 parent 96c23a4 commit bceefd1

File tree

1 file changed

+106
-4
lines changed

1 file changed

+106
-4
lines changed

src/enzyme_ad/jax/Passes/LowerEnzymeXLALapack.cpp

Lines changed: 106 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1283,6 +1283,48 @@ struct GemqrtOpLowering : public OpRewritePattern<enzymexla::GemqrtOp> {
12831283
}
12841284
};
12851285

1286+
Value anyNonFiniteValue(PatternRewriter &rewriter, Location loc, Type outType,
1287+
Value input, int64_t inputRank) {
1288+
auto areFinite = stablehlo::AndOp::create(
1289+
rewriter, loc,
1290+
stablehlo::IsFiniteOp::create(
1291+
rewriter, loc, stablehlo::RealOp::create(rewriter, loc, input)),
1292+
stablehlo::IsFiniteOp::create(
1293+
rewriter, loc, stablehlo::ImagOp::create(rewriter, loc, input)));
1294+
1295+
SmallVector<int64_t> reductionDims;
1296+
for (int i = inputRank - 2; i < inputRank; i++)
1297+
reductionDims.push_back(i);
1298+
1299+
auto initValType = RankedTensorType::get({}, rewriter.getI1Type());
1300+
auto initVal = stablehlo::ConstantOp::create(
1301+
rewriter, loc, initValType, cast<ElementsAttr>(makeAttr(initValType, 1)));
1302+
1303+
auto allFinite = stablehlo::ReduceOp::create(
1304+
rewriter, loc, ValueRange{areFinite.getResult()}, ValueRange{initVal},
1305+
rewriter.getDenseI64ArrayAttr(reductionDims));
1306+
1307+
{
1308+
OpBuilder::InsertionGuard guard(rewriter);
1309+
auto &region = allFinite.getBody();
1310+
auto *block = rewriter.createBlock(&region, {}, {initValType, initValType},
1311+
{loc, loc});
1312+
1313+
rewriter.setInsertionPointToStart(block);
1314+
stablehlo::ReturnOp::create(
1315+
rewriter, loc,
1316+
ValueRange{stablehlo::AndOp::create(rewriter, loc,
1317+
block->getArgument(0),
1318+
block->getArgument(1))
1319+
.getResult()});
1320+
}
1321+
1322+
// 0 is success, 1 is failure
1323+
return stablehlo::ConvertOp::create(
1324+
rewriter, loc, outType,
1325+
stablehlo::NotOp::create(rewriter, loc, allFinite.getResult(0)));
1326+
}
1327+
12861328
struct GetrfOpLowering : public OpRewritePattern<enzymexla::GetrfOp> {
12871329
std::string backend;
12881330
int64_t blasIntWidth;
@@ -1307,18 +1349,78 @@ struct GetrfOpLowering : public OpRewritePattern<enzymexla::GetrfOp> {
13071349

13081350
private:
13091351
LogicalResult matchAndRewriteCPU(enzymexla::GetrfOp op,
1310-
PatternRewriter &rewriter) const override {
1352+
PatternRewriter &rewriter) const {
13111353
return failure();
13121354
}
13131355

13141356
LogicalResult matchAndRewriteCUDA(enzymexla::GetrfOp op,
1315-
PatternRewriter &rewriter) const override {
1357+
PatternRewriter &rewriter) const {
13161358
return failure();
13171359
}
13181360

13191361
LogicalResult matchAndRewriteTPU(enzymexla::GetrfOp op,
1320-
PatternRewriter &rewriter) const override {
1321-
return failure();
1362+
PatternRewriter &rewriter) const {
1363+
auto input = op.getInput();
1364+
1365+
auto inputType = cast<RankedTensorType>(input.getType());
1366+
auto pivotType = cast<RankedTensorType>(op.getResult(1).getType());
1367+
auto infoType = cast<RankedTensorType>(op.getResult(3).getType());
1368+
1369+
auto inputShape = inputType.getShape();
1370+
auto inputRank = inputType.getRank();
1371+
1372+
SmallVector<int64_t> permutationShape(inputShape.begin(),
1373+
inputShape.end() - 2);
1374+
permutationShape.push_back(inputShape[inputRank - 2]);
1375+
auto permutationType =
1376+
RankedTensorType::get(permutationShape, rewriter.getI32Type());
1377+
1378+
auto pivotTPUType =
1379+
RankedTensorType::get(pivotType.getShape(), rewriter.getI32Type());
1380+
1381+
// TPU returns (LU, pivots, permutation). info isn't returned. based on
1382+
// how JAX operates, I am assuming info != 0 when there is a nan in the
1383+
// output.
1384+
auto customCall = stablehlo::CustomCallOp::create(
1385+
rewriter, op.getLoc(),
1386+
TypeRange{inputType, pivotTPUType, permutationType}, ValueRange{input},
1387+
rewriter.getStringAttr("LuDecomposition"),
1388+
/*has_side_effect*/ nullptr,
1389+
/*backend_config*/ nullptr,
1390+
/*api_version*/ nullptr,
1391+
/*calledcomputations*/ nullptr,
1392+
/*operand_layouts*/ nullptr,
1393+
/*result_layouts*/ nullptr,
1394+
/*output_operand_aliases*/ nullptr);
1395+
1396+
// LAPACK returns 1-indexed pivots, while XLA returns 0-indexed pivots.
1397+
// We make it consistent with LAPACK by adding 1 to the pivots.
1398+
auto pivots1Indexed = stablehlo::AddOp::create(
1399+
rewriter, op.getLoc(),
1400+
stablehlo::ConstantOp::create(
1401+
rewriter, op.getLoc(), pivotType,
1402+
cast<ElementsAttr>(makeAttr(pivotType, 1))),
1403+
stablehlo::ConvertOp::create(rewriter, op.getLoc(), pivotType,
1404+
customCall.getResult(1)));
1405+
1406+
auto permutation1Indexed = stablehlo::AddOp::create(
1407+
rewriter, op.getLoc(),
1408+
stablehlo::ConstantOp::create(
1409+
rewriter, op.getLoc(), permutationType,
1410+
cast<ElementsAttr>(makeAttr(permutationType, 1))),
1411+
stablehlo::ConvertOp::create(rewriter, op.getLoc(), permutationType,
1412+
customCall.getResult(2)));
1413+
1414+
auto info = anyNonFiniteValue(rewriter, op.getLoc(), infoType,
1415+
customCall.getResult(0), inputRank);
1416+
1417+
rewriter.replaceAllUsesWith(op.getResult(0), customCall.getResult(0));
1418+
rewriter.replaceAllUsesWith(op.getResult(1), pivots1Indexed);
1419+
rewriter.replaceAllUsesWith(op.getResult(2), permutation1Indexed);
1420+
rewriter.replaceAllUsesWith(op.getResult(3), info);
1421+
rewriter.eraseOp(op);
1422+
1423+
return success();
13221424
}
13231425
};
13241426

0 commit comments

Comments
 (0)