diff --git a/src/StructuralEquationModels.jl b/src/StructuralEquationModels.jl index 048b7181c..113022960 100644 --- a/src/StructuralEquationModels.jl +++ b/src/StructuralEquationModels.jl @@ -24,6 +24,10 @@ const SEM = StructuralEquationModels # type hierarchy include("types.jl") include("objective_gradient_hessian.jl") + +# helper objects and functions +include("additional_functions/commutation_matrix.jl") + # fitted objects include("frontend/fit/SemFit.jl") # specification of models diff --git a/src/additional_functions/commutation_matrix.jl b/src/additional_functions/commutation_matrix.jl new file mode 100644 index 000000000..345f809e0 --- /dev/null +++ b/src/additional_functions/commutation_matrix.jl @@ -0,0 +1,75 @@ +""" + + transpose_linear_indices(n, [m]) + +Put each linear index of the *n×m* matrix to the position of the +corresponding element in the transposed matrix. + +## Example +` +1 4 +2 5 => 1 2 3 +3 6 4 5 6 +` +""" +transpose_linear_indices(n::Integer, m::Integer = n) = + repeat(1:n, inner = m) .+ repeat((0:(m-1)) * n, outer = n) + +""" + CommutationMatrix(n::Integer) <: AbstractMatrix{Int} + +A *commutation matrix* *C* is a n²×n² matrix of 0s and 1s. +If *vec(A)* is a vectorized form of a n×n matrix *A*, +then ``C * vec(A) = vec(Aᵀ)``. +""" +struct CommutationMatrix <: AbstractMatrix{Int} + n::Int + n²::Int + transpose_inds::Vector{Int} # maps the linear indices of n×n matrix *B* to the indices of matrix *B'* + + CommutationMatrix(n::Integer) = new(n, n^2, transpose_linear_indices(n)) +end + +Base.size(A::CommutationMatrix) = (A.n², A.n²) +Base.size(A::CommutationMatrix, dim::Integer) = + 1 <= dim <= 2 ? A.n² : throw(ArgumentError("invalid matrix dimension $dim")) +Base.length(A::CommutationMatrix) = A.n²^2 +Base.getindex(A::CommutationMatrix, i::Int, j::Int) = j == A.transpose_inds[i] ? 1 : 0 + +function Base.:(*)(A::CommutationMatrix, B::AbstractVector) + size(A, 2) == size(B, 1) || throw( + DimensionMismatch("A has $(size(A, 2)) columns, but B has $(size(B, 1)) elements"), + ) + return B[A.transpose_inds] +end + +function Base.:(*)(A::CommutationMatrix, B::AbstractMatrix) + size(A, 2) == size(B, 1) || throw( + DimensionMismatch("A has $(size(A, 2)) columns, but B has $(size(B, 1)) rows"), + ) + return B[A.transpose_inds, :] +end + +function Base.:(*)(A::CommutationMatrix, B::SparseMatrixCSC) + size(A, 2) == size(B, 1) || throw( + DimensionMismatch("A has $(size(A, 2)) columns, but B has $(size(B, 1)) rows"), + ) + return SparseMatrixCSC( + size(B, 1), + size(B, 2), + copy(B.colptr), + A.transpose_inds[B.rowval], + copy(B.nzval), + ) +end + +function LinearAlgebra.lmul!(A::CommutationMatrix, B::SparseMatrixCSC) + size(A, 2) == size(B, 1) || throw( + DimensionMismatch("A has $(size(A, 2)) columns, but B has $(size(B, 1)) rows"), + ) + + @inbounds for (i, rowind) in enumerate(B.rowval) + B.rowval[i] = A.transpose_inds[rowind] + end + return B +end diff --git a/src/additional_functions/helper.jl b/src/additional_functions/helper.jl index abc37207c..b96813dc3 100644 --- a/src/additional_functions/helper.jl +++ b/src/additional_functions/helper.jl @@ -41,7 +41,7 @@ function get_observed(rowind, data, semobserved; args = (), kwargs = NamedTuple( return observed_vec end -skipmissing_mean(mat::AbstractMatrix) = +skipmissing_mean(mat::AbstractMatrix) = [mean(skipmissing(coldata)) for coldata in eachcol(mat)] function F_one_person(imp_mean, meandiff, inverse, data, logdet) @@ -111,143 +111,34 @@ function cov_and_mean(rows; corrected = false) return obs_cov, vec(obs_mean) end -function duplication_matrix(nobs) - nobs = Int(nobs) - n1 = Int(nobs * (nobs + 1) * 0.5) - n2 = Int(nobs^2) - Dt = zeros(n1, n2) - - for j in 1:nobs - for i in j:nobs - u = zeros(n1) - u[Int((j - 1) * nobs + i - 0.5 * j * (j - 1))] = 1 - T = zeros(nobs, nobs) - T[j, i] = 1 - T[i, j] = 1 - Dt += u * transpose(vec(T)) +# n²×(n(n+1)/2) matrix to transform a vector of lower +# triangular entries into a vectorized form of a n×n symmetric matrix, +# opposite of elimination_matrix() +function duplication_matrix(n::Integer) + ntri = div(n * (n + 1), 2) + D = zeros(n^2, ntri) + for j in 1:n + for i in j:n + tri_ix = (j - 1) * n + i - div(j * (j - 1), 2) + D[j+n*(i-1), tri_ix] = 1 + D[i+n*(j-1), tri_ix] = 1 end end - D = transpose(Dt) return D end -function elimination_matrix(nobs) - nobs = Int(nobs) - n1 = Int(nobs * (nobs + 1) * 0.5) - n2 = Int(nobs^2) - L = zeros(n1, n2) - - for j in 1:nobs - for i in j:nobs - u = zeros(n1) - u[Int((j - 1) * nobs + i - 0.5 * j * (j - 1))] = 1 - T = zeros(nobs, nobs) - T[i, j] = 1 - L += u * transpose(vec(T)) +# (n(n+1)/2)×n² matrix to transform a +# vectorized form of a n×n symmetric matrix +# into vector of its lower triangular entries, +# opposite of duplication_matrix() +function elimination_matrix(n::Integer) + ntri = div(n * (n + 1), 2) + L = zeros(ntri, n^2) + for j in 1:n + for i in j:n + tri_ix = (j - 1) * n + i - div(j * (j - 1), 2) + L[tri_ix, i+n*(j-1)] = 1 end end return L end - -function commutation_matrix(n; tosparse = false) - M = zeros(n^2, n^2) - - for i in 1:n - for j in 1:n - M[i+n*(j-1), j+n*(i-1)] = 1.0 - end - end - - if tosparse - M = sparse(M) - end - - return M -end - -function commutation_matrix_pre_square(A) - n2 = size(A, 1) - n = Int(sqrt(n2)) - - ind = repeat(1:n, inner = n) - indadd = (0:(n-1)) * n - for i in 1:n - ind[((i-1)*n+1):i*n] .+= indadd - end - - A_post = A[ind, :] - - return A_post -end - -function commutation_matrix_pre_square_add!(B, A) # comuptes B + KₙA - n2 = size(A, 1) - n = Int(sqrt(n2)) - - ind = repeat(1:n, inner = n) - indadd = (0:(n-1)) * n - for i in 1:n - ind[((i-1)*n+1):i*n] .+= indadd - end - - @views @inbounds B .+= A[ind, :] - - return B -end - -function get_commutation_lookup(n2::Int64) - n = Int(sqrt(n2)) - ind = repeat(1:n, inner = n) - indadd = (0:(n-1)) * n - for i in 1:n - ind[((i-1)*n+1):i*n] .+= indadd - end - - lookup = Dict{Int64, Int64}() - - for i in 1:n2 - j = findall(x -> (x == i), ind)[1] - push!(lookup, i => j) - end - - return lookup -end - -function commutation_matrix_pre_square!(A::SparseMatrixCSC, lookup) # comuptes B + KₙA - for (i, rowind) in enumerate(A.rowval) - A.rowval[i] = lookup[rowind] - end -end - -function commutation_matrix_pre_square!(A::SparseMatrixCSC) # computes KₙA - lookup = get_commutation_lookup(size(A, 2)) - commutation_matrix_pre_square!(A, lookup) -end - -function commutation_matrix_pre_square(A::SparseMatrixCSC) - B = copy(A) - commutation_matrix_pre_square!(B) - return B -end - -function commutation_matrix_pre_square(A::SparseMatrixCSC, lookup) - B = copy(A) - commutation_matrix_pre_square!(B, lookup) - return B -end - -function commutation_matrix_pre_square_add_mt!(B, A) # comuptes B + KₙA # 0 allocations but slower - n2 = size(A, 1) - n = Int(sqrt(n2)) - - indadd = (0:(n-1)) * n - - Threads.@threads for i in 1:n - for j in 1:n - row = i + indadd[j] - @views @inbounds B[row, :] .+= A[row, :] - end - end - - return B -end diff --git a/src/loss/ML/FIML.jl b/src/loss/ML/FIML.jl index 7a27e7615..1cc7c123c 100644 --- a/src/loss/ML/FIML.jl +++ b/src/loss/ML/FIML.jl @@ -24,7 +24,7 @@ Analytic gradients are available. ## Implementation Subtype of `SemLossFunction`. """ -mutable struct SemFIML{INV, C, L, O, M, IM, I, T, U, W} <: SemLossFunction +mutable struct SemFIML{INV, C, L, O, M, IM, I, T, W} <: SemLossFunction inverses::INV #preallocated inverses of imp_cov choleskys::C #preallocated choleskys logdets::L #logdets of implied covmats @@ -37,7 +37,7 @@ mutable struct SemFIML{INV, C, L, O, M, IM, I, T, U, W} <: SemLossFunction mult::T - commutation_indices::U + commutator::CommutationMatrix interaction::W end @@ -64,8 +64,6 @@ function SemFIML(; observed, specification, kwargs...) ∇ind = [findall(x -> !(x[1] ∈ ind || x[2] ∈ ind), ∇ind) for ind in patterns_not(observed)] - commutation_indices = get_commutation_lookup(get_n_nodes(specification)^2) - return SemFIML( inverses, choleskys, @@ -75,7 +73,7 @@ function SemFIML(; observed, specification, kwargs...) meandiff, imp_inv, mult, - commutation_indices, + CommutationMatrix(get_n_nodes(specification)), nothing, ) end @@ -163,10 +161,9 @@ function ∇F_fiml_outer(JΣ, Jμ, imply, model, semfiml) Iₙ = sparse(1.0I, size(A(imply))...) P = kron(F⨉I_A⁻¹(imply), F⨉I_A⁻¹(imply)) Q = kron(S(imply) * I_A⁻¹(imply)', Iₙ) - #commutation_matrix_pre_square_add!(Q, Q) - Q2 = commutation_matrix_pre_square(Q, semfiml.commutation_indices) + Q .+= semfiml.commutator * Q - ∇Σ = P * (∇S(imply) + (Q + Q2) * ∇A(imply)) + ∇Σ = P * (∇S(imply) + Q * ∇A(imply)) ∇μ = F⨉I_A⁻¹(imply) * ∇M(imply) + diff --git a/test/unit_tests/matrix_helpers.jl b/test/unit_tests/matrix_helpers.jl new file mode 100644 index 000000000..b2f32f31a --- /dev/null +++ b/test/unit_tests/matrix_helpers.jl @@ -0,0 +1,49 @@ +using StructuralEquationModels, Test, Random, SparseArrays, LinearAlgebra +using StructuralEquationModels: + CommutationMatrix, transpose_linear_indices, duplication_matrix, elimination_matrix + +Random.seed!(73721) + +n = 4 +m = 5 + +@testset "Commutation matrix" begin + # transpose linear indices + A = rand(n, m) + @test reshape(A[transpose_linear_indices(n, m)], m, n) == A' + # commutation matrix multiplication + K = CommutationMatrix(n) + # test K array interface methods + @test size(K) == (n^2, n^2) + @test size(K, 1) == n^2 + @test length(K) == n^4 + nn_linind = LinearIndices((n, n)) + @test K[nn_linind[3, 2], nn_linind[2, 3]] == 1 + @test K[nn_linind[3, 2], nn_linind[3, 2]] == 0 + + B = rand(n, n) + @test_throws DimensionMismatch K * rand(n, m) + @test K * vec(B) == vec(B') + C = sprand(n, n, 0.5) + @test K * vec(C) == vec(C') + # lmul! + D = sprand(n^2, n^2, 0.1) + E = copy(D) + F = Matrix(E) + lmul!(K, D) + @test D == K * E + @test Matrix(D) == K * F +end + +@testset "Duplication / elimination matrix" begin + A = rand(m, m) + A = A * A' + + # dupication + D = duplication_matrix(m) + @test D * A[tril(trues(size(A)))] == vec(A) + + # elimination + E = elimination_matrix(m) + @test E * vec(A) == A[tril(trues(size(A)))] +end diff --git a/test/unit_tests/unit_tests.jl b/test/unit_tests/unit_tests.jl index 87fdde2f1..eb58650c1 100644 --- a/test/unit_tests/unit_tests.jl +++ b/test/unit_tests/unit_tests.jl @@ -7,3 +7,7 @@ end @safetestset "SemObs" begin include("data_input_formats.jl") end + +@safetestset "Matrix algebra helper functions" begin + include("matrix_helpers.jl") +end