Skip to content

Commit

Permalink
Merge pull request #202 from alyst/spec_matrices
Browse files Browse the repository at this point in the history
Cleanup Special Matrices code
  • Loading branch information
Maximilian-Stefan-Ernst authored Apr 24, 2024
2 parents f1d0b85 + c1e7a69 commit cdc4415
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 140 deletions.
4 changes: 4 additions & 0 deletions src/StructuralEquationModels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ const SEM = StructuralEquationModels
# type hierarchy
include("types.jl")
include("objective_gradient_hessian.jl")

# helper objects and functions
include("additional_functions/commutation_matrix.jl")

# fitted objects
include("frontend/fit/SemFit.jl")
# specification of models
Expand Down
75 changes: 75 additions & 0 deletions src/additional_functions/commutation_matrix.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""
transpose_linear_indices(n, [m])
Put each linear index of the *n×m* matrix to the position of the
corresponding element in the transposed matrix.
## Example
`
1 4
2 5 => 1 2 3
3 6 4 5 6
`
"""
transpose_linear_indices(n::Integer, m::Integer = n) =
repeat(1:n, inner = m) .+ repeat((0:(m-1)) * n, outer = n)

"""
CommutationMatrix(n::Integer) <: AbstractMatrix{Int}
A *commutation matrix* *C* is a n²×n² matrix of 0s and 1s.
If *vec(A)* is a vectorized form of a n×n matrix *A*,
then ``C * vec(A) = vec(Aᵀ)``.
"""
struct CommutationMatrix <: AbstractMatrix{Int}
n::Int
::Int
transpose_inds::Vector{Int} # maps the linear indices of n×n matrix *B* to the indices of matrix *B'*

CommutationMatrix(n::Integer) = new(n, n^2, transpose_linear_indices(n))
end

Base.size(A::CommutationMatrix) = (A.n², A.n²)
Base.size(A::CommutationMatrix, dim::Integer) =
1 <= dim <= 2 ? A.: throw(ArgumentError("invalid matrix dimension $dim"))
Base.length(A::CommutationMatrix) = A.^2
Base.getindex(A::CommutationMatrix, i::Int, j::Int) = j == A.transpose_inds[i] ? 1 : 0

function Base.:(*)(A::CommutationMatrix, B::AbstractVector)
size(A, 2) == size(B, 1) || throw(
DimensionMismatch("A has $(size(A, 2)) columns, but B has $(size(B, 1)) elements"),
)
return B[A.transpose_inds]
end

function Base.:(*)(A::CommutationMatrix, B::AbstractMatrix)
size(A, 2) == size(B, 1) || throw(
DimensionMismatch("A has $(size(A, 2)) columns, but B has $(size(B, 1)) rows"),
)
return B[A.transpose_inds, :]
end

function Base.:(*)(A::CommutationMatrix, B::SparseMatrixCSC)
size(A, 2) == size(B, 1) || throw(
DimensionMismatch("A has $(size(A, 2)) columns, but B has $(size(B, 1)) rows"),
)
return SparseMatrixCSC(
size(B, 1),
size(B, 2),
copy(B.colptr),
A.transpose_inds[B.rowval],
copy(B.nzval),
)
end

function LinearAlgebra.lmul!(A::CommutationMatrix, B::SparseMatrixCSC)
size(A, 2) == size(B, 1) || throw(
DimensionMismatch("A has $(size(A, 2)) columns, but B has $(size(B, 1)) rows"),
)

@inbounds for (i, rowind) in enumerate(B.rowval)
B.rowval[i] = A.transpose_inds[rowind]
end
return B
end
155 changes: 23 additions & 132 deletions src/additional_functions/helper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ function get_observed(rowind, data, semobserved; args = (), kwargs = NamedTuple(
return observed_vec
end

skipmissing_mean(mat::AbstractMatrix) =
skipmissing_mean(mat::AbstractMatrix) =
[mean(skipmissing(coldata)) for coldata in eachcol(mat)]

function F_one_person(imp_mean, meandiff, inverse, data, logdet)
Expand Down Expand Up @@ -111,143 +111,34 @@ function cov_and_mean(rows; corrected = false)
return obs_cov, vec(obs_mean)
end

function duplication_matrix(nobs)
nobs = Int(nobs)
n1 = Int(nobs * (nobs + 1) * 0.5)
n2 = Int(nobs^2)
Dt = zeros(n1, n2)

for j in 1:nobs
for i in j:nobs
u = zeros(n1)
u[Int((j - 1) * nobs + i - 0.5 * j * (j - 1))] = 1
T = zeros(nobs, nobs)
T[j, i] = 1
T[i, j] = 1
Dt += u * transpose(vec(T))
# n²×(n(n+1)/2) matrix to transform a vector of lower
# triangular entries into a vectorized form of a n×n symmetric matrix,
# opposite of elimination_matrix()
function duplication_matrix(n::Integer)
ntri = div(n * (n + 1), 2)
D = zeros(n^2, ntri)
for j in 1:n
for i in j:n
tri_ix = (j - 1) * n + i - div(j * (j - 1), 2)
D[j+n*(i-1), tri_ix] = 1
D[i+n*(j-1), tri_ix] = 1
end
end
D = transpose(Dt)
return D
end

function elimination_matrix(nobs)
nobs = Int(nobs)
n1 = Int(nobs * (nobs + 1) * 0.5)
n2 = Int(nobs^2)
L = zeros(n1, n2)

for j in 1:nobs
for i in j:nobs
u = zeros(n1)
u[Int((j - 1) * nobs + i - 0.5 * j * (j - 1))] = 1
T = zeros(nobs, nobs)
T[i, j] = 1
L += u * transpose(vec(T))
# (n(n+1)/2)×n² matrix to transform a
# vectorized form of a n×n symmetric matrix
# into vector of its lower triangular entries,
# opposite of duplication_matrix()
function elimination_matrix(n::Integer)
ntri = div(n * (n + 1), 2)
L = zeros(ntri, n^2)
for j in 1:n
for i in j:n
tri_ix = (j - 1) * n + i - div(j * (j - 1), 2)
L[tri_ix, i+n*(j-1)] = 1
end
end
return L
end

function commutation_matrix(n; tosparse = false)
M = zeros(n^2, n^2)

for i in 1:n
for j in 1:n
M[i+n*(j-1), j+n*(i-1)] = 1.0
end
end

if tosparse
M = sparse(M)
end

return M
end

function commutation_matrix_pre_square(A)
n2 = size(A, 1)
n = Int(sqrt(n2))

ind = repeat(1:n, inner = n)
indadd = (0:(n-1)) * n
for i in 1:n
ind[((i-1)*n+1):i*n] .+= indadd
end

A_post = A[ind, :]

return A_post
end

function commutation_matrix_pre_square_add!(B, A) # comuptes B + KₙA
n2 = size(A, 1)
n = Int(sqrt(n2))

ind = repeat(1:n, inner = n)
indadd = (0:(n-1)) * n
for i in 1:n
ind[((i-1)*n+1):i*n] .+= indadd
end

@views @inbounds B .+= A[ind, :]

return B
end

function get_commutation_lookup(n2::Int64)
n = Int(sqrt(n2))
ind = repeat(1:n, inner = n)
indadd = (0:(n-1)) * n
for i in 1:n
ind[((i-1)*n+1):i*n] .+= indadd
end

lookup = Dict{Int64, Int64}()

for i in 1:n2
j = findall(x -> (x == i), ind)[1]
push!(lookup, i => j)
end

return lookup
end

function commutation_matrix_pre_square!(A::SparseMatrixCSC, lookup) # comuptes B + KₙA
for (i, rowind) in enumerate(A.rowval)
A.rowval[i] = lookup[rowind]
end
end

function commutation_matrix_pre_square!(A::SparseMatrixCSC) # computes KₙA
lookup = get_commutation_lookup(size(A, 2))
commutation_matrix_pre_square!(A, lookup)
end

function commutation_matrix_pre_square(A::SparseMatrixCSC)
B = copy(A)
commutation_matrix_pre_square!(B)
return B
end

function commutation_matrix_pre_square(A::SparseMatrixCSC, lookup)
B = copy(A)
commutation_matrix_pre_square!(B, lookup)
return B
end

function commutation_matrix_pre_square_add_mt!(B, A) # comuptes B + KₙA # 0 allocations but slower
n2 = size(A, 1)
n = Int(sqrt(n2))

indadd = (0:(n-1)) * n

Threads.@threads for i in 1:n
for j in 1:n
row = i + indadd[j]
@views @inbounds B[row, :] .+= A[row, :]
end
end

return B
end
13 changes: 5 additions & 8 deletions src/loss/ML/FIML.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ Analytic gradients are available.
## Implementation
Subtype of `SemLossFunction`.
"""
mutable struct SemFIML{INV, C, L, O, M, IM, I, T, U, W} <: SemLossFunction
mutable struct SemFIML{INV, C, L, O, M, IM, I, T, W} <: SemLossFunction
inverses::INV #preallocated inverses of imp_cov
choleskys::C #preallocated choleskys
logdets::L #logdets of implied covmats
Expand All @@ -37,7 +37,7 @@ mutable struct SemFIML{INV, C, L, O, M, IM, I, T, U, W} <: SemLossFunction

mult::T

commutation_indices::U
commutator::CommutationMatrix

interaction::W
end
Expand All @@ -64,8 +64,6 @@ function SemFIML(; observed, specification, kwargs...)
∇ind =
[findall(x -> !(x[1] ind || x[2] ind), ∇ind) for ind in patterns_not(observed)]

commutation_indices = get_commutation_lookup(get_n_nodes(specification)^2)

return SemFIML(
inverses,
choleskys,
Expand All @@ -75,7 +73,7 @@ function SemFIML(; observed, specification, kwargs...)
meandiff,
imp_inv,
mult,
commutation_indices,
CommutationMatrix(get_n_nodes(specification)),
nothing,
)
end
Expand Down Expand Up @@ -163,10 +161,9 @@ function ∇F_fiml_outer(JΣ, Jμ, imply, model, semfiml)
Iₙ = sparse(1.0I, size(A(imply))...)
P = kron(F⨉I_A⁻¹(imply), F⨉I_A⁻¹(imply))
Q = kron(S(imply) * I_A⁻¹(imply)', Iₙ)
#commutation_matrix_pre_square_add!(Q, Q)
Q2 = commutation_matrix_pre_square(Q, semfiml.commutation_indices)
Q .+= semfiml.commutator * Q

∇Σ = P * (∇S(imply) + (Q + Q2) * ∇A(imply))
∇Σ = P * (∇S(imply) + Q * ∇A(imply))

∇μ =
F⨉I_A⁻¹(imply) * ∇M(imply) +
Expand Down
49 changes: 49 additions & 0 deletions test/unit_tests/matrix_helpers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
using StructuralEquationModels, Test, Random, SparseArrays, LinearAlgebra
using StructuralEquationModels:
CommutationMatrix, transpose_linear_indices, duplication_matrix, elimination_matrix

Random.seed!(73721)

n = 4
m = 5

@testset "Commutation matrix" begin
# transpose linear indices
A = rand(n, m)
@test reshape(A[transpose_linear_indices(n, m)], m, n) == A'
# commutation matrix multiplication
K = CommutationMatrix(n)
# test K array interface methods
@test size(K) == (n^2, n^2)
@test size(K, 1) == n^2
@test length(K) == n^4
nn_linind = LinearIndices((n, n))
@test K[nn_linind[3, 2], nn_linind[2, 3]] == 1
@test K[nn_linind[3, 2], nn_linind[3, 2]] == 0

B = rand(n, n)
@test_throws DimensionMismatch K * rand(n, m)
@test K * vec(B) == vec(B')
C = sprand(n, n, 0.5)
@test K * vec(C) == vec(C')
# lmul!
D = sprand(n^2, n^2, 0.1)
E = copy(D)
F = Matrix(E)
lmul!(K, D)
@test D == K * E
@test Matrix(D) == K * F
end

@testset "Duplication / elimination matrix" begin
A = rand(m, m)
A = A * A'

# dupication
D = duplication_matrix(m)
@test D * A[tril(trues(size(A)))] == vec(A)

# elimination
E = elimination_matrix(m)
@test E * vec(A) == A[tril(trues(size(A)))]
end
4 changes: 4 additions & 0 deletions test/unit_tests/unit_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,7 @@ end
@safetestset "SemObs" begin
include("data_input_formats.jl")
end

@safetestset "Matrix algebra helper functions" begin
include("matrix_helpers.jl")
end

0 comments on commit cdc4415

Please sign in to comment.