|
1 |
| -@testset "ParameterTable - RAMMatrices conversion" begin |
2 |
| - partable = ParameterTable(ram_matrices) |
3 |
| - @test ram_matrices == RAMMatrices(partable) |
4 |
| -end |
| 1 | +using StructuralEquationModels |
5 | 2 |
|
6 |
| -@testset "params()" begin |
7 |
| - @test params(model_ml)[2, 10, 28] == [:x2, :x10, :x28] |
8 |
| - @test params(model_ml) == params(partable) |
9 |
| - @test params(model_ml) == params(RAMMatrices(partable)) |
10 |
| -end |
| 3 | +obs_vars = Symbol.("x", 1:9) |
| 4 | +lat_vars = [:visual, :textual, :speed] |
11 | 5 |
|
12 | 6 | graph = @StenoGraph begin
|
| 7 | + # measurement model |
| 8 | + visual → fixed(1.0) * x1 + fixed(0.5) * x2 + fixed(0.6) * x3 |
| 9 | + textual → fixed(1.0) * x4 + x5 + label(:a₁) * x6 |
| 10 | + speed → fixed(1.0) * x7 + fixed(1.0) * x8 + label(:λ₉) * x9 |
| 11 | + # variances and covariances |
| 12 | + _(obs_vars) ↔ _(obs_vars) |
| 13 | + _(lat_vars) ↔ _(lat_vars) |
| 14 | + visual ↔ textual + speed |
| 15 | + textual ↔ speed |
| 16 | +end |
| 17 | + |
| 18 | +ens_graph = @StenoGraph begin |
13 | 19 | # measurement model
|
14 | 20 | visual → fixed(1.0, 1.0) * x1 + fixed(0.5, 0.5) * x2 + fixed(0.6, 0.8) * x3
|
15 | 21 | textual → fixed(1.0, 1.0) * x4 + x5 + label(:a₁, :a₂) * x6
|
16 | 22 | speed → fixed(1.0, 1.0) * x7 + fixed(1.0, NaN) * x8 + label(:λ₉, :λ₉) * x9
|
17 | 23 | # variances and covariances
|
18 |
| - _(observed_vars) ↔ _(observed_vars) |
19 |
| - _(latent_vars) ↔ _(latent_vars) |
| 24 | + _(obs_vars) ↔ _(obs_vars) |
| 25 | + _(lat_vars) ↔ _(lat_vars) |
20 | 26 | visual ↔ textual + speed
|
21 | 27 | textual ↔ speed
|
22 | 28 | end
|
| 29 | + |
| 30 | +@testset "ParameterTable" begin |
| 31 | + @testset "from StenoGraph" begin |
| 32 | + @test_throws UndefKeywordError(:observed_vars) ParameterTable(graph) |
| 33 | + @test_throws UndefKeywordError(:latent_vars) ParameterTable( |
| 34 | + graph, |
| 35 | + observed_vars = obs_vars, |
| 36 | + ) |
| 37 | + partable = @inferred( |
| 38 | + ParameterTable(graph, observed_vars = obs_vars, latent_vars = lat_vars) |
| 39 | + ) |
| 40 | + |
| 41 | + @test partable isa ParameterTable |
| 42 | + |
| 43 | + # vars API |
| 44 | + @test observed_vars(partable) == obs_vars |
| 45 | + @test nobserved_vars(partable) == length(obs_vars) |
| 46 | + @test latent_vars(partable) == lat_vars |
| 47 | + @test nlatent_vars(partable) == length(lat_vars) |
| 48 | + @test nvars(partable) == length(obs_vars) + length(lat_vars) |
| 49 | + @test issetequal(vars(partable), [obs_vars; lat_vars]) |
| 50 | + |
| 51 | + # params API |
| 52 | + @test params(partable) == [[:θ_1, :a₁, :λ₉]; Symbol.("θ_", 2:16)] |
| 53 | + @test nparams(partable) == 18 |
| 54 | + |
| 55 | + # don't allow constructing ParameterTable from a graph for an ensemble |
| 56 | + @test_throws ArgumentError ParameterTable( |
| 57 | + ens_graph, |
| 58 | + observed_vars = obs_vars, |
| 59 | + latent_vars = lat_vars, |
| 60 | + ) |
| 61 | + end |
| 62 | + |
| 63 | + @testset "from RAMMatrices" begin |
| 64 | + partable_orig = |
| 65 | + ParameterTable(graph, observed_vars = obs_vars, latent_vars = lat_vars) |
| 66 | + ram_matrices = RAMMatrices(partable_orig) |
| 67 | + |
| 68 | + partable = @inferred(ParameterTable(ram_matrices)) |
| 69 | + @test partable isa ParameterTable |
| 70 | + @test issetequal(keys(partable.columns), keys(partable_orig.columns)) |
| 71 | + # FIXME nrow()? |
| 72 | + @test length(partable.columns[:from]) == length(partable_orig.columns[:from]) |
| 73 | + @test partable == partable_orig broken = true |
| 74 | + end |
| 75 | +end |
| 76 | + |
| 77 | +@testset "EnsembleParameterTable" begin |
| 78 | + groups = [:Pasteur, :Grant_White], |
| 79 | + @test_throws UndefKeywordError(:observed_vars) EnsembleParameterTable(ens_graph) |
| 80 | + @test_throws UndefKeywordError(:latent_vars) EnsembleParameterTable( |
| 81 | + ens_graph, |
| 82 | + observed_vars = obs_vars, |
| 83 | + ) |
| 84 | + @test_throws UndefKeywordError(:groups) EnsembleParameterTable( |
| 85 | + ens_graph, |
| 86 | + observed_vars = obs_vars, |
| 87 | + latent_vars = lat_vars, |
| 88 | + ) |
| 89 | + |
| 90 | + enspartable = @inferred( |
| 91 | + EnsembleParameterTable( |
| 92 | + ens_graph, |
| 93 | + observed_vars = obs_vars, |
| 94 | + latent_vars = lat_vars, |
| 95 | + groups = [:Pasteur, :Grant_White], |
| 96 | + ) |
| 97 | + ) |
| 98 | + @test enspartable isa EnsembleParameterTable |
| 99 | + |
| 100 | + @test nobserved_vars(enspartable) == length(obs_vars) broken = true |
| 101 | + @test observed_vars(enspartable) == obs_vars broken = true |
| 102 | + @test nlatent_vars(enspartable) == length(lat_vars) broken = true |
| 103 | + @test latent_vars(enspartable) == lat_vars broken = true |
| 104 | + @test nvars(enspartable) == length(obs_vars) + length(lat_vars) broken = true |
| 105 | + @test issetequal(vars(enspartable), [obs_vars; lat_vars]) broken = true |
| 106 | + |
| 107 | + @test nparams(enspartable) == 36 |
| 108 | + @test issetequal( |
| 109 | + params(enspartable), |
| 110 | + [Symbol.("gPasteur_", 1:16); Symbol.("gGrant_White_", 1:17); [:a₁, :a₂, :λ₉]], |
| 111 | + ) |
| 112 | +end |
| 113 | + |
| 114 | +@testset "RAMMatrices" begin |
| 115 | + partable = ParameterTable(graph, observed_vars = obs_vars, latent_vars = lat_vars) |
| 116 | + |
| 117 | + ram_matrices = @inferred(RAMMatrices(partable)) |
| 118 | + @test ram_matrices isa RAMMatrices |
| 119 | + |
| 120 | + # vars API |
| 121 | + @test nobserved_vars(ram_matrices) == length(obs_vars) |
| 122 | + @test observed_vars(ram_matrices) == obs_vars |
| 123 | + @test nlatent_vars(ram_matrices) == length(lat_vars) |
| 124 | + @test latent_vars(ram_matrices) == lat_vars |
| 125 | + @test nvars(ram_matrices) == length(obs_vars) + length(lat_vars) |
| 126 | + @test issetequal(vars(ram_matrices), [obs_vars; lat_vars]) |
| 127 | + |
| 128 | + # params API |
| 129 | + @test nparams(ram_matrices) == nparams(partable) |
| 130 | + @test params(ram_matrices) == params(partable) |
| 131 | +end |
0 commit comments