|
57 | 57 | @adjoint function Zygote.literal_getproperty(f::BaseField{B,M,T}, ::Val{:arr}) where {B<:SpatialBasis{AzFourier},M,T}
|
58 | 58 | getfield(f,:arr), Δ -> (BaseField{B}(Δ ./ adapt(typeof(Δ), T.(rfft_degeneracy_fac(f.Nx)' ./ Zfac(B(), f.metadata))), f.metadata),)
|
59 | 59 | end
|
60 |
| -# preserve field type for sub-component property getters |
61 |
| -function _getproperty_subcomponent_pullback(f, k) |
62 |
| - function getproperty_pullback(Δ) |
63 |
| - g = similar(f, promote_type(eltype(f), eltype(Δ))) |
64 |
| - g .= 0 |
| 60 | +# needed to preserve field type for sub-component property getters |
| 61 | +@adjoint function Zygote.getproperty(f::BaseField, k::Union{typeof.(Val.((:I,:Q,:U,:E,:B,:P,:IP)))...}) |
| 62 | + function field_getproperty_pullback(Δ) |
| 63 | + g = (similar(f, promote_type(eltype(f), eltype(Δ))) .= 0) |
65 | 64 | getproperty(g, k) .= Δ
|
66 | 65 | (g, nothing)
|
67 | 66 | end
|
68 |
| - getproperty(f, k), getproperty_pullback |
69 |
| -end |
70 |
| -@adjoint function Zygote.literal_getproperty(f::BaseField{B}, k::Union{typeof.(Val.((:I,:Q,:U,:E,:B)))...}) where {B₀, B<:SpatialBasis{B₀}} |
71 |
| - _getproperty_subcomponent_pullback(f, k) |
72 |
| -end |
73 |
| -@adjoint function Zygote.literal_getproperty(f::BaseS02{Basis3Prod{𝐈,B₂,B₀}}, k::Val{:P}) where {B₂,B₀} |
74 |
| - _getproperty_subcomponent_pullback(f, k) |
| 67 | + getproperty(f, k), field_getproperty_pullback |
75 | 68 | end
|
76 | 69 | # if accumulting from one branch that was just a f.metadata
|
77 | 70 | Zygote.accum(f::BaseField, nt::NamedTuple{(:arr,:metadata)}) = (@assert(isnothing(nt.arr)); f)
|
|
0 commit comments