Skip to content

Commit a29eabd

Browse files
committed
feat: update to allow algorithms
1 parent 992f864 commit a29eabd

File tree

5 files changed

+164
-14
lines changed

5 files changed

+164
-14
lines changed

deps/ReactantExtra/WORKSPACE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ NSYNC_COMMIT = "82b118aa7ace3132e517e2c467f8732978cf4023"
44

55
NSYNC_SHA256 = ""
66

7-
ENZYMEXLA_COMMIT = "719107a26418020ecee54c18cd2a299df7004786"
7+
ENZYMEXLA_COMMIT = "634866abf929fce4cb2d1ac25e76dca9b7a33f3c"
88

99
ENZYMEXLA_SHA256 = ""
1010

src/Compiler.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1722,6 +1722,8 @@ function compile_mlir!(
17221722

17231723
blas_int_width = sizeof(BlasInt) * 8
17241724
lower_enzymexla_linalg_pass = "lower-enzymexla-linalg{backend=$backend \
1725+
blas_int_width=$blas_int_width},\
1726+
lower-enzymexla-lapack{backend=$backend \
17251727
blas_int_width=$blas_int_width}"
17261728

17271729
legalize_chlo_to_stablehlo =

src/Ops.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3316,6 +3316,7 @@ end
33163316
x::TracedRArray{T,N},
33173317
::Type{iT}=Int32;
33183318
full::Bool=false,
3319+
algorithm::String="DEFAULT",
33193320
location=mlir_stacktrace("svd", @__FILE__, @__LINE__),
33203321
) where {T,iT,N}
33213322
@assert N >= 2
@@ -3329,13 +3330,26 @@ end
33293330
Vt_size = (batch_sizes..., full ? n : r, n)
33303331
info_size = batch_sizes
33313332

3333+
if algorithm == "DEFAULT"
3334+
algint = 0
3335+
elseif algorithm == "QRIteration"
3336+
algint = 1
3337+
elseif algorithm == "DivideAndConquer"
3338+
algint = 2
3339+
elseif algorithm == "Jacobi"
3340+
algint = 3
3341+
else
3342+
error("Unsupported SVD algorithm: $algorithm")
3343+
end
3344+
33323345
svd_op = enzymexla.linalg_svd(
33333346
x.mlir_data;
33343347
U=mlir_type(TracedRArray{T,N}, U_size),
33353348
S=mlir_type(TracedRArray{Base.real(T),N - 1}, S_size),
33363349
Vt=mlir_type(TracedRArray{T,N}, Vt_size),
33373350
info=mlir_type(TracedRArray{iT,N - 2}, info_size),
33383351
full=full,
3352+
algorithm=MLIR.API.enzymexlaSVDAlgorithmAttrGet(MLIR.IR.context(), algint),
33393353
location,
33403354
)
33413355

src/mlir/Dialects/EnzymeXLA.jl

Lines changed: 143 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -346,14 +346,13 @@ end
346346
"""
347347
`lapack_geqrf`
348348
349-
This operation computes the QR factorization of a matrix using Householder
350-
reflections. Mathematically, it decomposes A into the product of an
349+
This operation computes the QR factorization of a matrix using Householder
350+
reflections. Mathematically, it decomposes A into the product of an
351351
orthogonal matri x Q and an upper triangular matrix R,
352-
such that A = QR.
352+
such that A = QR.
353353
354-
This operation is modeled after
355-
LAPACK\'s *GEQRF routines, which returns the result in
356-
the QR packed format.
354+
This operation is modeled after LAPACK\'s *GEQRF routines, which returns the
355+
result in the QR packed format.
357356
"""
358357
function lapack_geqrf(
359358
input::Value; output::IR.Type, tau::IR.Type, info::IR.Type, location=Location()
@@ -379,11 +378,11 @@ end
379378
"""
380379
`lapack_geqrt`
381380
382-
This operation computes the QR factorization of a matrix using Householder
383-
reflections. Mathematically, it decomposes A into the product of an
381+
This operation computes the QR factorization of a matrix using Householder
382+
reflections. Mathematically, it decomposes A into the product of an
384383
orthogonal matrix Q and an upper triangular matrix R, such that A = QR.
385384
386-
This operation is modeled after LAPACK\'s *GEQRT routines, which returns the
385+
This operation is modeled after LAPACK\'s *GEQRT routines, which returns the
387386
result in the QR CompactWY format.
388387
"""
389388
function lapack_geqrt(
@@ -413,6 +412,90 @@ function lapack_geqrt(
413412
)
414413
end
415414

415+
function lapack_gesdd(
416+
input::Value;
417+
U::IR.Type,
418+
S::IR.Type,
419+
Vt::IR.Type,
420+
info::IR.Type,
421+
full=nothing,
422+
location=Location(),
423+
)
424+
op_ty_results = IR.Type[U, S, Vt, info]
425+
operands = Value[input,]
426+
owned_regions = Region[]
427+
successors = Block[]
428+
attributes = NamedAttribute[]
429+
!isnothing(full) && push!(attributes, namedattribute("full", full))
430+
431+
return create_operation(
432+
"enzymexla.lapack.gesdd",
433+
location;
434+
operands,
435+
owned_regions,
436+
successors,
437+
attributes,
438+
results=op_ty_results,
439+
result_inference=false,
440+
)
441+
end
442+
443+
function lapack_gesvd(
444+
input::Value;
445+
U::IR.Type,
446+
S::IR.Type,
447+
Vt::IR.Type,
448+
info::IR.Type,
449+
full=nothing,
450+
location=Location(),
451+
)
452+
op_ty_results = IR.Type[U, S, Vt, info]
453+
operands = Value[input,]
454+
owned_regions = Region[]
455+
successors = Block[]
456+
attributes = NamedAttribute[]
457+
!isnothing(full) && push!(attributes, namedattribute("full", full))
458+
459+
return create_operation(
460+
"enzymexla.lapack.gesvd",
461+
location;
462+
operands,
463+
owned_regions,
464+
successors,
465+
attributes,
466+
results=op_ty_results,
467+
result_inference=false,
468+
)
469+
end
470+
471+
function lapack_gesvj(
472+
input::Value;
473+
U::IR.Type,
474+
S::IR.Type,
475+
Vt::IR.Type,
476+
info::IR.Type,
477+
full=nothing,
478+
location=Location(),
479+
)
480+
op_ty_results = IR.Type[U, S, Vt, info]
481+
operands = Value[input,]
482+
owned_regions = Region[]
483+
successors = Block[]
484+
attributes = NamedAttribute[]
485+
!isnothing(full) && push!(attributes, namedattribute("full", full))
486+
487+
return create_operation(
488+
"enzymexla.lapack.gesvj",
489+
location;
490+
operands,
491+
owned_regions,
492+
successors,
493+
attributes,
494+
results=op_ty_results,
495+
result_inference=false,
496+
)
497+
end
498+
416499
function get_stream(; result::IR.Type, location=Location())
417500
op_ty_results = IR.Type[result,]
418501
operands = Value[]
@@ -432,6 +515,51 @@ function get_stream(; result::IR.Type, location=Location())
432515
)
433516
end
434517

518+
function lapack_getrf(
519+
input::Value;
520+
output::IR.Type,
521+
pivots::IR.Type,
522+
permutation::IR.Type,
523+
info::IR.Type,
524+
location=Location(),
525+
)
526+
op_ty_results = IR.Type[output, pivots, permutation, info]
527+
operands = Value[input,]
528+
owned_regions = Region[]
529+
successors = Block[]
530+
attributes = NamedAttribute[]
531+
532+
return create_operation(
533+
"enzymexla.lapack.getrf",
534+
location;
535+
operands,
536+
owned_regions,
537+
successors,
538+
attributes,
539+
results=op_ty_results,
540+
result_inference=false,
541+
)
542+
end
543+
544+
function lapack_getri(input::Value, ipiv::Value; output::IR.Type, location=Location())
545+
op_ty_results = IR.Type[output,]
546+
operands = Value[input, ipiv]
547+
owned_regions = Region[]
548+
successors = Block[]
549+
attributes = NamedAttribute[]
550+
551+
return create_operation(
552+
"enzymexla.lapack.getri",
553+
location;
554+
operands,
555+
owned_regions,
556+
successors,
557+
attributes,
558+
results=op_ty_results,
559+
result_inference=false,
560+
)
561+
end
562+
435563
function jit_call(
436564
inputs::Vector{Value};
437565
result_0::Vector{IR.Type},
@@ -754,15 +882,15 @@ end
754882
"""
755883
`linalg_qr`
756884
757-
This operation computes the QR factorization of a matrix using Householder
758-
reflections. Mathematically, it decomposes A into the product of an
759-
orthogonal (unitary if complex) matrix Q and an upper triangular matrix R,
885+
This operation computes the QR factorization of a matrix using Householder
886+
reflections. Mathematically, it decomposes A into the product of an
887+
orthogonal (unitary if complex) matrix Q and an upper triangular matrix R,
760888
such that A = QR.
761889
762890
If A has size m x n and m > n, Q is an m x n isometric matrix. If m < n, R
763891
will be a m x n trapezoidal matrix.
764892
765-
This operation is modeled after the mathematical formulation of the QR
893+
This operation is modeled after the mathematical formulation of the QR
766894
factorization, and not after LAPACK\'s compact formats.
767895
"""
768896
function linalg_qr(
@@ -842,6 +970,7 @@ function linalg_svd(
842970
Vt::IR.Type,
843971
info::IR.Type,
844972
full=nothing,
973+
algorithm=nothing,
845974
location=Location(),
846975
)
847976
op_ty_results = IR.Type[U, S, Vt, info]
@@ -850,6 +979,7 @@ function linalg_svd(
850979
successors = Block[]
851980
attributes = NamedAttribute[]
852981
!isnothing(full) && push!(attributes, namedattribute("full", full))
982+
!isnothing(algorithm) && push!(attributes, namedattribute("algorithm", algorithm))
853983

854984
return create_operation(
855985
"enzymexla.linalg.svd",

src/mlir/libMLIR_h.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11655,6 +11655,10 @@ function enzymexlaQRAlgorithmAttrGet(ctx, mode)
1165511655
@ccall mlir_c.enzymexlaQRAlgorithmAttrGet(ctx::MlirContext, mode::Int32)::MlirAttribute
1165611656
end
1165711657

11658+
function enzymexlaSVDAlgorithmAttrGet(ctx, mode)
11659+
@ccall mlir_c.enzymexlaSVDAlgorithmAttrGet(ctx::MlirContext, mode::Int32)::MlirAttribute
11660+
end
11661+
1165811662
function enzymexlaGeluApproximationAttrGet(ctx, mode)
1165911663
@ccall mlir_c.enzymexlaGeluApproximationAttrGet(
1166011664
ctx::MlirContext, mode::Int32

0 commit comments

Comments
 (0)