Skip to content

Commit b55ffd4

Browse files
authored
feat: progressive lowering of LU and SVD via lapack ops (#1625)
* feat: more lapack operations for LU/SVD * fix: lower linalg ops to lapack * chore: stub code for lowering * feat: lowering for getrf * feat: svd lapack op lowering * feat: cpu lowering * fix: missing impl in header * fix: export names * fix: update workspace buffer correctly * test: update all tests * feat: support batching
1 parent 8c2e29e commit b55ffd4

File tree

12 files changed

+1716
-1214
lines changed

12 files changed

+1716
-1214
lines changed

src/enzyme_ad/jax/Dialect/EnzymeXLAAttrs.td

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,24 @@ def EnzymeXLA_QrAlgorithmAttr : EnumAttr<EnzymeXLA_Dialect,
8282
let assemblyFormat = "`<` $value `>`";
8383
}
8484

85+
def EnzymeXLA_SVDAlgorithm : I32EnumAttr<"SVDAlgorithm",
86+
"Algorithm to use for the SVD factorization",
87+
[
88+
I32EnumAttrCase<"DEFAULT", 0>,
89+
I32EnumAttrCase<"QRIteration", 1>,
90+
I32EnumAttrCase<"DivideAndConquer", 2>,
91+
I32EnumAttrCase<"Jacobi", 3>
92+
]> {
93+
let genSpecializedAttr = 0;
94+
let cppNamespace = "::mlir::enzymexla";
95+
}
96+
97+
def EnzymeXLA_SVDAlgorithmAttr : EnumAttr<EnzymeXLA_Dialect,
98+
EnzymeXLA_SVDAlgorithm, "svd_algorithm"> {
99+
let assemblyFormat = "`<` $value `>`";
100+
let cppType = "::mlir::enzymexla::SVDAlgorithmAttr";
101+
}
102+
85103
// Machine Learning
86104

87105
def EnzymeXLA_GeluApproximation : I32EnumAttr<"GeluApproximation",

src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td

Lines changed: 116 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -446,15 +446,15 @@ def QRFactorizationOp: EnzymeXLA_Op<"linalg.qr", [Pure]> {
446446
let summary = "QR factorization operation (high-level op)";
447447

448448
let description = [{
449-
This operation computes the QR factorization of a matrix using Householder
450-
reflections. Mathematically, it decomposes A into the product of an
451-
orthogonal (unitary if complex) matrix Q and an upper triangular matrix R,
449+
This operation computes the QR factorization of a matrix using Householder
450+
reflections. Mathematically, it decomposes A into the product of an
451+
orthogonal (unitary if complex) matrix Q and an upper triangular matrix R,
452452
such that A = QR.
453453

454454
If A has size m x n and m > n, Q is an m x n isometric matrix. If m < n, R
455455
will be a m x n trapezoidal matrix.
456456

457-
This operation is modeled after the mathematical formulation of the QR
457+
This operation is modeled after the mathematical formulation of the QR
458458
factorization, and not after LAPACK's compact formats.
459459
}];
460460

@@ -473,28 +473,49 @@ def QRFactorizationOp: EnzymeXLA_Op<"linalg.qr", [Pure]> {
473473
}];
474474
}
475475

476+
def SVDFactorizationOp : EnzymeXLA_Op<"linalg.svd", [Pure]> {
477+
let summary = "Singular Value Decomposition (SVD) factorization operation.";
478+
479+
let arguments = (ins
480+
HLO_Tensor:$input,
481+
DefaultValuedAttr<BoolAttr, "false">:$full,
482+
DefaultValuedAttr<EnzymeXLA_SVDAlgorithmAttr, "::mlir::enzymexla::SVDAlgorithm::DEFAULT">:$algorithm
483+
);
484+
485+
let results = (outs
486+
HLO_Tensor:$U,
487+
HLO_Tensor:$S,
488+
HLO_Tensor:$Vt,
489+
HLO_Tensor:$info
490+
);
491+
492+
let assemblyFormat = [{
493+
$input attr-dict `:` functional-type($input, results)
494+
}];
495+
}
496+
476497
def GeqrfOp : EnzymeXLA_Op<"lapack.geqrf", [Pure]> {
477498
let summary = "QR factorization operation (low-level op)";
478499

479500
let description = [{
480-
This operation computes the QR factorization of a matrix using Householder
481-
reflections. Mathematically, it decomposes A into the product of an
501+
This operation computes the QR factorization of a matrix using Householder
502+
reflections. Mathematically, it decomposes A into the product of an
482503
orthogonal matri x Q and an upper triangular matrix R,
483-
such that A = QR.
504+
such that A = QR.
484505

485-
This operation is modeled after
486-
LAPACK's *GEQRF routines, which returns the result in
487-
the QR packed format.
488-
}];
506+
This operation is modeled after LAPACK's *GEQRF routines, which returns the
507+
result in the QR packed format.
508+
}];
489509

490-
let arguments = (ins HLO_Tensor : $input);
510+
let arguments = (ins HLO_Tensor : $input);
491511

492-
let results = (outs HLO_Tensor
493-
: $output, HLO_Tensor
494-
: $tau, HLO_Tensor
495-
: $info);
512+
let results = (outs
513+
HLO_Tensor:$output,
514+
HLO_Tensor:$tau,
515+
HLO_Tensor:$info
516+
);
496517

497-
let assemblyFormat = [{
518+
let assemblyFormat = [{
498519
$input attr-dict `:` functional-type($input, results)
499520
}];
500521
}
@@ -503,11 +524,11 @@ def GeqrtOp : EnzymeXLA_Op<"lapack.geqrt", [Pure]> {
503524
let summary = "QR factorization operation (low-level op)";
504525

505526
let description = [{
506-
This operation computes the QR factorization of a matrix using Householder
507-
reflections. Mathematically, it decomposes A into the product of an
527+
This operation computes the QR factorization of a matrix using Householder
528+
reflections. Mathematically, it decomposes A into the product of an
508529
orthogonal matrix Q and an upper triangular matrix R, such that A = QR.
509530

510-
This operation is modeled after LAPACK's *GEQRT routines, which returns the
531+
This operation is modeled after LAPACK's *GEQRT routines, which returns the
511532
result in the QR CompactWY format.
512533
}];
513534

@@ -605,7 +626,81 @@ def GemqrtOp : EnzymeXLA_Op<"lapack.gemqrt", [Pure]> {
605626
}];
606627
}
607628

608-
def SVDFactorizationOp : EnzymeXLA_Op<"linalg.svd", [Pure]> {
629+
def GetrfOp : EnzymeXLA_Op<"lapack.getrf", [Pure]> {
630+
let summary = "LU factorization operation with row-major pivoting.";
631+
632+
let arguments = (ins HLO_Tensor:$input);
633+
634+
let results = (outs
635+
HLO_Tensor:$output,
636+
HLO_Tensor:$pivots,
637+
HLO_Tensor:$permutation,
638+
HLO_Tensor:$info
639+
);
640+
641+
let assemblyFormat = [{
642+
$input attr-dict `:` functional-type($input, results)
643+
}];
644+
}
645+
646+
def GetriOp : EnzymeXLA_Op<"lapack.getri", [Pure]> {
647+
let summary = "Computes the inverse of a matrix using the LU factorization.";
648+
649+
let arguments = (ins
650+
HLO_Tensor:$input,
651+
HLO_Tensor:$ipiv
652+
);
653+
654+
let results = (outs
655+
HLO_Tensor:$output
656+
);
657+
658+
let assemblyFormat = [{
659+
$input `,` $ipiv attr-dict `:` functional-type(operands, results)
660+
}];
661+
}
662+
663+
def GesddOp : EnzymeXLA_Op<"lapack.gesdd", [Pure]> {
664+
let summary = "Singular Value Decomposition (SVD) factorization operation.";
665+
666+
let arguments = (ins
667+
HLO_Tensor:$input,
668+
DefaultValuedAttr<BoolAttr, "false">:$full
669+
);
670+
671+
let results = (outs
672+
HLO_Tensor:$U,
673+
HLO_Tensor:$S,
674+
HLO_Tensor:$Vt,
675+
HLO_Tensor:$info
676+
);
677+
678+
let assemblyFormat = [{
679+
$input attr-dict `:` functional-type($input, results)
680+
}];
681+
}
682+
683+
def GesvdOp : EnzymeXLA_Op<"lapack.gesvd", [Pure]> {
684+
let summary = "Singular Value Decomposition (SVD) factorization operation.";
685+
686+
let arguments = (ins
687+
HLO_Tensor:$input,
688+
DefaultValuedAttr<BoolAttr, "false">:$full
689+
);
690+
691+
let results = (outs
692+
HLO_Tensor:$U,
693+
HLO_Tensor:$S,
694+
HLO_Tensor:$Vt,
695+
HLO_Tensor:$info
696+
);
697+
698+
let assemblyFormat = [{
699+
$input attr-dict `:` functional-type($input, results)
700+
}];
701+
}
702+
703+
def GesvjOp : EnzymeXLA_Op<"lapack.gesvj", [Pure]> {
609704
let summary = "Singular Value Decomposition (SVD) factorization operation.";
610705

611706
let arguments = (ins

src/enzyme_ad/jax/Integrations/c/EnzymeXLA.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,25 @@ MlirAttribute enzymexlaQRAlgorithmAttrGet(MlirContext ctx, int32_t mode) {
6464
return wrap(mlir::enzymexla::QrAlgorithmAttr::get(unwrap(ctx), algorithm));
6565
}
6666

67+
MlirAttribute enzymexlaSVDAlgorithmAttrGet(MlirContext ctx, int32_t mode) {
68+
mlir::enzymexla::SVDAlgorithm algorithm;
69+
switch (mode) {
70+
case 0:
71+
algorithm = mlir::enzymexla::SVDAlgorithm::DEFAULT;
72+
break;
73+
case 1:
74+
algorithm = mlir::enzymexla::SVDAlgorithm::QRIteration;
75+
break;
76+
case 2:
77+
algorithm = mlir::enzymexla::SVDAlgorithm::DivideAndConquer;
78+
break;
79+
case 3:
80+
algorithm = mlir::enzymexla::SVDAlgorithm::Jacobi;
81+
break;
82+
}
83+
return wrap(mlir::enzymexla::SVDAlgorithmAttr::get(unwrap(ctx), algorithm));
84+
}
85+
6786
MlirAttribute enzymexlaGeluApproximationAttrGet(MlirContext ctx, int32_t mode) {
6887
mlir::enzymexla::GeluApproximation approximation;
6988
switch (mode) {

src/enzyme_ad/jax/Integrations/c/EnzymeXLA.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ MLIR_CAPI_EXPORTED MlirAttribute enzymexlaLapackUploAttrGet(MlirContext ctx,
3333
MLIR_CAPI_EXPORTED MlirAttribute enzymexlaQRAlgorithmAttrGet(MlirContext ctx,
3434
int32_t mode);
3535

36+
MLIR_CAPI_EXPORTED MlirAttribute enzymexlaSVDAlgorithmAttrGet(MlirContext ctx,
37+
int32_t mode);
38+
3639
//===----------------------------------------------------------------------===//
3740
// Machine Learning Ops
3841
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)