Skip to content

Commit 86c172a

Browse files
committed
refactor SemSpec tests
1 parent c0e2c9e commit 86c172a

File tree

2 files changed

+124
-11
lines changed

2 files changed

+124
-11
lines changed

test/unit_tests/specification.jl

Lines changed: 120 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,131 @@
1-
@testset "ParameterTable - RAMMatrices conversion" begin
2-
partable = ParameterTable(ram_matrices)
3-
@test ram_matrices == RAMMatrices(partable)
4-
end
1+
using StructuralEquationModels
52

6-
@testset "params()" begin
7-
@test params(model_ml)[2, 10, 28] == [:x2, :x10, :x28]
8-
@test params(model_ml) == params(partable)
9-
@test params(model_ml) == params(RAMMatrices(partable))
10-
end
3+
obs_vars = Symbol.("x", 1:9)
4+
lat_vars = [:visual, :textual, :speed]
115

126
graph = @StenoGraph begin
7+
# measurement model
8+
visual fixed(1.0) * x1 + fixed(0.5) * x2 + fixed(0.6) * x3
9+
textual fixed(1.0) * x4 + x5 + label(:a₁) * x6
10+
speed fixed(1.0) * x7 + fixed(1.0) * x8 + label(:λ₉) * x9
11+
# variances and covariances
12+
_(obs_vars) _(obs_vars)
13+
_(lat_vars) _(lat_vars)
14+
visual textual + speed
15+
textual speed
16+
end
17+
18+
ens_graph = @StenoGraph begin
1319
# measurement model
1420
visual fixed(1.0, 1.0) * x1 + fixed(0.5, 0.5) * x2 + fixed(0.6, 0.8) * x3
1521
textual fixed(1.0, 1.0) * x4 + x5 + label(:a₁, :a₂) * x6
1622
speed fixed(1.0, 1.0) * x7 + fixed(1.0, NaN) * x8 + label(:λ₉, :λ₉) * x9
1723
# variances and covariances
18-
_(observed_vars) _(observed_vars)
19-
_(latent_vars) _(latent_vars)
24+
_(obs_vars) _(obs_vars)
25+
_(lat_vars) _(lat_vars)
2026
visual textual + speed
2127
textual speed
2228
end
29+
30+
@testset "ParameterTable" begin
31+
@testset "from StenoGraph" begin
32+
@test_throws UndefKeywordError(:observed_vars) ParameterTable(graph)
33+
@test_throws UndefKeywordError(:latent_vars) ParameterTable(
34+
graph,
35+
observed_vars = obs_vars,
36+
)
37+
partable = @inferred(
38+
ParameterTable(graph, observed_vars = obs_vars, latent_vars = lat_vars)
39+
)
40+
41+
@test partable isa ParameterTable
42+
43+
# vars API
44+
@test observed_vars(partable) == obs_vars
45+
@test nobserved_vars(partable) == length(obs_vars)
46+
@test latent_vars(partable) == lat_vars
47+
@test nlatent_vars(partable) == length(lat_vars)
48+
@test nvars(partable) == length(obs_vars) + length(lat_vars)
49+
@test issetequal(vars(partable), [obs_vars; lat_vars])
50+
51+
# params API
52+
@test params(partable) == [[:θ_1, :a₁, :λ₉]; Symbol.("θ_", 2:16)]
53+
@test nparams(partable) == 18
54+
55+
# don't allow constructing ParameterTable from a graph for an ensemble
56+
@test_throws ArgumentError ParameterTable(
57+
ens_graph,
58+
observed_vars = obs_vars,
59+
latent_vars = lat_vars,
60+
)
61+
end
62+
63+
@testset "from RAMMatrices" begin
64+
partable_orig =
65+
ParameterTable(graph, observed_vars = obs_vars, latent_vars = lat_vars)
66+
ram_matrices = RAMMatrices(partable_orig)
67+
68+
partable = @inferred(ParameterTable(ram_matrices))
69+
@test partable isa ParameterTable
70+
@test issetequal(keys(partable.columns), keys(partable_orig.columns))
71+
# FIXME nrow()?
72+
@test length(partable.columns[:from]) == length(partable_orig.columns[:from])
73+
@test partable == partable_orig broken = true
74+
end
75+
end
76+
77+
@testset "EnsembleParameterTable" begin
78+
groups = [:Pasteur, :Grant_White],
79+
@test_throws UndefKeywordError(:observed_vars) EnsembleParameterTable(ens_graph)
80+
@test_throws UndefKeywordError(:latent_vars) EnsembleParameterTable(
81+
ens_graph,
82+
observed_vars = obs_vars,
83+
)
84+
@test_throws UndefKeywordError(:groups) EnsembleParameterTable(
85+
ens_graph,
86+
observed_vars = obs_vars,
87+
latent_vars = lat_vars,
88+
)
89+
90+
enspartable = @inferred(
91+
EnsembleParameterTable(
92+
ens_graph,
93+
observed_vars = obs_vars,
94+
latent_vars = lat_vars,
95+
groups = [:Pasteur, :Grant_White],
96+
)
97+
)
98+
@test enspartable isa EnsembleParameterTable
99+
100+
@test nobserved_vars(enspartable) == length(obs_vars) broken = true
101+
@test observed_vars(enspartable) == obs_vars broken = true
102+
@test nlatent_vars(enspartable) == length(lat_vars) broken = true
103+
@test latent_vars(enspartable) == lat_vars broken = true
104+
@test nvars(enspartable) == length(obs_vars) + length(lat_vars) broken = true
105+
@test issetequal(vars(enspartable), [obs_vars; lat_vars]) broken = true
106+
107+
@test nparams(enspartable) == 36
108+
@test issetequal(
109+
params(enspartable),
110+
[Symbol.("gPasteur_", 1:16); Symbol.("gGrant_White_", 1:17); [:a₁, :a₂, :λ₉]],
111+
)
112+
end
113+
114+
@testset "RAMMatrices" begin
115+
partable = ParameterTable(graph, observed_vars = obs_vars, latent_vars = lat_vars)
116+
117+
ram_matrices = @inferred(RAMMatrices(partable))
118+
@test ram_matrices isa RAMMatrices
119+
120+
# vars API
121+
@test nobserved_vars(ram_matrices) == length(obs_vars)
122+
@test observed_vars(ram_matrices) == obs_vars
123+
@test nlatent_vars(ram_matrices) == length(lat_vars)
124+
@test latent_vars(ram_matrices) == lat_vars
125+
@test nvars(ram_matrices) == length(obs_vars) + length(lat_vars)
126+
@test issetequal(vars(ram_matrices), [obs_vars; lat_vars])
127+
128+
# params API
129+
@test nparams(ram_matrices) == nparams(partable)
130+
@test params(ram_matrices) == params(partable)
131+
end

test/unit_tests/unit_tests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,7 @@ end
1111
@safetestset "SemObserved" begin
1212
include("data_input_formats.jl")
1313
end
14+
15+
@safetestset "SemSpecification" begin
16+
include("specification.jl")
17+
end

0 commit comments

Comments
 (0)