Skip to content

Commit 7994476

Browse files
committed
add Sem unit tests
1 parent b99ef4e commit 7994476

File tree

2 files changed

+95
-0
lines changed

2 files changed

+95
-0
lines changed

test/unit_tests/model.jl

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
using StructuralEquationModels, Test, Statistics
2+
using StructuralEquationModels:
3+
SemSpecification,
4+
samples,
5+
nsamples,
6+
observed_vars,
7+
nobserved_vars,
8+
obs_cov,
9+
obs_mean,
10+
vars,
11+
nvars,
12+
observed_vars,
13+
latent_vars,
14+
nobserved_vars,
15+
nlatent_vars,
16+
params,
17+
nparams
18+
19+
dat = example_data("political_democracy")
20+
dat_missing = example_data("political_democracy_missing")[:, names(dat)]
21+
22+
obs_vars = [Symbol.("x", 1:3); Symbol.("y", 1:8)]
23+
lat_vars = [:ind60, :dem60, :dem65]
24+
25+
graph = @StenoGraph begin
26+
# loadings
27+
ind60 fixed(1) * x1 + x2 + x3
28+
dem60 fixed(1) * y1 + y2 + y3 + y4
29+
dem65 fixed(1) * y5 + y6 + y7 + y8
30+
# latent regressions
31+
label(:a) * dem60 ind60
32+
dem65 dem60
33+
dem65 ind60
34+
# variances
35+
_(obs_vars) _(obs_vars)
36+
_(lat_vars) _(lat_vars)
37+
# covariances
38+
y1 y5
39+
y2 y4 + y6
40+
y3 y7
41+
y8 y4 + y6
42+
end
43+
44+
ram_matrices =
45+
RAMMatrices(ParameterTable(graph, observed_vars = obs_vars, latent_vars = lat_vars))
46+
47+
obs = SemObservedData(specification = ram_matrices, data = dat)
48+
49+
function test_vars_api(semobj, spec::SemSpecification)
50+
@test @inferred(nobserved_vars(semobj)) == nobserved_vars(spec)
51+
@test observed_vars(semobj) == observed_vars(spec)
52+
53+
@test @inferred(nlatent_vars(semobj)) == nlatent_vars(spec)
54+
@test latent_vars(semobj) == latent_vars(spec)
55+
56+
@test @inferred(nvars(semobj)) == nvars(spec)
57+
@test vars(semobj) == vars(spec)
58+
end
59+
60+
function test_params_api(semobj, spec::SemSpecification)
61+
@test @inferred(nparams(semobj)) == nparams(spec)
62+
@test @inferred(params(semobj)) == params(spec)
63+
end
64+
65+
@testset "Sem(imply=$implytype, loss=$losstype)" for implytype in (RAM, RAMSymbolic),
66+
losstype in (SemML, SemWLS)
67+
68+
model = Sem(
69+
specification = ram_matrices,
70+
observed = obs,
71+
imply = implytype,
72+
loss = losstype,
73+
)
74+
75+
@test model isa Sem
76+
@test @inferred(imply(model)) isa implytype
77+
@test @inferred(observed(model)) isa SemObserved
78+
@test @inferred(optimizer(model)) isa SemOptimizer
79+
80+
test_vars_api(model, ram_matrices)
81+
test_params_api(model, ram_matrices)
82+
83+
test_vars_api(imply(model), ram_matrices)
84+
test_params_api(imply(model), ram_matrices)
85+
86+
@test @inferred(loss(model)) isa SemLoss
87+
semloss = loss(model).functions[1]
88+
@test semloss isa losstype
89+
90+
@test @inferred(nsamples(model)) == nsamples(obs)
91+
end

test/unit_tests/unit_tests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,7 @@ end
1515
@safetestset "SemSpecification" begin
1616
include("specification.jl")
1717
end
18+
19+
@safetestset "Sem model" begin
20+
include("model.jl")
21+
end

0 commit comments

Comments
 (0)