Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion ext/SEMProximalOptExt/ProximalAlgorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ function ProximalAlgorithms.value_and_gradient(model::AbstractSem, params)
return obj, grad
end


mutable struct ProximalResult
result::Any
end
Expand Down
19 changes: 8 additions & 11 deletions src/frontend/fit/summary.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,4 @@
function details(
sem_fit::SemFit;
show_fitmeasures = false,
color = :light_cyan,
digits = 2,
)
function details(sem_fit::SemFit; show_fitmeasures = false, color = :light_cyan, digits = 2)
print("\n")
println("Fitted Structural Equation Model")
print("\n")
Expand Down Expand Up @@ -51,7 +46,7 @@ function details(
secondary_color = :light_yellow,
digits = 2,
show_variables = true,
show_columns = nothing
show_columns = nothing,
)
if show_variables
print("\n")
Expand Down Expand Up @@ -150,7 +145,8 @@ function details(
check_round(partable.columns[c][regression_indices]; digits = digits) for
c in regression_columns
)
regression_columns[2] = regression_columns[2] == :relation ? Symbol("") : regression_columns[2]
regression_columns[2] =
regression_columns[2] == :relation ? Symbol("") : regression_columns[2]

print("\n")
pretty_table(
Expand Down Expand Up @@ -222,7 +218,8 @@ function details(
printstyled("Means: \n"; color = color)

if isnothing(show_columns)
sorted_columns = [:from, :relation, :to, :estimate, :param, :value_fixed, :start]
sorted_columns =
[:from, :relation, :to, :estimate, :param, :value_fixed, :start]
mean_columns = sort_partially(sorted_columns, columns)
else
mean_columns = copy(show_columns)
Expand Down Expand Up @@ -256,7 +253,7 @@ function details(
secondary_color = :light_yellow,
digits = 2,
show_variables = true,
show_columns = nothing
show_columns = nothing,
)
if show_variables
print("\n")
Expand Down Expand Up @@ -297,7 +294,7 @@ function details(
secondary_color = secondary_color,
digits = digits,
show_variables = false,
show_columns = show_columns
show_columns = show_columns,
)
end

Expand Down
8 changes: 3 additions & 5 deletions src/frontend/specification/EnsembleParameterTable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ EnsembleParameterTable(::Nothing; params::Union{Nothing, Vector{Symbol}} = nothi
)

# convert pairs to dict
EnsembleParameterTable(ps::Pair{K, V}...; params = nothing) where {K, V} =
EnsembleParameterTable(ps::Pair{K, V}...; params = nothing) where {K, V} =
EnsembleParameterTable(Dict(ps...); params = params)

# dictionary of SEM specifications
Expand Down Expand Up @@ -148,8 +148,6 @@ end
############################################################################################

function Base.:(==)(p1::EnsembleParameterTable, p2::EnsembleParameterTable)
out =
(p1.tables == p2.tables) &&
(p1.params == p2.params)
out = (p1.tables == p2.tables) && (p1.params == p2.params)
return out
end
end
2 changes: 1 addition & 1 deletion src/frontend/specification/ParameterTable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ end

# Equality --------------------------------------------------------------------------------
function Base.:(==)(p1::ParameterTable, p2::ParameterTable)
out =
out =
(p1.columns == p2.columns) &&
(p1.observed_vars == p2.observed_vars) &&
(p1.latent_vars == p2.latent_vars) &&
Expand Down
32 changes: 19 additions & 13 deletions src/frontend/specification/RAMMatrices.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,10 @@ function RAMMatrices(
@assert length(partable.sorted_vars) == nvars(partable)
vars_sorted = copy(partable.sorted_vars)
else
vars_sorted = [partable.observed_vars
partable.latent_vars]
vars_sorted = [
partable.observed_vars
partable.latent_vars
]
end

# indices of the vars (A/S/M rows or columns)
Expand Down Expand Up @@ -216,13 +218,20 @@ function RAMMatrices(
sort!(M_consts, by = first)
end

return RAMMatrices(ParamsMatrix{T}(A_inds, A_consts, (n_vars, n_vars)),
ParamsMatrix{T}(S_inds, S_consts, (n_vars, n_vars)),
sparse(1:n_observed,
[vars_index[var] for var in partable.observed_vars],
ones(T, n_observed), n_observed, n_vars),
!isnothing(M_inds) ? ParamsVector{T}(M_inds, M_consts, (n_vars,)) : nothing,
params, vars_sorted)
return RAMMatrices(
ParamsMatrix{T}(A_inds, A_consts, (n_vars, n_vars)),
ParamsMatrix{T}(S_inds, S_consts, (n_vars, n_vars)),
sparse(
1:n_observed,
[vars_index[var] for var in partable.observed_vars],
ones(T, n_observed),
n_observed,
n_vars,
),
!isnothing(M_inds) ? ParamsVector{T}(M_inds, M_consts, (n_vars,)) : nothing,
params,
vars_sorted,
)
end

Base.convert(
Expand Down Expand Up @@ -360,10 +369,7 @@ function append_rows!(
arr_ix = arr_ixs[arr.linear_indices[j]]
skip_symmetric && (arr_ix ∈ visited_indices) && continue

push!(
partable,
partable_row(par, arr_ix, arr_name, varnames, free = true),
)
push!(partable, partable_row(par, arr_ix, arr_name, varnames, free = true))
if skip_symmetric
# mark index and its symmetric as visited
push!(visited_indices, arr_ix)
Expand Down
5 changes: 3 additions & 2 deletions src/implied/abstract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ function check_acyclic(A::AbstractMatrix; verbose::Bool = false)
# check if non-triangular matrix is acyclic
acyclic = isone(det(I - A))
if acyclic
verbose && @info "The matrix is acyclic. Reordering variables in the model to make the A matrix either Upper or Lower Triangular can significantly improve performance.\n" maxlog =
1
verbose &&
@info "The matrix is acyclic. Reordering variables in the model to make the A matrix either Upper or Lower Triangular can significantly improve performance.\n" maxlog =
1
end
return A
end
Expand Down
4 changes: 3 additions & 1 deletion src/loss/ML/FIML.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,9 @@ function ∇F_fiml_outer!(G, JΣ, Jμ, implied, model, semfiml)

∇Σ = P * (implied.∇S + Q * implied.∇A)

∇μ = implied.F⨉I_A⁻¹ * implied.∇M + kron((implied.I_A⁻¹ * implied.M)', implied.F⨉I_A⁻¹) * implied.∇A
∇μ =
implied.F⨉I_A⁻¹ * implied.∇M +
kron((implied.I_A⁻¹ * implied.M)', implied.F⨉I_A⁻¹) * implied.∇A

mul!(G, ∇Σ', JΣ) # actually transposed
mul!(G, ∇μ', Jμ, -1, 1)
Expand Down
16 changes: 13 additions & 3 deletions src/objective_gradient_hessian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,26 @@ evaluate!(objective, gradient, hessian, loss::SemLossFunction, model::AbstractSe
evaluate!(objective, gradient, hessian, loss, implied(model), model, params)

# fallback method
function evaluate!(obj, grad, hess, loss::SemLossFunction, implied::SemImplied, model, params)
function evaluate!(
obj,
grad,
hess,
loss::SemLossFunction,
implied::SemImplied,
model,
params,
)
isnothing(obj) || (obj = objective(loss, implied, model, params))
isnothing(grad) || copyto!(grad, gradient(loss, implied, model, params))
isnothing(hess) || copyto!(hess, hessian(loss, implied, model, params))
return obj
end

# fallback methods
objective(f::SemLossFunction, implied::SemImplied, model, params) = objective(f, model, params)
gradient(f::SemLossFunction, implied::SemImplied, model, params) = gradient(f, model, params)
objective(f::SemLossFunction, implied::SemImplied, model, params) =
objective(f, model, params)
gradient(f::SemLossFunction, implied::SemImplied, model, params) =
gradient(f, model, params)
hessian(f::SemLossFunction, implied::SemImplied, model, params) = hessian(f, model, params)

# fallback method for SemImplied that calls update_xxx!() methods
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
23 changes: 15 additions & 8 deletions test/examples/multigroup/build_models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@ model_ml_multigroup2 = SemEnsemble(
data = dat,
column = :school,
groups = [:Pasteur, :Grant_White],
loss = SemML
loss = SemML,
)


# gradients
@testset "ml_gradients_multigroup" begin
test_gradient(model_ml_multigroup, start_test; atol = 1e-9)
Expand Down Expand Up @@ -206,11 +205,19 @@ end
# GLS estimation
############################################################################################

model_ls_g1 =
Sem(specification = specification_g1, data = dat_g1, implied = RAMSymbolic, loss = SemWLS)
model_ls_g1 = Sem(
specification = specification_g1,
data = dat_g1,
implied = RAMSymbolic,
loss = SemWLS,
)

model_ls_g2 =
Sem(specification = specification_g2, data = dat_g2, implied = RAMSymbolic, loss = SemWLS)
model_ls_g2 = Sem(
specification = specification_g2,
data = dat_g2,
implied = RAMSymbolic,
loss = SemWLS,
)

model_ls_multigroup = SemEnsemble(model_ls_g1, model_ls_g2; optimizer = semoptimizer)

Expand Down Expand Up @@ -239,7 +246,7 @@ end
atol = 1e-5,
)

update_se_hessian!(partable, solution_ls)
@suppress update_se_hessian!(partable, solution_ls)
test_estimates(
partable,
solution_lav[:parameter_estimates_ls];
Expand Down Expand Up @@ -283,7 +290,7 @@ if !isnothing(specification_miss_g1)
groups = [:Pasteur, :Grant_White],
loss = SemFIML,
observed = SemObservedMissing,
meanstructure = true
meanstructure = true,
)

############################################################################################
Expand Down
8 changes: 3 additions & 5 deletions test/examples/multigroup/multigroup.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using StructuralEquationModels, Test, FiniteDiff
using StructuralEquationModels, Test, FiniteDiff, Suppressor
using LinearAlgebra: diagind, LowerTriangular

const SEM = StructuralEquationModels
Expand Down Expand Up @@ -71,10 +71,8 @@ specification_g2 = RAMMatrices(;
vars = [:x1, :x2, :x3, :x4, :x5, :x6, :x7, :x8, :x9, :visual, :textual, :speed],
)

partable = EnsembleParameterTable(
:Pasteur => specification_g1,
:Grant_White => specification_g2
)
partable =
EnsembleParameterTable(:Pasteur => specification_g1, :Grant_White => specification_g2)

specification_miss_g1 = nothing
specification_miss_g2 = nothing
Expand Down
7 changes: 4 additions & 3 deletions test/examples/political_democracy/by_parts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ end
)
@test (fm[:AIC] === missing) & (fm[:BIC] === missing) & (fm[:minus2ll] === missing)

update_se_hessian!(partable, solution_ls)
@suppress update_se_hessian!(partable, solution_ls)
test_estimates(
partable,
solution_lav[:parameter_estimates_ls];
Expand All @@ -158,7 +158,8 @@ if opt_engine == :Optim
),
)

implied_sym_hessian_vech = RAMSymbolic(specification = spec, vech = true, hessian = true)
implied_sym_hessian_vech =
RAMSymbolic(specification = spec, vech = true, hessian = true)

implied_sym_hessian = RAMSymbolic(specification = spec, hessian = true)

Expand Down Expand Up @@ -294,7 +295,7 @@ end
)
@test (fm[:AIC] === missing) & (fm[:BIC] === missing) & (fm[:minus2ll] === missing)

update_se_hessian!(partable_mean, solution_ls)
@suppress update_se_hessian!(partable_mean, solution_ls)
test_estimates(
partable_mean,
solution_lav[:parameter_estimates_ls_mean];
Expand Down
4 changes: 2 additions & 2 deletions test/examples/political_democracy/constructor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ end
)
@test ismissing(fm[:AIC]) && ismissing(fm[:BIC]) && ismissing(fm[:minus2ll])

update_se_hessian!(partable, solution_ls)
@suppress update_se_hessian!(partable, solution_ls)
test_estimates(
partable,
solution_lav[:parameter_estimates_ls];
Expand Down Expand Up @@ -337,7 +337,7 @@ end
)
@test ismissing(fm[:AIC]) && ismissing(fm[:BIC]) && ismissing(fm[:minus2ll])

update_se_hessian!(partable_mean, solution_ls)
@suppress update_se_hessian!(partable_mean, solution_ls)
test_estimates(
partable_mean,
solution_lav[:parameter_estimates_ls_mean];
Expand Down
36 changes: 3 additions & 33 deletions test/examples/political_democracy/political_democracy.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using StructuralEquationModels, Test, FiniteDiff
using StructuralEquationModels, Test, Suppressor, FiniteDiff

SEM = StructuralEquationModels

Expand Down Expand Up @@ -78,22 +78,7 @@ spec = RAMMatrices(;
S = S,
F = F,
params = x,
vars = [
:x1,
:x2,
:x3,
:y1,
:y2,
:y3,
:y4,
:y5,
:y6,
:y7,
:y8,
:ind60,
:dem60,
:dem65,
],
vars = [:x1, :x2, :x3, :y1, :y2, :y3, :y4, :y5, :y6, :y7, :y8, :ind60, :dem60, :dem65],
)

partable = ParameterTable(spec)
Expand All @@ -110,22 +95,7 @@ spec_mean = RAMMatrices(;
F = F,
M = M,
params = [SEM.params(spec); Symbol.("x", string.(32:38))],
vars = [
:x1,
:x2,
:x3,
:y1,
:y2,
:y3,
:y4,
:y5,
:y6,
:y7,
:y8,
:ind60,
:dem60,
:dem65,
],
vars = [:x1, :x2, :x3, :y1, :y2, :y3, :y4, :y5, :y6, :y7, :y8, :ind60, :dem60, :dem65],
)

partable_mean = ParameterTable(spec_mean)
Expand Down
4 changes: 2 additions & 2 deletions test/examples/proximal/ridge.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using StructuralEquationModels, Test, ProximalAlgorithms, ProximalOperators
using StructuralEquationModels, Test, ProximalAlgorithms, ProximalOperators, Suppressor

# load data
dat = example_data("political_democracy")
Expand Down Expand Up @@ -54,7 +54,7 @@ solution_ridge = sem_fit(model_ridge)

model_prox = Sem(specification = partable, data = dat, loss = SemML)

solution_prox = sem_fit(model_prox, engine = :Proximal, operator_g = SqrNormL2(λ))
solution_prox = @suppress sem_fit(model_prox, engine = :Proximal, operator_g = SqrNormL2(λ))

@testset "ridge_solution" begin
@test isapprox(solution_prox.solution, solution_ridge.solution; rtol = 1e-4)
Expand Down
Loading
Loading