Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions src/implied/RAM/symbolic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ function RAMSymbolic(;
specification::SemSpecification,
loss_types = nothing,
vech = false,
simplify_symbolics = false,
gradient = true,
hessian = false,
meanstructure = false,
Expand All @@ -116,7 +117,7 @@ function RAMSymbolic(;
I_A⁻¹ = neumann_series(A)

# Σ
Σ_symbolic = eval_Σ_symbolic(S, I_A⁻¹, F; vech = vech)
Σ_symbolic = eval_Σ_symbolic(S, I_A⁻¹, F; vech = vech, simplify = simplify_symbolics)
#print(Symbolics.build_function(Σ_symbolic)[2])
Σ_function = Symbolics.build_function(Σ_symbolic, par, expression = Val{false})[2]
Σ = zeros(size(Σ_symbolic))
Expand Down Expand Up @@ -157,7 +158,7 @@ function RAMSymbolic(;
# μ
if meanstructure
MS = HasMeanStruct
μ_symbolic = eval_μ_symbolic(M, I_A⁻¹, F)
μ_symbolic = eval_μ_symbolic(M, I_A⁻¹, F; simplify = simplify_symbolics)
μ_function = Symbolics.build_function(μ_symbolic, par, expression = Val{false})[2]
μ = zeros(size(μ_symbolic))
if gradient
Expand Down Expand Up @@ -235,23 +236,26 @@ end
############################################################################################

# expected covariations of observed vars
function eval_Σ_symbolic(S, I_A⁻¹, F; vech = false)
function eval_Σ_symbolic(S, I_A⁻¹, F; vech = false, simplify = false)
Σ = F * I_A⁻¹ * S * permutedims(I_A⁻¹) * permutedims(F)
Σ = Array(Σ)
vech && (Σ = Σ[tril(trues(size(F, 1), size(F, 1)))])
# Σ = Symbolics.simplify.(Σ)
Threads.@threads for i in eachindex(Σ)
Σ[i] = Symbolics.simplify(Σ[i])
if simplify
Threads.@threads for i in eachindex(Σ)
Σ[i] = Symbolics.simplify(Σ[i])
end
end
return Σ
end

# expected means of observed vars
function eval_μ_symbolic(M, I_A⁻¹, F)
function eval_μ_symbolic(M, I_A⁻¹, F; simplify = false)
μ = F * I_A⁻¹ * M
μ = Array(μ)
Threads.@threads for i in eachindex(μ)
μ[i] = Symbolics.simplify(μ[i])
if simplify
Threads.@threads for i in eachindex(μ)
μ[i] = Symbolics.simplify(μ[i])
end
end
return μ
end
Loading