@@ -3312,6 +3312,60 @@ Compute the row maximum pivoted LU factorization of `x` and return the factors `
33123312 return (res, ipiv, perm, info)
33133313end
33143314
3315+ @noinline function svd (
3316+ x:: TracedRArray{T,N} ,
3317+ :: Type{iT} = Int32;
3318+ full:: Bool = false ,
3319+ algorithm:: String = " DEFAULT" ,
3320+ location= mlir_stacktrace (" svd" , @__FILE__ , @__LINE__ ),
3321+ ) where {T,iT,N}
3322+ @assert N >= 2
3323+
3324+ batch_sizes = size (x)[1 : (end - 2 )]
3325+ m, n = size (x)[(end - 1 ): end ]
3326+ r = min (m, n)
3327+
3328+ U_size = (batch_sizes... , m, full ? m : r)
3329+ S_size = (batch_sizes... , r)
3330+ Vt_size = (batch_sizes... , full ? n : r, n)
3331+ info_size = batch_sizes
3332+
3333+ if algorithm == " DEFAULT"
3334+ algint = 0
3335+ elseif algorithm == " QRIteration"
3336+ algint = 1
3337+ elseif algorithm == " DivideAndConquer"
3338+ algint = 2
3339+ elseif algorithm == " Jacobi"
3340+ algint = 3
3341+ else
3342+ error (" Unsupported SVD algorithm: $algorithm " )
3343+ end
3344+
3345+ svd_op = enzymexla. linalg_svd (
3346+ x. mlir_data;
3347+ U= mlir_type (TracedRArray{T,N}, U_size),
3348+ S= mlir_type (TracedRArray{Base. real (T),N - 1 }, S_size),
3349+ Vt= mlir_type (TracedRArray{T,N}, Vt_size),
3350+ info= mlir_type (TracedRArray{iT,N - 2 }, info_size),
3351+ full= full,
3352+ algorithm= MLIR. API. enzymexlaSVDAlgorithmAttrGet (MLIR. IR. context (), algint),
3353+ location,
3354+ )
3355+
3356+ U = TracedRArray {T,N} ((), MLIR. IR. result (svd_op, 1 ), U_size)
3357+ S = TracedRArray {Base.real(T),N - 1} ((), MLIR. IR. result (svd_op, 2 ), S_size)
3358+ Vt = TracedRArray {T,N} ((), MLIR. IR. result (svd_op, 3 ), Vt_size)
3359+
3360+ if N == 2
3361+ info = TracedRNumber {iT} ((), MLIR. IR. result (svd_op, 4 ))
3362+ else
3363+ info = TracedRArray {iT,N - 2} ((), MLIR. IR. result (svd_op, 4 ), info_size)
3364+ end
3365+
3366+ return U, S, Vt, info
3367+ end
3368+
33153369@noinline function reduce_window (
33163370 f:: F ,
33173371 inputs:: Vector{TracedRArray{T,N}} ,
0 commit comments