From a19464293421fa41ebdfccb18976c26dc2601243 Mon Sep 17 00:00:00 2001 From: marius Date: Thu, 26 May 2022 23:14:03 -0700 Subject: [PATCH 1/2] allow scalars and Arrays in FieldTuples --- src/field_tuples.jl | 12 +++++++----- src/generic.jl | 2 +- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/field_tuples.jl b/src/field_tuples.jl index 105bf2ab..bc5c928d 100644 --- a/src/field_tuples.jl +++ b/src/field_tuples.jl @@ -18,7 +18,7 @@ end FieldTuple(pairs::Vector{<:Pair}) = FieldTuple(NamedTuple(pairs)) ### printing -getindex(f::FieldTuple,::Colon) = vcat(getindex.(values(f.fs),:)...)[:] +getindex(f::FieldTuple,::Colon) = mapreduce(collect, vcat, values(f.fs)) getindex(D::DiagOp{<:FieldTuple}, i::Int, j::Int) = (i==j) ? D.diag[:][i] : diagzero(D, i, j) typealias_def(::Type{<:FieldTuple{NamedTuple{Names,FS},T}}) where {Names,FS<:Tuple,T} = "Field-($(join(map(string,Names),",")))-$FS" @@ -102,10 +102,12 @@ propertynames(f::FieldTuple) = (:fs, propertynames(f.fs)...) randn!(rng::AbstractRNG, ξ::FieldTuple) = FieldTuple(map(f -> randn!(rng, f), ξ.fs)) ### Diagonal-ops +Diagonal_or_scalar(x::Number) = x +Diagonal_or_scalar(x) = Diagonal(x) # need a method specific for FieldTuple since we don't carry around # the basis in a way that works with the default implementation -(*)(D::DiagOp{<:FieldTuple}, f::FieldTuple) = FieldTuple(map((d,f)->Diagonal(d)*f, D.diag.fs, f.fs)) -(\)(D::DiagOp{<:FieldTuple}, f::FieldTuple) = FieldTuple(map((d,f)->Diagonal(d)\f, D.diag.fs, f.fs)) +(*)(D::DiagOp{<:FieldTuple}, f::FieldTuple) = FieldTuple(map((d,f)->Diagonal_or_scalar(d)*f, D.diag.fs, f.fs)) +(\)(D::DiagOp{<:FieldTuple}, f::FieldTuple) = FieldTuple(map((d,f)->Diagonal_or_scalar(d)\f, D.diag.fs, f.fs)) # promote before recursing for these @@ -113,8 +115,8 @@ dot(a::FieldTuple, b::FieldTuple) = reduce(+, map(dot, getfield.(promote(a,b),:f hash(ft::FieldTuple, h::UInt64) = foldr(hash, (typeof(ft), ft.fs), init=h) # logdet & trace -@auto_adjoint logdet(L::Diagonal{<:Union{Real,Complex}, <:FieldTuple}) = reduce(+, map(logdet∘Diagonal, diag(L).fs), init=0) -tr(L::Diagonal{<:Union{Real,Complex}, <:FieldTuple}) = reduce(+, map(tr∘Diagonal, diag(L).fs), init=0) +@auto_adjoint logdet(L::Diagonal{<:Union{Real,Complex}, <:FieldTuple}) = reduce(+, map(logdet∘Diagonal_or_scalar, diag(L).fs), init=0) +tr(L::Diagonal{<:Union{Real,Complex}, <:FieldTuple}) = reduce(+, map(tr∘Diagonal_or_scalar, diag(L).fs), init=0) # misc batch_length(ft::FieldTuple) = only(unique(map(batch_length, ft.fs))) diff --git a/src/generic.jl b/src/generic.jl index 6ee81d6c..f5acccc8 100644 --- a/src/generic.jl +++ b/src/generic.jl @@ -224,7 +224,7 @@ unknown_rule_error(::typeof(promote_basis_strict_rule), ::B₁, ::B₂) where {B basis(f::F) where {F<:Field} = basis(F) basis(::Type{<:Field{B}}) where {B<:Basis} = B basis(::Type{<:Field}) = Basis -basis(::AbstractVector) = Basis +basis(::Union{Number,AbstractVector}) = Basis # allows them to be in FieldTuple ### printing From 4b2fcb620eb2893e924d69ec00a55ceb042dd253 Mon Sep 17 00:00:00 2001 From: marius Date: Thu, 26 May 2022 23:16:56 -0700 Subject: [PATCH 2/2] fix gradients --- src/generic.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/generic.jl b/src/generic.jl index f5acccc8..47e8912f 100644 --- a/src/generic.jl +++ b/src/generic.jl @@ -123,7 +123,7 @@ HarmonicBasis(::Basis2Prod{𝐄𝐁, <:S0Basis}) = EBFourier (::Type{B})(dst::AbstractArray{<:Field}, src::AbstractArray{<:Field}) where {B<:Basis} = B.(dst,src) # The abstract `Basis` type means "any basis", hence this conversion rule: -Basis(f::Field) = f +Basis(f) = f # used in make_field_aliases below