@@ -1283,6 +1283,48 @@ struct GemqrtOpLowering : public OpRewritePattern<enzymexla::GemqrtOp> {
12831283 }
12841284};
12851285
1286+ Value anyNonFiniteValue (PatternRewriter &rewriter, Location loc, Type outType,
1287+ Value input, int64_t inputRank) {
1288+ auto areFinite = stablehlo::AndOp::create (
1289+ rewriter, loc,
1290+ stablehlo::IsFiniteOp::create (
1291+ rewriter, loc, stablehlo::RealOp::create (rewriter, loc, input)),
1292+ stablehlo::IsFiniteOp::create (
1293+ rewriter, loc, stablehlo::ImagOp::create (rewriter, loc, input)));
1294+
1295+ SmallVector<int64_t > reductionDims;
1296+ for (int i = inputRank - 2 ; i < inputRank; i++)
1297+ reductionDims.push_back (i);
1298+
1299+ auto initValType = RankedTensorType::get ({}, rewriter.getI1Type ());
1300+ auto initVal = stablehlo::ConstantOp::create (
1301+ rewriter, loc, initValType, cast<ElementsAttr>(makeAttr (initValType, 1 )));
1302+
1303+ auto allFinite = stablehlo::ReduceOp::create (
1304+ rewriter, loc, ValueRange{areFinite.getResult ()}, ValueRange{initVal},
1305+ rewriter.getDenseI64ArrayAttr (reductionDims));
1306+
1307+ {
1308+ OpBuilder::InsertionGuard guard (rewriter);
1309+ auto &region = allFinite.getBody ();
1310+ auto *block = rewriter.createBlock (&region, {}, {initValType, initValType},
1311+ {loc, loc});
1312+
1313+ rewriter.setInsertionPointToStart (block);
1314+ stablehlo::ReturnOp::create (
1315+ rewriter, loc,
1316+ ValueRange{stablehlo::AndOp::create (rewriter, loc,
1317+ block->getArgument (0 ),
1318+ block->getArgument (1 ))
1319+ .getResult ()});
1320+ }
1321+
1322+ // 0 is success (all reduced elements finite), 1 is failure (some non-finite)
1323+ return stablehlo::ConvertOp::create (
1324+ rewriter, loc, outType,
1325+ stablehlo::NotOp::create (rewriter, loc, allFinite.getResult (0 )));
1326+ }
1327+
12861328struct GetrfOpLowering : public OpRewritePattern <enzymexla::GetrfOp> {
12871329 std::string backend;
12881330 int64_t blasIntWidth;
@@ -1307,18 +1349,78 @@ struct GetrfOpLowering : public OpRewritePattern<enzymexla::GetrfOp> {
13071349
13081350private:
13091351 LogicalResult matchAndRewriteCPU (enzymexla::GetrfOp op,
1310- PatternRewriter &rewriter) const override {
1352+ PatternRewriter &rewriter) const {
13111353 return failure ();
13121354 }
13131355
13141356 LogicalResult matchAndRewriteCUDA (enzymexla::GetrfOp op,
1315- PatternRewriter &rewriter) const override {
1357+ PatternRewriter &rewriter) const {
13161358 return failure ();
13171359 }
13181360
13191361 LogicalResult matchAndRewriteTPU (enzymexla::GetrfOp op,
1320- PatternRewriter &rewriter) const override {
1321- return failure ();
1362+ PatternRewriter &rewriter) const {
1363+ auto input = op.getInput ();
1364+
1365+ auto inputType = cast<RankedTensorType>(input.getType ());
1366+ auto pivotType = cast<RankedTensorType>(op.getResult (1 ).getType ());
1367+ auto infoType = cast<RankedTensorType>(op.getResult (3 ).getType ());
1368+
1369+ auto inputShape = inputType.getShape ();
1370+ auto inputRank = inputType.getRank ();
1371+
1372+ SmallVector<int64_t > permutationShape (inputShape.begin (),
1373+ inputShape.end () - 2 );
1374+ permutationShape.push_back (inputShape[inputRank - 2 ]);
1375+ auto permutationType =
1376+ RankedTensorType::get (permutationShape, rewriter.getI32Type ());
1377+
1378+ auto pivotTPUType =
1379+ RankedTensorType::get (pivotType.getShape (), rewriter.getI32Type ());
1380+
1381+ // TPU returns (LU, pivots, permutation); info isn't returned. Based on
1382+ // how JAX operates, we assume info != 0 when the LU output contains a
1383+ // non-finite value (NaN or Inf).
1384+ auto customCall = stablehlo::CustomCallOp::create (
1385+ rewriter, op.getLoc (),
1386+ TypeRange{inputType, pivotTPUType, permutationType}, ValueRange{input},
1387+ rewriter.getStringAttr (" LuDecomposition" ),
1388+ /* has_side_effect*/ nullptr ,
1389+ /* backend_config*/ nullptr ,
1390+ /* api_version*/ nullptr ,
1391+ /* called_computations*/ nullptr ,
1392+ /* operand_layouts*/ nullptr ,
1393+ /* result_layouts*/ nullptr ,
1394+ /* output_operand_aliases*/ nullptr );
1395+
1396+ // LAPACK returns 1-indexed pivots, while XLA returns 0-indexed pivots.
1397+ // We make it consistent with LAPACK by adding 1 to the pivots.
1398+ auto pivots1Indexed = stablehlo::AddOp::create (
1399+ rewriter, op.getLoc (),
1400+ stablehlo::ConstantOp::create (
1401+ rewriter, op.getLoc (), pivotType,
1402+ cast<ElementsAttr>(makeAttr (pivotType, 1 ))),
1403+ stablehlo::ConvertOp::create (rewriter, op.getLoc (), pivotType,
1404+ customCall.getResult (1 )));
1405+
1406+ auto permutation1Indexed = stablehlo::AddOp::create (
1407+ rewriter, op.getLoc (),
1408+ stablehlo::ConstantOp::create (
1409+ rewriter, op.getLoc (), permutationType,
1410+ cast<ElementsAttr>(makeAttr (permutationType, 1 ))),
1411+ stablehlo::ConvertOp::create (rewriter, op.getLoc (), permutationType,
1412+ customCall.getResult (2 )));
1413+
1414+ auto info = anyNonFiniteValue (rewriter, op.getLoc (), infoType,
1415+ customCall.getResult (0 ), inputRank);
1416+
1417+ rewriter.replaceAllUsesWith (op.getResult (0 ), customCall.getResult (0 ));
1418+ rewriter.replaceAllUsesWith (op.getResult (1 ), pivots1Indexed);
1419+ rewriter.replaceAllUsesWith (op.getResult (2 ), permutation1Indexed);
1420+ rewriter.replaceAllUsesWith (op.getResult (3 ), info);
1421+ rewriter.eraseOp (op);
1422+
1423+ return success ();
13221424 }
13231425};
13241426
0 commit comments