Skip to content

Commit

Permalink
BlockDiagIEB AD fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
marius311 committed Jan 20, 2023
1 parent a4f959b commit 9b821d4
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 11 deletions.
2 changes: 1 addition & 1 deletion src/CMBLensing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ using Serialization
using Setfield
using SparseArrays
using StaticArrays: @SMatrix, @SVector, SMatrix, StaticArray, StaticArrayStyle,
StaticMatrix, StaticVector, SVector, SArray, SizedArray
StaticMatrix, StaticVector, SVector, SArray, SizedArray, SizedMatrix
using Statistics
using StatsBase
using TimerOutputs: @timeit, get_defaulttimer, reset_timer!
Expand Down
1 change: 0 additions & 1 deletion src/field_vectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,3 @@ function pinv!(dst::FieldOrOpMatrix{<:Diagonal}, src::FieldOrOpMatrix{<:Diagonal
end

promote_rule(::Type{F}, ::Type{<:Scalar}) where {F<:Field} = F
arithmetic_closure(::F) where {F<:Field} = F
1 change: 1 addition & 0 deletions src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,7 @@ precompute!!(L::Adjoint, f) = precompute!!(parent(L),f)'
# splatted into a giant matrix when doing [f f; f f] (which they would othewise
# be since they're Arrays)
hvcat(rows::Tuple{Vararg{Int}}, values::Field...) = hvcat(rows, ([x] for x in values)...)
hvcat(rows::Tuple{Vararg{Int}}, values::DiagOp...) = hvcat(rows, ([x] for x in values)...)
hcat(values::Field...) = hcat(([x] for x in values)...)

### printing
Expand Down
2 changes: 1 addition & 1 deletion src/proj_lambert.jl
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ function Cℓ_to_Cov(::Val{:P}, proj::ProjLambert, CℓEE::Cℓs, CℓBB::Cℓs;
end
function Cℓ_to_Cov(::Val{:IP}, proj::ProjLambert, CℓTT, CℓEE, CℓBB, CℓTE; kwargs...)
ΣTT, ΣEE, ΣBB, ΣTE = [Cℓ_to_Cov(:I,proj,Cℓ; kwargs...) for Cℓ in (CℓTT,CℓEE,CℓBB,CℓTE)]
BlockDiagIEB(@SMatrix([ΣTT ΣTE; ΣTE ΣEE]), ΣBB)
BlockDiagIEB([ΣTT ΣTE; ΣTE ΣEE], ΣBB)
end

## ParamDependentOp covariances scaled by amplitudes in different ℓ-bins
Expand Down
11 changes: 6 additions & 5 deletions src/specialops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,10 @@ end
# call sqrt/inv on it, and the ΣBB block separately as ΣB. This type
# is generic with regards to the field type, F.
struct BlockDiagIEB{T,F} <: ImplicitOp{T}
ΣTE :: SMatrix{2,2,Diagonal{T,F},4}
ΣTE :: SizedMatrix{2,2,Diagonal{T,F},2,Matrix{Diagonal{T,F}}}
ΣB :: Diagonal{T,F}
end
BlockDiagIEB(ΣTE::AbstractMatrix{Diagonal{T,F}}, ΣB::Diagonal{T,F}) where {T,F} = BlockDiagIEB{T,F}(ΣTE, ΣB)
# applying
*(L::BlockDiagIEB, f::BaseS02) = L * IEBFourier(f)
\(L::BlockDiagIEB, f::BaseS02) = pinv(L) * IEBFourier(f)
Expand All @@ -77,13 +78,13 @@ similar(L::BlockDiagIEB) = BlockDiagIEB(similar.(L.ΣTE), similar(L.ΣB))
get_storage(L::BlockDiagIEB) = get_storage(L.ΣB)
adapt_structure(storage, L::BlockDiagIEB) = BlockDiagIEB(adapt.(Ref(storage), L.ΣTE), adapt(storage, L.ΣB))
simulate(rng::AbstractRNG, L::BlockDiagIEB; Nbatch=()) = sqrt(L) * randn!(rng, similar(diag(L), Nbatch...))
logdet(L::BlockDiagIEB) = logdet(L.ΣTE[1,1]*L.ΣTE[2,2]-L.ΣTE[1,2]*L.ΣTE[2,1]) + logdet(L.ΣB)
logdet(L::BlockDiagIEB) = logdet(det(L.ΣTE)) + logdet(L.ΣB)
# arithmetic
*(L::BlockDiagIEB, D::DiagOp{<:BaseIEBFourier}) = BlockDiagIEB(SMatrix{2,2}(L.ΣTE * [[D[:I]] [0]; [0] [D[:E]]]), L.ΣB * D[:B])
+(L::BlockDiagIEB, D::DiagOp{<:BaseIEBFourier}) = BlockDiagIEB(@SMatrix[L.ΣTE[1,1]+D[:I] L.ΣTE[1,2]; L.ΣTE[2,1] L.ΣTE[2,2]+D[:E]], L.ΣB + D[:B])
*(L::BlockDiagIEB, D::DiagOp{<:BaseIEBFourier}) = BlockDiagIEB(L.ΣTE * [[D[:I]] [0]; [0] [D[:E]]], L.ΣB * D[:B])
+(L::BlockDiagIEB, D::DiagOp{<:BaseIEBFourier}) = BlockDiagIEB([L.ΣTE[1,1]+D[:I] L.ΣTE[1,2]; L.ΣTE[2,1] L.ΣTE[2,2]+D[:E]], L.ΣB + D[:B])
*(La::F, Lb::F) where {F<:BlockDiagIEB} = F(La.ΣTE * Lb.ΣTE, La.ΣB * Lb.ΣB)
+(La::F, Lb::F) where {F<:BlockDiagIEB} = F(La.ΣTE + Lb.ΣTE, La.ΣB + Lb.ΣB)
+(L::BlockDiagIEB, U::UniformScaling{<:Scalar}) = BlockDiagIEB(@SMatrix[(L.ΣTE[1,1]+U) L.ΣTE[1,2]; L.ΣTE[2,1] (L.ΣTE[2,2]+U)], L.ΣB+U)
+(L::BlockDiagIEB, U::UniformScaling{<:Scalar}) = BlockDiagIEB([(L.ΣTE[1,1]+U) L.ΣTE[1,2]; L.ΣTE[2,1] (L.ΣTE[2,2]+U)], L.ΣB+U)
*(L::BlockDiagIEB, λ::Scalar) = BlockDiagIEB(L.ΣTE * λ, L.ΣB * λ)
*(D::DiagOp{<:BaseIEBFourier}, L::BlockDiagIEB) = L * D
+(U::UniformScaling{<:Scalar}, L::BlockDiagIEB) = L + U
Expand Down
11 changes: 8 additions & 3 deletions src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,20 +100,25 @@ end
# these allow pinv and sqrt of SMatrices of Diagonals to work correctly, which
# we use for the T-E block of the covariance. hopefully some of this can be cut
# down on in the futue with some PRs into StaticArrays.
permutedims(A::SMatrix{2,2}) = @SMatrix[A[1] A[3]; A[2] A[4]]
@auto_adjoint function sqrt(A::SMatrix{2,2,<:Diagonal})
# permutedims(A::SMatrix{2,2}) = @SMatrix[A[1] A[3]; A[2] A[4]]
@auto_adjoint function sqrt(A::SizedMatrix{2,2,<:Diagonal})
# A = [a b; c d]
a,c,b,d = A
s = sqrt(a*d-b*c)
t = pinv(sqrt(a+(d+2s)))
@SMatrix[t*(a+s) t*b; t*c t*(d+s)]
end
@auto_adjoint function pinv(A::SMatrix{2,2,<:Diagonal})
@auto_adjoint function pinv(A::SizedMatrix{2,2,<:Diagonal})
# A = [a b; c d]
a,c,b,d = A
idet = pinv(a*d-b*c)
@SMatrix[d*idet -(b*idet); -(c*idet) a*idet]
end
@auto_adjoint function det(A::SizedMatrix{2,2,<:Diagonal})
# A = [a b; c d]
a,c,b,d = A
a*d-b*c
end


# some usefule tuple manipulation functions:
Expand Down

0 comments on commit 9b821d4

Please sign in to comment.