Skip to content

Commit

Permalink
EM: optimize mean handling
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexey Stukalov authored and alyst committed Apr 19, 2024
1 parent 22985e6 commit 981baf1
Showing 1 changed file with 46 additions and 26 deletions.
72 changes: 46 additions & 26 deletions src/observed/EM.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ via expectation maximization (EM) for `observed`.
Returns the tuple of the EM covariance matrix and the EM mean vector.
Uses the EM algorithm for MVN-distributed data with missing values
Based on the EM algorithm for MVN-distributed data with missing values
adapted from the supplementary material to the book *Machine Learning: A Probabilistic Perspective*,
copyright (2010) Kevin Murphy and Matt Dunham: see
[*gaussMissingFitEm.m*](https://github.com/probml/pmtk3/blob/master/toolbox/BasicModels/gauss/sub/gaussMissingFitEm.m) and
Expand Down Expand Up @@ -106,6 +106,8 @@ function em_step!(
copy!(μ, 𝔼x_full)
copy!(Σ, 𝔼xxᵀ_full)
nobs_used = nobs_full
mul!(Σ, μ₀, μ₀', -nobs_used, 1)
axpy!(-nobs_used, μ₀, μ)

# Compute the expected sufficient statistics
for pat in patterns
Expand All @@ -115,50 +117,68 @@ function em_step!(
u = pat.miss_mask
o = pat.obs_mask

# precompute for pattern
Σoo_chol = cholesky(Symmetric(Σ₀[o, o]))
Σuo = Σ₀[u, o]
μu = μ₀[u]
μo = μ₀[o]
# compute cholesky to speed-up ldiv!()
Σ₀oo_chol = cholesky(Symmetric(Σ₀[o, o]))
Σ₀uo = Σ₀[u, o]
μ₀u = μ₀[u]
μ₀o = μ₀[o]

# get pattern observations
nobs = !isnothing(max_nobs_em) ? min(max_nobs_em, n_obs(pat)) : n_obs(pat)
pat_data =
zo =
nobs < n_obs(pat) ?
view(pat.data, :, sort!(sample(1:n_obs(pat), nobs, replace = false))) : pat.data
pat.data[:, sort!(sample(1:n_obs(pat), nobs, replace = false))] : copy(pat.data)
zo .-= μ₀o # subtract current mean from observations

𝔼xu = fill!(similar(μu), 0)
𝔼xo = fill!(similar(μo), 0)
𝔼xᵢu = similar(μu)
𝔼zo = sum(zo, dims = 2)
𝔼zu = fill!(similar(μ₀u), 0)

𝔼xxᵀuo = fill!(similar(Σuo), 0)
𝔼xxᵀuu = n_obs(pat) * (Σ₀[u, u] - Σuo * (Σoo_chol \ Σuo'))
𝔼zzᵀuo = fill!(similar(Σ₀uo), 0)
𝔼zzᵀuu = nobs * Σ₀[u, u]
mul!(𝔼zzᵀuu, Σ₀uo, Σ₀oo_chol \ Σ₀uo', -nobs, 1)

# loop through observations
@inbounds for obsdata in eachcol(pat_data)
mul!(𝔼xᵢu, Σuo, Σoo_chol \ (obsdata - μo))
𝔼xᵢu .+= μu
mul!(𝔼xxᵀuu, 𝔼xᵢu, 𝔼xᵢu', 1, 1)
mul!(𝔼xxᵀuo, 𝔼xᵢu, obsdata', 1, 1)
𝔼xu .+= 𝔼xᵢu
𝔼xo .+= obsdata
yᵢo = similar(μ₀o)
𝔼zᵢu = similar(μ₀u)
@inbounds for zᵢo in eachcol(zo)
ldiv!(yᵢo, Σ₀oo_chol, zᵢo)
mul!(𝔼zᵢu, Σ₀uo, yᵢo)
mul!(𝔼zzᵀuu, 𝔼zᵢu, 𝔼zᵢu', 1, 1)
mul!(𝔼zzᵀuo, 𝔼zᵢu, zᵢo', 1, 1)
𝔼zu .+= 𝔼zᵢu
end
# correct 𝔼zzᵀ by adding back μ₀×𝔼z' + 𝔼z×μ₀'
mul!(𝔼zzᵀuo, μ₀u, 𝔼zo', 1, 1)
mul!(𝔼zzᵀuo, 𝔼zu, μ₀o', 1, 1)

Σ[o, o] .+= pat_data * pat_data'
Σ[u, o] .+= 𝔼xxᵀuo
Σ[o, u] .+= 𝔼xxᵀuo'
Σ[u, u] .+= 𝔼xxᵀuu
mul!(𝔼zzᵀuu, μ₀u, 𝔼zu', 1, 1)
mul!(𝔼zzᵀuu, 𝔼zu, μ₀u', 1, 1)

μ[o] .+= 𝔼xo
μ[u] .+= 𝔼xu
𝔼zzᵀoo = zo * zo'
mul!(𝔼zzᵀoo, μ₀o, 𝔼zo', 1, 1)
mul!(𝔼zzᵀoo, 𝔼zo, μ₀o', 1, 1)

# update Σ and μ
Σ[o, o] .+= 𝔼zzᵀoo
Σ[u, o] .+= 𝔼zzᵀuo
Σ[o, u] .+= 𝔼zzᵀuo'
Σ[u, u] .+= 𝔼zzᵀuu

μ[o] .+= 𝔼zo
μ[u] .+= 𝔼zu

nobs_used += nobs
end

# M step, update em_model
lmul!(1 / nobs_used, Σ)
lmul!(1 / nobs_used, μ)
# at this point μ holds (μ - μ₀), i.e. the new mean relative to μ₀,
# and Σ holds Σ + μ×μ' - μ₀×μ₀'
#           = Σ + (μ - μ₀)×μ₀' + μ₀×(μ - μ₀)' + (μ - μ₀)×(μ - μ₀)'
mul!(Σ, μ, μ₀', -1, 1)
mul!(Σ, μ₀, μ', -1, 1)
mul!(Σ, μ, μ', -1, 1)
μ .+= μ₀

# ridge Σ
# while !isposdef(Σ)
Expand Down

0 comments on commit 981baf1

Please sign in to comment.