Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/enzyme_ad/jax/Dialect/EnzymeXLAAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,23 @@ def EnzymeXLA_LapackSideAttr : EnumAttr<EnzymeXLA_Dialect,
let assemblyFormat = "`<` $value `>`";
}

def EnzymeXLA_LapackJob : I32EnumAttr<"Job",
"Job to perform in the SVD factorization",
[
I32EnumAttrCase<"none", 0>,
I32EnumAttrCase<"all", 1>,
I32EnumAttrCase<"overwrite", 2>,
I32EnumAttrCase<"some", 3>
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::enzymexla";
}

def EnzymeXLA_LapackJobAttr : EnumAttr<EnzymeXLA_Dialect,
EnzymeXLA_LapackJob, "job"> {
let assemblyFormat = "`<` $value `>`";
}

def EnzymeXLA_QrAlgorithm : I32EnumAttr<"QrAlgorithm",
"Algorithm to use for the QR factorization",
[
Expand Down
35 changes: 35 additions & 0 deletions src/enzyme_ad/jax/Dialect/EnzymeXLAOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,41 @@ def SVDFactorizationOp : EnzymeXLA_Op<"linalg.svd", [Pure]> {
}];
}

def GesvdOp : EnzymeXLA_Op<"lapack.gesvd", [Pure]> {
let summary = "Perform Singular Value Decomposition (SVD) factorization using QR iteration."

let arguments = (ins
HLO_Tensor:$input,
EnzymeXLA_LapackJobAttr:$jobu,
EnzymeXLA_LapackJobAttr:$jobvt
);

let results = (outs
Variadic<HLO_Tensor>:$outputs
);

let assemblyFormat = [{
$input attr-dict `:` functional-type($input, results)
}];
}

def GesddOp : EnzymeXLA_Op<"lapack.gesdd", [Pure]> {
let summary = "Singular Value Decomposition (SVD) using divide-and-conquer strategy.";

let arguments = (ins
HLO_Tensor:$input,
EnzymeXLA_LapackJobAttr:$jobz
);

let results = (outs
Variadic<HLO_Tensor>:$outputs
);

let assemblyFormat = [{
$input attr-dict `:` functional-type($input, results)
}];
}

// Machine Learning Ops

def GeluOp: EnzymeXLA_Op<"ml.gelu", [Pure, SameOperandsAndResultType, Elementwise]> {
Expand Down
Loading