diff --git a/src/CMBLensing.jl b/src/CMBLensing.jl index f1cbdd38..d5db4bb0 100644 --- a/src/CMBLensing.jl +++ b/src/CMBLensing.jl @@ -51,7 +51,7 @@ using Serialization using Setfield using SparseArrays using StaticArrays: @SMatrix, @SVector, SMatrix, StaticArray, StaticArrayStyle, - StaticMatrix, StaticVector, SVector, SArray, SizedArray + StaticMatrix, StaticVector, SVector, SArray, SizedArray, SizedMatrix using Statistics using StatsBase using TimerOutputs: @timeit, get_defaulttimer, reset_timer! diff --git a/src/field_vectors.jl b/src/field_vectors.jl index f8a5d1ef..dbae81a7 100644 --- a/src/field_vectors.jl +++ b/src/field_vectors.jl @@ -65,4 +65,3 @@ function pinv!(dst::FieldOrOpMatrix{<:Diagonal}, src::FieldOrOpMatrix{<:Diagonal end promote_rule(::Type{F}, ::Type{<:Scalar}) where {F<:Field} = F -arithmetic_closure(::F) where {F<:Field} = F diff --git a/src/generic.jl b/src/generic.jl index 891e99b9..010bea01 100644 --- a/src/generic.jl +++ b/src/generic.jl @@ -322,6 +322,7 @@ precompute!!(L::Adjoint, f) = precompute!!(parent(L),f)' # splatted into a giant matrix when doing [f f; f f] (which they would othewise # be since they're Arrays) hvcat(rows::Tuple{Vararg{Int}}, values::Field...) = hvcat(rows, ([x] for x in values)...) +hvcat(rows::Tuple{Vararg{Int}}, values::DiagOp...) = hvcat(rows, ([x] for x in values)...) hcat(values::Field...) = hcat(([x] for x in values)...) ### printing diff --git a/src/proj_lambert.jl b/src/proj_lambert.jl index f49be830..84f92624 100644 --- a/src/proj_lambert.jl +++ b/src/proj_lambert.jl @@ -366,7 +366,7 @@ function Cℓ_to_Cov(::Val{:P}, proj::ProjLambert, CℓEE::Cℓs, CℓBB::Cℓs; end function Cℓ_to_Cov(::Val{:IP}, proj::ProjLambert, CℓTT, CℓEE, CℓBB, CℓTE; kwargs...) ΣTT, ΣEE, ΣBB, ΣTE = [Cℓ_to_Cov(:I,proj,Cℓ; kwargs...) for Cℓ in (CℓTT,CℓEE,CℓBB,CℓTE)] - BlockDiagIEB(@SMatrix([ΣTT ΣTE; ΣTE ΣEE]), ΣBB) + BlockDiagIEB([ΣTT ΣTE; ΣTE ΣEE], ΣBB) end ## ParamDependentOp covariances scaled by amplitudes in different ℓ-bins diff --git a/src/specialops.jl b/src/specialops.jl index 86647321..74c0efc3 100644 --- a/src/specialops.jl +++ b/src/specialops.jl @@ -57,9 +57,10 @@ end # call sqrt/inv on it, and the ΣBB block separately as ΣB. This type # is generic with regards to the field type, F. struct BlockDiagIEB{T,F} <: ImplicitOp{T} - ΣTE :: SMatrix{2,2,Diagonal{T,F},4} + ΣTE :: SizedMatrix{2,2,Diagonal{T,F},2,Matrix{Diagonal{T,F}}} ΣB :: Diagonal{T,F} end +BlockDiagIEB(ΣTE::AbstractMatrix{Diagonal{T,F}}, ΣB::Diagonal{T,F}) where {T,F} = BlockDiagIEB{T,F}(ΣTE, ΣB) # applying *(L::BlockDiagIEB, f::BaseS02) = L * IEBFourier(f) \(L::BlockDiagIEB, f::BaseS02) = pinv(L) * IEBFourier(f) @@ -77,13 +78,13 @@ similar(L::BlockDiagIEB) = BlockDiagIEB(similar.(L.ΣTE), similar(L.ΣB)) get_storage(L::BlockDiagIEB) = get_storage(L.ΣB) adapt_structure(storage, L::BlockDiagIEB) = BlockDiagIEB(adapt.(Ref(storage), L.ΣTE), adapt(storage, L.ΣB)) simulate(rng::AbstractRNG, L::BlockDiagIEB; Nbatch=()) = sqrt(L) * randn!(rng, similar(diag(L), Nbatch...)) -logdet(L::BlockDiagIEB) = logdet(L.ΣTE[1,1]*L.ΣTE[2,2]-L.ΣTE[1,2]*L.ΣTE[2,1]) + logdet(L.ΣB) +logdet(L::BlockDiagIEB) = logdet(det(L.ΣTE)) + logdet(L.ΣB) # arithmetic -*(L::BlockDiagIEB, D::DiagOp{<:BaseIEBFourier}) = BlockDiagIEB(SMatrix{2,2}(L.ΣTE * [[D[:I]] [0]; [0] [D[:E]]]), L.ΣB * D[:B]) -+(L::BlockDiagIEB, D::DiagOp{<:BaseIEBFourier}) = BlockDiagIEB(@SMatrix[L.ΣTE[1,1]+D[:I] L.ΣTE[1,2]; L.ΣTE[2,1] L.ΣTE[2,2]+D[:E]], L.ΣB + D[:B]) +*(L::BlockDiagIEB, D::DiagOp{<:BaseIEBFourier}) = BlockDiagIEB(L.ΣTE * [[D[:I]] [0]; [0] [D[:E]]], L.ΣB * D[:B]) ++(L::BlockDiagIEB, D::DiagOp{<:BaseIEBFourier}) = BlockDiagIEB([L.ΣTE[1,1]+D[:I] L.ΣTE[1,2]; L.ΣTE[2,1] L.ΣTE[2,2]+D[:E]], L.ΣB + D[:B]) *(La::F, Lb::F) where {F<:BlockDiagIEB} = F(La.ΣTE * Lb.ΣTE, La.ΣB * Lb.ΣB) +(La::F, Lb::F) where {F<:BlockDiagIEB} = F(La.ΣTE + Lb.ΣTE, La.ΣB + Lb.ΣB) -+(L::BlockDiagIEB, U::UniformScaling{<:Scalar}) = BlockDiagIEB(@SMatrix[(L.ΣTE[1,1]+U) L.ΣTE[1,2]; L.ΣTE[2,1] (L.ΣTE[2,2]+U)], L.ΣB+U) ++(L::BlockDiagIEB, U::UniformScaling{<:Scalar}) = BlockDiagIEB([(L.ΣTE[1,1]+U) L.ΣTE[1,2]; L.ΣTE[2,1] (L.ΣTE[2,2]+U)], L.ΣB+U) *(L::BlockDiagIEB, λ::Scalar) = BlockDiagIEB(L.ΣTE * λ, L.ΣB * λ) *(D::DiagOp{<:BaseIEBFourier}, L::BlockDiagIEB) = L * D +(U::UniformScaling{<:Scalar}, L::BlockDiagIEB) = L + U diff --git a/src/util.jl b/src/util.jl index 66322e43..9c765f6a 100644 --- a/src/util.jl +++ b/src/util.jl @@ -100,20 +100,25 @@ end # these allow pinv and sqrt of SMatrices of Diagonals to work correctly, which # we use for the T-E block of the covariance. hopefully some of this can be cut # down on in the futue with some PRs into StaticArrays. -permutedims(A::SMatrix{2,2}) = @SMatrix[A[1] A[3]; A[2] A[4]] -@auto_adjoint function sqrt(A::SMatrix{2,2,<:Diagonal}) +# permutedims(A::SMatrix{2,2}) = @SMatrix[A[1] A[3]; A[2] A[4]] +@auto_adjoint function sqrt(A::SizedMatrix{2,2,<:Diagonal}) # A = [a b; c d] a,c,b,d = A s = sqrt(a*d-b*c) t = pinv(sqrt(a+(d+2s))) @SMatrix[t*(a+s) t*b; t*c t*(d+s)] end -@auto_adjoint function pinv(A::SMatrix{2,2,<:Diagonal}) +@auto_adjoint function pinv(A::SizedMatrix{2,2,<:Diagonal}) # A = [a b; c d] a,c,b,d = A idet = pinv(a*d-b*c) @SMatrix[d*idet -(b*idet); -(c*idet) a*idet] end +@auto_adjoint function det(A::SizedMatrix{2,2,<:Diagonal}) + # A = [a b; c d] + a,c,b,d = A + a*d-b*c +end # some usefule tuple manipulation functions: