@@ -446,15 +446,15 @@ def QRFactorizationOp: EnzymeXLA_Op<"linalg.qr", [Pure]> {
446446 let summary = "QR factorization operation (high-level op)";
447447
448448 let description = [{
449- This operation computes the QR factorization of a matrix using Householder
450- reflections. Mathematically, it decomposes A into the product of an
451- orthogonal (unitary if complex) matrix Q and an upper triangular matrix R,
449+ This operation computes the QR factorization of a matrix using Householder
450+ reflections. Mathematically, it decomposes A into the product of an
451+ orthogonal (unitary if complex) matrix Q and an upper triangular matrix R,
452452 such that A = QR.
453453
454454 If A has size m x n and m > n, Q is an m x n isometric matrix. If m < n, R
455455 will be a m x n trapezoidal matrix.
456456
457- This operation is modeled after the mathematical formulation of the QR
457+ This operation is modeled after the mathematical formulation of the QR
458458 factorization, and not after LAPACK's compact formats.
459459 }];
460460
@@ -473,28 +473,49 @@ def QRFactorizationOp: EnzymeXLA_Op<"linalg.qr", [Pure]> {
473473 }];
474474}
475475
476+ def SVDFactorizationOp : EnzymeXLA_Op<"linalg.svd", [Pure]> {
477+ let summary = "Singular Value Decomposition (SVD) factorization operation.";
478+
479+ let arguments = (ins
480+ HLO_Tensor:$input,
481+ DefaultValuedAttr<BoolAttr, "false">:$full,
482+ DefaultValuedAttr<EnzymeXLA_SVDAlgorithmAttr, "::mlir::enzymexla::SVDAlgorithm::DEFAULT">:$algorithm
483+ );
484+
485+ let results = (outs
486+ HLO_Tensor:$U,
487+ HLO_Tensor:$S,
488+ HLO_Tensor:$Vt,
489+ HLO_Tensor:$info
490+ );
491+
492+ let assemblyFormat = [{
493+ $input attr-dict `:` functional-type($input, results)
494+ }];
495+ }
496+
476497def GeqrfOp : EnzymeXLA_Op<"lapack.geqrf", [Pure]> {
477498 let summary = "QR factorization operation (low-level op)";
478499
479500 let description = [{
480- This operation computes the QR factorization of a matrix using Householder
481- reflections. Mathematically, it decomposes A into the product of an
501+ This operation computes the QR factorization of a matrix using Householder
502+ reflections. Mathematically, it decomposes A into the product of an
482503 orthogonal matri x Q and an upper triangular matrix R,
483- such that A = QR.
504+ such that A = QR.
484505
485- This operation is modeled after
486- LAPACK's *GEQRF routines, which returns the result in
487- the QR packed format.
488- }];
506+ This operation is modeled after LAPACK's *GEQRF routines, which returns the
507+ result in the QR packed format.
508+ }];
489509
490- let arguments = (ins HLO_Tensor : $input);
510+ let arguments = (ins HLO_Tensor : $input);
491511
492- let results = (outs HLO_Tensor
493- : $output, HLO_Tensor
494- : $tau, HLO_Tensor
495- : $info);
512+ let results = (outs
513+ HLO_Tensor:$output,
514+ HLO_Tensor:$tau,
515+ HLO_Tensor:$info
516+ );
496517
497- let assemblyFormat = [{
518+ let assemblyFormat = [{
498519 $input attr-dict `:` functional-type($input, results)
499520 }];
500521}
@@ -503,11 +524,11 @@ def GeqrtOp : EnzymeXLA_Op<"lapack.geqrt", [Pure]> {
503524 let summary = "QR factorization operation (low-level op)";
504525
505526 let description = [{
506- This operation computes the QR factorization of a matrix using Householder
507- reflections. Mathematically, it decomposes A into the product of an
527+ This operation computes the QR factorization of a matrix using Householder
528+ reflections. Mathematically, it decomposes A into the product of an
508529 orthogonal matrix Q and an upper triangular matrix R, such that A = QR.
509530
510- This operation is modeled after LAPACK's *GEQRT routines, which returns the
531+ This operation is modeled after LAPACK's *GEQRT routines, which returns the
511532 result in the QR CompactWY format.
512533 }];
513534
@@ -605,7 +626,81 @@ def GemqrtOp : EnzymeXLA_Op<"lapack.gemqrt", [Pure]> {
605626 }];
606627}
607628
608- def SVDFactorizationOp : EnzymeXLA_Op<"linalg.svd", [Pure]> {
629+ def GetrfOp : EnzymeXLA_Op<"lapack.getrf", [Pure]> {
630+ let summary = "LU factorization operation with row-major pivoting.";
631+
632+ let arguments = (ins HLO_Tensor:$input);
633+
634+ let results = (outs
635+ HLO_Tensor:$output,
636+ HLO_Tensor:$pivots,
637+ HLO_Tensor:$permutation,
638+ HLO_Tensor:$info
639+ );
640+
641+ let assemblyFormat = [{
642+ $input attr-dict `:` functional-type($input, results)
643+ }];
644+ }
645+
646+ def GetriOp : EnzymeXLA_Op<"lapack.getri", [Pure]> {
647+ let summary = "Computes the inverse of a matrix using the LU factorization.";
648+
649+ let arguments = (ins
650+ HLO_Tensor:$input,
651+ HLO_Tensor:$ipiv
652+ );
653+
654+ let results = (outs
655+ HLO_Tensor:$output
656+ );
657+
658+ let assemblyFormat = [{
659+ $input `,` $ipiv attr-dict `:` functional-type(operands, results)
660+ }];
661+ }
662+
663+ def GesddOp : EnzymeXLA_Op<"lapack.gesdd", [Pure]> {
664+ let summary = "Singular Value Decomposition (SVD) factorization operation.";
665+
666+ let arguments = (ins
667+ HLO_Tensor:$input,
668+ DefaultValuedAttr<BoolAttr, "false">:$full
669+ );
670+
671+ let results = (outs
672+ HLO_Tensor:$U,
673+ HLO_Tensor:$S,
674+ HLO_Tensor:$Vt,
675+ HLO_Tensor:$info
676+ );
677+
678+ let assemblyFormat = [{
679+ $input attr-dict `:` functional-type($input, results)
680+ }];
681+ }
682+
683+ def GesvdOp : EnzymeXLA_Op<"lapack.gesvd", [Pure]> {
684+ let summary = "Singular Value Decomposition (SVD) factorization operation.";
685+
686+ let arguments = (ins
687+ HLO_Tensor:$input,
688+ DefaultValuedAttr<BoolAttr, "false">:$full
689+ );
690+
691+ let results = (outs
692+ HLO_Tensor:$U,
693+ HLO_Tensor:$S,
694+ HLO_Tensor:$Vt,
695+ HLO_Tensor:$info
696+ );
697+
698+ let assemblyFormat = [{
699+ $input attr-dict `:` functional-type($input, results)
700+ }];
701+ }
702+
703+ def GesvjOp : EnzymeXLA_Op<"lapack.gesvj", [Pure]> {
609704 let summary = "Singular Value Decomposition (SVD) factorization operation.";
610705
611706 let arguments = (ins
0 commit comments