Skip to content

Commit a4f959b

Browse files
committed
improve getproperty(::Field, k) pullbacks
1 parent a7c6bbc commit a4f959b

File tree

1 file changed

+5
-12
lines changed

1 file changed

+5
-12
lines changed

src/autodiff.jl

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -57,21 +57,14 @@ end
5757
@adjoint function Zygote.literal_getproperty(f::BaseField{B,M,T}, ::Val{:arr}) where {B<:SpatialBasis{AzFourier},M,T}
5858
getfield(f,:arr), Δ -> (BaseField{B}./ adapt(typeof(Δ), T.(rfft_degeneracy_fac(f.Nx)' ./ Zfac(B(), f.metadata))), f.metadata),)
5959
end
60-
# preserve field type for sub-component property getters
61-
function _getproperty_subcomponent_pullback(f, k)
62-
function getproperty_pullback(Δ)
63-
g = similar(f, promote_type(eltype(f), eltype(Δ)))
64-
g .= 0
60+
# needed to preserve field type for sub-component property getters
61+
@adjoint function Zygote.getproperty(f::BaseField, k::Union{typeof.(Val.((:I,:Q,:U,:E,:B,:P,:IP)))...})
62+
function field_getproperty_pullback(Δ)
63+
g = (similar(f, promote_type(eltype(f), eltype(Δ))) .= 0)
6564
getproperty(g, k) .= Δ
6665
(g, nothing)
6766
end
68-
getproperty(f, k), getproperty_pullback
69-
end
70-
@adjoint function Zygote.literal_getproperty(f::BaseField{B}, k::Union{typeof.(Val.((:I,:Q,:U,:E,:B)))...}) where {B₀, B<:SpatialBasis{B₀}}
71-
_getproperty_subcomponent_pullback(f, k)
72-
end
73-
@adjoint function Zygote.literal_getproperty(f::BaseS02{Basis3Prod{𝐈,B₂,B₀}}, k::Val{:P}) where {B₂,B₀}
74-
_getproperty_subcomponent_pullback(f, k)
67+
getproperty(f, k), field_getproperty_pullback
7568
end
7669
# if accumulting from one branch that was just a f.metadata
7770
Zygote.accum(f::BaseField, nt::NamedTuple{(:arr,:metadata)}) = (@assert(isnothing(nt.arr)); f)

0 commit comments

Comments
 (0)