Skip to content

Commit 219ed56

Browse files
committed
fix for individual params needed in the model
1 parent 3d1766a commit 219ed56

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

ext/MTKCasADiDynamicOptExt.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,9 @@ function MTK.lowered_integral(model::CasADiModel, expr, lo, hi)
150150
model.tₛ * total
151151
end
152152

153+
MTK.needs_individual_tunables(::Opti) = true
154+
MTK.get_param_for_pmap(::Opti, P, i) = P[i]
155+
153156
function add_solve_constraints!(prob::CasADiDynamicOptProblem, tableau)
154157
@unpack A, α, c = tableau
155158
@unpack wrapped_model, f, p = prob
@@ -240,7 +243,7 @@ end
240243

241244
function MTK.get_P_values(model::CasADiModel)
242245
value_getter = MTK.successful_solve(model) ? CasADi.debug_value : CasADi.value
243-
value_getter(model.solver_opti, model.P)
246+
[value_getter(model.solver_opti, model.P[i]) for i in eachindex(model.P)]
244247
end
245248

246249
function MTK.get_t_values(model::CasADiModel)

src/systems/optimal_control_interface.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -267,13 +267,15 @@ function process_DynamicOptProblem(
267267
U = generate_state_variable!(model, u0, length(states), tsteps)
268268
V = generate_input_variable!(model, c0, length(ctrls), tsteps)
269269
P = generate_tunable_params!(model, p0, length(tunable_params))
270-
tₛ = generate_timescale!(model, get(pmap, tspan[2], tspan[2]), is_free_t)
271-
fullmodel = model_type(model, U, V, P, tₛ, is_free_t)
272-
273270
# Add the symbolic representation of the tunable parameters to the map
274271
# The order of the Tunable section is given by the tunable_parameters(sys)
275272
# Some backends need symbolic accessors instead of raw variables
276-
P_syms = [get_param_for_pmap(fullmodel, P, i) for i in eachindex(tunable_params)]
273+
P_syms = [get_param_for_pmap(model, P, i) for i in eachindex(tunable_params)]
274+
P_backend = needs_individual_tunables(model) ? P_syms : P
275+
276+
tₛ = generate_timescale!(model, get(pmap, tspan[2], tspan[2]), is_free_t)
277+
fullmodel = model_type(model, U, V, P_backend, tₛ, is_free_t)
278+
277279
merge!(pmap, Dict(tunable_params .=> P_syms))
278280

279281
set_variable_bounds!(fullmodel, sys, pmap, tspan[2])
@@ -294,6 +296,8 @@ function add_initial_constraints! end
294296
function add_constraint! end
295297
# Default: return P[i] directly. Symbolic backends (like Pyomo) can override this.
296298
get_param_for_pmap(model, P, i) = P isa AbstractArray ? P[i] : P
299+
# Some backends need symbolic accessors instead of raw variables (CasADi in particular)
300+
needs_individual_tunables(model) = false
297301

298302
function f_wrapper(f, Uₙ, Vₙ, p, P, t)
299303
if isempty(P)

0 commit comments

Comments
 (0)