Skip to content

Commit 6d14af0

Browse files
SebastianM-Cclaude
andcommitted
add tune_parameters for the rest of the backends
Co-authored-by: Claude <[email protected]>
1 parent 0655fd2 commit 6d14af0

File tree

3 files changed

+56
-27
lines changed

3 files changed

+56
-27
lines changed

ext/MTKCasADiDynamicOptExt.jl

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,17 @@ function Base.getindex(m::MXLinearInterpolation, i...)
2121
length(i) == length(size(m.u)) ? m.u[i...] : m.u[i..., :]
2222
end
2323

24-
mutable struct CasADiModel
24+
mutable struct CasADiModel{T}
2525
model::Opti
2626
U::MXLinearInterpolation
2727
V::MXLinearInterpolation
28+
P::T
2829
tₛ::MX
2930
is_free_final::Bool
3031
solver_opti::Union{Nothing, Opti}
3132

32-
function CasADiModel(opti, U, V, tₛ, is_free_final, solver_opti = nothing)
33-
new(opti, U, V, tₛ, is_free_final, solver_opti)
33+
function CasADiModel(opti, U, V, P, tₛ, is_free_final, solver_opti = nothing)
34+
new{typeof(P)}(opti, U, V, P, tₛ, is_free_final, solver_opti)
3435
end
3536
end
3637

@@ -66,10 +67,11 @@ end
6667
function MTK.CasADiDynamicOptProblem(sys::System, op, tspan;
6768
dt = nothing,
6869
steps = nothing,
70+
tune_parameters = false,
6971
guesses = Dict(), kwargs...)
7072
prob,
7173
_ = MTK.process_DynamicOptProblem(
72-
CasADiDynamicOptProblem, CasADiModel, sys, op, tspan; dt, steps, guesses, kwargs...)
74+
CasADiDynamicOptProblem, CasADiModel, sys, op, tspan; dt, steps, tune_parameters, guesses, kwargs...)
7375
prob
7476
end
7577

@@ -90,6 +92,14 @@ function MTK.generate_input_variable!(model::Opti, c0, nc, tsteps)
9092
MXLinearInterpolation(V, tsteps, tsteps[2] - tsteps[1])
9193
end
9294

95+
function MTK.generate_tunable_params!(model::Opti, p0, np)
96+
P = CasADi.variable!(model, np)
97+
for i in 1:np
98+
set_initial!(model, P[i], p0[i])
99+
end
100+
P
101+
end
102+
93103
function MTK.generate_timescale!(model::Opti, guess, is_free_t)
94104
if is_free_t
95105
tₛ = variable!(model)
@@ -143,7 +153,7 @@ end
143153
function add_solve_constraints!(prob::CasADiDynamicOptProblem, tableau)
144154
@unpack A, α, c = tableau
145155
@unpack wrapped_model, f, p = prob
146-
@unpack model, U, V, tₛ = wrapped_model
156+
@unpack model, U, V, P, tₛ = wrapped_model
147157
solver_opti = copy(model)
148158

149159
tsteps = U.t
@@ -160,7 +170,7 @@ function add_solve_constraints!(prob::CasADiDynamicOptProblem, tableau)
160170
ΔU = sum([A[i, j] * K[j] for j in 1:(i - 1)], init = MX(zeros(nᵤ)))
161171
Uₙ = U.u[:, k] + ΔU * dt
162172
Vₙ = V.u[:, k]
163-
Kₙ = tₛ * f(Uₙ, Vₙ, p, τ + h * dt) # scale the time
173+
Kₙ = tₛ * MTK.f_wrapper(f, Uₙ, Vₙ, p, P, τ + h * dt) # scale the time
164174
push!(K, Kₙ)
165175
end
166176
ΔU = dt * sum([α[i] * K[i] for i in 1:length(α)])
@@ -176,7 +186,7 @@ function add_solve_constraints!(prob::CasADiDynamicOptProblem, tableau)
176186
ΔU = ΔUs[i, :]'
177187
Uₙ = U.u[:, k] + ΔU * dt
178188
Vₙ = V.u[:, k]
179-
subject_to!(solver_opti, Kᵢ[:, i] == tₛ * f(Uₙ, Vₙ, p, τ + h * dt))
189+
subject_to!(solver_opti, Kᵢ[:, i] == tₛ * MTK.f_wrapper(f, Uₙ, Vₙ, p, P, τ + h * dt))
180190
end
181191
ΔU_tot = dt * (Kᵢ * α)
182192
subject_to!(solver_opti, U.u[:, k] + ΔU_tot == U.u[:, k + 1])
@@ -228,6 +238,11 @@ function MTK.get_V_values(model::CasADiModel)
228238
end
229239
end
230240

241+
function MTK.get_P_values(model::CasADiModel)
242+
value_getter = MTK.successful_solve(model) ? CasADi.debug_value : CasADi.value
243+
value_getter(model.solver_opti, model.P)
244+
end
245+
231246
function MTK.get_t_values(model::CasADiModel)
232247
value_getter = MTK.successful_solve(model) ? CasADi.debug_value : CasADi.value
233248
ts = value_getter(model.solver_opti, model.tₛ) .* model.U.t

ext/MTKInfiniteOptExt.jl

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,11 @@ end
9797
function MTK.InfiniteOptDynamicOptProblem(sys::System, op, tspan;
9898
dt = nothing,
9999
steps = nothing,
100+
tune_parameters = false,
100101
guesses = Dict(), kwargs...)
101102
prob,
102103
pmap = MTK.process_DynamicOptProblem(InfiniteOptDynamicOptProblem, InfiniteOptModel,
103-
sys, op, tspan; dt, steps, guesses, kwargs...)
104+
sys, op, tspan; dt, steps, tune_parameters, guesses, kwargs...)
104105
MTK.add_equational_constraints!(prob.wrapped_model, sys, pmap, tspan)
105106
prob
106107
end
@@ -131,16 +132,6 @@ function MTK.lowered_var(m::InfiniteOptModel, uv, i, t)
131132
t isa Union{Num, Symbolics.Symbolic} ? X[i] : X[i](t)
132133
end
133134

134-
function f_wrapper(f, Uₙ, Vₙ, p, P, t)
135-
if SciMLStructures.isscimlstructure(p)
136-
_, repack, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), p)
137-
p′ = repack(P)
138-
f(Uₙ, Vₙ, p′, t)
139-
else
140-
f(Uₙ, Vₙ, P, t)
141-
end
142-
end
143-
144135
function add_solve_constraints!(prob::JuMPDynamicOptProblem, tableau)
145136
@unpack A, α, c = tableau
146137
@unpack wrapped_model, f, p = prob
@@ -159,7 +150,7 @@ function add_solve_constraints!(prob::JuMPDynamicOptProblem, tableau)
159150
ΔU = sum([A[i, j] * K[j] for j in 1:(i - 1)], init = zeros(nᵤ))
160151
Uₙ = [U[i](τ) + ΔU[i] * dt for i in 1:nᵤ]
161152
Vₙ = [V[i](τ) for i in 1:nᵥ]
162-
Kₙ = tₛ * f_wrapper(f, Uₙ, Vₙ, p, P, τ + h * dt)
153+
Kₙ = tₛ * MTK.f_wrapper(f, Uₙ, Vₙ, p, P, τ + h * dt)
163154
push!(K, Kₙ)
164155
end
165156
ΔU = dt * sum([α[i] * K[i] for i in 1:length(α)])
@@ -175,12 +166,12 @@ function add_solve_constraints!(prob::JuMPDynamicOptProblem, tableau)
175166
for (i, h) in enumerate(c)
176167
ΔU = @view ΔUs[i, :]
177168
Uₙ = U + ΔU * dt
178-
@constraint(model, [j = 1:nᵤ], K[i, j]==(tₛ * f_wrapper(f, Uₙ, V, p, P, τ + h * dt)[j]),
179-
DomainRestrictions(t => τ), base_name="solve_K$i()")
169+
@constraint(model, [j = 1:nᵤ], K[i, j]==(tₛ * MTK.f_wrapper(f, Uₙ, V, p, P, τ + h * dt)[j]),
170+
DomainRestriction(==(τ), t), base_name="solve_K$i()")
180171
end
181172
@constraint(model,
182173
[n = 1:nᵤ], U[n](τ) + ΔU_tot[n]==U[n](min+ dt, tsteps[end])),
183-
DomainRestrictions(t => τ), base_name="solve_U()")
174+
DomainRestriction(==(τ), t), base_name="solve_U()")
184175
end
185176
end
186177
end

ext/MTKPyomoDynamicOptExt.jl

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@ const SPECIAL_FUNCTIONS_DICT = Dict([acos => Pyomo.py_acos,
1616
log => Pyomo.py_log,
1717
sin => Pyomo.py_sin,
1818
sqrt => Pyomo.py_sqrt,
19-
exp => Pyomo.py_exp])
19+
exp => Pyomo.py_exp,
20+
abs2 => (x -> x^2)])
2021

2122
struct PyomoDynamicOptModel
2223
model::ConcreteModel
2324
U::PyomoVar
2425
V::PyomoVar
26+
P::PyomoVar
2527
tₛ::PyomoVar
2628
is_free_final::Bool
2729
solver_model::Union{Nothing, ConcreteModel}
@@ -30,10 +32,10 @@ struct PyomoDynamicOptModel
3032
t_sym::Union{Num, Symbolics.BasicSymbolic}
3133
dummy_sym::Union{Num, Symbolics.BasicSymbolic}
3234

33-
function PyomoDynamicOptModel(model, U, V, tₛ, is_free_final)
35+
function PyomoDynamicOptModel(model, U, V, P, tₛ, is_free_final)
3436
@variables MODEL_SYM::Symbolics.symstruct(ConcreteModel) T_SYM DUMMY_SYM
3537
model.dU = dae.DerivativeVar(U, wrt = model.t, initialize = 0)
36-
new(model, U, V, tₛ, is_free_final, nothing,
38+
new(model, U, V, P, tₛ, is_free_final, nothing,
3739
PyomoVar(model.dU), MODEL_SYM, T_SYM, DUMMY_SYM)
3840
end
3941
end
@@ -60,11 +62,11 @@ end
6062
_getproperty(s, name::Val{fieldname}) where {fieldname} = getproperty(s, fieldname)
6163

6264
function MTK.PyomoDynamicOptProblem(sys::System, op, tspan;
63-
dt = nothing, steps = nothing,
65+
dt = nothing, steps = nothing, tune_parameters = false,
6466
guesses = Dict(), kwargs...)
6567
prob,
6668
pmap = MTK.process_DynamicOptProblem(PyomoDynamicOptProblem, PyomoDynamicOptModel,
67-
sys, op, tspan; dt, steps, guesses, kwargs...)
69+
sys, op, tspan; dt, steps, tune_parameters, guesses, kwargs...)
6870
conc_model = prob.wrapped_model.model
6971
MTK.add_equational_constraints!(prob.wrapped_model, sys, pmap, tspan)
7072
prob
@@ -94,6 +96,13 @@ function MTK.generate_input_variable!(m::ConcreteModel, c0, nc, ts)
9496
PyomoVar(m.V)
9597
end
9698

99+
function MTK.generate_tunable_params!(m::ConcreteModel, p0, np)
100+
m.p_idxs = pyomo.RangeSet(1, np)
101+
init_f = Pyomo.pyfunc((m, i) -> (p0[Pyomo.pyconvert(Int, i)]))
102+
m.P = pyomo.Var(m.p_idxs, initialize = init_f)
103+
PyomoVar(m.P)
104+
end
105+
97106
function MTK.generate_timescale!(m::ConcreteModel, guess, is_free_t)
98107
m.tₛ = is_free_t ? pyomo.Var(initialize = guess, bounds = (0, Inf)) : Pyomo.Py(1)
99108
PyomoVar(m.tₛ)
@@ -169,6 +178,16 @@ function MTK.lowered_var(m::PyomoDynamicOptModel, uv, i, t)
169178
Symbolics.unwrap(var)
170179
end
171180

181+
function MTK.lowered_param(m::PyomoDynamicOptModel, i)
182+
P = Symbolics.value(pysym_getproperty(m.model_sym, :P))
183+
Symbolics.unwrap(P[i])
184+
end
185+
186+
# For Pyomo, return symbolic accessors instead of raw PyomoVar
187+
function MTK.get_param_for_pmap(m::PyomoDynamicOptModel, P, i)
188+
MTK.lowered_param(m, i)
189+
end
190+
172191
struct PyomoCollocation <: AbstractCollocation
173192
solver::Union{String, Symbol}
174193
derivative_method::Pyomo.DiscretizationMethod
@@ -208,6 +227,10 @@ function MTK.get_V_values(output::PyomoOutput)
208227
m = output.model
209228
[[Pyomo.pyconvert(Float64, pyomo.value(m.V[i, t])) for i in m.v_idxs] for t in m.t]
210229
end
230+
function MTK.get_P_values(output::PyomoOutput)
231+
m = output.model
232+
[Pyomo.pyconvert(Float64, pyomo.value(m.P[i])) for i in m.p_idxs]
233+
end
211234
function MTK.get_t_values(output::PyomoOutput)
212235
m = output.model
213236
Pyomo.pyconvert(Float64, pyomo.value(m.tₛ)) * [Pyomo.pyconvert(Float64, t) for t in m.t]

0 commit comments

Comments
 (0)