Skip to content

Commit 2f6e8b7

Browse files
committed
add Sem unit tests
1 parent 86c172a commit 2f6e8b7

File tree

2 files changed

+79
-0
lines changed

2 files changed

+79
-0
lines changed

test/unit_tests/model.jl

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
using StructuralEquationModels, Test, Statistics
2+
3+
dat = example_data("political_democracy")
4+
dat_missing = example_data("political_democracy_missing")[:, names(dat)]
5+
6+
obs_vars = [Symbol.("x", 1:3); Symbol.("y", 1:8)]
7+
lat_vars = [:ind60, :dem60, :dem65]
8+
9+
graph = @StenoGraph begin
10+
# loadings
11+
ind60 fixed(1) * x1 + x2 + x3
12+
dem60 fixed(1) * y1 + y2 + y3 + y4
13+
dem65 fixed(1) * y5 + y6 + y7 + y8
14+
# latent regressions
15+
label(:a) * dem60 ind60
16+
dem65 dem60
17+
dem65 ind60
18+
# variances
19+
_(obs_vars) _(obs_vars)
20+
_(lat_vars) _(lat_vars)
21+
# covariances
22+
y1 y5
23+
y2 y4 + y6
24+
y3 y7
25+
y8 y4 + y6
26+
end
27+
28+
ram_matrices =
29+
RAMMatrices(ParameterTable(graph, observed_vars = obs_vars, latent_vars = lat_vars))
30+
31+
obs = SemObservedData(specification = ram_matrices, data = dat)
32+
33+
function test_vars_api(semobj, spec::SemSpecification)
34+
@test @inferred(nobserved_vars(semobj)) == nobserved_vars(spec)
35+
@test observed_vars(semobj) == observed_vars(spec)
36+
37+
@test @inferred(nlatent_vars(semobj)) == nlatent_vars(spec)
38+
@test latent_vars(semobj) == latent_vars(spec)
39+
40+
@test @inferred(nvars(semobj)) == nvars(spec)
41+
@test vars(semobj) == vars(spec)
42+
end
43+
44+
function test_params_api(semobj, spec::SemSpecification)
45+
@test @inferred(nparams(semobj)) == nparams(spec)
46+
@test @inferred(params(semobj)) == params(spec)
47+
end
48+
49+
@testset "Sem(imply=$implytype, loss=$losstype)" for implytype in (RAM, RAMSymbolic),
50+
losstype in (SemML, SemWLS)
51+
52+
model = Sem(
53+
specification = ram_matrices,
54+
observed = obs,
55+
imply = implytype,
56+
loss = losstype,
57+
)
58+
59+
@test model isa Sem
60+
@test @inferred(imply(model)) isa implytype
61+
@test @inferred(observed(model)) isa SemObserved
62+
@test @inferred(optimizer(model)) isa SemOptimizer
63+
64+
test_vars_api(model, ram_matrices)
65+
test_params_api(model, ram_matrices)
66+
67+
test_vars_api(imply(model), ram_matrices)
68+
test_params_api(imply(model), ram_matrices)
69+
70+
@test @inferred(loss(model)) isa SemLoss
71+
semloss = loss(model).functions[1]
72+
@test semloss isa losstype
73+
74+
@test @inferred(nsamples(model)) == nsamples(obs)
75+
end

test/unit_tests/unit_tests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,7 @@ end
1515
@safetestset "SemSpecification" begin
1616
include("specification.jl")
1717
end
18+
19+
@safetestset "Sem model" begin
20+
include("model.jl")
21+
end

0 commit comments

Comments
 (0)