Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1140,7 +1140,35 @@ namespace_parameters(sys::AbstractSystem) = parameters(sys, parameters(sys))

function namespace_defaults(sys)
defs = defaults(sys)
Dict((isparameter(k) ? parameters(sys, k) : unknowns(sys, k)) => namespace_expr(v, sys)
sys_params = Set(parameters(sys))
sys_unknowns = Set(unknowns(sys))

function should_namespace(val)
# If it's a parameter from parent scope, don't namespace it
if isparameter(val) && !(unwrap(val) in sys_params || unwrap(val) in sys_unknowns)
return false
end

# Check if the expression contains any parent scope parameters
# vars() collects all variables in an expression
try
expr_vars = vars(val; op = Nothing)
for var in expr_vars
var_unwrapped = unwrap(var)
# If any variable in the expression is from parent scope, don't namespace
if isparameter(var_unwrapped) && !(var_unwrapped in sys_params || var_unwrapped in sys_unknowns)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
if isparameter(var_unwrapped) && !(var_unwrapped in sys_params || var_unwrapped in sys_unknowns)
if isparameter(var_unwrapped) &&
!(var_unwrapped in sys_params || var_unwrapped in sys_unknowns)

return false
end
end
catch
# If vars() fails, fall back to default behavior
end

return true
end

Dict((isparameter(k) ? parameters(sys, k) : unknowns(sys, k)) =>
(should_namespace(v) ? namespace_expr(v, sys) : v)
Comment on lines +1170 to +1171
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
Dict((isparameter(k) ? parameters(sys, k) : unknowns(sys, k)) =>
(should_namespace(v) ? namespace_expr(v, sys) : v)
Dict((isparameter(k) ? parameters(sys, k) :
unknowns(sys, k)) => (should_namespace(v) ? namespace_expr(v, sys) : v)

for (k, v) in pairs(defs))
end

Expand Down
74 changes: 74 additions & 0 deletions test/model_parsing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -935,6 +935,80 @@ end
@test getdefault(main_sys.v1) == 13
end

@testset "Array parameters in nested @mtkmodel components" begin
# Test for issue where array parameters passed to nested models
# were not correctly resolved to their parent scope values
@mtkmodel InnerWithArrayParam begin
@parameters begin
a[1:2] = a
end
@variables begin
x(t)[1:2] = zeros(2)
end
@equations begin
D(x) ~ a
end
end

@mtkmodel OuterWithArrayParam begin
@parameters begin
a1 = 1
a2 = 2
end
@components begin
inner = InnerWithArrayParam(a = [a1+a2, a2])
end
@variables begin
y(t)[1:2] = zeros(2)
end
@equations begin
D(y) ~ [a1+a2, a2]
end
end

@named sys = OuterWithArrayParam()
sys = complete(sys)

# Check that the defaults are correctly mapped
defs = ModelingToolkit.defaults(sys)

# Find the keys for inner.a[1] and inner.a[2]
inner_a1_key = nothing
inner_a2_key = nothing
outer_a1_key = nothing
outer_a2_key = nothing

for (k, v) in defs
k_str = string(k)
if k_str == "inner₊a[1]"
inner_a1_key = k
elseif k_str == "inner₊a[2]"
inner_a2_key = k
elseif k_str == "a1"
outer_a1_key = k
elseif k_str == "a2"
outer_a2_key = k
end
end

@test inner_a1_key !== nothing
@test inner_a2_key !== nothing
@test outer_a1_key !== nothing
@test outer_a2_key !== nothing

# The inner array parameter elements should map to the outer parameters
@test isequal(defs[inner_a1_key], sys.a1 + sys.a2)
@test isequal(defs[inner_a2_key], sys.a2)
@test defs[outer_a1_key] == 1
@test defs[outer_a2_key] == 2

# Test that ODEProblem can be created successfully
prob = ODEProblem(mtkcompile(sys), [], (0.0, 1.0))
@test prob isa ODEProblem
# sol = solve(prob, Tsit5())
# @test sol[sys.y] ≈ sol[sys.inner.x]
end

@mtkmodel InnerModel begin
@parameters begin
p
Expand Down
Loading