Skip to content

Commit

Permalink
update MG tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Maximilian-Stefan-Ernst committed Apr 24, 2022
1 parent 031f41c commit 2325fbc
Show file tree
Hide file tree
Showing 10 changed files with 523 additions and 574 deletions.
6 changes: 3 additions & 3 deletions Artifacts.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ lazy = true
url = "https://github.com/StructuralEquationModels/Data/raw/0462a36018c71b8559dc76819d56a538857593c8/datasets/holzinger_swineford/compressed/data_fiml.csv.tar.gz"

[holzinger_swineford_solution]
git-tree-sha1 = "736fd9ef52635a38a420536bd5d32848329785b1"
git-tree-sha1 = "e2ba8a13e85ea1eeaa1353bc54c9e6bf915dea3b"
lazy = true

[[holzinger_swineford_solution.download]]
sha256 = "7bd5132038f27ccb6f8fb318850613eae8d7f98d0e07738e64c0715a118745ed"
url = "https://github.com/StructuralEquationModels/Data/raw/c0e14988ebf0ccd370db6150f25126812f51329b/datasets/holzinger_swineford/compressed/holzinger_swineford_solution.tar.gz"
sha256 = "6c77146c68817302df0b28469d98ea650a158f375102bafa94dbe5945033eee2"
url = "https://github.com/StructuralEquationModels/Data/raw/a34d06ea80738f8d0f640ef421e310b03f88be23/datasets/holzinger_swineford/compressed/holzinger_swineford_solution.tar.gz"

[political_democracy]
git-tree-sha1 = "a02b07e3523570d8a27c9fe10dcac3a1e4705c33"
Expand Down
4 changes: 1 addition & 3 deletions test/examples/examples.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,4 @@ using Test, SafeTestsets

@safetestset "Political Democracy" begin include("political_democracy/political_democracy.jl") end
@safetestset "Recover Parameters" begin include("recover_parameters/recover_parameters_twofact.jl") end

#@safetestset "Multigroup" begin include("multigroup.jl") end
#@safetestset "Multigroup Parser" begin include("multigroup_parser.jl") end
@safetestset "Multigroup" begin include("multigroup/multigroup.jl") end
239 changes: 239 additions & 0 deletions test/examples/multigroup/build_models.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
####################################################################
# ML estimation
####################################################################

model_g1 = Sem(
specification = specification_g1,
data = dat_g1,
imply = RAMSymbolic
)

model_g2 = Sem(
specification = specification_g2,
data = dat_g2,
imply = RAM
)

model_ml_multigroup = SemEnsemble(model_g1, model_g2; diff = semdiff)

############################################################################
### test gradients
############################################################################

@testset "ml_gradients_multigroup" begin
@test test_gradient(model_ml_multigroup, start_test; atol = 1e-9)
end

# fit
@testset "ml_solution_multigroup" begin
solution = sem_fit(model_ml_multigroup)
update_estimate!(partable, solution)
@test compare_estimates(
partable,
solution_lav[:parameter_estimates_ml]; atol = 1e-4,
lav_groups = Dict(:Pasteur => 1, :Grant_White => 2))
end

@testset "fitmeasures/se_ml" begin
solution_ml = sem_fit(model_ml_multigroup)
@test all(test_fitmeasures(
fit_measures(solution_ml),
solution_lav[:fitmeasures_ml]; rtol = 1e-2, atol = 1e-7))

update_partable!(
partable, identifier(model_ml_multigroup), se_hessian(solution_ml), :se)
@test compare_estimates(
partable,
solution_lav[:parameter_estimates_ml]; atol = 1e-3,
col = :se, lav_col = :se,
lav_groups = Dict(:Pasteur => 1, :Grant_White => 2))
end

####################################################################
# ML estimation - user defined loss function
####################################################################

struct UserSemML <: SemLossFunction
objective
gradient
hessian
end

############################################################################
### constructor
############################################################################

UserSemML(;n_par, kwargs...) = UserSemML([1.0], zeros(n_par), zeros(n_par, n_par))

############################################################################
### functors
############################################################################

import LinearAlgebra: Symmetric, cholesky, isposdef, logdet, tr
import LinearAlgebra

function (semml::UserSemML)(par, F, G, H, model)
if G error("analytic gradient of ML is not implemented (yet)") end
if H error("analytic hessian of ML is not implemented (yet)") end

a = cholesky(Symmetric(model.imply.Σ); check = false)
if !isposdef(a)
semml.objective[1] = Inf
else
ld = logdet(a)
Σ_inv = LinearAlgebra.inv(a)
if !isnothing(F)
prod = Σ_inv*model.observed.obs_cov
semml.objective[1] = ld + tr(prod)
end
end
end

# models
model_g1 = Sem(
specification = specification_g1,
data = dat_g1,
imply = RAMSymbolic
)

model_g2 = SemFiniteDiff(
specification = specification_g2,
data = dat_g2,
imply = RAMSymbolic,
loss = UserSemML
)

model_ml_multigroup = SemEnsemble(model_g1, model_g2; diff = semdiff)

@testset "gradients_user_defined_loss" begin
@test test_gradient(model_ml_multigroup, start_test; atol = 1e-9)
end

# fit
@testset "solution_user_defined_loss" begin
solution = sem_fit(model_ml_multigroup)
update_estimate!(partable, solution)
@test compare_estimates(
partable,
solution_lav[:parameter_estimates_ml]; atol = 1e-4,
lav_groups = Dict(:Pasteur => 1, :Grant_White => 2))
end

####################################################################
# GLS estimation
####################################################################

model_ls_g1 = Sem(
specification = specification_g1,
data = dat_g1,
imply = RAMSymbolic,
loss = SemWLS
)

model_ls_g2 = Sem(
specification = specification_g2,
data = dat_g2,
imply = RAMSymbolic,
loss = SemWLS
)

model_ls_multigroup = SemEnsemble(model_ls_g1, model_ls_g2; diff = semdiff)

@testset "ls_gradients_multigroup" begin
@test test_gradient(model_ls_multigroup, start_test; atol = 1e-9)
end

@testset "ls_solution_multigroup" begin
solution = sem_fit(model_ls_multigroup)
update_estimate!(partable, solution)
@test compare_estimates(
partable,
solution_lav[:parameter_estimates_ls]; atol = 1e-4,
lav_groups = Dict(:Pasteur => 1, :Grant_White => 2))
end

@testset "fitmeasures/se_ls" begin
solution_ls = sem_fit(model_ls_multigroup)
@test all(test_fitmeasures(
fit_measures(solution_ls),
solution_lav[:fitmeasures_ls];
fitmeasure_names = fitmeasure_names_ls, rtol = 1e-2, atol = 1e-5))

update_partable!(
partable, identifier(model_ls_multigroup), se_hessian(solution_ls), :se)
@test compare_estimates(
partable,
solution_lav[:parameter_estimates_ls]; atol = 1e-2,
col = :se, lav_col = :se,
lav_groups = Dict(:Pasteur => 1, :Grant_White => 2))
end

if !isnothing(specification_miss_g1)

####################################################################
# FIML estimation
####################################################################

model_g1 = Sem(
specification = specification_miss_g1,
observed = SemObsMissing,
loss = SemFIML,
data = dat_miss_g1,
imply = RAM,
diff = SemDiffEmpty()
)

model_g2 = Sem(
specification = specification_miss_g2,
observed = SemObsMissing,
loss = SemFIML,
data = dat_miss_g2,
imply = RAM,
diff = SemDiffEmpty()
)

model_ml_multigroup = SemEnsemble(model_g1, model_g2; diff = semdiff)

############################################################################
### test gradients
############################################################################

start_test = [
fill(0.5, 6);
fill(1.0, 9);
0.05; 0.01; 0.01; 0.05; 0.01; 0.05;
fill(0.01, 9);
fill(1.0, 9);
0.05; 0.01; 0.01; 0.05; 0.01; 0.05;
fill(0.01, 9)]

@testset "fiml_gradients_multigroup" begin
@test test_gradient(model_ml_multigroup, start_test; atol = 1e-7)
end


@testset "fiml_solution_multigroup" begin
solution = sem_fit(model_ml_multigroup)
update_estimate!(partable_miss, solution)
@test compare_estimates(
partable_miss,
solution_lav[:parameter_estimates_fiml]; atol = 1e-4,
lav_groups = Dict(:Pasteur => 1, :Grant_White => 2))
end

@testset "fitmeasures/se_fiml" begin
solution = sem_fit(model_ml_multigroup)
@test all(test_fitmeasures(
fit_measures(solution),
solution_lav[:fitmeasures_fiml]; rtol = 1e-3, atol = 0))

update_partable!(
partable_miss, identifier(model_ml_multigroup), se_hessian(solution), :se)
@test compare_estimates(
partable_miss,
solution_lav[:parameter_estimates_fiml]; atol = 1e-3,
col = :se, lav_col = :se,
lav_groups = Dict(:Pasteur => 1, :Grant_White => 2))
end

end
Loading

0 comments on commit 2325fbc

Please sign in to comment.