113113 end
114114end
115115
116+ function MatrixAlgebraKit. eig_trunc! (A, DV, ϵ:: Vector{T} , alg:: MatrixAlgebraKit.TruncatedAlgorithm ) where {T}
117+ D, V = eig_full! (A, DV, alg. alg)
118+ DVtrunc, ind = MatrixAlgebraKit. truncate (eig_trunc!, (D, V), alg. trunc)
119+ ϵ[1 ] = MatrixAlgebraKit. truncation_error! (diagview (D), ind)
120+ return DVtrunc... , ϵ
121+ end
122+ function dummy_eig_trunc (A, ϵ:: Vector{T} , alg:: TruncatedAlgorithm ) where {T}
123+ Ac = MatrixAlgebraKit. copy_input (MatrixAlgebraKit. eig_trunc, A)
124+ DV = MatrixAlgebraKit. initialize_output (eig_trunc!, A, alg)
125+ Dtrunc, Vtrunc, ϵ = MatrixAlgebraKit. eig_trunc! (Ac, DV, ϵ, alg)
126+ return Dtrunc, Vtrunc, ϵ
127+ end
128+
116129@timedtestset " EIG AD Rules with eltype $T " for T in ETs
117130 rng = StableRNG (12345 )
118131 m = 19
@@ -129,16 +142,16 @@ end
129142 test_reverse (eig_full, RT, (copy (A), TA); fkwargs= (alg= alg,), atol= atol, rtol= rtol, output_tangent= (copy (ΔD2), copy (ΔV)))
130143 test_reverse (eig_vals, RT, (copy (A), TA); fkwargs= (alg= alg,), atol= atol, rtol= rtol, output_tangent= copy (ΔD2. diag))
131144 end
132- @testset " reverse: RT $RT , TA $TA " for RT in (MixedDuplicated ,), TA in (Duplicated,)
145+ @testset " reverse: RT $RT , TA $TA " for RT in (Duplicated ,), TA in (Duplicated,)
133146 for r in 1 : 4 : m
134147 truncalg = TruncatedAlgorithm (alg, truncrank (r; by = abs))
135148 ind = MatrixAlgebraKit. findtruncated (diagview (D), truncalg. trunc)
136149 Dtrunc = Diagonal (diagview (D)[ind])
137150 Vtrunc = V[:, ind]
138151 ΔDtrunc = Diagonal (diagview (ΔD2)[ind])
139152 ΔVtrunc = ΔV[:, ind]
140- # broken right now due to Enzyme
141- # test_reverse(eig_trunc! , RT, (A, TA), ((D, V), TA), (truncalg, Const); atol=atol, rtol=rtol, output_tangent=Base.RefValue(( ΔDtrunc, ΔVtrunc, zero(real(T))) ))
153+ ϵ = [ zero ( real (T))]
154+ test_reverse (dummy_eig_trunc , RT, (A, TA), (ϵ, TA), (truncalg, Const); atol= atol, rtol= rtol, output_tangent= ( ΔDtrunc, ΔVtrunc, [ zero (real (T))] ))
142155 dA1 = MatrixAlgebraKit. eig_pullback! (zero (A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
143156 dA2 = MatrixAlgebraKit. eig_trunc_pullback! (zero (A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc))
144157 @test isapprox (dA1, dA2; atol = atol, rtol = rtol)
149162 Vtrunc = V[:, ind]
150163 ΔDtrunc = Diagonal (diagview (ΔD2)[ind])
151164 ΔVtrunc = ΔV[:, ind]
152- # broken right now due to Enzyme
153- # test_reverse(eig_trunc , RT, (A, TA); fkwargs=(alg=truncalg, ), atol=atol, rtol=rtol, output_tangent=Base.RefValue(( ΔDtrunc, ΔVtrunc, zero(real(T))) ))
165+ ϵ = [ zero ( real (T))]
166+ test_reverse (dummy_eig_trunc , RT, (A, TA), (ϵ, TA ), (truncalg, Const); atol= atol, rtol= rtol, output_tangent= ( ΔDtrunc, ΔVtrunc, [ zero (real (T))] ))
154167 dA1 = MatrixAlgebraKit. eig_pullback! (zero (A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
155168 dA2 = MatrixAlgebraKit. eig_trunc_pullback! (zero (A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc))
156169 @test isapprox (dA1, dA2; atol = atol, rtol = rtol)
@@ -168,10 +181,18 @@ function copy_eigh_vals(A; kwargs...)
168181 eigh_vals (A; kwargs... )
169182end
170183
171- function copy_eigh_trunc! (A; kwargs... )
184+ function MatrixAlgebraKit. eigh_trunc! (A, DV, ϵ:: Vector{T} , alg:: MatrixAlgebraKit.TruncatedAlgorithm ) where {T}
185+ D, V = eigh_full! (A, DV, alg. alg)
186+ DVtrunc, ind = MatrixAlgebraKit. truncate (eigh_trunc!, (D, V), alg. trunc)
187+ ϵ[1 ] = MatrixAlgebraKit. truncation_error! (diagview (D), ind)
188+ return DVtrunc... , ϵ
189+ end
190+ function dummy_eigh_trunc (A, ϵ:: Vector{T} , alg:: TruncatedAlgorithm ) where {T}
172191 A = (A + A' )/ 2
173- DV = MatrixAlgebraKit. initialize_output (eigh_trunc!, A, kwargs[:alg ])
174- eigh_trunc! (A, DV; kwargs... )
192+ Ac = MatrixAlgebraKit. copy_input (MatrixAlgebraKit. eigh_trunc, A)
193+ DV = MatrixAlgebraKit. initialize_output (eigh_trunc!, A, alg)
194+ Dtrunc, Vtrunc, ϵ = MatrixAlgebraKit. eigh_trunc! (Ac, DV, ϵ, alg)
195+ return Dtrunc, Vtrunc, ϵ
175196end
176197
177198@timedtestset " EIGH AD Rules with eltype $T " for T in ETs
@@ -194,11 +215,11 @@ end
194215 @testset " forward: RT $RT , TA $TA " for RT in (Const, Duplicated,), TA in (Const, Duplicated,)
195216 test_forward (copy_eigh_vals, RT, (copy (A), TA); fkwargs= (alg= alg,), atol= atol, rtol= rtol)
196217 end
197- @testset " reverse: RT $RT , TA $TA " for RT in (Duplicated,), TA in (Duplicated,)
218+ @testset " reverse: RT $RT , TA $TA " for RT in (Duplicated,), TA in (Duplicated,)
198219 test_reverse (copy_eigh_full, RT, (copy (A), TA); fkwargs= (alg= alg,), atol= atol, rtol= rtol, output_tangent= (copy (ΔD2), copy (ΔV)))
199220 test_reverse (copy_eigh_vals, RT, (copy (A), TA); fkwargs= (alg= alg,), atol= atol, rtol= rtol, output_tangent= copy (ΔD2. diag))
200221 end
201- @testset " reverse: RT $RT , TA $TA " for RT in (MixedDuplicated ,), TA in (Duplicated,)
222+ @testset " reverse: RT $RT , TA $TA " for RT in (Duplicated ,), TA in (Duplicated,)
202223 for r in 1 : 4 : m
203224 Ddiag = diagview (D)
204225 truncalg = TruncatedAlgorithm (alg, truncrank (r; by = abs))
207228 Vtrunc = V[:, ind]
208229 ΔDtrunc = Diagonal (diagview (ΔD2)[ind])
209230 ΔVtrunc = ΔV[:, ind]
210- # broken right now due to Enzyme
211- # test_reverse(copy_eigh_trunc , RT, (A, TA); fkwargs=(alg=truncalg, ), atol=atol, rtol=rtol, output_tangent=(copy( ΔDtrunc), copy( ΔVtrunc), zero(real(T))))
231+ ϵ = [ zero ( real (T))]
232+ test_reverse (dummy_eigh_trunc , RT, (A, TA), (ϵ, TA ), (truncalg, Const); atol= atol, rtol= rtol, output_tangent= (ΔDtrunc, ΔVtrunc, [ zero (real (T))] ))
212233 dA1 = MatrixAlgebraKit. eigh_pullback! (zero (A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
213234 dA2 = MatrixAlgebraKit. eigh_trunc_pullback! (zero (A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc))
214235 @test isapprox (dA1, dA2; atol = atol, rtol = rtol)
@@ -220,15 +241,28 @@ end
220241 Vtrunc = V[:, ind]
221242 ΔDtrunc = Diagonal (diagview (ΔD2)[ind])
222243 ΔVtrunc = ΔV[:, ind]
223- # broken right now due to Enzyme
224- # test_reverse(copy_eigh_trunc , RT, (A, TA); fkwargs=(alg=truncalg, ), atol=atol, rtol=rtol, output_tangent=(copy( ΔDtrunc), copy( ΔVtrunc), zero(real(T))))
244+ ϵ = [ zero ( real (T))]
245+ test_reverse (dummy_eigh_trunc , RT, (A, TA), (ϵ, TA ), (truncalg, Const); atol= atol, rtol= rtol, output_tangent= (ΔDtrunc, ΔVtrunc, [ zero (real (T))] ))
225246 dA1 = MatrixAlgebraKit. eigh_pullback! (zero (A), A, (D, V), (ΔDtrunc, ΔVtrunc), ind)
226247 dA2 = MatrixAlgebraKit. eigh_trunc_pullback! (zero (A), A, (Dtrunc, Vtrunc), (ΔDtrunc, ΔVtrunc))
227248 @test isapprox (dA1, dA2; atol = atol, rtol = rtol)
228249 end
229250 end
230251end
231252
253+ function MatrixAlgebraKit. svd_trunc! (A, USVᴴ, ϵ:: Vector{T} , alg:: MatrixAlgebraKit.TruncatedAlgorithm ) where {T}
254+ U, S, Vᴴ = svd_compact! (A, USVᴴ, alg. alg)
255+ USVᴴtrunc, ind = MatrixAlgebraKit. truncate (svd_trunc!, (U, S, Vᴴ), alg. trunc)
256+ ϵ[1 ] = MatrixAlgebraKit. truncation_error! (diagview (S), ind)
257+ return USVᴴtrunc... , ϵ
258+ end
259+ function dummy_svd_trunc (A, ϵ:: Vector{T} , alg:: TruncatedAlgorithm ) where {T}
260+ Ac = MatrixAlgebraKit. copy_input (MatrixAlgebraKit. svd_trunc, A)
261+ USVᴴ = MatrixAlgebraKit. initialize_output (svd_trunc!, A, alg)
262+ Utrunc, Strunc, Vᴴtrunc, ϵ = MatrixAlgebraKit. svd_trunc! (Ac, USVᴴ, ϵ, alg)
263+ return Utrunc, Strunc, Vᴴtrunc, ϵ
264+ end
265+
232266@timedtestset " SVD AD Rules with eltype $T " for T in ETs
233267 rng = StableRNG (12345 )
234268 m = 19
239273 @testset for alg in (LAPACK_QRIteration (),
240274 LAPACK_DivideAndConquer (),
241275 )
242- @testset " reverse: RT $RT , TA $TA " for RT in (Duplicated,), TA in (Duplicated,)
276+ @testset " reverse: RT $RT , TA $TA " for RT in (Duplicated,), TA in (Duplicated,)
243277 @testset " svd_compact" begin
244278 U, S, Vᴴ = svd_compact (A)
245279 ΔU = randn (rng, T, m, minmn)
250284 test_reverse (svd_compact, RT, (A, TA); atol= atol, rtol= rtol, fkwargs= (alg= alg,), output_tangent= (ΔU, ΔS, ΔVᴴ), fdm= fdm)
251285 end
252286 end
253- @testset " reverse: RT $RT , TA $TA " for RT in (MixedDuplicated ,), TA in (Duplicated,)
287+ @testset " reverse: RT $RT , TA $TA " for RT in (Duplicated ,), TA in (Duplicated,)
254288 @testset " svd_trunc" begin
255289 for r in 1 : 4 : minmn
256290 U, S, Vᴴ = svd_compact (A)
269303 ΔVᴴtrunc = ΔVᴴ[ind, :]
270304 fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils. FiniteDifferences. central_fdm (5 , 1 , max_range= 1e-2 ) : EnzymeTestUtils. FiniteDifferences. central_fdm (5 , 1 )
271305 # broken due to Enzyme
272- # test_reverse(svd_trunc, RT, (A, TA); fkwargs=(alg=truncalg,), atol=atol, rtol=rtol, output_tangent=(ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), fdm=fdm)
306+ ϵ = [zero (real (T))]
307+ test_reverse (dummy_svd_trunc, RT, (A, TA), (ϵ, TA), (truncalg, Const); atol= atol, rtol= rtol, output_tangent= (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, [zero (real (T))]), fdm= fdm)
273308 dA1 = MatrixAlgebraKit. svd_pullback! (zero (A), copy (A), (U, S, Vᴴ), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc), ind)
274309 dA2 = MatrixAlgebraKit. svd_trunc_pullback! (zero (A), copy (A), (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc))
275310 @test isapprox (dA1, dA2; atol = atol, rtol = rtol)
290325 ΔVᴴtrunc = ΔVᴴ[ind, :]
291326 fdm = T <: Union{Float32, ComplexF32} ? EnzymeTestUtils. FiniteDifferences. central_fdm (5 , 1 , max_range= 1e-2 ) : EnzymeTestUtils. FiniteDifferences. central_fdm (5 , 1 )
292327 # broken due to Enzyme
293- # test_reverse(svd_trunc, RT, (A, TA); fkwargs=(alg=truncalg,), atol=atol, rtol=rtol, output_tangent=(ΔUtrunc, ΔStrunc, ΔVᴴtrunc, zero(real(T))), fdm)
328+ ϵ = [zero (real (T))]
329+ test_reverse (dummy_svd_trunc, RT, (A, TA), (ϵ, TA), (truncalg, Const); atol= atol, rtol= rtol, output_tangent= (ΔUtrunc, ΔStrunc, ΔVᴴtrunc, [zero (real (T))]), fdm= fdm)
294330 dA1 = MatrixAlgebraKit. svd_pullback! (zero (A), copy (A), (copy (U), copy (S), copy (Vᴴ)), (copy (ΔUtrunc), copy (ΔStrunc), copy (ΔVᴴtrunc)), ind)
295331 dA2 = MatrixAlgebraKit. svd_trunc_pullback! (zero (A), copy (A), (Utrunc, Strunc, Vᴴtrunc), (ΔUtrunc, ΔStrunc, ΔVᴴtrunc))
296332 @test isapprox (dA1, dA2; atol = atol, rtol = rtol)
0 commit comments