Skip to content

Commit c665fc4

Browse files
committed
Mooncake reverse rules
1 parent 4fbc3bf commit c665fc4

File tree

10 files changed

+642
-47
lines changed

10 files changed

+642
-47
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1010
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
1111
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1212
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
13+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
1314

1415
[extensions]
1516
MatrixAlgebraKitChainRulesCoreExt = "ChainRulesCore"
1617
MatrixAlgebraKitAMDGPUExt = "AMDGPU"
1718
MatrixAlgebraKitCUDAExt = "CUDA"
19+
MatrixAlgebraKitMooncakeExt = "Mooncake"
1820

1921
[compat]
2022
AMDGPU = "2"
@@ -24,6 +26,7 @@ ChainRulesTestUtils = "1"
2426
CUDA = "5"
2527
JET = "0.9, 0.10"
2628
LinearAlgebra = "1"
29+
Mooncake = "0.4.167"
2730
SafeTestsets = "0.1"
2831
StableRNGs = "1"
2932
Test = "1"
@@ -35,11 +38,12 @@ julia = "1.10"
3538
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
3639
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
3740
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
41+
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
3842
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
3943
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
4044
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
4145
TestExtras = "5ed8adda-3752-4e41-b88a-e8b09835ee3a"
4246
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
4347

4448
[targets]
45-
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras","ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU"]
49+
test = ["Aqua", "JET", "SafeTestsets", "Test", "TestExtras","ChainRulesCore", "ChainRulesTestUtils", "StableRNGs", "Zygote", "CUDA", "AMDGPU", "Mooncake"]
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
module MatrixAlgebraKitMooncakeExt
2+
3+
using Mooncake
4+
using Mooncake: DefaultCtx, CoDual, Dual, NoRData, rrule!!, frule!!, arrayify, @is_primitive
5+
using MatrixAlgebraKit
6+
using MatrixAlgebraKit: inv_safe, diagview
7+
using MatrixAlgebraKit: qr_pullback!, lq_pullback!
8+
using MatrixAlgebraKit: qr_null_pullback!, lq_null_pullback!
9+
using MatrixAlgebraKit: eig_pullback!, eigh_pullback!
10+
using MatrixAlgebraKit: left_polar_pullback!, right_polar_pullback!
11+
using LinearAlgebra
12+
13+
# two-argument factorizations like LQ, QR, EIG
14+
for (f, pb, adj) in (
15+
(qr_full!, qr_pullback!, :dqr_adjoint),
16+
(qr_compact!, qr_pullback!, :dqr_adjoint),
17+
(lq_full!, lq_pullback!, :dlq_adjoint),
18+
(lq_compact!, lq_pullback!, :dlq_adjoint),
19+
(eig_full!, eig_pullback!, :deig_adjoint),
20+
(eigh_full!, eigh_pullback!, :deigh_adjoint),
21+
(left_polar!, left_polar_pullback!, :dleft_polar_adjoint),
22+
(right_polar!, right_polar_pullback!, :dright_polar_adjoint),
23+
)
24+
25+
@eval begin
26+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), AbstractMatrix, Tuple{<:AbstractMatrix, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm}
27+
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual{<:AbstractMatrix}, args_dargs::CoDual, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...)
28+
A, dA = arrayify(A_dA)
29+
dA .= zero(eltype(A))
30+
args = Mooncake.primal(args_dargs)
31+
dargs = Mooncake.tangent(args_dargs)
32+
arg1, darg1 = arrayify(args[1], dargs[1])
33+
arg2, darg2 = arrayify(args[2], dargs[2])
34+
function $adj(::Mooncake.NoRData)
35+
dA = $pb(dA, A, (arg1, arg2), (darg1, darg2); kwargs...)
36+
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
37+
end
38+
args = $f(A, args, Mooncake.primal(alg_dalg); kwargs...)
39+
darg1 .= zero(eltype(arg1))
40+
darg2 .= zero(eltype(arg2))
41+
return Mooncake.CoDual(args, dargs), $adj
42+
end
43+
end
44+
end
45+
46+
for (f, f_full, pb, adj) in (
47+
(qr_null!, qr_full, qr_null_pullback!, :dqr_null_adjoint),
48+
(lq_null!, lq_full, lq_null_pullback!, :dlq_null_adjoint),
49+
)
50+
@eval begin
51+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), AbstractMatrix, AbstractMatrix, MatrixAlgebraKit.AbstractAlgorithm}
52+
function Mooncake.rrule!!(f_df::CoDual{typeof($f)}, A_dA::CoDual{<:AbstractMatrix}, arg_darg::CoDual{<:AbstractMatrix}, alg_dalg::CoDual{<:MatrixAlgebraKit.AbstractAlgorithm}; kwargs...)
53+
A, dA = arrayify(A_dA)
54+
Ac = MatrixAlgebraKit.copy_input($f_full, A)
55+
arg, darg = arrayify(Mooncake.primal(arg_darg), Mooncake.tangent(arg_darg))
56+
arg = $f(Ac, arg, Mooncake.primal(alg_dalg))
57+
function $adj(::Mooncake.NoRData)
58+
dA .= zero(eltype(A))
59+
$pb(dA, A, arg, darg; kwargs...)
60+
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
61+
end
62+
return arg_darg, $adj
63+
end
64+
end
65+
end
66+
67+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.eig_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm}
68+
function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.eig_vals!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual; kwargs...)
69+
# compute primal
70+
D_ = Mooncake.primal(D_dD)
71+
dD_ = Mooncake.tangent(D_dD)
72+
A_ = Mooncake.primal(A_dA)
73+
dA_ = Mooncake.tangent(A_dA)
74+
A, dA = arrayify(A_, dA_)
75+
D, dD = arrayify(D_, dD_)
76+
dA .= zero(eltype(dA))
77+
# update primal
78+
DV = eig_full(A, Mooncake.primal(alg_dalg); kwargs...)
79+
V = DV[2]
80+
dD .= zero(eltype(D))
81+
function deig_vals_adjoint(::Mooncake.NoRData)
82+
PΔV = V' \ Diagonal(dD)
83+
if eltype(dA) <: Real
84+
ΔAc = PΔV * V'
85+
dA .+= real.(ΔAc)
86+
else
87+
mul!(dA, PΔV, V', 1, 0)
88+
end
89+
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
90+
end
91+
return Mooncake.CoDual(DV[1].diag, dD_), deig_vals_adjoint
92+
end
93+
94+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.eigh_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm}
95+
function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.eigh_vals!)}, A_dA::CoDual, D_dD::CoDual, alg_dalg::CoDual; kwargs...)
96+
# compute primal
97+
D_ = Mooncake.primal(D_dD)
98+
dD_ = Mooncake.tangent(D_dD)
99+
A_ = Mooncake.primal(A_dA)
100+
dA_ = Mooncake.tangent(A_dA)
101+
A, dA = arrayify(A_, dA_)
102+
D, dD = arrayify(D_, dD_)
103+
DV = eigh_full(A, Mooncake.primal(alg_dalg); kwargs...)
104+
function deigh_vals_adjoint(::Mooncake.NoRData)
105+
mul!(dA, DV[2] * Diagonal(real(dD)), DV[2]', 1, 0)
106+
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
107+
end
108+
return Mooncake.CoDual(DV[1].diag, dD_), deigh_vals_adjoint
109+
end
110+
111+
112+
for (f, St) in ((svd_full!, :AbstractMatrix), (svd_compact!, :Diagonal))
113+
@eval begin
114+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof($f), AbstractMatrix, Tuple{<:AbstractMatrix, <:$St, <:AbstractMatrix}, MatrixAlgebraKit.AbstractAlgorithm}
115+
function Mooncake.rrule!!(::CoDual{typeof($f)}, A_dA::CoDual, USVᴴ_dUSVᴴ::CoDual, alg_dalg::CoDual; kwargs...)
116+
A, dA = arrayify(A_dA)
117+
USVᴴ = Mooncake.primal(USVᴴ_dUSVᴴ)
118+
dUSVᴴ = Mooncake.tangent(USVᴴ_dUSVᴴ)
119+
U, dU = arrayify(USVᴴ[1], dUSVᴴ[1])
120+
S, dS = arrayify(USVᴴ[2], dUSVᴴ[2])
121+
Vᴴ, dVᴴ = arrayify(USVᴴ[3], dUSVᴴ[3])
122+
USVᴴ = $f(A, USVᴴ, Mooncake.primal(alg_dalg); kwargs...)
123+
function dsvd_adjoint(::Mooncake.NoRData)
124+
dA .= zero(eltype(A))
125+
minmn = min(size(A)...)
126+
if size(U, 2) == size(Vᴴ, 1) == minmn # compact
127+
dA = MatrixAlgebraKit.svd_pullback!(dA, A, (U, S, Vᴴ), (dU, dS, dVᴴ))
128+
else # full
129+
vU = view(U, :, 1:minmn)
130+
vS = Diagonal(diagview(S)[1:minmn])
131+
vVᴴ = view(Vᴴ, 1:minmn, :)
132+
vdU = view(dU, :, 1:minmn)
133+
vdS = Diagonal(diagview(dS)[1:minmn])
134+
vdVᴴ = view(dVᴴ, 1:minmn, :)
135+
dA = MatrixAlgebraKit.svd_pullback!(dA, A, (vU, vS, vVᴴ), (vdU, vdS, vdVᴴ))
136+
end
137+
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
138+
end
139+
return Mooncake.CoDual(USVᴴ, dUSVᴴ), dsvd_adjoint
140+
end
141+
end
142+
end
143+
144+
@is_primitive Mooncake.DefaultCtx Mooncake.ReverseMode Tuple{typeof(MatrixAlgebraKit.svd_vals!), AbstractMatrix, AbstractVector, MatrixAlgebraKit.AbstractAlgorithm}
145+
function Mooncake.rrule!!(::CoDual{<:typeof(MatrixAlgebraKit.svd_vals!)}, A_dA::CoDual, S_dS::CoDual, alg_dalg::CoDual; kwargs...)
146+
# compute primal
147+
S_ = Mooncake.primal(S_dS)
148+
dS_ = Mooncake.tangent(S_dS)
149+
A_ = Mooncake.primal(A_dA)
150+
dA_ = Mooncake.tangent(A_dA)
151+
A, dA = arrayify(A_, dA_)
152+
S, dS = arrayify(S_, dS_)
153+
U, nS, Vᴴ = svd_compact(A, Mooncake.primal(alg_dalg); kwargs...)
154+
S .= diagview(nS)
155+
dS .= zero(eltype(S))
156+
function dsvd_vals_adjoint(::Mooncake.NoRData)
157+
dA .= U * Diagonal(dS) * Vᴴ
158+
return Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData(), Mooncake.NoRData()
159+
end
160+
return S_dS, dsvd_vals_adjoint
161+
end
162+
163+
end

src/implementations/eigh.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ function check_hermitian(A; atol::Real = default_hermitian_tol(A), rtol::Real =
1919
end
2020

2121
function check_input(::typeof(eigh_full!), A::AbstractMatrix, DV, alg::AbstractAlgorithm)
22-
check_hermitian(A, alg)
22+
#check_hermitian(A, alg)
2323
D, V = DV
2424
m = size(A, 1)
2525
@assert D isa Diagonal && V isa AbstractMatrix

src/pullbacks/eig.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,8 @@ function eig_pullback!(
4848
Δgauge < gauge_atol ||
4949
@warn "`eig` cotangents sensitive to gauge choice: (|Δgauge| = $Δgauge)"
5050

51-
VᴴΔV .*= conj.(inv_safe.(transpose(D) .- D, degeneracy_atol))
51+
VᴴΔV ./= conj.(transpose(D) .- D)
52+
diagview(VᴴΔV) .= zero(eltype(VᴴΔV))
5253

5354
if !iszerotangent(ΔDmat)
5455
ΔDvec = diagview(ΔDmat)

src/pullbacks/polar.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
Adds the pullback from the left polar decomposition of `A` to `ΔA` given the output `WP` and
55
cotangent `ΔWP` of `left_polar(A)`.
66
"""
7-
function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP)
7+
function left_polar_pullback!(ΔA::AbstractMatrix, A, WP, ΔWP; kwargs...)
88
# Extract the Polar components
99
W, P = WP
1010

@@ -34,7 +34,7 @@ end
3434
Adds the pullback from the left polar decomposition of `A` to `ΔA` given the output `PWᴴ`
3535
and cotangent `ΔPWᴴ` of `right_polar(A)`.
3636
"""
37-
function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ)
37+
function right_polar_pullback!(ΔA::AbstractMatrix, A, PWᴴ, ΔPWᴴ; kwargs...)
3838
# Extract the Polar components
3939
P, Wᴴ = PWᴴ
4040

src/pullbacks/svd.jl

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,13 @@ function svd_pullback!(
2828
degeneracy_atol::Real = tol,
2929
gauge_atol::Real = tol
3030
)
31-
3231
# Extract the SVD components
3332
U, Smat, Vᴴ = USVᴴ
3433
m, n = size(U, 1), size(Vᴴ, 2)
35-
(m, n) == size(ΔA) || throw(DimensionMismatch())
34+
(m, n) == size(ΔA) || throw(DimensionMismatch("size of ΔA ($(size(ΔA))) does not match size of U*S*Vᴴ ($m, $n)"))
3635
minmn = min(m, n)
3736
S = diagview(Smat)
38-
length(S) == minmn || throw(DimensionMismatch())
37+
length(S) == minmn || throw(DimensionMismatch("length of S ($(length(S))) does not matrix minimum dimension of U, Vᴴ ($minmn)"))
3938
r = searchsortedlast(S, rank_atol; rev = true) # rank
4039
Ur = view(U, :, 1:r)
4140
Vᴴr = view(Vᴴ, 1:r, :)
@@ -46,22 +45,22 @@ function svd_pullback!(
4645
UΔU = fill!(similar(U, (r, r)), 0)
4746
VΔV = fill!(similar(Vᴴ, (r, r)), 0)
4847
if !iszerotangent(ΔU)
49-
m == size(ΔU, 1) || throw(DimensionMismatch())
48+
m == size(ΔU, 1) || throw(DimensionMismatch("first dimension of ΔU ($(size(ΔU, 1))) does not match first dimension of U ($m)"))
5049
pU = size(ΔU, 2)
51-
pU > r && throw(DimensionMismatch())
50+
pU > r && throw(DimensionMismatch("second dimension of ΔU ($(size(ΔU, 2))) does not match rank of S ($r)"))
5251
indU = axes(U, 2)[ind]
53-
length(indU) == pU || throw(DimensionMismatch())
52+
length(indU) == pU || throw(DimensionMismatch("length of selected U columns ($(length(indU))) does not match second dimension of ΔU ($(size(ΔU, 2)))"))
5453
UΔUp = view(UΔU, :, indU)
5554
mul!(UΔUp, Ur', ΔU)
5655
# ΔU -= Ur * UΔUp but one less allocation without overwriting ΔU
5756
ΔU = mul!(copy(ΔU), Ur, UΔUp, -1, 1)
5857
end
5958
if !iszerotangent(ΔVᴴ)
60-
n == size(ΔVᴴ, 2) || throw(DimensionMismatch())
59+
n == size(ΔVᴴ, 2) || throw(DimensionMismatch("second dimension of ΔVᴴ ($(size(ΔVᴴ, 2))) does not match second dimension of Vᴴ ($n)"))
6160
pV = size(ΔVᴴ, 1)
62-
pV > r && throw(DimensionMismatch())
61+
pV > r && throw(DimensionMismatch("first dimension of ΔVᴴ ($(size(ΔVᴴ, 1))) does not match rank of S ($r)"))
6362
indV = axes(Vᴴ, 1)[ind]
64-
length(indV) == pV || throw(DimensionMismatch())
63+
length(indV) == pV || throw(DimensionMismatch("length of selected Vᴴ rows ($(length(indV))) does not match first dimension of ΔVᴴ ($(size(ΔVᴴ, 1)))"))
6564
VΔVp = view(VΔV, :, indV)
6665
mul!(VΔVp, Vᴴr, ΔVᴴ')
6766
# ΔVᴴ -= VΔVp' * Vᴴr but one less allocation without overwriting ΔVᴴ
@@ -84,7 +83,7 @@ function svd_pullback!(
8483
ΔS = diagview(ΔSmat)
8584
pS = length(ΔS)
8685
indS = axes(S, 1)[ind]
87-
length(indS) == pS || throw(DimensionMismatch())
86+
length(indS) == pS || throw(DimensionMismatch("length of selected S diagonals ($(length(indS))) does not match length of ΔS diagonal ($(length(ΔS)))"))
8887
view(diagview(UdΔAV), indS) .+= real.(ΔS)
8988
end
9089
ΔA = mul!(ΔA, Ur, UdΔAV * Vᴴr, 1, 1) # add the contribution to ΔA

test/ad_utils.jl

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
function remove_svdgauge_dependence!(
2+
ΔU, ΔVᴴ, U, S, Vᴴ;
3+
degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(S)
4+
)
5+
gaugepart = U' * ΔU + Vᴴ * ΔVᴴ'
6+
gaugepart = (gaugepart - gaugepart') / 2
7+
gaugepart[abs.(transpose(diagview(S)) .- diagview(S)) .>= degeneracy_atol] .= 0
8+
mul!(ΔU, U, gaugepart, -1, 1)
9+
return ΔU, ΔVᴴ
10+
end
11+
function remove_eiggauge_dependence!(
12+
ΔV, D, V;
13+
degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(S)
14+
)
15+
gaugepart = V' * ΔV
16+
gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0
17+
mul!(ΔV, V / (V' * V), gaugepart, -1, 1)
18+
return ΔV
19+
end
20+
function remove_eighgauge_dependence!(
21+
ΔV, D, V;
22+
degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(S)
23+
)
24+
gaugepart = V' * ΔV
25+
gaugepart = (gaugepart - gaugepart') / 2
26+
gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0
27+
mul!(ΔV, V / (V' * V), gaugepart, -1, 1)
28+
return ΔV
29+
end
30+
31+
precision(::Type{<:Union{Float32, Complex{Float32}}}) = sqrt(eps(Float32))
32+
precision(::Type{<:Union{Float64, Complex{Float64}}}) = sqrt(eps(Float64))

test/chainrules.jl

Lines changed: 1 addition & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,38 +6,7 @@ using ChainRulesCore, ChainRulesTestUtils, Zygote
66
using MatrixAlgebraKit: diagview, TruncatedAlgorithm, PolarViaSVD
77
using LinearAlgebra: UpperTriangular, Diagonal, Hermitian, mul!
88

9-
function remove_svdgauge_dependence!(
10-
ΔU, ΔVᴴ, U, S, Vᴴ;
11-
degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(S)
12-
)
13-
gaugepart = U' * ΔU + Vᴴ * ΔVᴴ'
14-
gaugepart = (gaugepart - gaugepart') / 2
15-
gaugepart[abs.(transpose(diagview(S)) .- diagview(S)) .>= degeneracy_atol] .= 0
16-
mul!(ΔU, U, gaugepart, -1, 1)
17-
return ΔU, ΔVᴴ
18-
end
19-
function remove_eiggauge_dependence!(
20-
ΔV, D, V;
21-
degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(D)
22-
)
23-
gaugepart = V' * ΔV
24-
gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0
25-
mul!(ΔV, V / (V' * V), gaugepart, -1, 1)
26-
return ΔV
27-
end
28-
function remove_eighgauge_dependence!(
29-
ΔV, D, V;
30-
degeneracy_atol = MatrixAlgebraKit.default_pullback_gaugetol(D)
31-
)
32-
gaugepart = V' * ΔV
33-
gaugepart = (gaugepart - gaugepart') / 2
34-
gaugepart[abs.(transpose(diagview(D)) .- diagview(D)) .>= degeneracy_atol] .= 0
35-
mul!(ΔV, V, gaugepart, -1, 1)
36-
return ΔV
37-
end
38-
39-
precision(::Type{<:Union{Float32, Complex{Float32}}}) = sqrt(eps(Float32))
40-
precision(::Type{<:Union{Float64, Complex{Float64}}}) = sqrt(eps(Float64))
9+
include("ad_utils.jl")
4110

4211
for f in
4312
(

0 commit comments

Comments
 (0)