8989function initialize_output (:: typeof (svd_vals!), A:: AbstractMatrix , :: AbstractAlgorithm )
9090 return similar (A, real (eltype (A)), (min (size (A)... ),))
9191end
92- function initialize_output (:: typeof (svd_trunc!), A, alg:: TruncatedAlgorithm )
92+ function initialize_output (:: typeof (svd_trunc!), A:: AbstractMatrix , alg:: TruncatedAlgorithm )
9393 return initialize_output (svd_compact!, A, alg. alg)
9494end
9595
@@ -347,46 +347,25 @@ function _gpu_gesvdj!(
347347 )
348348 throw (MethodError (_gpu_gesvdj!, (A, S, U, Vᴴ)))
349349end
350- function _gpu_gesvd_maybe_transpose! (A:: AbstractMatrix , S:: AbstractVector , U:: AbstractMatrix , Vᴴ:: AbstractMatrix )
351- m, n = size (A)
352- m ≥ n && return _gpu_gesvd! (A, S, U, Vᴴ)
353- # both CUSOLVER and ROCSOLVER require m ≥ n for gesvd (QR_Iteration)
354- # if this condition is not met, do the SVD via adjoint
355- minmn = min (m, n)
356- Aᴴ = min (m, n) > 0 ? adjoint! (similar (A' ), A):: AbstractMatrix : similar (A' )
357- Uᴴ = similar (U' )
358- V = similar (Vᴴ' )
359- if size (U) == (m, m)
360- _gpu_gesvd! (Aᴴ, view (S, 1 : minmn, 1 ), V, Uᴴ)
361- else
362- _gpu_gesvd! (Aᴴ, S, V, Uᴴ)
363- end
364- length (U) > 0 && adjoint! (U, Uᴴ)
365- length (Vᴴ) > 0 && adjoint! (Vᴴ, V)
366- return U, S, Vᴴ
367- end
368-
369350# GPU SVD implementation
370- function svd_full! (A:: AbstractMatrix , USVᴴ, alg:: GPU_SVDAlgorithm )
351+ function MatrixAlgebraKit . svd_full! (A:: AbstractMatrix , USVᴴ, alg:: GPU_SVDAlgorithm )
371352 check_input (svd_full!, A, USVᴴ, alg)
372353 U, S, Vᴴ = USVᴴ
373354 fill! (S, zero (eltype (S)))
374355 m, n = size (A)
375356 minmn = min (m, n)
376- if minmn == 0
377- one! (U)
378- zero! (S)
379- one! (Vᴴ)
380- return USVᴴ
381- end
382357 if alg isa GPU_QRIteration
383358 isempty (alg. kwargs) ||
384- @warn " GPU_QRIteration does not accept any keyword arguments"
385- _gpu_gesvd_maybe_transpose ! (A, view (S, 1 : minmn, 1 ), U, Vᴴ)
359+ throw ( ArgumentError ( " GPU_QRIteration does not accept any keyword arguments" ))
360+ _gpu_gesvd ! (A, view (S, 1 : minmn, 1 ), U, Vᴴ)
386361 elseif alg isa GPU_SVDPolar
387362 _gpu_Xgesvdp! (A, view (S, 1 : minmn, 1 ), U, Vᴴ; alg. kwargs... )
388363 elseif alg isa GPU_Jacobi
389364 _gpu_gesvdj! (A, view (S, 1 : minmn, 1 ), U, Vᴴ; alg. kwargs... )
365+ # elseif alg isa LAPACK_Bisection
366+ # throw(ArgumentError("LAPACK_Bisection is not supported for full SVD"))
367+ # elseif alg isa LAPACK_Jacobi
368+ # throw(ArgumentError("LAPACK_Bisection is not supported for full SVD"))
390369 else
391370 throw (ArgumentError (" Unsupported SVD algorithm" ))
392371 end
@@ -403,21 +382,16 @@ function svd_trunc!(A::AbstractMatrix, USVᴴ, alg::TruncatedAlgorithm{<:GPU_Ran
403382 _gpu_Xgesvdr! (A, S. diag, U, Vᴴ; alg. alg. kwargs... )
404383 # TODO : make this controllable using a `gaugefix` keyword argument
405384 gaugefix! (svd_trunc!, U, S, Vᴴ, size (A)... )
406- # TODO : make sure that truncation is based on maxrank, otherwise this might be wrong
407- USVᴴtrunc, ind = truncate (svd_trunc!, (U, S, Vᴴ), alg. trunc)
408- Strunc = diagview (USVᴴtrunc[2 ])
409- # normal `truncation_error!` does not work here since `S` is not the full singular value spectrum
410- ϵ = sqrt (norm (A)^ 2 - norm (Strunc)^ 2 ) # is there a more accurate way to do this?
411- return USVᴴtrunc... , ϵ
385+ return first (truncate (svd_trunc!, USVᴴ, alg. trunc))
412386end
413387
414- function svd_compact! (A:: AbstractMatrix , USVᴴ, alg:: GPU_SVDAlgorithm )
388+ function MatrixAlgebraKit . svd_compact! (A:: AbstractMatrix , USVᴴ, alg:: GPU_SVDAlgorithm )
415389 check_input (svd_compact!, A, USVᴴ, alg)
416390 U, S, Vᴴ = USVᴴ
417391 if alg isa GPU_QRIteration
418392 isempty (alg. kwargs) ||
419- @warn " GPU_QRIteration does not accept any keyword arguments"
420- _gpu_gesvd_maybe_transpose ! (A, S. diag, U, Vᴴ)
393+ throw ( ArgumentError ( " GPU_QRIteration does not accept any keyword arguments" ))
394+ _gpu_gesvd ! (A, S. diag, U, Vᴴ)
421395 elseif alg isa GPU_SVDPolar
422396 _gpu_Xgesvdp! (A, S. diag, U, Vᴴ; alg. kwargs... )
423397 elseif alg isa GPU_Jacobi
@@ -437,8 +411,8 @@ function MatrixAlgebraKit.svd_vals!(A::AbstractMatrix, S, alg::GPU_SVDAlgorithm)
437411 U, Vᴴ = similar (A, (0 , 0 )), similar (A, (0 , 0 ))
438412 if alg isa GPU_QRIteration
439413 isempty (alg. kwargs) ||
440- @warn " GPU_QRIteration does not accept any keyword arguments"
441- _gpu_gesvd_maybe_transpose ! (A, S, U, Vᴴ)
414+ throw ( ArgumentError ( " GPU_QRIteration does not accept any keyword arguments" ))
415+ _gpu_gesvd ! (A, S, U, Vᴴ)
442416 elseif alg isa GPU_SVDPolar
443417 _gpu_Xgesvdp! (A, S, U, Vᴴ; alg. kwargs... )
444418 elseif alg isa GPU_Jacobi
0 commit comments