diff --git a/src/enzyme_ad/jax/Dialect/EnzymeXLAAttrs.td b/src/enzyme_ad/jax/Dialect/EnzymeXLAAttrs.td index 25a3e3bb5f..d2e53d1662 100644 --- a/src/enzyme_ad/jax/Dialect/EnzymeXLAAttrs.td +++ b/src/enzyme_ad/jax/Dialect/EnzymeXLAAttrs.td @@ -52,6 +52,23 @@ def EnzymeXLA_LapackSideAttr : EnumAttr, + I32EnumAttrCase<"all", 1>, + I32EnumAttrCase<"overwrite", 2>, + I32EnumAttrCase<"some", 3> + ]> { + let genSpecializedAttr = 0; + let cppNamespace = "::mlir::enzymexla"; +} + +def EnzymeXLA_LapackJobAttr : EnumAttr { + let assemblyFormat = "`<` $value `>`"; +} + def EnzymeXLA_QrAlgorithm : I32EnumAttr<"QrAlgorithm", "Algorithm to use for the QR factorization", [ diff --git a/src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td b/src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td index 4d2ec27a91..43cf2cdcc6 100644 --- a/src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td +++ b/src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td @@ -550,6 +550,41 @@ def SVDFactorizationOp : EnzymeXLA_Op<"linalg.svd", [Pure]> { }]; } +def GesvdOp : EnzymeXLA_Op<"lapack.gesvd", [Pure]> { + let summary = "Perform Singular Value Decomposition (SVD) factorization using QR iteration." + + let arguments = (ins + HLO_Tensor:$input, + EnzymeXLA_LapackJobAttr:$jobu, + EnzymeXLA_LapackJobAttr:$jobvt + ); + + let results = (outs + Variadic:$outputs + ); + + let assemblyFormat = [{ + $input attr-dict `:` functional-type($input, results) + }]; +} + +def GesddOp : EnzymeXLA_Op<"lapack.gesdd", [Pure]> { + let summary = "Singular Value Decomposition (SVD) using divide-and-conquer strategy."; + + let arguments = (ins + HLO_Tensor:$input, + EnzymeXLA_LapackJobAttr:$jobz + ); + + let results = (outs + Variadic:$outputs + ); + + let assemblyFormat = [{ + $input attr-dict `:` functional-type($input, results) + }]; +} + // Machine Learning Ops def GeluOp: EnzymeXLA_Op<"ml.gelu", [Pure, SameOperandsAndResultType, Elementwise]> {