Skip to content

Commit b117e36

Browse files
authored
feat: updated lowering pipeline of linalg ops (#1783)
* feat: svd op * feat: map more symbols * feat: update to allow algorithms * chore: bump reactant_jll version
1 parent 417c361 commit b117e36

File tree

4 files changed

+73
-1
lines changed

4 files changed

+73
-1
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ PythonCall = "0.9.25"
105105
Random = "1.10"
106106
Random123 = "1.7"
107107
ReactantCore = "0.1.16"
108-
Reactant_jll = "0.0.262"
108+
Reactant_jll = "0.0.263"
109109
ScopedValues = "1.3.0"
110110
Scratch = "1.2"
111111
Sockets = "1.10"

src/Compiler.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1722,6 +1722,8 @@ function compile_mlir!(
17221722

17231723
blas_int_width = sizeof(BlasInt) * 8
17241724
lower_enzymexla_linalg_pass = "lower-enzymexla-linalg{backend=$backend \
1725+
blas_int_width=$blas_int_width},\
1726+
lower-enzymexla-lapack{backend=$backend \
17251727
blas_int_width=$blas_int_width}"
17261728

17271729
legalize_chlo_to_stablehlo =

src/Ops.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3312,6 +3312,60 @@ Compute the row maximum pivoted LU factorization of `x` and return the factors `
33123312
return (res, ipiv, perm, info)
33133313
end
33143314

3315+
@noinline function svd(
3316+
x::TracedRArray{T,N},
3317+
::Type{iT}=Int32;
3318+
full::Bool=false,
3319+
algorithm::String="DEFAULT",
3320+
location=mlir_stacktrace("svd", @__FILE__, @__LINE__),
3321+
) where {T,iT,N}
3322+
@assert N >= 2
3323+
3324+
batch_sizes = size(x)[1:(end - 2)]
3325+
m, n = size(x)[(end - 1):end]
3326+
r = min(m, n)
3327+
3328+
U_size = (batch_sizes..., m, full ? m : r)
3329+
S_size = (batch_sizes..., r)
3330+
Vt_size = (batch_sizes..., full ? n : r, n)
3331+
info_size = batch_sizes
3332+
3333+
if algorithm == "DEFAULT"
3334+
algint = 0
3335+
elseif algorithm == "QRIteration"
3336+
algint = 1
3337+
elseif algorithm == "DivideAndConquer"
3338+
algint = 2
3339+
elseif algorithm == "Jacobi"
3340+
algint = 3
3341+
else
3342+
error("Unsupported SVD algorithm: $algorithm")
3343+
end
3344+
3345+
svd_op = enzymexla.linalg_svd(
3346+
x.mlir_data;
3347+
U=mlir_type(TracedRArray{T,N}, U_size),
3348+
S=mlir_type(TracedRArray{Base.real(T),N - 1}, S_size),
3349+
Vt=mlir_type(TracedRArray{T,N}, Vt_size),
3350+
info=mlir_type(TracedRArray{iT,N - 2}, info_size),
3351+
full=full,
3352+
algorithm=MLIR.API.enzymexlaSVDAlgorithmAttrGet(MLIR.IR.context(), algint),
3353+
location,
3354+
)
3355+
3356+
U = TracedRArray{T,N}((), MLIR.IR.result(svd_op, 1), U_size)
3357+
S = TracedRArray{Base.real(T),N - 1}((), MLIR.IR.result(svd_op, 2), S_size)
3358+
Vt = TracedRArray{T,N}((), MLIR.IR.result(svd_op, 3), Vt_size)
3359+
3360+
if N == 2
3361+
info = TracedRNumber{iT}((), MLIR.IR.result(svd_op, 4))
3362+
else
3363+
info = TracedRArray{iT,N - 2}((), MLIR.IR.result(svd_op, 4), info_size)
3364+
end
3365+
3366+
return U, S, Vt, info
3367+
end
3368+
33153369
@noinline function reduce_window(
33163370
f::F,
33173371
inputs::Vector{TracedRArray{T,N}},

src/stdlibs/LinearAlgebra.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,26 @@ function __init__()
2525
libblastrampoline_handle = Libdl.dlopen(BLAS.libblas)
2626

2727
for (cname, enzymexla_name) in [
28+
# LU
2829
(BLAS.@blasfunc(sgetrf_), :enzymexla_lapack_sgetrf_),
2930
(BLAS.@blasfunc(dgetrf_), :enzymexla_lapack_dgetrf_),
3031
(BLAS.@blasfunc(cgetrf_), :enzymexla_lapack_cgetrf_),
3132
(BLAS.@blasfunc(zgetrf_), :enzymexla_lapack_zgetrf_),
33+
# SVD QR Iteration
34+
(BLAS.@blasfunc(sgesvd_), :enzymexla_lapack_sgesvd_),
35+
(BLAS.@blasfunc(dgesvd_), :enzymexla_lapack_dgesvd_),
36+
(BLAS.@blasfunc(cgesvd_), :enzymexla_lapack_cgesvd_),
37+
(BLAS.@blasfunc(zgesvd_), :enzymexla_lapack_zgesvd_),
38+
# SVD Divide and Conquer
39+
(BLAS.@blasfunc(sgesdd_), :enzymexla_lapack_sgesdd_),
40+
(BLAS.@blasfunc(dgesdd_), :enzymexla_lapack_dgesdd_),
41+
(BLAS.@blasfunc(cgesdd_), :enzymexla_lapack_cgesdd_),
42+
(BLAS.@blasfunc(zgesdd_), :enzymexla_lapack_zgesdd_),
43+
# SVD Jacobi
44+
(BLAS.@blasfunc(sgesvj_), :enzymexla_lapack_sgesvj_),
45+
(BLAS.@blasfunc(dgesvj_), :enzymexla_lapack_dgesvj_),
46+
(BLAS.@blasfunc(cgesvj_), :enzymexla_lapack_cgesvj_),
47+
(BLAS.@blasfunc(zgesvj_), :enzymexla_lapack_zgesvj_),
3248
]
3349
sym = Libdl.dlsym(libblastrampoline_handle, cname)
3450
@ccall MLIR.API.mlir_c.EnzymeJaXMapSymbol(

0 commit comments

Comments
 (0)