diff --git a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl index d845dd97..da115a97 100644 --- a/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl +++ b/ext/MatrixAlgebraKitMooncakeExt/MatrixAlgebraKitMooncakeExt.jl @@ -3,12 +3,13 @@ module MatrixAlgebraKitMooncakeExt using Mooncake using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive using MatrixAlgebraKit -using MatrixAlgebraKit: inv_safe, diagview, copy_input -using MatrixAlgebraKit: qr_pullback!, lq_pullback! -using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback! +using MatrixAlgebraKit: inv_safe, diagview, copy_input, zero!, truncate, truncation_error! +using MatrixAlgebraKit: qr_pullback!, qr_pushforward!, lq_pullback!, lq_pushforward! +using MatrixAlgebraKit: qr_null_pullback!, qr_null_pushforward!, lq_null_pullback!, lq_null_pushforward! using MatrixAlgebraKit: eig_pullback!, eigh_pullback!, eig_trunc_pullback!, eigh_trunc_pullback! -using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback! -using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback! +using MatrixAlgebraKit: eig_pushforward!, eigh_pushforward!, eig_trunc_pushforward!, eigh_trunc_pushforward! +using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!, left_polar_pushforward!, right_polar_pushforward! +using MatrixAlgebraKit: svd_pullback!, svd_trunc_pullback!, svd_pushforward!, svd_trunc_pushforward! using LinearAlgebra @@ -24,21 +25,20 @@ function Mooncake.rrule!!(::CoDual{typeof(copy_input)}, f_df::CoDual, A_dA::CoDu end return CoDual(Ac, dAc), copy_input_pb end - -# two-argument in-place factorizations like LQ, QR, EIG -for (f!, f, pb, adj) in ( - (:qr_full!, :qr_full, :qr_pullback!, :qr_adjoint), - (:lq_full!, :lq_full, :lq_pullback!, :lq_adjoint), - (:qr_compact!, :qr_compact, :qr_pullback!, :qr_adjoint), - (:lq_compact!, :lq_compact, :lq_pullback!, :lq_adjoint), - (:eig_full!, :eig_full, :eig_pullback!, :eig_adjoint), - (:eigh_full!, :eigh_full, :eigh_pullback!, :eigh_adjoint), - (:left_polar!, :left_polar, :left_polar_pullback!, :left_polar_adjoint), - (:right_polar!, :right_polar, :right_polar_pullback!, :right_polar_adjoint), - ) +# two-argument factorizations like LQ, QR, EIG +for (f!, f, pb, pf, adj) in ((qr_full!, qr_full, qr_pullback!, qr_pushforward!, :dqr_adjoint), + (qr_compact!, qr_compact, qr_pullback!, qr_pushforward!, :dqr_adjoint), + (lq_full!, lq_full, lq_pullback!, lq_pushforward!, :dlq_adjoint), + (lq_compact!, lq_compact, lq_pullback!, lq_pushforward!, :dlq_adjoint), + (eig_full!, eig_full, eig_pullback!, eig_pushforward!, :deig_adjoint), + (eigh_full!, eigh_full, eigh_pullback!, eigh_pushforward!, :deigh_adjoint), + (left_polar!, left_polar, left_polar_pullback!, left_polar_pushforward!, :dleft_polar_adjoint), + (right_polar!, right_polar, right_polar_pullback!, right_polar_pushforward!, :dright_polar_adjoint), + ) @eval begin - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Tuple{<:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm} + @is_primitive Mooncake.DefaultCtx Tuple{typeof($f!), Any, Tuple{<:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm} + @is_primitive Mooncake.DefaultCtx Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, args_dargs::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) A, dA = arrayify(A_dA) args = Mooncake.primal(args_dargs) @@ -54,13 +54,12 @@ for (f!, f, pb, adj) in ( $pb(dA, A, (arg1, arg2), (darg1, darg2)) copy!(arg1, arg1c) copy!(arg2, arg2c) - MatrixAlgebraKit.zero!(darg1) - MatrixAlgebraKit.zero!(darg2) + zero!(darg1) + zero!(darg2) return NoRData(), NoRData(), NoRData(), NoRData() end return args_dargs, $adj end - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) A, dA = arrayify(A_dA) output = $f(A, Mooncake.primal(alg_dalg)) @@ -81,15 +80,38 @@ for (f!, f, pb, adj) in ( end return output_codual, $adj end + function Mooncake.frule!!(::Dual{typeof($f!)}, A_dA::Dual, args_dargs::Dual, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm}) + A, dA = arrayify(A_dA) + args = Mooncake.primal(args_dargs) + args = $f!(A, args, Mooncake.primal(alg_dalg)) + dargs = Mooncake.tangent(args_dargs) + arg1, darg1 = arrayify(args[1], dargs[1]) + arg2, darg2 = arrayify(args[2], dargs[2]) + darg1, darg2 = $pf(dA, A, (arg1, arg2), (darg1, darg2)) + zero!(dA) + return args_dargs + end + function Mooncake.frule!!(::Dual{typeof($f)}, A_dA::Dual, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm}) + A, dA = arrayify(A_dA) + args = $f(A, Mooncake.primal(alg_dalg)) + args_dargs = Mooncake.zero_dual(args) + arg1, arg2 = args + dargs = Mooncake.tangent(args_dargs) + arg1, darg1 = arrayify(arg1, dargs[1]) + arg2, darg2 = arrayify(arg2, dargs[2]) + $pf(dA, A, (arg1, arg2), (darg1, darg2)) + return args_dargs + end end end -for (f!, f, pb, adj) in ( - (:qr_null!, :qr_null, :qr_null_pullback!, :qr_null_adjoint), - (:lq_null!, :lq_null, :lq_null_pullback!, :lq_null_adjoint), - ) +for (f!, f, pb, pf, adj) in ((qr_null!, qr_null, qr_null_pullback!, qr_null_pushforward!, :dqr_null_adjoint), + (lq_null!, lq_null, lq_null_pullback!, lq_null_pushforward!, :dlq_null_adjoint), + ) + #forward mode not implemented yet @eval begin - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} + @is_primitive Mooncake.DefaultCtx Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} + @is_primitive Mooncake.DefaultCtx Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(f_df::CoDual{typeof($f!)}, A_dA::CoDual, arg_darg::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) A, dA = arrayify(A_dA) Ac = copy(A) @@ -100,12 +122,11 @@ for (f!, f, pb, adj) in ( copy!(A, Ac) $pb(dA, A, arg, darg) copy!(arg, argc) - MatrixAlgebraKit.zero!(darg) + zero!(darg) return NoRData(), NoRData(), NoRData(), NoRData() end return arg_darg, $adj end - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}) A, dA = arrayify(A_dA) output = $f(A, Mooncake.primal(alg_dalg)) @@ -113,20 +134,37 @@ for (f!, f, pb, adj) in ( function $adj(::NoRData) arg, darg = arrayify(output_codual) $pb(dA, A, arg, darg) - MatrixAlgebraKit.zero!(darg) + zero!(darg) return NoRData(), NoRData(), NoRData() end return output_codual, $adj end + function Mooncake.frule!!(f_df::Dual{typeof($f!)}, A_dA::Dual, arg_darg::Dual, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm}) + A, dA = arrayify(A_dA) + Ac = MatrixAlgebraKit.copy_input($f, A) + arg, darg = arrayify(Mooncake.primal(arg_darg), Mooncake.tangent(arg_darg)) + arg = $f!(A, arg, Mooncake.primal(alg_dalg)) + $pf(dA, Ac, arg, darg) + zero!(dA) + return arg_darg + end + function Mooncake.frule!!(f_df::Dual{typeof($f)}, A_dA::Dual, alg_dalg::Dual{<:MatrixAlgebraKit.AbstractAlgorithm}) + A, dA = arrayify(A_dA) + arg = $f(A, Mooncake.primal(alg_dalg)) + darg = Mooncake.zero_tangent(arg) + $pf(dA, A, arg, darg) + return Dual(arg, darg) + end end end -for (f!, f, f_full, pb, adj) in ( - (:eig_vals!, :eig_vals, :eig_full, :eig_pullback!, :eig_vals_adjoint), - (:eigh_vals!, :eigh_vals, :eigh_full, :eigh_pullback!, :eigh_vals_adjoint), +for (f!, f, f_full, pb, pf, adj) in ( + (:eig_vals!, :eig_vals, :eig_full, :eig_pullback!, :eig_pushforward!, :eig_vals_adjoint), + (:eigh_vals!, :eigh_vals, :eigh_full, :eigh_pullback!, :eigh_pushforward!, :eigh_vals_adjoint), ) @eval begin - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} + @is_primitive Mooncake.DefaultCtx Tuple{typeof($f!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} + @is_primitive Mooncake.DefaultCtx Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual) # compute primal A, dA = arrayify(A_dA) @@ -142,7 +180,16 @@ for (f!, f, f_full, pb, adj) in ( end return D_dD, $adj end - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} + function Mooncake.frule!!(::Dual{typeof($f!)}, A_dA::Dual, D_dD::Dual, alg_dalg::Dual) + # compute primal + A, dA = arrayify(A_dA) + D, dD = arrayify(D_dD) + nD, V = $f_full(A, Mooncake.primal(alg_dalg)) + copy!(D, diagview(nD)) + $pf(dA, A, (Diagonal(D), V), (Diagonal(dD), nothing)) + zero!(dA) + return D_dD + end function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual) # compute primal A, dA = arrayify(A_dA) @@ -159,15 +206,24 @@ for (f!, f, f_full, pb, adj) in ( end return output_codual, $adj end + function Mooncake.frule!!(::Dual{typeof($f)}, A_dA::Dual, alg_dalg::Dual) + # compute primal + A, dA = arrayify(A_dA) + fullD, V = $f_full(A, Mooncake.primal(alg_dalg)) + D_dD = Mooncake.zero_dual(diagview(fullD)) + D, dD = arrayify(D_dD) + $pf(dA, A, (Diagonal(D), V), (Diagonal(dD), nothing)) + return D_dD + end end end -for (f, pb, adj) in ( - (:eig_trunc, :eig_trunc_pullback!, :eig_trunc_adjoint), - (:eigh_trunc, :eigh_trunc_pullback!, :eigh_trunc_adjoint), +for (f, pb, pf, adj) in ( + (:eig_trunc, :eig_trunc_pullback!, :eig_trunc_pushforward!, :eig_trunc_adjoint), + (:eigh_trunc, :eigh_trunc_pullback!, :eigh_trunc_pushforward!, :eigh_trunc_adjoint), ) @eval begin - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} + @is_primitive Mooncake.DefaultCtx Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual) # compute primal A, dA = arrayify(A_dA) @@ -185,12 +241,25 @@ for (f, pb, adj) in ( D, dD = arrayify(Dtrunc, dDtrunc_) V, dV = arrayify(Vtrunc, dVtrunc_) $pb(dA, A, (D, V), (dD, dV)) - MatrixAlgebraKit.zero!(dD) - MatrixAlgebraKit.zero!(dV) + zero!(dD) + zero!(dV) return NoRData(), NoRData(), NoRData() end return output_codual, $adj end + function Mooncake.frule!!(::Dual{typeof($f)}, A_dA::Dual, alg_dalg::Dual) + # compute primal + A, dA = arrayify(A_dA) + alg = Mooncake.primal(alg_dalg) + output = $f(A, alg) + output_dual = Mooncake.zero_dual(output) + dD_ = Mooncake.tangent(output_dual)[1] + dV_ = Mooncake.tangent(output_dual)[2] + D, dD = arrayify(output[1], dD_) + V, dV = arrayify(output[2], dV_) + $pf(dA, A, (D, V), (dD, dV)) + return output_dual + end end end @@ -199,7 +268,8 @@ for (f!, f) in ( (:svd_compact!, :svd_compact), ) @eval begin - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f!), Any, Tuple{<:Any, <:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm} + @is_primitive Mooncake.DefaultCtx Tuple{typeof($f!), Any, Tuple{<:Any, <:Any, <:Any}, MatrixAlgebraKit.AbstractAlgorithm} + @is_primitive Mooncake.DefaultCtx Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof($f!)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual) A, dA = arrayify(A_dA) Ac = copy(A) @@ -223,14 +293,13 @@ for (f!, f) in ( vdVᴴ = view(dVᴴ, 1:minmn, :) svd_pullback!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ)) end - MatrixAlgebraKit.zero!(dU) - MatrixAlgebraKit.zero!(dS) - MatrixAlgebraKit.zero!(dVᴴ) + zero!(dU) + zero!(dS) + zero!(dVᴴ) return NoRData(), NoRData(), NoRData(), NoRData() end return CoDual(output, dUSVᴴ), svd_adjoint end - @is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, alg_dalg::CoDual) A, dA = arrayify(A_dA) USVᴴ = $f(A, Mooncake.primal(alg_dalg)) @@ -257,17 +326,134 @@ for (f!, f) in ( vdVᴴ = view(dVᴴ, 1:minmn, :) svd_pullback!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ)) end - MatrixAlgebraKit.zero!(dU) - MatrixAlgebraKit.zero!(dS) - MatrixAlgebraKit.zero!(dVᴴ) + zero!(dU) + zero!(dS) + zero!(dVᴴ) return NoRData(), NoRData(), NoRData() end return USVᴴ_codual, svd_adjoint end + function Mooncake.frule!!(::Dual{typeof($f!)}, A_dA::Dual, USVᴴ_dUSVᴴ::Dual, alg_dalg::Dual) + # compute primal + USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ) + dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ) + A, dA = arrayify(A_dA) + $f!(A, USVᴴ, Mooncake.primal(alg_dalg)) + # update tangents + U_, S_, Vᴴ_ = USVᴴ + dU_, dS_, dVᴴ_ = dUSVᴴ + U, dU = arrayify(U_, dU_) + S, dS = arrayify(S_, dS_) + Vᴴ, dVᴴ = arrayify(Vᴴ_, dVᴴ_) + minmn = min(size(A)...) + if $(f == svd_compact!) # compact + svd_pushforward!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ)) + else # full + vU = view(U, :, 1:minmn) + vS = view(S, 1:minmn, 1:minmn) + vVᴴ = view(Vᴴ, 1:minmn, :) + vdU = view(dU, :, 1:minmn) + vdS = view(dS, 1:minmn, 1:minmn) + vdVᴴ = view(dVᴴ, 1:minmn, :) + svd_pushforward!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ)) + end + zero!(dA) + return USVᴴ_dUSVᴴ + end + function Mooncake.frule!!(::Dual{typeof($f)}, A_dA::Dual, alg_dalg::Dual) + # compute primal + A, dA = arrayify(A_dA) + USVᴴ = $f(A, Mooncake.primal(alg_dalg)) + # update tangents + U, S, Vᴴ = USVᴴ + dU_ = Mooncake.zero_tangent(U) + dS_ = Mooncake.zero_tangent(S) + dVᴴ_ = Mooncake.zero_tangent(Vᴴ) + U, dU = arrayify(U, dU_) + S, dS = arrayify(S, dS_) + Vᴴ, dVᴴ = arrayify(Vᴴ, dVᴴ_) + if $(f == svd_compact!) # compact + svd_pushforward!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ)) + else # full + minmn = min(size(A)...) + vU = view(U, :, 1:minmn) + vS = view(S, 1:minmn, 1:minmn) + vVᴴ = view(Vᴴ, 1:minmn, :) + vdU = view(dU, :, 1:minmn) + vdS = view(dS, 1:minmn, 1:minmn) + vdVᴴ = view(dVᴴ, 1:minmn, :) + svd_pushforward!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ)) + end + return Dual(USVᴴ, (dU_, dS_, dVᴴ_)) + end end end -@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_vals!), Any, Any, MatrixAlgebraKit.AbstractAlgorithm} +@is_primitive Mooncake.DefaultCtx Tuple{typeof(svd_trunc), Any, MatrixAlgebraKit.AbstractAlgorithm} +function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::CoDual) + # compute primal + A, dA = arrayify(A_dA) + alg = Mooncake.primal(alg_dalg) + output = svd_trunc(A, alg) + # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal + # of ComplexF32) into the correct **forwards** data type (since we are now in the forward + # pass). For many types this is done automatically when the forward step returns, but + # not for nested structs with various fields (like Diagonal{Complex}) + output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output))) + function svd_trunc_adjoint(dy::Tuple{NoRData, NoRData, NoRData, T}) where {T <: Real} + Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake.primal(output_codual) + dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake.tangent(output_codual) + abs(dy[4]) > MatrixAlgebraKit.defaulttol(dy[4]) && @warn "Pullback for svd_trunc! does not yet support non-zero tangent for the truncation error" + U, dU = arrayify(Utrunc, dUtrunc_) + S, dS = arrayify(Strunc, dStrunc_) + Vᴴ, dVᴴ = arrayify(Vᴴtrunc, dVᴴtrunc_) + svd_trunc_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ)) + zero!(dU) + zero!(dS) + zero!(dVᴴ) + return NoRData(), NoRData(), NoRData() + end + return output_codual, svd_trunc_adjoint +end +function Mooncake.frule!!(::Dual{typeof(svd_trunc)}, A_dA::Dual, alg_dalg::Dual) + # compute primal + A, dA = Mooncake.arrayify(A_dA) + alg = Mooncake.primal(alg_dalg) + USVᴴ = svd_compact(A, alg.alg) + U, S, Vᴴ = USVᴴ + dUfull = zeros(eltype(U), size(U)) + dSfull = Diagonal(zeros(eltype(S), length(diagview(S)))) + dVᴴfull = zeros(eltype(Vᴴ), size(Vᴴ)) + svd_pushforward!(dA, A, (U, S, Vᴴ), (dUfull, dSfull, dVᴴfull)) + + USVᴴtrunc, ind = truncate(svd_trunc!, USVᴴ, alg.trunc) + ϵ = truncation_error!(diagview(S), ind) + output = (USVᴴtrunc..., ϵ) + output_dual = Mooncake.zero_dual(output) + Utrunc, Strunc, Vᴴtrunc, ϵ = output + dU_, dS_, dVᴴ_, dϵ = Mooncake.tangent(output_dual) + Utrunc, dU = arrayify(Utrunc, dU_) + Strunc, dS = arrayify(Strunc, dS_) + Vᴴtrunc, dVᴴ = arrayify(Vᴴtrunc, dVᴴ_) + dU .= view(dUfull, :, ind) + diagview(dS) .= view(diagview(dSfull), ind) + dVᴴ .= view(dVᴴfull, ind, :) + return output_dual +end + +@is_primitive Mooncake.DefaultCtx Tuple{typeof(MatrixAlgebraKit.svd_vals!), Any, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm} +function Mooncake.frule!!(::Dual{typeof(svd_vals!)}, A_dA::Dual, S_dS::Dual, alg_dalg::Dual) + # compute primal + S, dS = Mooncake.arrayify(S_dS) + A, dA = Mooncake.arrayify(A_dA) + U, nS, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg)) + # update tangent + copyto!(dS, diag(real.(Vᴴ * dA' * U))) + copyto!(S, diagview(nS)) + zero!(dA) + return S_dS +end + function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDual, alg_dalg::CoDual) # compute primal A, dA = arrayify(A_dA) @@ -282,7 +468,7 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals!)}, A_dA::CoDual, S_dS::CoDua return S_dS, svd_vals_adjoint end -@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_vals), Any, MatrixAlgebraKit.AbstractAlgorithm} +@is_primitive Mooncake.DefaultCtx Tuple{typeof(svd_vals), Any, MatrixAlgebraKit.AbstractAlgorithm} function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::CoDual) # compute primal A, dA = arrayify(A_dA) @@ -301,33 +487,14 @@ function Mooncake.rrule!!(::CoDual{typeof(svd_vals)}, A_dA::CoDual, alg_dalg::Co return S_codual, svd_vals_adjoint end -@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(svd_trunc), Any, MatrixAlgebraKit.AbstractAlgorithm} -function Mooncake.rrule!!(::CoDual{typeof(svd_trunc)}, A_dA::CoDual, alg_dalg::CoDual) +function Mooncake.frule!!(::Dual{typeof(svd_vals)}, A_dA::Dual, alg_dalg::Dual) # compute primal - A_ = Mooncake.primal(A_dA) - dA_ = Mooncake.tangent(A_dA) - A, dA = arrayify(A_, dA_) - alg = Mooncake.primal(alg_dalg) - output = svd_trunc(A, alg) - # fdata call here is necessary to convert complicated Tangent type (e.g. of a Diagonal - # of ComplexF32) into the correct **forwards** data type (since we are now in the forward - # pass). For many types this is done automatically when the forward step returns, but - # not for nested structs with various fields (like Diagonal{Complex}) - output_codual = CoDual(output, Mooncake.fdata(Mooncake.zero_tangent(output))) - function svd_trunc_adjoint(dy::Tuple{NoRData, NoRData, NoRData, T}) where {T <: Real} - Utrunc, Strunc, Vᴴtrunc, ϵ = Mooncake.primal(output_codual) - dUtrunc_, dStrunc_, dVᴴtrunc_, dϵ = Mooncake.tangent(output_codual) - abs(dy[4]) > MatrixAlgebraKit.defaulttol(dy[4]) && @warn "Pullback for svd_trunc! does not yet support non-zero tangent for the truncation error" - U, dU = arrayify(Utrunc, dUtrunc_) - S, dS = arrayify(Strunc, dStrunc_) - Vᴴ, dVᴴ = arrayify(Vᴴtrunc, dVᴴtrunc_) - svd_trunc_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ)) - MatrixAlgebraKit.zero!(dU) - MatrixAlgebraKit.zero!(dS) - MatrixAlgebraKit.zero!(dVᴴ) - return NoRData(), NoRData(), NoRData() - end - return output_codual, svd_trunc_adjoint + A, dA = arrayify(A_dA) + U, S, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg)) + S_dS = Mooncake.zero_dual(diagview(S)) + S_, dS = arrayify(S_dS) + copyto!(dS, diag(real.(Vᴴ * dA' * U))) + return S_dS end end diff --git a/src/MatrixAlgebraKit.jl b/src/MatrixAlgebraKit.jl index d3a04741..d329f579 100644 --- a/src/MatrixAlgebraKit.jl +++ b/src/MatrixAlgebraKit.jl @@ -113,4 +113,11 @@ include("pullbacks/eigh.jl") include("pullbacks/svd.jl") include("pullbacks/polar.jl") +include("pushforwards/qr.jl") +include("pushforwards/lq.jl") +include("pushforwards/eig.jl") +include("pushforwards/eigh.jl") +include("pushforwards/polar.jl") +include("pushforwards/svd.jl") + end diff --git a/src/common/view.jl b/src/common/view.jl index c8ae1aa5..0bc7b9ef 100644 --- a/src/common/view.jl +++ b/src/common/view.jl @@ -1,5 +1,5 @@ # diagind: provided by LinearAlgebra.jl -diagview(D::Diagonal) = D.diag +diagview(D::Diagonal) = D.diag diagview(D::AbstractMatrix) = view(D, diagind(D)) # triangularind diff --git a/src/implementations/eigh.jl b/src/implementations/eigh.jl index a45300dc..527f5425 100644 --- a/src/implementations/eigh.jl +++ b/src/implementations/eigh.jl @@ -19,7 +19,7 @@ function check_hermitian(A; atol::Real = default_hermitian_tol(A), rtol::Real = end function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, alg::AbstractAlgorithm) - check_hermitian(A, alg) + #check_hermitian(A, alg) D, V = DV m = size(A, 1) @assert D isa Diagonal && V isa AbstractMatrix diff --git a/src/pullbacks/eig.jl b/src/pullbacks/eig.jl index ff0de512..5c16c03f 100644 --- a/src/pullbacks/eig.jl +++ b/src/pullbacks/eig.jl @@ -46,7 +46,8 @@ function eig_pullback!( Δgauge ≤ gauge_atol || @warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)" - VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol)) + VᴴΔV ./= conj.(transpose(D) .- D) + diagview(VᴴΔV) .= zero(eltype(VᴴΔV)) if !iszerotangent(ΔDmat) ΔDvec = diagview(ΔDmat) diff --git a/src/pushforwards/eig.jl b/src/pushforwards/eig.jl new file mode 100644 index 00000000..66094111 --- /dev/null +++ b/src/pushforwards/eig.jl @@ -0,0 +1,15 @@ +function eig_pushforward!(ΔA, A, DV, ΔDV; kwargs...) + D, V = DV + ΔD, ΔV = ΔDV + iVΔAV = inv(V) * ΔA * V + diagview(ΔD) .= diagview(iVΔAV) + if !iszerotangent(ΔV) + F = 1 ./ (transpose(diagview(D)) .- diagview(D)) + fill!(diagview(F), zero(eltype(F))) + K̇ = F .* iVΔAV + mul!(ΔV, V, K̇, 1, 0) + end + return ΔDV +end + +function eig_trunc_pushforward!(ΔA, A, DV, ΔDV; kwargs...) end diff --git a/src/pushforwards/eigh.jl b/src/pushforwards/eigh.jl new file mode 100644 index 00000000..edf418ad --- /dev/null +++ b/src/pushforwards/eigh.jl @@ -0,0 +1,19 @@ +function eigh_pushforward!(dA, A, DV, dDV; kwargs...) + D, V = DV + dD, dV = dDV + tmpV = V \ dA + ∂K = tmpV * V + ∂Kdiag = diag(∂K) + diagview(dD) .= real.(∂Kdiag) + if !iszerotangent(dV) + dDD = transpose(diagview(D)) .- diagview(D) + F = one(eltype(dDD)) ./ dDD + diagview(F) .= zero(eltype(F)) + ∂K .*= F + ∂V = mul!(tmpV, V, ∂K) + copyto!(dV, ∂V) + end + return (dD, dV) +end + +function eigh_trunc_pushforward!(dA, A, DV, dDV; kwargs...) end diff --git a/src/pushforwards/lq.jl b/src/pushforwards/lq.jl new file mode 100644 index 00000000..2f5c0e57 --- /dev/null +++ b/src/pushforwards/lq.jl @@ -0,0 +1,7 @@ +function lq_pushforward!(dA, A, LQ, dLQ; tol::Real=default_pullback_gauge_atol(LQ[1]), rank_atol::Real=tol, gauge_atol::Real=tol) + qr_pushforward!(adjoint(dA), adjoint(A), adjoint.(reverse(LQ)), adjoint.(reverse(dLQ)); tol, rank_atol, gauge_atol) +end + +function lq_null_pushforward!(dA, A, Nᴴ, dNᴴ; tol::Real=default_pullback_gauge_atol(Nᴴ), rank_atol::Real=tol, gauge_atol::Real=tol) + iszero(min(size(Nᴴ)...)) && return # nothing to do +end diff --git a/src/pushforwards/polar.jl b/src/pushforwards/polar.jl new file mode 100644 index 00000000..e8f89bb5 --- /dev/null +++ b/src/pushforwards/polar.jl @@ -0,0 +1,21 @@ +function left_polar_pushforward!(ΔA, A, WP, ΔWP; kwargs...) + W, P = WP + ΔW, ΔP = ΔWP + aWdA = adjoint(W) * ΔA + K̇ = sylvester(P, P, -(aWdA - adjoint(aWdA))) + L̇ = (Diagonal(ones(eltype(W), size(W, 1))) - W*adjoint(W))*ΔA*inv(P) + ΔW .= W * K̇ + L̇ + ΔP .= aWdA - K̇*P + return (ΔW, ΔP) +end + +function right_polar_pushforward!(ΔA, A, PWᴴ, ΔPWᴴ; kwargs...) + P, Wᴴ = PWᴴ + ΔP, ΔWᴴ = ΔPWᴴ + dAW = ΔA * adjoint(Wᴴ) + K̇ = sylvester(P, P, -(dAW - adjoint(dAW))) + L̇ = inv(P)*ΔA*(Diagonal(ones(eltype(Wᴴ), size(Wᴴ, 2))) - adjoint(Wᴴ) * Wᴴ) + ΔWᴴ .= K̇ * Wᴴ + L̇ + ΔP .= dAW - P * K̇ + return (ΔWᴴ, ΔP) +end diff --git a/src/pushforwards/qr.jl b/src/pushforwards/qr.jl new file mode 100644 index 00000000..aba05b0a --- /dev/null +++ b/src/pushforwards/qr.jl @@ -0,0 +1,61 @@ +function qr_pushforward!(dA, A, QR, dQR; tol::Real=default_pullback_gauge_atol(QR[2]), rank_atol::Real=tol, gauge_atol::Real=tol) + Q, R = QR + m = size(A, 1) + n = size(A, 2) + minmn = min(m, n) + Rd = diagview(R) + p = findlast(>=(rank_atol) ∘ abs, Rd) + + m1 = p + m2 = minmn - p + m3 = m - minmn + n1 = p + n2 = n - p + + Q1 = view(Q, 1:m, 1:m1) # full rank portion + Q2 = view(Q, 1:m, m1+1:m2+m1) + R11 = view(R, 1:m1, 1:n1) + R12 = view(R, 1:m1, n1+1:n) + + dA1 = view(dA, 1:m, 1:n1) + dA2 = view(dA, 1:m, (n1 + 1):n) + + dQ, dR = dQR + dQ1 = view(dQ, 1:m, 1:m1) + dQ2 = view(dQ, 1:m, m1+1:m2+m1) + dQ3 = minmn+1 < size(dQ, 2) ? view(dQ, :, minmn+1:size(dQ,2)) : similar(dQ, eltype(dQ), (0, 0)) + dR11 = view(dR, 1:m1, 1:n1) + dR12 = view(dR, 1:m1, n1+1:n) + dR22 = view(dR, m1+1:m1+m2, n1+1:n) + + # fwd rule for Q1 and R11 -- for a non-rank redeficient QR, this is all we need + invR11 = inv(R11) + tmp = Q1' * dA1 * invR11 + Rtmp = tmp + tmp' + diagview(Rtmp) ./= 2 + ltRtmp = view(Rtmp, lowertriangularind(Rtmp)) + ltRtmp .= zero(eltype(Rtmp)) + dR11 .= Rtmp * R11 + dQ1 .= dA1 * invR11 - Q1 * dR11 * invR11 + dR12 .= adjoint(Q1) * (dA2 - dQ1 * R12) + if size(Q2, 2) > 0 + dQ2 .= -Q1 * (Q1' * Q2) + dQ2 .+= Q2 * (Q2' * dQ2) + end + if m3 > 0 && size(Q, 2) > minmn + # only present for qr_full or rank-deficient qr_compact + Q′ = view(Q, :, 1:minmn) + Q3 = view(Q, :, minmn+1:m) + #dQ3 .= Q′ * (Q′' * Q3) + dQ3 .= Q3 + end + if !isempty(dR22) + _, r22 = qr_compact(dA2 - dQ1*R12 - Q1*dR12; positive=true) + dR22 .= view(r22, 1:size(dR22, 1), 1:size(dR22, 2)) + end + return (dQ, dR) +end + +function qr_null_pushforward!(dA, A, N, dN; tol::Real=default_pullback_gauge_atol(N), rank_atol::Real=tol, gauge_atol::Real=tol) + iszero(min(size(N)...)) && return # nothing to do +end diff --git a/src/pushforwards/svd.jl b/src/pushforwards/svd.jl new file mode 100644 index 00000000..d8591d94 --- /dev/null +++ b/src/pushforwards/svd.jl @@ -0,0 +1,82 @@ +function svd_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ; rank_atol=default_pullback_rank_atol(A), kwargs...) + U, Smat, Vᴴ = USVᴴ + m, n = size(U, 1), size(Vᴴ, 2) + (m, n) == size(ΔA) || throw(DimensionMismatch("size of ΔA ($(size(ΔA))) does not match size of U*S*Vᴴ ($m, $n)")) + minmn = min(m, n) + S = diagview(Smat) + ΔU, ΔS, ΔVᴴ = ΔUSVᴴ + r = searchsortedlast(S, rank_atol; rev = true) # rank + + vΔU = view(ΔU, :, 1:r) + vΔS = view(ΔS, 1:r, 1:r) + vΔVᴴ = view(ΔVᴴ, 1:r, :) + + vU = view(U, :, 1:r) + vS = view(S, 1:r) + vSmat = view(Smat, 1:r, 1:r) + vVᴴ = view(Vᴴ, 1:r, :) + + # compact region + vV = adjoint(vVᴴ) + UΔAV = vU' * ΔA * vV + copyto!(diagview(vΔS), diag(real.(UΔAV))) + F = one(eltype(S)) ./ (transpose(vS) .- vS) + G = one(eltype(S)) ./ (transpose(vS) .+ vS) + diagview(F) .= zero(eltype(F)) + hUΔAV = F .* (UΔAV + UΔAV') ./ 2 + aUΔAV = G .* (UΔAV - UΔAV') ./ 2 + K̇ = hUΔAV + aUΔAV + Ṁ = hUΔAV - aUΔAV + + # check gauge condition + @assert isantihermitian(K̇) + @assert isantihermitian(Ṁ) + K̇diag = diagview(K̇) + for i in 1:length(K̇diag) + @assert K̇diag[i] ≈ (im/2) * imag(diagview(UΔAV)[i])/S[i] + end + + ∂U = vU * K̇ + ∂V = vV * Ṁ + # full component + if size(U, 2) > minmn && size(Vᴴ, 1) > minmn + Uperp = view(U, :, minmn+1:m) + Vᴴperp = view(Vᴴ, minmn+1:n, :) + + aUAV = adjoint(Uperp) * A * adjoint(Vᴴperp) + + UÃÃV = similar(A, (size(aUAV, 1) + size(aUAV, 2), size(aUAV, 1) + size(aUAV, 2))) + fill!(UÃÃV, 0) + view(UÃÃV, (1:size(aUAV, 1)), size(aUAV, 1) .+ (1:size(aUAV, 2))) .= aUAV + view(UÃÃV, size(aUAV, 1) .+ (1:size(aUAV, 2)), 1:size(aUAV, 1)) .= aUAV' + rhs = vcat( adjoint(Uperp, ΔA, V), Vᴴperp * ΔA' * U) + superKM = -sylvester(UÃÃV, Smat, rhs) + K̇perp = view(superKM, 1:size(aUAV, 2)) + Ṁperp = view(superKM, size(aUAV, 2)+1:size(aUAV, 1)+size(aUAV, 2)) + ∂U .+= Uperp * K̇perp + ∂V .+= Vperp * Ṁperp + else + ImUU = (LinearAlgebra.diagm(ones(eltype(U), m)) - vU*vU') + ImVV = (LinearAlgebra.diagm(ones(eltype(Vᴴ), n)) - vV*vVᴴ) + upper = ImUU * ΔA * vV + lower = ImVV * ΔA' * vU + rhs = vcat(upper, lower) + + Ã = ImUU * A * ImVV + ÃÃ = similar(A, (m + n, m + n)) + fill!(ÃÃ, 0) + view(ÃÃ, (1:m), m .+ (1:n)) .= Ã + view(ÃÃ, m .+ (1:n), 1:m ) .= Ã' + + superLN = -sylvester(ÃÃ, vSmat, rhs) + ∂U += view(superLN, 1:size(upper, 1), :) + ∂V += view(superLN, size(upper, 1)+1:size(upper,1)+size(lower,1), :) + end + copyto!(vΔU, ∂U) + adjoint!(vΔVᴴ, ∂V) + return (ΔU, ΔS, ΔVᴴ) +end + +function svd_trunc_pushforward!(ΔA, A, USVᴴ, ΔUSVᴴ, ind; rank_atol=default_pullback_rank_atol(A), kwargs...) + +end diff --git a/test/mooncake.jl b/test/mooncake.jl index 7aed68cf..23e5b077 100644 --- a/test/mooncake.jl +++ b/test/mooncake.jl @@ -22,7 +22,7 @@ make_mooncake_tangent(T::Tuple) = Mooncake.build_tangent(typeof(T), T...) make_mooncake_fdata(x) = make_mooncake_tangent(x) make_mooncake_fdata(x::Diagonal) = Mooncake.FData((diag = make_mooncake_tangent(x.diag),)) -ETs = (Float64, Float32, ComplexF64, ComplexF32) +ETs = (Float64, ComplexF64,)# Float32,)# ComplexF64, ComplexF32) # no `alg` argument function _get_copying_derivative(f_c, rrule, A, ΔA, args, Δargs, ::Nothing, rdata) @@ -119,8 +119,8 @@ end rng = StableRNG(12345) m = 19 @testset "size ($m, $n)" for n in (17, m, 23) - atol = rtol = m * n * precision(T) - A = randn(rng, T, m, n) + atol = rtol = m * n * precision(T) + A = randn(rng, T, m, n) minmn = min(m, n) @testset for alg in ( LAPACK_HouseholderQR(), @@ -128,9 +128,9 @@ end ) @testset "qr_compact" begin QR = qr_compact(A, alg) - Q = randn(rng, T, m, minmn) - R = randn(rng, T, minmn, n) - Mooncake.TestUtils.test_rule(rng, qr_compact, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol) + Q = randn(rng, T, m, minmn) + R = randn(rng, T, minmn, n) + Mooncake.TestUtils.test_rule(rng, qr_compact, A, alg; atol=atol, rtol=rtol) test_pullbacks_match(rng, qr_compact!, qr_compact, A, (Q, R), (randn(rng, T, m, minmn), randn(rng, T, minmn, n)), alg) end @testset "qr_null" begin @@ -138,38 +138,46 @@ end ΔN = Q * randn(rng, T, minmn, max(0, m - minmn)) N = qr_null(A, alg) dN = make_mooncake_tangent(copy(ΔN)) - Mooncake.TestUtils.test_rule(rng, qr_null, A, alg; mode = Mooncake.ReverseMode, output_tangent = dN, is_primitive = false, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, qr_null, A, alg; output_tangent = dN, atol=atol, rtol=rtol) test_pullbacks_match(rng, qr_null!, qr_null, A, N, ΔN, alg) end @testset "qr_full" begin Q, R = qr_full(A, alg) - Q1 = view(Q, 1:m, 1:minmn) - ΔQ = randn(rng, T, m, m) - ΔQ2 = view(ΔQ, :, (minmn + 1):m) + Q1 = view(Q, 1:m, 1:minmn) + ΔQ = randn(rng, T, m, m) + ΔQ2 = view(ΔQ, :, (minmn + 1):m) mul!(ΔQ2, Q1, Q1' * ΔQ2) - ΔR = randn(rng, T, m, n) - dQ = make_mooncake_tangent(copy(ΔQ)) - dR = make_mooncake_tangent(copy(ΔR)) - dQR = Mooncake.build_tangent(typeof((ΔQ, ΔR)), dQ, dR) - Mooncake.TestUtils.test_rule(rng, qr_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dQR, is_primitive = false, atol = atol, rtol = rtol) + ΔR = randn(rng, T, m, n) + dQ = make_mooncake_tangent(copy(ΔQ)) + dR = make_mooncake_tangent(copy(ΔR)) + dQR = Mooncake.build_tangent(typeof((ΔQ, ΔR)), dQ, dR) + #Mooncake.TestUtils.test_rule(rng, qr_full, A, alg; output_tangent = dQR, atol=atol, rtol=rtol) + Mooncake.TestUtils.test_rule(rng, ((A, alg)->qr_full(A, alg)[2]), A, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol) + Mooncake.TestUtils.test_rule(rng, ((A, alg)->qr_full(A, alg)[1][1:m, 1:minmn]), A, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol) + Mooncake.TestUtils.test_rule(rng, ((A, alg)->qr_full(A, alg)[1][1:m, minmn+1:m]), A, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol) test_pullbacks_match(rng, qr_full!, qr_full, A, (Q, R), (ΔQ, ΔR), alg) end @testset "qr_compact - rank-deficient A" begin - r = minmn - 5 - Ard = randn(rng, T, m, r) * randn(rng, T, r, n) + r = minmn - 5 + Ard = randn(rng, T, m, r) * randn(rng, T, r, n) Q, R = qr_compact(Ard, alg) - QR = (Q, R) - ΔQ = randn(rng, T, m, minmn) - Q1 = view(Q, 1:m, 1:r) - Q2 = view(Q, 1:m, (r + 1):minmn) - ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn) + QR = (Q, R) + ΔQ = randn(rng, T, m, minmn) + Q1 = view(Q, 1:m, 1:r) + Q2 = view(Q, 1:m, (r + 1):minmn) + ΔQ2 = view(ΔQ, 1:m, (r + 1):minmn) ΔQ2 .= 0 - ΔR = randn(rng, T, minmn, n) + ΔR = randn(rng, T, minmn, n) view(ΔR, (r + 1):minmn, :) .= 0 - dQ = make_mooncake_tangent(copy(ΔQ)) - dR = make_mooncake_tangent(copy(ΔR)) - dQR = Mooncake.build_tangent(typeof((ΔQ, ΔR)), dQ, dR) - Mooncake.TestUtils.test_rule(rng, qr_compact, Ard, alg; mode = Mooncake.ReverseMode, output_tangent = dQR, is_primitive = false, atol = atol, rtol = rtol) + dQ = make_mooncake_tangent(copy(ΔQ)) + dR = make_mooncake_tangent(copy(ΔR)) + dQR = Mooncake.build_tangent(typeof((ΔQ, ΔR)), dQ, dR) + Mooncake.TestUtils.test_rule(rng, qr_compact, copy(Ard), alg; output_tangent = dQR, atol=atol, rtol=rtol) + Mooncake.TestUtils.test_rule(rng, ((A, alg)->qr_compact(A, alg)[2]), A, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol) + Mooncake.TestUtils.test_rule(rng, ((A, alg)->qr_compact(A, alg)[1][1:r, 1:r]), A, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol) + Mooncake.TestUtils.test_rule(rng, ((A, alg)->qr_compact(A, alg)[1][r+1:m, 1:r]), A, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol) + Mooncake.TestUtils.test_rule(rng, ((A, alg)->qr_compact(A, alg)[1][1:r, r+1:minmn]), A, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol) + Mooncake.TestUtils.test_rule(rng, ((A, alg)->qr_compact(A, alg)[1][r+1:m, r+1:minmn]), A, alg; mode=Mooncake.ForwardMode, is_primitive=false, atol=atol, rtol=rtol) test_pullbacks_match(rng, qr_compact!, qr_compact, Ard, (Q, R), (ΔQ, ΔR), alg) end end @@ -189,45 +197,45 @@ end ) @testset "lq_compact" begin L, Q = lq_compact(A, alg) - Mooncake.TestUtils.test_rule(rng, lq_compact, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, lq_compact, A, alg; atol = atol, rtol = rtol) test_pullbacks_match(rng, lq_compact!, lq_compact, A, (L, Q), (randn(rng, T, m, minmn), randn(rng, T, minmn, n)), alg) end @testset "lq_null" begin L, Q = lq_compact(A, alg) - ΔNᴴ = randn(rng, T, max(0, n - minmn), minmn) * Q - Nᴴ = randn(rng, T, max(0, n - minmn), n) - dNᴴ = make_mooncake_tangent(ΔNᴴ) - Mooncake.TestUtils.test_rule(rng, lq_null, A, alg; mode = Mooncake.ReverseMode, output_tangent = dNᴴ, is_primitive = false, atol = atol, rtol = rtol) + ΔNᴴ = randn(rng, T, max(0, n - minmn), minmn) * Q + Nᴴ = randn(rng, T, max(0, n - minmn), n) + dNᴴ = make_mooncake_tangent(ΔNᴴ) + Mooncake.TestUtils.test_rule(rng, lq_null, A, alg; output_tangent = dNᴴ, atol = atol, rtol = rtol) test_pullbacks_match(rng, lq_null!, lq_null, A, Nᴴ, ΔNᴴ, alg) end @testset "lq_full" begin L, Q = lq_full(A, alg) - Q1 = view(Q, 1:minmn, 1:n) - ΔQ = randn(rng, T, n, n) - ΔQ2 = view(ΔQ, (minmn + 1):n, 1:n) + Q1 = view(Q, 1:minmn, 1:n) + ΔQ = randn(rng, T, n, n) + ΔQ2 = view(ΔQ, (minmn + 1):n, 1:n) mul!(ΔQ2, ΔQ2 * Q1', Q1) - ΔL = randn(rng, T, m, n) - dL = make_mooncake_tangent(ΔL) - dQ = make_mooncake_tangent(ΔQ) - dLQ = Mooncake.build_tangent(typeof((ΔL, ΔQ)), dL, dQ) - Mooncake.TestUtils.test_rule(rng, lq_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dLQ, is_primitive = false, atol = atol, rtol = rtol) + ΔL = randn(rng, T, m, n) + dL = make_mooncake_tangent(ΔL) + dQ = make_mooncake_tangent(ΔQ) + dLQ = Mooncake.build_tangent(typeof((ΔL, ΔQ)), dL, dQ) + Mooncake.TestUtils.test_rule(rng, lq_full, A, alg; output_tangent = dLQ, atol = atol, rtol = rtol) test_pullbacks_match(rng, lq_full!, lq_full, A, (L, Q), (ΔL, ΔQ), alg) end @testset "lq_compact - rank-deficient A" begin - r = minmn - 5 - Ard = randn(rng, T, m, r) * randn(rng, T, r, n) + r = minmn - 5 + Ard = randn(rng, T, m, r) * randn(rng, T, r, n) L, Q = lq_compact(Ard, alg) - ΔL = randn(rng, T, m, minmn) - ΔQ = randn(rng, T, minmn, n) - Q1 = view(Q, 1:r, 1:n) - Q2 = view(Q, (r + 1):minmn, 1:n) - ΔQ2 = view(ΔQ, (r + 1):minmn, 1:n) + ΔL = randn(rng, T, m, minmn) + ΔQ = randn(rng, T, minmn, n) + Q1 = view(Q, 1:r, 1:n) + Q2 = view(Q, (r + 1):minmn, 1:n) + ΔQ2 = view(ΔQ, (r + 1):minmn, 1:n) ΔQ2 .= 0 view(ΔL, :, (r + 1):minmn) .= 0 - dL = make_mooncake_tangent(ΔL) - dQ = make_mooncake_tangent(ΔQ) - dLQ = Mooncake.build_tangent(typeof((ΔL, ΔQ)), dL, dQ) - Mooncake.TestUtils.test_rule(rng, lq_compact, Ard, alg; mode = Mooncake.ReverseMode, output_tangent = dLQ, is_primitive = false, atol = atol, rtol = rtol) + dL = make_mooncake_tangent(ΔL) + dQ = make_mooncake_tangent(ΔQ) + dLQ = Mooncake.build_tangent(typeof((ΔL, ΔQ)), dL, dQ) + Mooncake.TestUtils.test_rule(rng, lq_compact, Ard, alg; output_tangent = dLQ, atol = atol, rtol = rtol) test_pullbacks_match(rng, lq_compact!, lq_compact, Ard, (L, Q), (ΔL, ΔQ), alg) end end @@ -253,14 +261,15 @@ end # compute the dA corresponding to the above dD, dV @testset for alg in (LAPACK_Simple(), LAPACK_Expert()) @testset "eig_full" begin - Mooncake.TestUtils.test_rule(rng, eig_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dDV, is_primitive = false, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, eig_full, A, alg; output_tangent = dDV, atol = atol, rtol = rtol) test_pullbacks_match(rng, eig_full!, eig_full, A, (D, V), (ΔD2, ΔV), alg) end @testset "eig_vals" begin - Mooncake.TestUtils.test_rule(rng, eig_vals, A, alg; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, eig_vals, A, alg; atol = atol, rtol = rtol) test_pullbacks_match(rng, eig_vals!, eig_vals, A, D.diag, ΔD2.diag, alg) end @testset "eig_trunc" begin + Ah = (A + A')/2 for r in 1:4:m truncalg = TruncatedAlgorithm(alg, truncrank(r; by = abs)) ind = MatrixAlgebraKit.findtruncated(Ddiag, truncalg.trunc) @@ -271,7 +280,7 @@ end dDtrunc = make_mooncake_tangent(ΔDtrunc) dVtrunc = make_mooncake_tangent(ΔVtrunc) dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) - Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; output_tangent = dDVtrunc, atol = atol, rtol = rtol) test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) end truncalg = TruncatedAlgorithm(alg, truncrank(5; by = real)) @@ -283,7 +292,7 @@ end dDtrunc = make_mooncake_tangent(ΔDtrunc) dVtrunc = make_mooncake_tangent(ΔVtrunc) dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) - Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, eig_trunc, A, truncalg; output_tangent = dDVtrunc, atol = atol, rtol = rtol) test_pullbacks_match(rng, eig_trunc!, eig_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) end end @@ -345,11 +354,11 @@ MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc), A) = MatrixAlgebraKit.cop LAPACK_MultipleRelativelyRobustRepresentations(), ) @testset "eigh_full" begin - Mooncake.TestUtils.test_rule(rng, copy_eigh_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dDV, is_primitive = false, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, copy_eigh_full, A, alg; output_tangent = dDV, is_primitive = false, atol = atol, rtol = rtol) test_pullbacks_match(rng, copy_eigh_full!, copy_eigh_full, A, (D, V), (ΔD2, ΔV), alg) end @testset "eigh_vals" begin - Mooncake.TestUtils.test_rule(rng, copy_eigh_vals, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, copy_eigh_vals, A, alg; is_primitive = false, atol = atol, rtol = rtol) test_pullbacks_match(rng, copy_eigh_vals!, copy_eigh_vals, A, D.diag, ΔD2.diag, alg) end @testset "eigh_trunc" begin @@ -363,7 +372,7 @@ MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc), A) = MatrixAlgebraKit.cop dDtrunc = make_mooncake_tangent(ΔDtrunc) dVtrunc = make_mooncake_tangent(ΔVtrunc) dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) - Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc, A, truncalg; output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) end truncalg = TruncatedAlgorithm(alg, trunctol(; atol = maximum(abs, Ddiag) / 2)) @@ -375,7 +384,7 @@ MatrixAlgebraKit.copy_input(::typeof(copy_eigh_trunc), A) = MatrixAlgebraKit.cop dDtrunc = make_mooncake_tangent(ΔDtrunc) dVtrunc = make_mooncake_tangent(ΔVtrunc) dDVtrunc = Mooncake.build_tangent(typeof((ΔDtrunc, ΔVtrunc, zero(real(T)))), dDtrunc, dVtrunc, zero(real(T))) - Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, copy_eigh_trunc, A, truncalg; output_tangent = dDVtrunc, atol = atol, rtol = rtol, is_primitive = false) test_pullbacks_match(rng, copy_eigh_trunc!, copy_eigh_trunc, A, (D, V), (ΔD2, ΔV), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) end end @@ -403,7 +412,7 @@ end dU = make_mooncake_tangent(ΔU) dVᴴ = make_mooncake_tangent(ΔVᴴ) dUSVᴴ = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ)), dU, dS, dVᴴ) - Mooncake.TestUtils.test_rule(rng, svd_compact, A, alg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, svd_compact, A, alg; output_tangent = dUSVᴴ, atol = atol, rtol = rtol) test_pullbacks_match(rng, svd_compact!, svd_compact, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), alg) end @testset "svd_full" begin @@ -424,11 +433,11 @@ end dU = make_mooncake_tangent(ΔUfull) dVᴴ = make_mooncake_tangent(ΔVᴴfull) dUSVᴴ = Mooncake.build_tangent(typeof((ΔUfull, ΔSfull, ΔVᴴfull)), dU, dS, dVᴴ) - Mooncake.TestUtils.test_rule(rng, svd_full, A, alg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴ, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, svd_full, A, alg; output_tangent = dUSVᴴ, atol = atol, rtol = rtol) test_pullbacks_match(rng, svd_full!, svd_full, A, (U, S, Vᴴ), (ΔUfull, ΔSfull, ΔVᴴfull), alg) end @testset "svd_vals" begin - Mooncake.TestUtils.test_rule(rng, svd_vals, A, alg; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, svd_vals, A, alg; atol = atol, rtol = rtol) S = svd_vals(A, alg) test_pullbacks_match(rng, svd_vals!, svd_vals, A, S, randn(rng, real(T), minmn), alg) end @@ -439,44 +448,44 @@ end ΔS = randn(rng, real(T), minmn, minmn) ΔS2 = Diagonal(randn(rng, real(T), minmn)) ΔVᴴ = randn(rng, T, minmn, n) - ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) truncalg = TruncatedAlgorithm(alg, truncrank(r)) - ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) - Strunc = Diagonal(diagview(S)[ind]) - Utrunc = U[:, ind] - Vᴴtrunc = Vᴴ[ind, :] - ΔStrunc = Diagonal(diagview(ΔS2)[ind]) - ΔUtrunc = ΔU[:, ind] + ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) + Strunc = Diagonal(diagview(S)[ind]) + Utrunc = U[:, ind] + Vᴴtrunc = Vᴴ[ind, :] + ΔStrunc = Diagonal(diagview(ΔS2)[ind]) + ΔUtrunc = ΔU[:, ind] ΔVᴴtrunc = ΔVᴴ[ind, :] - dStrunc = make_mooncake_tangent(ΔStrunc) - dUtrunc = make_mooncake_tangent(ΔUtrunc) + dStrunc = make_mooncake_tangent(ΔStrunc) + dUtrunc = make_mooncake_tangent(ΔUtrunc) dVᴴtrunc = make_mooncake_tangent(ΔVᴴtrunc) - ϵ = zero(real(T)) + ϵ = zero(real(T)) dUSVᴴerr = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ, ϵ)), dUtrunc, dStrunc, dVᴴtrunc, ϵ) - Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) end @testset "trunctol" begin U, S, Vᴴ = svd_compact(A) - ΔU = randn(rng, T, m, minmn) - ΔS = randn(rng, real(T), minmn, minmn) - ΔS2 = Diagonal(randn(rng, real(T), minmn)) - ΔVᴴ = randn(rng, T, minmn, n) - ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) + ΔU = randn(rng, T, m, minmn) + ΔS = randn(rng, real(T), minmn, minmn) + ΔS2 = Diagonal(randn(rng, real(T), minmn)) + ΔVᴴ = randn(rng, T, minmn, n) + ΔU, ΔVᴴ = remove_svdgauge_dependence!(ΔU, ΔVᴴ, U, S, Vᴴ; degeneracy_atol = atol) truncalg = TruncatedAlgorithm(alg, trunctol(atol = S[1, 1] / 2)) - ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) - Strunc = Diagonal(diagview(S)[ind]) - Utrunc = U[:, ind] - Vᴴtrunc = Vᴴ[ind, :] - ΔStrunc = Diagonal(diagview(ΔS2)[ind]) - ΔUtrunc = ΔU[:, ind] + ind = MatrixAlgebraKit.findtruncated(diagview(S), truncalg.trunc) + Strunc = Diagonal(diagview(S)[ind]) + Utrunc = U[:, ind] + Vᴴtrunc = Vᴴ[ind, :] + ΔStrunc = Diagonal(diagview(ΔS2)[ind]) + ΔUtrunc = ΔU[:, ind] ΔVᴴtrunc = ΔVᴴ[ind, :] - dStrunc = make_mooncake_tangent(ΔStrunc) - dUtrunc = make_mooncake_tangent(ΔUtrunc) + dStrunc = make_mooncake_tangent(ΔStrunc) + dUtrunc = make_mooncake_tangent(ΔUtrunc) dVᴴtrunc = make_mooncake_tangent(ΔVᴴtrunc) - ϵ = zero(real(T)) + ϵ = zero(real(T)) dUSVᴴerr = Mooncake.build_tangent(typeof((ΔU, ΔS2, ΔVᴴ, ϵ)), dUtrunc, dStrunc, dVᴴtrunc, ϵ) - Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; mode = Mooncake.ReverseMode, output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, svd_trunc, A, truncalg; output_tangent = dUSVᴴerr, atol = atol, rtol = rtol) test_pullbacks_match(rng, svd_trunc!, svd_trunc, A, (U, S, Vᴴ), (ΔU, ΔS2, ΔVᴴ), truncalg; rdata = (Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), zero(real(T)))) end end @@ -493,11 +502,11 @@ end @testset for alg in PolarViaSVD.((LAPACK_QRIteration(), LAPACK_DivideAndConquer())) if m >= n WP = left_polar(A, alg) - Mooncake.TestUtils.test_rule(rng, left_polar, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, left_polar, A, alg; atol = atol, rtol = rtol) test_pullbacks_match(rng, left_polar!, left_polar, A, WP, (randn(rng, T, m, n), randn(rng, T, n, n)), alg) elseif m <= n PWᴴ = right_polar(A, alg) - Mooncake.TestUtils.test_rule(rng, right_polar, A, alg; mode = Mooncake.ReverseMode, is_primitive = false, atol = atol, rtol = rtol) + Mooncake.TestUtils.test_rule(rng, right_polar, A, alg; atol = atol, rtol = rtol) test_pullbacks_match(rng, right_polar!, right_polar, A, PWᴴ, (randn(rng, T, m, m), randn(rng, T, m, n)), alg) end end @@ -511,12 +520,12 @@ right_orth_lq(X) = right_orth(X; alg = :lq) right_orth_polar(X) = right_orth(X; alg = :polar) right_null_lq(X) = right_null(X; alg = :lq) -MatrixAlgebraKit.copy_input(::typeof(left_orth_qr), A) = MatrixAlgebraKit.copy_input(left_orth, A) -MatrixAlgebraKit.copy_input(::typeof(left_orth_polar), A) = MatrixAlgebraKit.copy_input(left_orth, A) -MatrixAlgebraKit.copy_input(::typeof(left_null_qr), A) = MatrixAlgebraKit.copy_input(left_null, A) -MatrixAlgebraKit.copy_input(::typeof(right_orth_lq), A) = MatrixAlgebraKit.copy_input(right_orth, A) +MatrixAlgebraKit.copy_input(::typeof(left_orth_qr), A) = MatrixAlgebraKit.copy_input(left_orth, A) +MatrixAlgebraKit.copy_input(::typeof(left_orth_polar), A) = MatrixAlgebraKit.copy_input(left_orth, A) +MatrixAlgebraKit.copy_input(::typeof(left_null_qr), A) = MatrixAlgebraKit.copy_input(left_null, A) +MatrixAlgebraKit.copy_input(::typeof(right_orth_lq), A) = MatrixAlgebraKit.copy_input(right_orth, A) MatrixAlgebraKit.copy_input(::typeof(right_orth_polar), A) = MatrixAlgebraKit.copy_input(right_orth, A) -MatrixAlgebraKit.copy_input(::typeof(right_null_lq), A) = MatrixAlgebraKit.copy_input(right_null, A) +MatrixAlgebraKit.copy_input(::typeof(right_null_lq), A) = MatrixAlgebraKit.copy_input(right_null, A) @timedtestset "Orth and null with eltype $T" for T in ETs rng = StableRNG(12345) @@ -526,36 +535,36 @@ MatrixAlgebraKit.copy_input(::typeof(right_null_lq), A) = MatrixAlgebraKit.copy_ A = randn(rng, T, m, n) VC = left_orth(A) CVᴴ = right_orth(A) - Mooncake.TestUtils.test_rule(rng, left_orth, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, left_orth, A; atol = atol, rtol = rtol, is_primitive = false) test_pullbacks_match(rng, left_orth!, left_orth, A, VC, (randn(rng, T, size(VC[1])...), randn(rng, T, size(VC[2])...))) - Mooncake.TestUtils.test_rule(rng, right_orth, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, right_orth, A; atol = atol, rtol = rtol, is_primitive = false) test_pullbacks_match(rng, right_orth!, right_orth, A, CVᴴ, (randn(rng, T, size(CVᴴ[1])...), randn(rng, T, size(CVᴴ[2])...))) - Mooncake.TestUtils.test_rule(rng, left_orth_qr, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, left_orth_qr, A; atol = atol, rtol = rtol, is_primitive = false) test_pullbacks_match(rng, ((X, VC) -> left_orth!(X, VC; alg = :qr)), left_orth_qr, A, VC, (randn(rng, T, size(VC[1])...), randn(rng, T, size(VC[2])...))) if m >= n - Mooncake.TestUtils.test_rule(rng, left_orth_polar, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, left_orth_polar, A; atol = atol, rtol = rtol, is_primitive = false) test_pullbacks_match(rng, ((X, VC) -> left_orth!(X, VC; alg = :polar)), left_orth_polar, A, VC, (randn(rng, T, size(VC[1])...), randn(rng, T, size(VC[2])...))) end N = left_orth(A; alg = :qr)[1] * randn(rng, T, min(m, n), m - min(m, n)) ΔN = left_orth(A; alg = :qr)[1] * randn(rng, T, min(m, n), m - min(m, n)) dN = make_mooncake_tangent(ΔN) - Mooncake.TestUtils.test_rule(rng, left_null_qr, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false, output_tangent = dN) + Mooncake.TestUtils.test_rule(rng, left_null_qr, A; atol = atol, rtol = rtol, is_primitive = false, output_tangent = dN) test_pullbacks_match(rng, ((X, N) -> left_null!(X, N; alg = :qr)), left_null_qr, A, N, ΔN) - Mooncake.TestUtils.test_rule(rng, right_orth_lq, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, right_orth_lq, A; atol = atol, rtol = rtol, is_primitive = false) test_pullbacks_match(rng, ((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :lq)), right_orth_lq, A, CVᴴ, (randn(rng, T, size(CVᴴ[1])...), randn(rng, T, size(CVᴴ[2])...))) if m <= n - Mooncake.TestUtils.test_rule(rng, right_orth_polar, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false) + Mooncake.TestUtils.test_rule(rng, right_orth_polar, A; atol = atol, rtol = rtol, is_primitive = false) test_pullbacks_match(rng, ((X, CVᴴ) -> right_orth!(X, CVᴴ; alg = :polar)), right_orth_polar, A, CVᴴ, (randn(rng, T, size(CVᴴ[1])...), randn(rng, T, size(CVᴴ[2])...))) end Nᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; alg = :lq)[2] ΔNᴴ = randn(rng, T, n - min(m, n), min(m, n)) * right_orth(A; alg = :lq)[2] dNᴴ = make_mooncake_tangent(ΔNᴴ) - Mooncake.TestUtils.test_rule(rng, right_null_lq, A; mode = Mooncake.ReverseMode, atol = atol, rtol = rtol, is_primitive = false, output_tangent = dNᴴ) + Mooncake.TestUtils.test_rule(rng, right_null_lq, A; atol = atol, rtol = rtol, is_primitive = false, output_tangent = dNᴴ) test_pullbacks_match(rng, ((X, Nᴴ) -> right_null!(X, Nᴴ; alg = :lq)), right_null_lq, A, Nᴴ, ΔNᴴ) end end diff --git a/test/runtests.jl b/test/runtests.jl index 4b69a3dc..a9f7f69e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,7 +4,7 @@ using SafeTestsets # specific ones is_buildkite = get(ENV, "BUILDKITE", "false") == "true" if !is_buildkite - @safetestset "Algorithms" begin + #=@safetestset "Algorithms" begin include("algorithms.jl") end @safetestset "Projections" begin @@ -38,9 +38,11 @@ if !is_buildkite @safetestset "Image and Null Space" begin include("orthnull.jl") end + =# @safetestset "Mooncake" begin include("mooncake.jl") end + #= @safetestset "ChainRules" begin include("chainrules.jl") end @@ -55,7 +57,7 @@ if !is_buildkite using JET JET.test_package(MatrixAlgebraKit; target_defined_modules = true) end - end + end=# end using CUDA