diff --git a/Project.toml b/Project.toml index e37c6e017..d45497f2d 100644 --- a/Project.toml +++ b/Project.toml @@ -59,7 +59,7 @@ Random = "1.10" RecursiveArrayTools = "3.31.2" ReTestItems = "1.29" Reexport = "1.2" -SciMLBase = "2.108.0" +SciMLBase = "2.120.0" Sparspak = "0.3.11" StaticArrays = "1.9.8" Test = "1.10" diff --git a/docs/src/tutorials/optimal_control.md b/docs/src/tutorials/optimal_control.md new file mode 100644 index 000000000..f14a1ab14 --- /dev/null +++ b/docs/src/tutorials/optimal_control.md @@ -0,0 +1,248 @@ +# Solve Optimal Control problem + +A classical optimal control problem is the rocket launching problem(aka [Goddard Rocket problem](https://en.wikipedia.org/wiki/Goddard_problem)). Say we have a rocket with limited fuel and is launched vertically. And we want to control the final altitude of this rocket so that we can make the best of the limited fuel in rocket to get to the highest altitude. In this optimal control problem, the state variables are: + + - Velocity of the rocket: $x_v(t)$ + - Altitude of the rocket: $x_h(t)$ + - Mass of the rocket and the fuel: $x_m(t)$ + +The control variable is + + - Thrust of the rocket: $u_t(t)$ + +The dynamics of the launching can be formulated with three differential equations: + +$$ +\left\{\begin{aligned} +&\frac{dx_v}{dt}=\frac{u_t-drag(x_h,x_v)}{x_m}-g(x_h)\\ +&\frac{dx_h}{dt}=x_v\\ +&\frac{dx_m}{dt}=-\frac{u_t}{c} +\end{aligned}\right. +$$ + +where the drag $D(x_h,x_v)$ is a function of altitude and velocity: + +$$ +D(x_h,x_v)=D_c\cdot x_v^2\cdot\exp^{h_c(\frac{x_h-x_h(0)}{x_h(0)})} +$$ + +gravity $g(x_h)$ is a function of altitude: + +$$ +g(x_h)=g_0\cdot (\frac{x_h(0)}{x_h})^2 +$$ + +$c$ is a constant. Suppose the final time is $T$, we here want to maximize the final altitude $x_h(T)$: + +$$ +\max x_h(T) +$$ + +The inequality constraints for the state variables and control variables are: + +$$ +\left\{\begin{aligned} +&x_v>0\\ +&x_h>0\\ +&m_T 0.0 +@inline __default_cost(f) = f +@inline __default_cost(fun::BVPFunction) = __default_cost(fun.cost) + +@inline function __extract_lcons_ucons( + prob::AbstractBVProblem, ::Type{T}, M, N, bcresid_prototype, f_prototype) where {T} + L_f_prototype = length(f_prototype) + L_bcresid_prototype = length(bcresid_prototype) + lcons = if isnothing(prob.lcons) + zeros(T, L_bcresid_prototype + (N - 1)*L_f_prototype) + else + lcons_length = length(prob.lcons) + vcat(prob.lcons, zeros(T, N*M - lcons_length)) + end + ucons = if isnothing(prob.ucons) + zeros(T, L_bcresid_prototype + (N - 1)*L_f_prototype) + else + ucons_length = length(prob.ucons) + vcat(prob.ucons, zeros(T, N*M - ucons_length)) + end + return lcons, ucons +end + +@inline function __extract_lcons_ucons(prob::AbstractBVProblem, ::Type{T}, M, N) where {T} + lcons = if isnothing(prob.lcons) + zeros(T, N*M) + else + lcons_length = length(prob.lcons) + vcat(prob.lcons, zeros(T, N*M - lcons_length)) + end + ucons = if isnothing(prob.ucons) + zeros(T, N*M) + else + ucons_length = length(prob.ucons) + vcat(prob.ucons, zeros(T, N*M - ucons_length)) + end + return lcons, ucons +end + +@inline function __extract_lb_ub(prob::AbstractBVProblem, ::Type{T}, M, N) where {T} + lb = if isnothing(prob.lb) + nothing + else + repeat(prob.lb, N) + end + ub = if isnothing(prob.ub) + nothing + else + repeat(prob.ub, N) + end + return lb, ub +end + +""" + __construct_internal_problem + +Constructs the internal problem based on the type of the boundary value problem and the +algorithm used. It returns either a `NonlinearProblem` or an `OptimizationProblem`. +""" +function __construct_internal_problem( + prob, pt::StandardBVProblem, alg, loss, jac, jac_prototype, + resid_prototype, bcresid_prototype, f_prototype, y, p, M::Int, N::Int) + T = eltype(y) + iip = SciMLBase.isinplace(prob) + if !isnothing(alg.nlsolve) || (isnothing(alg.nlsolve) && isnothing(alg.optimize)) + nlf = NonlinearFunction{iip}(loss; jac = jac, resid_prototype = resid_prototype, + jac_prototype = jac_prototype) + return __internal_nlsolve_problem(prob, resid_prototype, y, nlf, y, p) + else + optf = OptimizationFunction{true}(__default_cost(prob.f), + AutoSparse(get_dense_ad(alg.jac_alg.nonbc_diffmode), + sparsity_detector = __default_sparsity_detector(alg.jac_alg.diffmode)), + cons = loss, + cons_j = jac, + cons_jac_prototype = sparse(jac_prototype)) + lcons, ucons = __extract_lcons_ucons(prob, T, M, N, bcresid_prototype, f_prototype) + lb, ub = __extract_lb_ub(prob, T, M, N) + + return __internal_optimization_problem( + prob, optf, y, p; lcons = lcons, ucons = ucons, lb = lb, ub = ub) + end +end + +function __construct_internal_problem( + prob, pt::TwoPointBVProblem, alg, loss, jac, jac_prototype, + resid_prototype, bcresid_prototype, f_prototype, y, p, M::Int, N::Int) + T = eltype(y) + iip = SciMLBase.isinplace(prob) + if !isnothing(alg.nlsolve) || (isnothing(alg.nlsolve) && isnothing(alg.optimize)) + nlf = NonlinearFunction{iip}(loss; jac = jac, resid_prototype = resid_prototype, + jac_prototype = jac_prototype) + return __internal_nlsolve_problem(prob, resid_prototype, y, nlf, y, p) + else + optf = OptimizationFunction{true}(__default_cost(prob.f), + AutoSparse(get_dense_ad(alg.jac_alg.diffmode), + sparsity_detector = __default_sparsity_detector(alg.jac_alg.diffmode)), + cons = loss, + cons_j = jac, + cons_jac_prototype = sparse(jac_prototype)) + lcons, ucons = __extract_lcons_ucons(prob, T, M, N, bcresid_prototype, f_prototype) + lb, ub = __extract_lb_ub(prob, T, M, N) + + return __internal_optimization_problem( + prob, optf, y, p; lcons = lcons, ucons = ucons, lb = lb, ub = ub) + end +end + +# Single shooting use diffmode for StandardBVProblem and TwoPointBVProblem +function __construct_internal_problem(prob, alg, loss, jac, jac_prototype, resid_prototype, + bcresid_prototype, f_prototype, y, p, M::Int, N::Int, ::Nothing) + T = eltype(y) + iip = SciMLBase.isinplace(prob) + if !isnothing(alg.nlsolve) || (isnothing(alg.nlsolve) && isnothing(alg.optimize)) + nlf = NonlinearFunction{iip}(loss; jac = jac, resid_prototype = resid_prototype, + jac_prototype = jac_prototype) + return __internal_nlsolve_problem(prob, resid_prototype, y, nlf, y, p) + else + optf = OptimizationFunction{iip}(__default_cost(prob.f), + AutoSparse(get_dense_ad(alg.jac_alg.diffmode), + sparsity_detector = __default_sparsity_detector(alg.jac_alg.diffmode)), + cons = loss, + cons_j = jac, + cons_jac_prototype = sparse(jac_prototype)) + lcons, ucons = __extract_lcons_ucons(prob, T, M, N, bcresid_prototype, f_prototype) + lb, ub = __extract_lb_ub(prob, T, M, N) + + return __internal_optimization_problem( + prob, optf, y, p; lcons = lcons, ucons = ucons, lb = lb, ub = ub) + end +end + +# Multiple shooting always use inplace version internal problem constructor +function __construct_internal_problem( + prob, pt::StandardBVProblem, alg, loss, jac, jac_prototype, resid_prototype, + bcresid_prototype, f_prototype, y, p, M::Int, N::Int, ::Nothing) + T = eltype(y) + if !isnothing(alg.nlsolve) || (isnothing(alg.nlsolve) && isnothing(alg.optimize)) + nlf = NonlinearFunction{true}(loss; jac = jac, resid_prototype = resid_prototype, + jac_prototype = jac_prototype) + return __internal_nlsolve_problem(prob, resid_prototype, y, nlf, y, p) + else + optf = OptimizationFunction{true}(__default_cost(prob.f), + AutoSparse(get_dense_ad(alg.jac_alg.nonbc_diffmode), + sparsity_detector = __default_sparsity_detector(alg.jac_alg.nonbc_diffmode)), + cons = loss, + cons_j = jac, + cons_jac_prototype = sparse(jac_prototype)) + lcons, ucons = __extract_lcons_ucons(prob, T, M, N, bcresid_prototype, f_prototype) + lb, ub = __extract_lb_ub(prob, T, M, N) + + return __internal_optimization_problem( + prob, optf, y, p; lcons = lcons, ucons = ucons, lb = lb, ub = ub) + end +end +function __construct_internal_problem( + prob, pt::TwoPointBVProblem, alg, loss, jac, jac_prototype, resid_prototype, + bcresid_prototype, f_prototype, y, p, M::Int, N::Int, ::Nothing) + T = eltype(y) + iip = SciMLBase.isinplace(prob) + if !isnothing(alg.nlsolve) || (isnothing(alg.nlsolve) && isnothing(alg.optimize)) + nlf = NonlinearFunction{iip}(loss; jac = jac, resid_prototype = resid_prototype, + jac_prototype = jac_prototype) + return __internal_nlsolve_problem(prob, resid_prototype, y, nlf, y, p) + else + optf = OptimizationFunction{true}(__default_cost(prob.f), + AutoSparse(get_dense_ad(alg.jac_alg.diffmode), + sparsity_detector = __default_sparsity_detector(alg.jac_alg.nonbc_diffmode)), + cons = loss, + cons_j = jac, + cons_jac_prototype = sparse(jac_prototype)) + lcons, ucons = __extract_lcons_ucons(prob, T, M, N, bcresid_prototype, f_prototype) + lb, ub = __extract_lb_ub(prob, T, M, N) + + return __internal_optimization_problem( + prob, optf, y, p; lcons = lcons, ucons = ucons, lb = lb, ub = ub) + end +end + +# SecondOrderBVProblem +function __construct_internal_problem( + prob, pt::StandardSecondOrderBVProblem, alg, loss, jac, + jac_prototype, resid_prototype, y, p, M::Int, N::Int) + T = eltype(y) + if !isnothing(alg.nlsolve) || (isnothing(alg.nlsolve) && isnothing(alg.optimize)) + nlf = NonlinearFunction{true}(loss; jac = jac, resid_prototype = resid_prototype, + jac_prototype = jac_prototype) + return __internal_nlsolve_problem(prob, resid_prototype, y, nlf, y, p) + else + optf = OptimizationFunction{true}(__default_cost(prob.f), + AutoSparse(get_dense_ad(alg.jac_alg.nonbc_diffmode), + sparsity_detector = __default_sparsity_detector(alg.jac_alg.nonbc_diffmode)), + cons = loss, + cons_j = jac, + cons_jac_prototype = sparse(jac_prototype)) + lcons, ucons = __extract_lcons_ucons(prob, T, M, N) + lb, ub = __extract_lb_ub(prob, T, M, N) + + return __internal_optimization_problem( + prob, optf, y, p; lcons = lcons, ucons = ucons, lb = lb, ub = ub) + end +end +function __construct_internal_problem( + prob, pt::TwoPointSecondOrderBVProblem, alg, loss, jac, + jac_prototype, resid_prototype, y, p, M::Int, N::Int) + T = eltype(y) + iip = SciMLBase.isinplace(prob) + if !isnothing(alg.nlsolve) || (isnothing(alg.nlsolve) && isnothing(alg.optimize)) + nlf = NonlinearFunction{iip}(loss; jac = jac, resid_prototype = resid_prototype, + jac_prototype = jac_prototype) + return __internal_nlsolve_problem(prob, resid_prototype, y, nlf, y, p) + else + optf = OptimizationFunction{true}(__default_cost(prob.f), + AutoSparse(get_dense_ad(alg.jac_alg.diffmode), + sparsity_detector = __default_sparsity_detector(alg.jac_alg.nonbc_diffmode)), + cons = loss, + cons_j = jac, + cons_jac_prototype = sparse(jac_prototype)) + lcons, ucons = __extract_lcons_ucons(prob, T, M, N) + lb, ub = __extract_lb_ub(prob, T, M, N) + + return __internal_optimization_problem( + prob, optf, y, p; lcons = lcons, ucons = ucons, lb = lb, ub = ub) + end +end diff --git a/lib/BoundaryValueDiffEqCore/src/utils.jl b/lib/BoundaryValueDiffEqCore/src/utils.jl index 2cc396b14..e0ffdccc6 100644 --- a/lib/BoundaryValueDiffEqCore/src/utils.jl +++ b/lib/BoundaryValueDiffEqCore/src/utils.jl @@ -620,188 +620,3 @@ end @inline __concrete_kwargs(::Nothing, optimize, nlsolve_kwargs, optimize_kwargs) = (;) # Doesn't support for now @inline __concrete_kwargs(::Nothing, ::Nothing, nlsolve_kwargs, optimize_kwargs) = (; nlsolve_kwargs...) - -## Optimization solver related utils ## - -@inline __default_cost(::Nothing) = (x, p) -> 0.0 -@inline __default_cost(f) = f -@inline __default_cost(fun::BVPFunction) = __default_cost(fun.cost) - -@inline function __extract_lcons_ucons(prob::AbstractBVProblem, ::Type{T}, M, N) where {T} - lcons = if isnothing(prob.lcons) - zeros(T, N*M) - else - lcons_length = length(prob.lcons) - vcat(prob.lcons, zeros(T, N*M - lcons_length)) - end - ucons = if isnothing(prob.ucons) - zeros(T, N*M) - else - ucons_length = length(prob.ucons) - vcat(prob.ucons, zeros(T, N*M - ucons_length)) - end - return lcons, ucons -end - -""" - __construct_internal_problem - -Constructs the internal problem based on the type of the boundary value problem and the -algorithm used. It returns either a `NonlinearProblem` or an `OptimizationProblem`. -""" -function __construct_internal_problem(prob, pt::StandardBVProblem, alg, loss, jac, - jac_prototype, resid_prototype, y, p, M::Int, N::Int) - T = eltype(y) - iip = SciMLBase.isinplace(prob) - if !isnothing(alg.nlsolve) || (isnothing(alg.nlsolve) && isnothing(alg.optimize)) - nlf = NonlinearFunction{iip}(loss; jac = jac, resid_prototype = resid_prototype, - jac_prototype = jac_prototype) - return __internal_nlsolve_problem(prob, resid_prototype, y, nlf, y, p) - else - optf = OptimizationFunction{true}(__default_cost(prob.f), - AutoSparse(get_dense_ad(alg.jac_alg.nonbc_diffmode), - sparsity_detector = __default_sparsity_detector(alg.jac_alg.diffmode)), - cons = loss, - cons_j = jac, - cons_jac_prototype = jac_prototype) - lcons, ucons = __extract_lcons_ucons(prob, T, M, N) - return __internal_optimization_problem( - prob, optf, y, p; lcons = lcons, ucons = ucons) - end -end - -function __construct_internal_problem(prob, pt::TwoPointBVProblem, alg, loss, jac, - jac_prototype, resid_prototype, y, p, M::Int, N::Int) - T = eltype(y) - iip = SciMLBase.isinplace(prob) - if !isnothing(alg.nlsolve) || (isnothing(alg.nlsolve) && isnothing(alg.optimize)) - nlf = NonlinearFunction{iip}(loss; jac = jac, resid_prototype = resid_prototype, - jac_prototype = jac_prototype) - return __internal_nlsolve_problem(prob, resid_prototype, y, nlf, y, p) - else - optf = OptimizationFunction{true}(__default_cost(prob.f), - AutoSparse(get_dense_ad(alg.jac_alg.diffmode), - sparsity_detector = __default_sparsity_detector(alg.jac_alg.diffmode)), - cons = loss, - cons_j = jac, - cons_jac_prototype = jac_prototype) - lcons, ucons = __extract_lcons_ucons(prob, T, M, N) - - return __internal_optimization_problem( - prob, optf, y, p; lcons = lcons, ucons = ucons) - end -end - -# Single shooting use diffmode for StandardBVProblem and TwoPointBVProblem -function __construct_internal_problem(prob, alg, loss, jac, jac_prototype, - resid_prototype, y, p, M::Int, N::Int, ::Nothing) - T = eltype(y) - iip = SciMLBase.isinplace(prob) - if !isnothing(alg.nlsolve) || (isnothing(alg.nlsolve) && isnothing(alg.optimize)) - nlf = NonlinearFunction{iip}(loss; jac = jac, resid_prototype = resid_prototype, - jac_prototype = jac_prototype) - return __internal_nlsolve_problem(prob, resid_prototype, y, nlf, y, p) - else - optf = OptimizationFunction{iip}(__default_cost(prob.f), - AutoSparse(get_dense_ad(alg.jac_alg.diffmode), - sparsity_detector = __default_sparsity_detector(alg.jac_alg.diffmode)), - cons = loss, - cons_j = jac, - cons_jac_prototype = jac_prototype) - lcons, ucons = __extract_lcons_ucons(prob, T, M, N) - - return __internal_optimization_problem( - prob, optf, y, p; lcons = lcons, ucons = ucons) - end -end - -# Multiple shooting always use inplace version internal problem constructor -function __construct_internal_problem( - prob, pt::StandardBVProblem, alg, loss, jac, jac_prototype, - resid_prototype, y, p, M::Int, N::Int, ::Nothing) - T = eltype(y) - if !isnothing(alg.nlsolve) || (isnothing(alg.nlsolve) && isnothing(alg.optimize)) - nlf = NonlinearFunction{true}(loss; jac = jac, resid_prototype = resid_prototype, - jac_prototype = jac_prototype) - return __internal_nlsolve_problem(prob, resid_prototype, y, nlf, y, p) - else - optf = OptimizationFunction{true}(__default_cost(prob.f), - AutoSparse(get_dense_ad(alg.jac_alg.nonbc_diffmode), - sparsity_detector = __default_sparsity_detector(alg.jac_alg.nonbc_diffmode)), - cons = loss, - cons_j = jac, - cons_jac_prototype = jac_prototype) - lcons, ucons = __extract_lcons_ucons(prob, T, M, N) - - return __internal_optimization_problem( - prob, optf, y, p; lcons = lcons, ucons = ucons) - end -end -function __construct_internal_problem( - prob, pt::TwoPointBVProblem, alg, loss, jac, jac_prototype, - resid_prototype, y, p, M::Int, N::Int, ::Nothing) - T = eltype(y) - if !isnothing(alg.nlsolve) || (isnothing(alg.nlsolve) && isnothing(alg.optimize)) - nlf = NonlinearFunction{true}(loss; jac = jac, resid_prototype = resid_prototype, - jac_prototype = jac_prototype) - return __internal_nlsolve_problem(prob, resid_prototype, y, nlf, y, p) - else - optf = OptimizationFunction{true}(__default_cost(prob.f), - AutoSparse(get_dense_ad(alg.jac_alg.diffmode), - sparsity_detector = __default_sparsity_detector(alg.jac_alg.nonbc_diffmode)), - cons = loss, - cons_j = jac, - cons_jac_prototype = jac_prototype) - lcons, ucons = __extract_lcons_ucons(prob, T, M, N) - - return __internal_optimization_problem( - prob, optf, y, p; lcons = lcons, ucons = ucons) - end -end - -# Second order BVProblem -function __construct_internal_problem( - prob, pt::StandardSecondOrderBVProblem, alg, loss, jac, - jac_prototype, resid_prototype, y, p, M::Int, N::Int) - T = eltype(y) - iip = SciMLBase.isinplace(prob) - if !isnothing(alg.nlsolve) || (isnothing(alg.nlsolve) && isnothing(alg.optimize)) - nlf = NonlinearFunction{iip}(loss; jac = jac, resid_prototype = resid_prototype, - jac_prototype = jac_prototype) - return __internal_nlsolve_problem(prob, resid_prototype, y, nlf, y, p) - else - optf = OptimizationFunction{iip}(__default_cost(prob.f.f), - AutoSparse(get_dense_ad(alg.jac_alg.nonbc_diffmode), - sparsity_detector = __default_sparsity_detector(alg.jac_alg.nonbc_diffmode)), - cons = loss, - cons_j = jac, - cons_jac_prototype = jac_prototype) - lcons, ucons = __extract_lcons_ucons(prob, T, M, N) - return __internal_optimization_problem( - prob, optf, y, p; lcons = lcons, ucons = ucons) - end -end - -# Two point BVProblem -function __construct_internal_problem( - prob, pt::TwoPointSecondOrderBVProblem, alg, loss, jac, - jac_prototype, resid_prototype, y, p, M::Int, N::Int) - T = eltype(y) - iip = SciMLBase.isinplace(prob) - if !isnothing(alg.nlsolve) || (isnothing(alg.nlsolve) && isnothing(alg.optimize)) - nlf = NonlinearFunction{iip}(loss; jac = jac, resid_prototype = resid_prototype, - jac_prototype = jac_prototype) - return __internal_nlsolve_problem(prob, resid_prototype, y, nlf, y, p) - else - optf = OptimizationFunction{iip}(__default_cost(prob.f.f), - AutoSparse(get_dense_ad(alg.jac_alg.diffmode), - sparsity_detector = __default_sparsity_detector(alg.jac_alg.diffmode)), - cons = loss, - cons_j = jac, - cons_jac_prototype = jac_prototype) - lcons, ucons = __extract_lcons_ucons(prob, T, M, N) - - return __internal_optimization_problem( - prob, optf, y, p; lcons = lcons, ucons = ucons) - end -end diff --git a/lib/BoundaryValueDiffEqFIRK/src/BoundaryValueDiffEqFIRK.jl b/lib/BoundaryValueDiffEqFIRK/src/BoundaryValueDiffEqFIRK.jl index f836f7917..2603f0c45 100644 --- a/lib/BoundaryValueDiffEqFIRK/src/BoundaryValueDiffEqFIRK.jl +++ b/lib/BoundaryValueDiffEqFIRK/src/BoundaryValueDiffEqFIRK.jl @@ -24,7 +24,7 @@ using BoundaryValueDiffEqCore: AbstractBoundaryValueDiffEqAlgorithm, __initial_guess_on_mesh, __flatten_initial_guess, __build_solution, __Fix3, __split_kwargs, _sparse_like, get_dense_ad, __internal_optimization_problem, - __internal_solve + __internal_solve, __default_sparsity_detector using ConcreteStructs: @concrete using DiffEqBase: DiffEqBase diff --git a/lib/BoundaryValueDiffEqFIRK/src/collocation.jl b/lib/BoundaryValueDiffEqFIRK/src/collocation.jl index bbdcf1350..8317e4d6c 100644 --- a/lib/BoundaryValueDiffEqFIRK/src/collocation.jl +++ b/lib/BoundaryValueDiffEqFIRK/src/collocation.jl @@ -1,15 +1,58 @@ -function Φ!(residual, cache::FIRKCacheExpand, y, u, trait) - return Φ!(residual, cache.fᵢ_cache, cache.k_discrete, cache.f, cache.TU, - y, u, cache.p, cache.mesh, cache.mesh_dt, cache.stage, trait) +function Φ!(residual, cache::FIRKCacheExpand, y, u, trait, constraint) + return Φ!(residual, cache.fᵢ_cache, cache.k_discrete, cache.f, cache.TU, y, u, cache.p, + cache.mesh, cache.mesh_dt, cache.stage, cache.f_prototype, trait, constraint) end -function Φ!(residual, cache::FIRKCacheNested, y, u, trait) - return Φ!(residual, cache.fᵢ_cache, cache.k_discrete, cache.f, cache.TU, y, - u, cache.p, cache.mesh, cache.mesh_dt, cache.stage, cache, trait) +function Φ!(residual, cache::FIRKCacheNested, y, u, trait, constraint) + return Φ!(residual, cache.fᵢ_cache, cache.k_discrete, cache.f, cache.TU, y, u, + cache.p, cache.mesh, cache.mesh_dt, cache.stage, cache, trait, constraint) end -@views function Φ!(residual, fᵢ_cache, k_discrete, f!, TU::FIRKTableau{false}, - y, u, p, mesh, mesh_dt, stage::Int, ::DiffCacheNeeded) +@views function Φ!(residual, fᵢ_cache, k_discrete, f!, TU::FIRKTableau{false}, y, u, p, + mesh, mesh_dt, stage::Int, f_prototype, ::DiffCacheNeeded, ::Val{true}) + (; c, a, b) = TU + L_f_prototype = length(f_prototype) + tmp1, + tmpu = get_tmp(fᵢ_cache, u)[1:L_f_prototype], + get_tmp(fᵢ_cache, u)[(f_prototype + 1):end] + + K = get_tmp(k_discrete[1], u) # Not optimal # TODO + T = eltype(u) + ctr = 1 + + for i in eachindex(mesh_dt) + h = mesh_dt[i] + yᵢ = get_tmp(y[ctr], u) + yᵢ₊₁ = get_tmp(y[ctr + stage + 1], u) + + yᵢ, uᵢ = yᵢ[1:L_f_prototype], yᵢ[(L_f_prototype + 1):end] + yᵢ₊₁, uᵢ₊₁ = yᵢ₊₁[1:L_f_prototype], yᵢ₊₁[(L_f_prototype + 1):end] + + # Load interpolation residual + for j in 1:stage + tmp = get_tmp(y[ctr + j], u) + K[:, j] = tmp[1:3] + end + + # Update interpolation residual + for r in 1:stage + @. tmp1 = yᵢ + @. tmpu = uᵢ + __maybe_matmul!(tmp1, K, a[:, r], h, T(1)) + f!(residual[ctr + r], vcat(tmp1, tmpu), p, mesh[i] + c[r] * h) + residual[ctr + r] .-= K[:, r] + end + + # Update mesh point residual + residᵢ = residual[ctr] + @. residᵢ = yᵢ₊₁ - yᵢ + __maybe_matmul!(residᵢ, K, b, -h, T(1)) + ctr += stage + 1 + end +end + +@views function Φ!(residual, fᵢ_cache, k_discrete, f!, TU::FIRKTableau{false}, y, u, p, + mesh, mesh_dt, stage::Int, f_prototype, ::DiffCacheNeeded, ::Val{false}) (; c, a, b) = TU tmp1 = get_tmp(fᵢ_cache, u) K = get_tmp(k_discrete[1], u) # Not optimal # TODO @@ -42,8 +85,9 @@ end end end -@views function Φ!(residual, fᵢ_cache, k_discrete, f!, TU::FIRKTableau{false}, - y, u, p, mesh, mesh_dt, stage::Int, ::NoDiffCacheNeeded) +@views function Φ!( + residual, fᵢ_cache, k_discrete, f!, TU::FIRKTableau{false}, y, u, p, mesh, + mesh_dt, stage::Int, f_prototype, ::NoDiffCacheNeeded, ::Val{false}) (; c, a, b) = TU tmp1 = similar(fᵢ_cache) K = similar(k_discrete[1]) @@ -114,8 +158,38 @@ function FIRK_nlsolve(K, p_nlsolve, f!, TU::FIRKTableau{true}, p_f!) return res end -@views function Φ!(residual, fᵢ_cache, k_discrete, f!, TU::FIRKTableau{true}, y, - u, p, mesh, mesh_dt, stage::Int, cache, ::DiffCacheNeeded) +@views function Φ!(residual, fᵢ_cache, k_discrete, f!, TU::FIRKTableau{true}, y, u, p, + mesh, mesh_dt, stage::Int, cache, ::DiffCacheNeeded, ::Val{true}) + (; b) = TU + (; nest_prob, alg) = cache + + T = eltype(u) + nestprob_p = vcat(T(mesh[1]), T(mesh_dt[1]), get_tmp(y[1], u)) + nest_nlsolve_alg = __concrete_solve_algorithm(nest_prob, alg.nlsolve) + + for i in eachindex(k_discrete) + residᵢ = residual[i] + h = mesh_dt[i] + + yᵢ = get_tmp(y[i], u) + yᵢ₊₁ = get_tmp(y[i + 1], u) + + nestprob_p[1] = T(mesh[i]) + nestprob_p[2] = T(mesh_dt[i]) + nestprob_p[3:end] = yᵢ + + K = get_tmp(k_discrete[i], u) + + _nestprob = remake(nest_prob, p = nestprob_p) + nestsol = __solve(_nestprob, nest_nlsolve_alg; alg.nested_nlsolve_kwargs...) + @. K = nestsol.u + @. residᵢ = yᵢ₊₁ - yᵢ + __maybe_matmul!(residᵢ, nestsol.u, b, -h, T(1)) + end +end + +@views function Φ!(residual, fᵢ_cache, k_discrete, f!, TU::FIRKTableau{true}, y, u, p, + mesh, mesh_dt, stage::Int, cache, ::DiffCacheNeeded, ::Val{false}) (; b) = TU (; nest_prob, alg) = cache @@ -144,8 +218,8 @@ end end end -@views function Φ!(residual, fᵢ_cache, k_discrete, f!, TU::FIRKTableau{true}, y, - u, p, mesh, mesh_dt, stage::Int, cache, ::NoDiffCacheNeeded) +@views function Φ!(residual, fᵢ_cache, k_discrete, f!, TU::FIRKTableau{true}, y, u, p, + mesh, mesh_dt, stage::Int, cache, ::NoDiffCacheNeeded, ::Val{false}) (; b) = TU (; nest_prob, alg) = cache diff --git a/lib/BoundaryValueDiffEqFIRK/src/firk.jl b/lib/BoundaryValueDiffEqFIRK/src/firk.jl index f74327e74..4993840fa 100644 --- a/lib/BoundaryValueDiffEqFIRK/src/firk.jl +++ b/lib/BoundaryValueDiffEqFIRK/src/firk.jl @@ -12,6 +12,7 @@ alg # FIRK methods TU # FIRK Tableau ITU # FIRK Interpolation Tableau + f_prototype bcresid_prototype # Everything below gets resized in adaptive methods mesh # Discrete mesh @@ -47,6 +48,7 @@ Base.eltype(::FIRKCacheNested{iip, T}) where {iip, T} = T alg # FIRK methods TU # FIRK Tableau ITU # FIRK Interpolation Tableau + f_prototype bcresid_prototype # Everything below gets resized in adaptive methods mesh # Discrete mesh @@ -113,6 +115,10 @@ function init_nested( end diffcache = __cache_trait(alg.jac_alg) fit_parameters = haskey(prob.kwargs, :fit_parameters) + constraint = (!isnothing(prob.f.inequality)) || + (!isnothing(prob.f.equality)) || + (!isnothing(prob.lb)) || + (!isnothing(prob.ub)) t₀, t₁ = prob.tspan ig, T, @@ -134,17 +140,34 @@ function init_nested( y = __alloc.(copy.(y₀.u)) TU, ITU = constructRK(alg, T) stage = alg_stage(alg) + f_prototype = isnothing(prob.f.f_prototype) ? nothing : __vec(prob.f.f_prototype) + L_f_prototype = isnothing(f_prototype) ? M : length(f_prototype) - k_discrete = [__maybe_allocate_diffcache(safe_similar(X, M, stage), chunksize, alg.jac_alg) - for _ in 1:Nig] + k_discrete = if !constraint + [__maybe_allocate_diffcache(safe_similar(X, M, stage), chunksize, alg.jac_alg) + for _ in 1:Nig] + else + [__maybe_allocate_diffcache(safe_similar(X, L_f_prototype, stage), chunksize, alg.jac_alg) + for _ in 1:Nig] + end bcresid_prototype, resid₁_size = __get_bcresid_prototype(prob.problem_type, prob, X) residual = if iip - if prob.problem_type isa TwoPointBVProblem - vcat([__alloc(__vec(bcresid_prototype))], __alloc.(copy.(@view(y₀.u[2:end])))) + if !constraint + if prob.problem_type isa TwoPointBVProblem + vcat([__alloc(__vec(bcresid_prototype))], __alloc.(copy.(@view(y₀.u[2:end])))) + else + vcat([__alloc(bcresid_prototype)], __alloc.(copy.(@view(y₀.u[2:end])))) + end else - vcat([__alloc(bcresid_prototype)], __alloc.(copy.(@view(y₀.u[2:end])))) + if prob.problem_type isa TwoPointBVProblem + vcat([__alloc(__vec(bcresid_prototype))], + __alloc.(copy.([f_prototype for _ in 1:length(y₀.u[2:end])]))) + else + vcat([__alloc(bcresid_prototype)], + __alloc.(copy.([f_prototype for _ in 1:length(y₀.u[2:end])]))) + end end else nothing @@ -205,9 +228,9 @@ function init_nested( return FIRKCacheNested{iip, T, typeof(diffcache), fit_parameters}( alg_order(alg), stage, M, size(X), f, bc, prob_, prob.problem_type, prob.p, - alg, TU, ITU, bcresid_prototype, mesh, mesh_dt, k_discrete, y, y₀, residual, - fᵢ_cache, fᵢ₂_cache, defect, nestprob, resid₁_size, nlsolve_kwargs, - optimize_kwargs, (; abstol, dt, adaptive, controller, kwargs...)) + alg, TU, ITU, f_prototype, bcresid_prototype, mesh, mesh_dt, k_discrete, + y, y₀, residual, fᵢ_cache, fᵢ₂_cache, defect, nestprob, resid₁_size, + nlsolve_kwargs, optimize_kwargs, (; abstol, dt, adaptive, controller, kwargs...)) end function init_expanded( @@ -222,6 +245,10 @@ function init_expanded( end diffcache = __cache_trait(alg.jac_alg) fit_parameters = haskey(prob.kwargs, :fit_parameters) + constraint = (!isnothing(prob.f.inequality)) || + (!isnothing(prob.f.equality)) || + (!isnothing(prob.lb)) || + (!isnothing(prob.ub)) t₀, t₁ = prob.tspan ig, T, @@ -233,6 +260,8 @@ function init_expanded( TU, ITU = constructRK(alg, T) stage = alg_stage(alg) + f_prototype = isnothing(prob.f.f_prototype) ? nothing : __vec(prob.f.f_prototype) + L_f_prototype = isnothing(f_prototype) ? M : length(f_prototype) chunksize = pickchunksize(M + M * Nig * (stage + 1)) __alloc = @closure x -> __maybe_allocate_diffcache(vec(x), chunksize, alg.jac_alg) @@ -245,16 +274,31 @@ function init_expanded( y₀ = extend_y(_y₀, Nig + 1, stage) y = __alloc.(copy.(y₀.u)) # Runtime dispatch - k_discrete = [__maybe_allocate_diffcache(safe_similar(X, M, stage), chunksize, alg.jac_alg) - for _ in 1:Nig] # Runtime dispatch + k_discrete = if !constraint + [__maybe_allocate_diffcache(safe_similar(X, M, stage), chunksize, alg.jac_alg) + for _ in 1:Nig] # Runtime dispatch + else + [__maybe_allocate_diffcache(safe_similar(X, L_f_prototype, stage), chunksize, alg.jac_alg) + for _ in 1:Nig] # Runtime dispatch + end bcresid_prototype, resid₁_size = __get_bcresid_prototype(prob.problem_type, prob, X) residual = if iip - if prob.problem_type isa TwoPointBVProblem - vcat([__alloc(__vec(bcresid_prototype))], __alloc.(copy.(@view(y₀.u[2:end])))) + if !constraint + if prob.problem_type isa TwoPointBVProblem + vcat([__alloc(__vec(bcresid_prototype))], __alloc.(copy.(@view(y₀.u[2:end])))) + else + vcat([__alloc(bcresid_prototype)], __alloc.(copy.(@view(y₀.u[2:end])))) + end else - vcat([__alloc(bcresid_prototype)], __alloc.(copy.(@view(y₀.u[2:end])))) + if prob.problem_type isa TwoPointBVProblem + vcat([__alloc(__vec(bcresid_prototype))], + __alloc.(copy.([f_prototype for _ in 1:length(y₀.u[2:end])]))) + else + vcat([__alloc(bcresid_prototype)], + __alloc.(copy.([f_prototype for _ in 1:length(y₀.u[2:end])]))) + end end else nothing @@ -303,9 +347,9 @@ function init_expanded( prob_ = !(prob.u0 isa AbstractArray) ? remake(prob; u0 = X) : prob return FIRKCacheExpand{iip, T, typeof(diffcache), fit_parameters}( - alg_order(alg), stage, M, size(X), f, bc, prob_, prob.problem_type, - prob.p, alg, TU, ITU, bcresid_prototype, mesh, mesh_dt, k_discrete, y, - y₀, residual, fᵢ_cache, fᵢ₂_cache, defect, resid₁_size, nlsolve_kwargs, + alg_order(alg), stage, M, size(X), f, bc, prob_, prob.problem_type, prob.p, + alg, TU, ITU, f_prototype, bcresid_prototype, mesh, mesh_dt, k_discrete, + y, y₀, residual, fᵢ_cache, fᵢ₂_cache, defect, resid₁_size, nlsolve_kwargs, optimize_kwargs, (; abstol, dt, adaptive, controller, kwargs...)) end @@ -456,6 +500,16 @@ end # Constructing the Nonlinear Problem function __construct_problem(cache::Union{FIRKCacheNested{iip}, FIRKCacheExpand{iip}}, y::AbstractVector, y₀::AbstractVectorOfArray) where {iip} + constraint = (!isnothing(cache.prob.f.inequality)) || + (!isnothing(cache.prob.f.equality)) || + (!isnothing(cache.prob.lb)) || + (!isnothing(cache.prob.ub)) + return __construct_problem(cache, y, y₀, Val(constraint)) +end + +# Constructing the Nonlinear Problem +function __construct_problem(cache::Union{FIRKCacheNested{iip}, FIRKCacheExpand{iip}}, + y::AbstractVector, y₀::AbstractVectorOfArray, constraint) where {iip} pt = cache.problem_type (; jac_alg) = cache.alg @@ -476,7 +530,7 @@ function __construct_problem(cache::Union{FIRKCacheNested{iip}, FIRKCacheExpand{ @closure (du, u, p) -> __firk_loss_collocation!( - du, u, p, cache.y, cache.mesh, cache.residual, cache, trait) + du, u, p, cache.y, cache.mesh, cache.residual, cache, trait, constraint) else @closure (u, p) -> __firk_loss_collocation( @@ -487,7 +541,7 @@ function __construct_problem(cache::Union{FIRKCacheNested{iip}, FIRKCacheExpand{ @closure (du, u, p) -> __firk_loss!(du, u, p, cache.y, pt, cache.bc, cache.residual, - cache.mesh, cache, eval_sol, trait) + cache.mesh, cache, eval_sol, trait, constraint) else @closure (u, p) -> __firk_loss( @@ -497,17 +551,81 @@ function __construct_problem(cache::Union{FIRKCacheNested{iip}, FIRKCacheExpand{ if !isnothing(cache.alg.optimize) loss = @closure (du, u, - p) -> __firk_loss!( - du, u, p, cache.y, pt, cache.bc, cache.residual, cache.mesh, cache, trait) + p) -> __firk_loss!(du, u, p, cache.y, pt, cache.bc, cache.residual, + cache.bcresid_prototype, cache.mesh, cache, eval_sol, trait, constraint) + end + + return __construct_problem(cache, y, loss_bc, loss_collocation, loss, pt, constraint) +end + +function __construct_problem( + cache::FIRKCacheExpand{iip}, y, loss_bc::BC, loss_collocation::C, + loss::LF, ::StandardBVProblem, ::Val{true}) where {iip, BC, C, LF} + (; alg, stage, bcresid_prototype, f_prototype) = cache + (; jac_alg) = alg + (; bc_diffmode) = jac_alg + N = length(cache.mesh) + + resid_bc = cache.bcresid_prototype + L = length(resid_bc) + L_f_prototype = length(f_prototype) + resid_collocation = safe_similar(y, L_f_prototype * (N - 1) * (stage + 1)) + + cache_bc = if iip + DI.prepare_jacobian(loss_bc, resid_bc, bc_diffmode, y, Constant(cache.p)) + else + DI.prepare_jacobian(loss_bc, bc_diffmode, y, Constant(cache.p)) + end + + nonbc_diffmode = AutoSparse(get_dense_ad(jac_alg.nonbc_diffmode), + sparsity_detector = __default_sparsity_detector(jac_alg.nonbc_diffmode), + coloring_algorithm = __default_coloring_algorithm(jac_alg.nonbc_diffmode)) + + cache_collocation = if iip + DI.prepare_jacobian( + loss_collocation, resid_collocation, nonbc_diffmode, y, Constant(cache.p)) + else + DI.prepare_jacobian(loss_collocation, nonbc_diffmode, y, Constant(cache.p)) + end + + J_bc = if iip + DI.jacobian(loss_bc, resid_bc, cache_bc, bc_diffmode, y, Constant(cache.p)) + else + DI.jacobian(loss_bc, cache_bc, bc_diffmode, y, Constant(cache.p)) + end + J_c = if iip + DI.jacobian(loss_collocation, resid_collocation, cache_collocation, + nonbc_diffmode, y, Constant(cache.p)) + else + DI.jacobian( + loss_collocation, cache_collocation, nonbc_diffmode, y, Constant(cache.p)) + end + + jac_prototype = vcat(J_bc, J_c) + jac = if iip + @closure (J, + u, + p) -> __firk_mpoint_jacobian!( + J, J_c, u, bc_diffmode, nonbc_diffmode, cache_bc, cache_collocation, + loss_bc, loss_collocation, resid_bc, resid_collocation, L, cache.p) + else + @closure (u, + p) -> __firk_mpoint_jacobian( + jac_prototype, J_c, u, bc_diffmode, nonbc_diffmode, cache_bc, + cache_collocation, loss_bc, loss_collocation, L, cache.p) end - return __construct_problem(cache, y, loss_bc, loss_collocation, loss, pt) + resid_prototype = vcat(resid_bc, resid_collocation) + return __construct_internal_problem( + cache.prob, cache.problem_type, cache.alg, loss, jac, + jac_prototype, resid_prototype, bcresid_prototype, + f_prototype, y, cache.p, cache.M, (N - 1) * (stage + 1) + 1) end function __construct_problem( cache::FIRKCacheExpand{iip}, y, loss_bc::BC, loss_collocation::C, - loss::LF, ::StandardBVProblem) where {iip, BC, C, LF} - (; alg, stage) = cache + loss::LF, ::StandardBVProblem, ::Val{false}) where {iip, BC, C, LF} + (; alg, stage, bcresid_prototype, f_prototype) = cache (; jac_alg) = alg (; bc_diffmode) = jac_alg N = length(cache.mesh) @@ -586,15 +704,66 @@ function __construct_problem( resid_prototype = vcat(resid_bc, resid_collocation) return __construct_internal_problem( - cache.prob, cache.problem_type, cache.alg, loss, jac, jac_prototype, - resid_prototype, y, cache.p, cache.M, (N - 1) * (stage + 1) + 1) + cache.prob, cache.problem_type, cache.alg, loss, jac, + jac_prototype, resid_prototype, bcresid_prototype, + f_prototype, y, cache.p, cache.M, (N - 1) * (stage + 1) + 1) +end + +function __construct_problem( + cache::FIRKCacheExpand{iip}, y, loss_bc::BC, loss_collocation::C, + loss::LF, ::TwoPointBVProblem, ::Val{true}) where {iip, BC, C, LF} + (; jac_alg) = cache.alg + (; stage, bcresid_prototype, f_prototype) = cache + N = length(cache.mesh) + + resid_collocation = safe_similar(y, cache.M * (N - 1) * (stage + 1)) + + resid = vcat( + @view(cache.bcresid_prototype[1:prod(cache.resid_size[1])]), resid_collocation, + @view(cache.bcresid_prototype[(prod(cache.resid_size[1]) + 1):end])) + L = length(cache.bcresid_prototype) + + diffmode = if jac_alg.diffmode isa AutoSparse + AutoSparse(get_dense_ad(jac_alg.diffmode); + sparsity_detector = __default_sparsity_detector(jac_alg.diffmode), + coloring_algorithm = __default_coloring_algorithm(jac_alg.diffmode)) + else + jac_alg.diffmode + end + + diffcache = if iip + DI.prepare_jacobian(loss, resid, diffmode, y, Constant(cache.p)) + else + DI.prepare_jacobian(loss, diffmode, y, Constant(cache.p)) + end + + jac_prototype = if iip + DI.jacobian(loss, resid, diffcache, diffmode, y, Constant(cache.p)) + else + DI.jacobian(loss, diffcache, diffmode, y, Constant(cache.p)) + end + + jac = if iip + @closure (J, u, + p) -> __firk_2point_jacobian!(J, u, diffmode, diffcache, loss, resid, cache.p) + else + @closure (u, + p) -> __firk_2point_jacobian( + u, jac_prototype, diffmode, diffcache, loss, cache.p) + end + + resid_prototype = copy(resid) + return __construct_internal_problem( + cache.prob, cache.problem_type, cache.alg, loss, jac, + jac_prototype, resid_prototype, bcresid_prototype, + f_prototype, y, cache.p, cache.M, (N - 1) * (stage + 1) + 1) end function __construct_problem( cache::FIRKCacheExpand{iip}, y, loss_bc::BC, loss_collocation::C, - loss::LF, ::TwoPointBVProblem) where {iip, BC, C, LF} + loss::LF, ::TwoPointBVProblem, ::Val{false}) where {iip, BC, C, LF} (; jac_alg) = cache.alg - (; stage) = cache + (; stage, bcresid_prototype, f_prototype) = cache N = length(cache.mesh) resid_collocation = safe_similar(y, cache.M * (N - 1) * (stage + 1)) @@ -645,16 +814,78 @@ function __construct_problem( end resid_prototype = copy(resid) + return __construct_internal_problem( + cache.prob, cache.problem_type, cache.alg, loss, jac, + jac_prototype, resid_prototype, bcresid_prototype, + f_prototype, y, cache.p, cache.M, (N - 1) * (stage + 1) + 1) +end + +function __construct_problem( + cache::FIRKCacheNested{iip}, y, loss_bc::BC, loss_collocation::C, + loss::LF, ::StandardBVProblem, ::Val{true}) where {iip, BC, C, LF} + (; jac_alg) = cache.alg + (; bc_diffmode) = jac_alg + (; bcresid_prototype, f_prototype) = cache + N = length(cache.mesh) + resid_bc = cache.bcresid_prototype + L = length(resid_bc) + resid_collocation = safe_similar(y, cache.M * (N - 1)) + cache_bc = if iip + DI.prepare_jacobian(loss_bc, resid_bc, bc_diffmode, y, Constant(cache.p)) + else + DI.prepare_jacobian(loss_bc, bc_diffmode, y, Constant(cache.p)) + end + + nonbc_diffmode = AutoSparse(get_dense_ad(jac_alg.nonbc_diffmode), + sparsity_detector = __default_sparsity_detector(jac_alg.nonbc_diffmode), + coloring_algorithm = __default_coloring_algorithm(jac_alg.nonbc_diffmode)) + + cache_collocation = if iip + DI.prepare_jacobian( + loss_collocation, resid_collocation, nonbc_diffmode, y, Constant(cache.p)) + else + DI.prepare_jacobian(loss_collocation, nonbc_diffmode, y, Constant(cache.p)) + end + + J_bc = if iip + DI.jacobian(loss_bc, resid_bc, cache_bc, bc_diffmode, y, Constant(cache.p)) + else + DI.jacobian(loss_bc, cache_bc, bc_diffmode, y, Constant(cache.p)) + end + J_c = if iip + DI.jacobian(loss_collocation, resid_collocation, cache_collocation, + nonbc_diffmode, y, Constant(cache.p)) + else + DI.jacobian( + loss_collocation, cache_collocation, nonbc_diffmode, y, Constant(cache.p)) + end + + jac_prototype = vcat(J_bc, J_c) + jac = if iip + @closure (J, + u, + p) -> __firk_mpoint_jacobian!( + J, J_c, u, bc_diffmode, nonbc_diffmode, cache_bc, cache_collocation, + loss_bc, loss_collocation, resid_bc, resid_collocation, L, cache.p) + else + @closure (u, + p) -> __firk_mpoint_jacobian( + jac_prototype, J_c, u, bc_diffmode, nonbc_diffmode, cache_bc, + cache_collocation, loss_bc, loss_collocation, L, cache.p) + end + + resid_prototype = vcat(resid_bc, resid_collocation) return __construct_internal_problem( cache.prob, cache.problem_type, cache.alg, loss, jac, jac_prototype, - resid_prototype, y, cache.p, cache.M, (N - 1) * (stage + 1) + 1) + resid_prototype, bcresid_prototype, f_prototype, y, cache.p, cache.M, N) end function __construct_problem( cache::FIRKCacheNested{iip}, y, loss_bc::BC, loss_collocation::C, - loss::LF, ::StandardBVProblem) where {iip, BC, C, LF} + loss::LF, ::StandardBVProblem, ::Val{false}) where {iip, BC, C, LF} (; jac_alg) = cache.alg (; bc_diffmode) = jac_alg + (; bcresid_prototype, f_prototype) = cache N = length(cache.mesh) resid_bc = cache.bcresid_prototype L = length(resid_bc) @@ -726,14 +957,61 @@ function __construct_problem( resid_prototype = vcat(resid_bc, resid_collocation) return __construct_internal_problem( - cache.prob, cache.problem_type, cache.alg, loss, jac, - jac_prototype, resid_prototype, y, cache.p, cache.M, N) + cache.prob, cache.problem_type, cache.alg, loss, jac, jac_prototype, + resid_prototype, bcresid_prototype, f_prototype, y, cache.p, cache.M, N) end function __construct_problem( cache::FIRKCacheNested{iip}, y, loss_bc::BC, loss_collocation::C, - loss::LF, ::TwoPointBVProblem) where {iip, BC, C, LF} + loss::LF, ::TwoPointBVProblem, ::Val{true}) where {iip, BC, C, LF} (; jac_alg) = cache.alg + (; bcresid_prototype, f_prototype) = cache + N = length(cache.mesh) + + resid = vcat(@view(cache.bcresid_prototype[1:prod(cache.resid_size[1])]), + safe_similar(y, cache.M * (N - 1)), + @view(cache.bcresid_prototype[(prod(cache.resid_size[1]) + 1):end])) + + diffmode = if jac_alg.diffmode isa AutoSparse + AutoSparse(get_dense_ad(jac_alg.diffmode); + sparsity_detector = __default_sparsity_detector(jac_alg.diffmode), + coloring_algorithm = __default_coloring_algorithm(jac_alg.diffmode)) + else + jac_alg.diffmode + end + + diffcache = if iip + DI.prepare_jacobian(loss, resid, diffmode, y, Constant(cache.p)) + else + DI.prepare_jacobian(loss, diffmode, y, Constant(cache.p)) + end + + jac_prototype = if iip + DI.jacobian(loss, resid, diffcache, diffmode, y, Constant(cache.p)) + else + DI.jacobian(loss, diffcache, diffmode, y, Constant(cache.p)) + end + + jac = if iip + @closure (J, u, + p) -> __firk_2point_jacobian!(J, u, diffmode, diffcache, loss, resid, cache.p) + else + @closure (u, + p) -> __firk_2point_jacobian( + u, jac_prototype, diffmode, diffcache, loss, cache.p) + end + + resid_prototype = copy(resid) + return __construct_internal_problem( + cache.prob, cache.problem_type, cache.alg, loss, jac, jac_prototype, + resid_prototype, bcresid_prototype, f_prototype, y, cache.p, cache.M, N) +end + +function __construct_problem( + cache::FIRKCacheNested{iip}, y, loss_bc::BC, loss_collocation::C, + loss::LF, ::TwoPointBVProblem, ::Val{false}) where {iip, BC, C, LF} + (; jac_alg) = cache.alg + (; bcresid_prototype, f_prototype) = cache N = length(cache.mesh) resid = vcat(@view(cache.bcresid_prototype[1:prod(cache.resid_size[1])]), @@ -776,26 +1054,26 @@ function __construct_problem( resid_prototype = copy(resid) return __construct_internal_problem( - cache.prob, cache.problem_type, cache.alg, loss, jac, - jac_prototype, resid_prototype, y, cache.p, cache.M, N) + cache.prob, cache.problem_type, cache.alg, loss, jac, jac_prototype, + resid_prototype, bcresid_prototype, f_prototype, y, cache.p, cache.M, N) end -@views function __firk_loss!(resid, u, p, y, pt::StandardBVProblem, bc!::BC, residual, - mesh, cache, eval_sol, trait::DiffCacheNeeded) where {BC} +@views function __firk_loss!(resid, u, p, y, pt::StandardBVProblem, bc!::BC, residual, mesh, + cache, eval_sol, trait::DiffCacheNeeded, constraint) where {BC} y_ = recursive_unflatten!(y, u) resids = [get_tmp(r, u) for r in residual] - Φ!(resids[2:end], cache, y_, u, trait) + Φ!(resids[2:end], cache, y_, u, trait, constraint) eval_sol.u[1:end] .= y_ eval_bc_residual!(resids[1], pt, bc!, eval_sol, p, mesh) recursive_flatten!(resid, resids) return nothing end -@views function __firk_loss!(resid, u, p, y, pt::StandardBVProblem, bc!::BC, residual, - mesh, cache, eval_sol, trait::NoDiffCacheNeeded) where {BC} +@views function __firk_loss!(resid, u, p, y, pt::StandardBVProblem, bc!::BC, residual, mesh, + cache, eval_sol, trait::NoDiffCacheNeeded, constraint) where {BC} y_ = recursive_unflatten!(y, u) resids = [r for r in residual] - Φ!(resids[2:end], cache, y_, u, trait) + Φ!(resids[2:end], cache, y_, u, trait, constraint) eval_sol.u[1:end] .= y_ eval_bc_residual!(resids[1], pt, bc!, eval_sol, p, mesh) recursive_flatten!(resid, resids) @@ -803,44 +1081,46 @@ end end # loss function for optimization based solvers -@views function __firk_loss!(resid, u, p, y, pt::StandardBVProblem, bc!::BC, - residual, mesh, cache, trait) where {BC} +@views function __firk_loss!(resid, u, p, y, pt::StandardBVProblem, bc!::BC, residual, + bcresid_prototype, mesh, cache, _, trait, constraint) where {BC} bcresid = length(cache.bcresid_prototype) __firk_loss_bc!(resid[1:bcresid], u, p, pt, bc!, y, mesh, cache, trait) __firk_loss_collocation!( - resid[(bcresid + 1):end], u, p, y, mesh, residual, cache, trait) + resid[(bcresid + 1):end], u, p, y, mesh, residual, cache, trait, constraint) return nothing end @views function __firk_loss!( resid, u, p, y::AbstractVector, pt::TwoPointBVProblem, bc!::Tuple{BC1, BC2}, - residual, mesh, cache, _, trait::DiffCacheNeeded) where {BC1, BC2} + residual, mesh, cache, _, trait::DiffCacheNeeded, constraint) where {BC1, BC2} y_ = recursive_unflatten!(y, u) resids = [get_tmp(r, u) for r in residual] resida = resids[1][1:prod(cache.resid_size[1])] residb = resids[1][(prod(cache.resid_size[1]) + 1):end] eval_bc_residual!((resida, residb), pt, bc!, y_, p, mesh) - Φ!(resids[2:end], cache, y_, u, trait) + Φ!(resids[2:end], cache, y_, u, trait, constraint) recursive_flatten_twopoint!(resid, resids, cache.resid_size) return nothing end -@views function __firk_loss!(resid, u, p, y, pt::TwoPointBVProblem, bc!::Tuple{BC1, BC2}, - residual, mesh, cache, _, trait::NoDiffCacheNeeded) where {BC1, BC2} +@views function __firk_loss!( + resid, u, p, y, pt::TwoPointBVProblem, bc!::Tuple{BC1, BC2}, residual, + mesh, cache, _, trait::NoDiffCacheNeeded, constraint) where {BC1, BC2} y_ = recursive_unflatten!(y, u) soly_ = VectorOfArray(y_) resida = residual[1][1:prod(cache.resid_size[1])] residb = residual[1][(prod(cache.resid_size[1]) + 1):end] eval_bc_residual!((resida, residb), pt, bc!, soly_, p, mesh) - Φ!(residual[2:end], cache, y_, u, trait) + Φ!(residual[2:end], cache, y_, u, trait, constraint) recursive_flatten_twopoint!(resid, residual, cache.resid_size) return nothing end # loss function for optimization based solvers -@views function __firk_loss!(resid, u, p, y, pt::TwoPointBVProblem, bc!::Tuple{BC1, BC2}, - residual, mesh, cache, trait) where {BC1, BC2} - __firk_loss!(resid, u, p, y, pt, bc!, residual, mesh, cache, nothing, trait) +@views function __firk_loss!( + resid, u, p, y, pt::TwoPointBVProblem, bc!::Tuple{BC1, BC2}, residual, + bcresid_prototype, mesh, cache, _, trait, constraint) where {BC1, BC2} + __firk_loss!(resid, u, p, y, pt, bc!, residual, mesh, cache, nothing, trait, constraint) return nothing end @@ -877,19 +1157,19 @@ end end @views function __firk_loss_collocation!( - resid, u, p, y, mesh, residual, cache, trait::DiffCacheNeeded) + resid, u, p, y, mesh, residual, cache, trait::DiffCacheNeeded, constraint) y_ = recursive_unflatten!(y, u) resids = [get_tmp(r, u) for r in residual[2:end]] - Φ!(resids, cache, y_, u, trait) + Φ!(resids, cache, y_, u, trait, constraint) recursive_flatten!(resid, resids) return nothing end @views function __firk_loss_collocation!( - resid, u, p, y, mesh, residual, cache, trait::NoDiffCacheNeeded) + resid, u, p, y, mesh, residual, cache, trait::NoDiffCacheNeeded, constraint) y_ = recursive_unflatten!(y, u) resids = [r for r in residual[2:end]] - Φ!(resids, cache, y_, u, trait) + Φ!(resids, cache, y_, u, trait, constraint) recursive_flatten!(resid, resids) return nothing end diff --git a/lib/BoundaryValueDiffEqMIRK/src/BoundaryValueDiffEqMIRK.jl b/lib/BoundaryValueDiffEqMIRK/src/BoundaryValueDiffEqMIRK.jl index 9801f9b31..9606e118e 100644 --- a/lib/BoundaryValueDiffEqMIRK/src/BoundaryValueDiffEqMIRK.jl +++ b/lib/BoundaryValueDiffEqMIRK/src/BoundaryValueDiffEqMIRK.jl @@ -23,7 +23,8 @@ using BoundaryValueDiffEqCore: AbstractBoundaryValueDiffEqAlgorithm, __use_both_error_control, __default_coloring_algorithm, DiffCacheNeeded, NoDiffCacheNeeded, __split_kwargs, __concrete_kwargs, __FastShortcutNonlinearPolyalg, - __construct_internal_problem, __internal_solve + __construct_internal_problem, __internal_solve, + __default_sparsity_detector using ConcreteStructs: @concrete using DiffEqBase: DiffEqBase diff --git a/lib/BoundaryValueDiffEqMIRK/src/collocation.jl b/lib/BoundaryValueDiffEqMIRK/src/collocation.jl index cb971737b..327502fae 100644 --- a/lib/BoundaryValueDiffEqMIRK/src/collocation.jl +++ b/lib/BoundaryValueDiffEqMIRK/src/collocation.jl @@ -1,10 +1,44 @@ -function Φ!(residual, cache::MIRKCache, y, u, trait) - return Φ!(residual, cache.fᵢ_cache, cache.k_discrete, cache.f, cache.TU, - y, u, cache.p, cache.mesh, cache.mesh_dt, cache.stage, trait) +function Φ!(residual, cache::MIRKCache, y, u, trait, constraint) + return Φ!(residual, cache.fᵢ_cache, cache.k_discrete, cache.f, cache.TU, y, u, cache.p, + cache.mesh, cache.mesh_dt, cache.stage, cache.f_prototype, trait, constraint) end -@views function Φ!(residual, fᵢ_cache, k_discrete, f!, TU::MIRKTableau, - y, u, p, mesh, mesh_dt, stage::Int, ::DiffCacheNeeded) +@views function Φ!(residual, fᵢ_cache, k_discrete, f!, TU::MIRKTableau, y, u, p, mesh, + mesh_dt, stage::Int, f_prototype, ::DiffCacheNeeded, ::Val{true}) + (; c, v, x, b) = TU + L_f_prototype = length(f_prototype) + + tmpy, + tmpu = get_tmp(fᵢ_cache, u)[1:L_f_prototype], + get_tmp(fᵢ_cache, u)[(L_f_prototype + 1):end] + + T = eltype(u) + for i in eachindex(k_discrete) + K = get_tmp(k_discrete[i], u) + residᵢ = residual[i] + h = mesh_dt[i] + + yᵢ = get_tmp(y[i], u) + yᵢ₊₁ = get_tmp(y[i + 1], u) + + yᵢ, uᵢ = yᵢ[1:L_f_prototype], yᵢ[(L_f_prototype + 1):end] + yᵢ₊₁, uᵢ₊₁ = yᵢ₊₁[1:L_f_prototype], yᵢ₊₁[(L_f_prototype + 1):end] + + for r in 1:stage + @. tmpy = (1 - v[r]) * yᵢ + v[r] * yᵢ₊₁ + @. tmpu = (1 - v[r]) * uᵢ + v[r] * uᵢ₊₁ + __maybe_matmul!(tmpy, K[:, 1:(r - 1)], x[r, 1:(r - 1)], h, T(1)) + f!(K[:, r], vcat(tmpy, tmpu), p, mesh[i] + c[r] * h) + end + + # Update residual + @. residᵢ = yᵢ₊₁ - yᵢ + __maybe_matmul!(residᵢ, K[:, 1:stage], b[1:stage], -h, T(1)) + end +end + +@views function Φ!(residual, fᵢ_cache, k_discrete, f!, TU::MIRKTableau, y, u, p, mesh, + mesh_dt, stage::Int, _, ::DiffCacheNeeded, constraint::Val{false}) (; c, v, x, b) = TU tmp = get_tmp(fᵢ_cache, u) @@ -29,8 +63,8 @@ end end end -@views function Φ!(residual, fᵢ_cache, k_discrete, f!, TU::MIRKTableau, y, - u, p, mesh, mesh_dt, stage::Int, ::NoDiffCacheNeeded) +@views function Φ!(residual, fᵢ_cache, k_discrete, f!, TU::MIRKTableau, y, u, p, + mesh, mesh_dt, stage::Int, _, ::NoDiffCacheNeeded, ::Val{false}) (; c, v, x, b) = TU tmp = similar(fᵢ_cache) diff --git a/lib/BoundaryValueDiffEqMIRK/src/mirk.jl b/lib/BoundaryValueDiffEqMIRK/src/mirk.jl index aafd5e6cc..cf896c028 100644 --- a/lib/BoundaryValueDiffEqMIRK/src/mirk.jl +++ b/lib/BoundaryValueDiffEqMIRK/src/mirk.jl @@ -12,6 +12,7 @@ alg # MIRK methods TU # MIRK Tableau ITU # MIRK Interpolation Tableau + f_prototype bcresid_prototype # Everything below gets resized in adaptive methods mesh # Discrete mesh @@ -43,6 +44,10 @@ function SciMLBase.__init( diffcache = __cache_trait(alg.jac_alg) @assert (iip || isnothing(alg.optimize)) "Out-of-place constraints don't allow optimization solvers " fit_parameters = haskey(prob.kwargs, :fit_parameters) + constraint = (!isnothing(prob.f.inequality)) || + (!isnothing(prob.f.equality)) || + (!isnothing(prob.lb)) || + (!isnothing(prob.ub)) t₀, t₁ = prob.tspan ig, T, @@ -64,44 +69,79 @@ function SciMLBase.__init( y = __alloc.(copy.(y₀.u)) TU, ITU = constructMIRK(alg, T) stage = alg_stage(alg) + f_prototype = isnothing(prob.f.f_prototype) ? nothing : __vec(prob.f.f_prototype) + L_f_prototype = isnothing(f_prototype) ? N : length(f_prototype) - k_discrete = [__maybe_allocate_diffcache(safe_similar(X, N, stage), chunksize, alg.jac_alg) - for _ in 1:Nig] - k_interp = VectorOfArray([similar(X, N, ITU.s_star - stage) for _ in 1:Nig]) + k_discrete = if !constraint + [__maybe_allocate_diffcache(safe_similar(X, N, stage), chunksize, alg.jac_alg) + for _ in 1:Nig] + else + [__maybe_allocate_diffcache(safe_similar(X, L_f_prototype, stage), chunksize, alg.jac_alg) + for _ in 1:Nig] + end + k_interp = if !constraint + VectorOfArray([similar(X, N, ITU.s_star - stage) for _ in 1:Nig]) + else + VectorOfArray([similar(X, L_f_prototype, ITU.s_star - stage) for _ in 1:Nig]) + end bcresid_prototype, resid₁_size = __get_bcresid_prototype(prob.problem_type, prob, X) residual = if iip - if prob.problem_type isa TwoPointBVProblem - vcat([__alloc(__vec(bcresid_prototype))], __alloc.(copy.(@view(y₀.u[2:end])))) + if !constraint + if prob.problem_type isa TwoPointBVProblem + vcat([__alloc(__vec(bcresid_prototype))], __alloc.(copy.(@view(y₀.u[2:end])))) + else + vcat([__alloc(bcresid_prototype)], __alloc.(copy.(@view(y₀.u[2:end])))) + end else - vcat([__alloc(bcresid_prototype)], __alloc.(copy.(@view(y₀.u[2:end])))) + if prob.problem_type isa TwoPointBVProblem + vcat([__alloc(__vec(bcresid_prototype))], __alloc.(copy.([f_prototype + for _ in 1:Nig]))) + else + vcat([__alloc(bcresid_prototype)], __alloc.(copy.([f_prototype + for _ in 1:Nig]))) + end end else nothing end use_both = __use_both_error_control(controller) - errors = VectorOfArray([similar(X, ifelse(adaptive, N, 0)) - for _ in 1:ifelse(use_both, 2Nig, Nig)]) - new_stages = VectorOfArray([similar(X, N) for _ in 1:Nig]) + errors = if !constraint + VectorOfArray([similar(X, ifelse(adaptive, N, 0)) + for _ in 1:ifelse(use_both, 2Nig, Nig)]) + else + VectorOfArray([similar(X, ifelse(adaptive, L_f_prototype, 0)) + for _ in 1:ifelse(use_both, 2Nig, Nig)]) + end + new_stages = if !constraint + VectorOfArray([similar(X, N) for _ in 1:Nig]) + else + VectorOfArray([similar(X, L_f_prototype) for _ in 1:Nig]) + end # Transform the functions to handle non-vector inputs bcresid_prototype = __vec(bcresid_prototype) f, bc = if X isa AbstractVector - #TODO: Simplify the logic by wrapping the functions - if fit_parameters == true + f_wrapped = prob.f + bc_wrapped = prob.f.bc + if fit_parameters l_parameters = length(prob.p) - vecf! = function (du, u, p, t) - prob.f(du, u, @view(u[(end - l_parameters + 1):end]), t) - du[(end - l_parameters + 1):end] .= 0 + base_f = f_wrapped + f_wrapped = @closure (du, + u, + p, + t) -> begin + @inbounds @views begin + base_f(du, u, u[(end - l_parameters + 1):end], t) + fill!(du[(end - l_parameters + 1):end], zero(eltype(du))) + end + return nothing end - vecbc! = prob.f.bc - vecf!, vecbc! - else - prob.f, prob.f.bc end + f_wrapped, bc_wrapped elseif iip vecf! = @closure (du, u, p, t) -> __vec_f!(du, u, p, t, prob.f, size(X)) vecbc! = if !(prob.problem_type isa TwoPointBVProblem) @@ -128,9 +168,9 @@ function SciMLBase.__init( prob_ = !(prob.u0 isa AbstractArray) ? remake(prob; u0 = X) : prob return MIRKCache{iip, T, use_both, typeof(diffcache), fit_parameters}( - alg_order(alg), stage, N, size(X), f, bc, prob_, prob.problem_type, prob.p, - alg, TU, ITU, bcresid_prototype, mesh, mesh_dt, k_discrete, k_interp, y, y₀, - residual, fᵢ_cache, fᵢ₂_cache, errors, new_stages, resid₁_size, nlsolve_kwargs, + alg_order(alg), stage, N, size(X), f, bc, prob_, prob.problem_type, prob.p, alg, + TU, ITU, f_prototype, bcresid_prototype, mesh, mesh_dt, k_discrete, k_interp, y, + y₀, residual, fᵢ_cache, fᵢ₂_cache, errors, new_stages, resid₁_size, nlsolve_kwargs, optimize_kwargs, (; abstol, dt, adaptive, controller, fit_parameters, kwargs...)) end @@ -239,6 +279,15 @@ end # Constructing the Nonlinear Problem function __construct_problem(cache::MIRKCache{iip}, y::AbstractVector, y₀::AbstractVectorOfArray) where {iip} + constraint = (!isnothing(cache.prob.f.inequality)) || + (!isnothing(cache.prob.f.equality)) || + (!isnothing(cache.prob.lb)) || + (!isnothing(cache.prob.ub)) + return __construct_problem(cache, y, y₀, Val(constraint)) +end + +function __construct_problem(cache::MIRKCache{iip}, y::AbstractVector, + y₀::AbstractVectorOfArray, constraint) where {iip} pt = cache.problem_type (; jac_alg) = cache.alg @@ -259,7 +308,7 @@ function __construct_problem(cache::MIRKCache{iip}, y::AbstractVector, y₀::Abs @closure (du, u, p) -> __mirk_loss_collocation!( - du, u, p, cache.y, cache.mesh, cache.residual, cache, trait) + du, u, p, cache.y, cache.mesh, cache.residual, cache, trait, constraint) else @closure (u, p) -> __mirk_loss_collocation( @@ -270,7 +319,7 @@ function __construct_problem(cache::MIRKCache{iip}, y::AbstractVector, y₀::Abs @closure (du, u, p) -> __mirk_loss!(du, u, p, cache.y, pt, cache.bc, cache.residual, - cache.mesh, cache, eval_sol, trait) + cache.mesh, cache, eval_sol, trait, constraint) else @closure (u, p) -> __mirk_loss( @@ -280,18 +329,18 @@ function __construct_problem(cache::MIRKCache{iip}, y::AbstractVector, y₀::Abs if !isnothing(cache.alg.optimize) loss = @closure (du, u, - p) -> __mirk_loss!( - du, u, p, cache.y, pt, cache.bc, cache.residual, cache.mesh, cache, trait) + p) -> __mirk_loss!(du, u, p, cache.y, pt, cache.bc, cache.residual, + cache.bcresid_prototype, cache.mesh, cache, eval_sol, trait, constraint) end - return __construct_problem(cache, y, loss_bc, loss_collocation, loss, pt) + return __construct_problem(cache, y, loss_bc, loss_collocation, loss, pt, constraint) end -@views function __mirk_loss!(resid, u, p, y, pt::StandardBVProblem, bc!::BC, residual, - mesh, cache, EvalSol, trait::DiffCacheNeeded) where {BC} +@views function __mirk_loss!(resid, u, p, y, pt::StandardBVProblem, bc!::BC, residual, mesh, + cache, EvalSol, trait::DiffCacheNeeded, constraint) where {BC} y_ = recursive_unflatten!(y, u) resids = [get_tmp(r, u) for r in residual] - Φ!(resids[2:end], cache, y_, u, trait) + Φ!(resids[2:end], cache, y_, u, trait, constraint) EvalSol.u[1:end] .= __restructure_sol(y_, cache.in_size) EvalSol.cache.k_discrete[1:end] .= cache.k_discrete eval_bc_residual!(resids[1], pt, bc!, EvalSol, p, mesh) @@ -299,10 +348,10 @@ end return nothing end -@views function __mirk_loss!(resid, u, p, y, pt::StandardBVProblem, bc!::BC, residual, - mesh, cache, EvalSol, trait::NoDiffCacheNeeded) where {BC} +@views function __mirk_loss!(resid, u, p, y, pt::StandardBVProblem, bc!::BC, residual, mesh, + cache, EvalSol, trait::NoDiffCacheNeeded, constraint) where {BC} y_ = recursive_unflatten!(y, u) - Φ!(residual[2:end], cache, y_, u, trait) + Φ!(residual[2:end], cache, y_, u, trait, constraint) EvalSol.u[1:end] .= __restructure_sol(y_, cache.in_size) EvalSol.cache.k_discrete[1:end] .= cache.k_discrete eval_bc_residual!(residual[1], pt, bc!, EvalSol, p, mesh) @@ -311,20 +360,21 @@ end end # loss function for optimization based solvers -@views function __mirk_loss!(resid, u, p, y, pt::StandardBVProblem, bc!::BC, - residual, mesh, cache, trait) where {BC} - bcresid = length(cache.bcresid_prototype) +@views function __mirk_loss!(resid, u, p, y, pt::StandardBVProblem, bc!::BC, residual, + bcresid_prototype, mesh, cache, _, trait, constraint) where {BC} + bcresid = length(bcresid_prototype) __mirk_loss_bc!(resid[1:bcresid], u, p, pt, bc!, y, mesh, cache, trait) __mirk_loss_collocation!( - resid[(bcresid + 1):end], u, p, y, mesh, residual, cache, trait) + resid[(bcresid + 1):end], u, p, y, mesh, residual, cache, trait, constraint) return nothing end -@views function __mirk_loss!(resid, u, p, y, pt::TwoPointBVProblem, bc!::Tuple{BC1, BC2}, - residual, mesh, cache, _, trait::DiffCacheNeeded) where {BC1, BC2} +@views function __mirk_loss!( + resid, u, p, y, pt::TwoPointBVProblem, bc!::Tuple{BC1, BC2}, residual, + mesh, cache, _, trait::DiffCacheNeeded, constraint) where {BC1, BC2} y_ = recursive_unflatten!(y, u) resids = [get_tmp(r, u) for r in residual] - Φ!(resids[2:end], cache, y_, u, trait) + Φ!(resids[2:end], cache, y_, u, trait, constraint) resida = resids[1][1:prod(cache.resid_size[1])] residb = resids[1][(prod(cache.resid_size[1]) + 1):end] eval_bc_residual!((resida, residb), pt, bc!, y_, p, mesh) @@ -332,10 +382,11 @@ end return nothing end -@views function __mirk_loss!(resid, u, p, y, pt::TwoPointBVProblem, bc!::Tuple{BC1, BC2}, - residual, mesh, cache, _, trait::NoDiffCacheNeeded) where {BC1, BC2} +@views function __mirk_loss!( + resid, u, p, y, pt::TwoPointBVProblem, bc!::Tuple{BC1, BC2}, residual, + mesh, cache, _, trait::NoDiffCacheNeeded, constraint) where {BC1, BC2} y_ = recursive_unflatten!(y, u) - Φ!(residual[2:end], cache, y_, u, trait) + Φ!(residual[2:end], cache, y_, u, trait, constraint) resida = residual[1][1:prod(cache.resid_size[1])] residb = residual[1][(prod(cache.resid_size[1]) + 1):end] eval_bc_residual!((resida, residb), pt, bc!, y_, p, mesh) @@ -344,9 +395,10 @@ end end # loss function for optimization based solvers -@views function __mirk_loss!(resid, u, p, y, pt::TwoPointBVProblem, bc!::Tuple{BC1, BC2}, - residual, mesh, cache, trait) where {BC1, BC2} - __mirk_loss!(resid, u, p, y, pt, bc!, residual, mesh, cache, nothing, trait) +@views function __mirk_loss!( + resid, u, p, y, pt::TwoPointBVProblem, bc!::Tuple{BC1, BC2}, residual, + bcresid_prototype, mesh, cache, _, trait, constraint) where {BC1, BC2} + __mirk_loss!(resid, u, p, y, pt, bc!, residual, mesh, cache, nothing, trait, constraint) return nothing end @@ -384,19 +436,19 @@ end end @views function __mirk_loss_collocation!( - resid, u, p, y, mesh, residual, cache, trait::DiffCacheNeeded) + resid, u, p, y, mesh, residual, cache, trait::DiffCacheNeeded, constraint) y_ = recursive_unflatten!(y, u) resids = [get_tmp(r, u) for r in residual[2:end]] - Φ!(resids, cache, y_, u, trait) + Φ!(resids, cache, y_, u, trait, constraint) recursive_flatten!(resid, resids) return nothing end @views function __mirk_loss_collocation!( - resid, u, p, y, mesh, residual, cache, trait::NoDiffCacheNeeded) + resid, u, p, y, mesh, residual, cache, trait::NoDiffCacheNeeded, constraint) y_ = recursive_unflatten!(y, u) resids = [r for r in residual[2:end]] - Φ!(resids, cache, y_, u, trait) + Φ!(resids, cache, y_, u, trait, constraint) recursive_flatten!(resid, resids) return nothing end @@ -407,15 +459,81 @@ end return mapreduce(vec, vcat, resids) end -function __construct_problem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collocation::C, - loss::LF, ::StandardBVProblem) where {iip, BC, C, LF} +function __construct_problem( + cache::MIRKCache{iip}, y, loss_bc::BC, loss_collocation::C, loss::LF, + ::StandardBVProblem, constraint::Val{true}) where {iip, BC, C, LF} + (; jac_alg) = cache.alg + (; f_prototype, bcresid_prototype) = cache + (; bc_diffmode) = jac_alg + N = length(cache.mesh) + + resid_bc = bcresid_prototype + L = length(resid_bc) + L_f_prototype = length(f_prototype) + resid_collocation = safe_similar(y, L_f_prototype * (N - 1)) + + cache_bc = if iip + DI.prepare_jacobian(loss_bc, resid_bc, bc_diffmode, y, Constant(cache.p)) + else + DI.prepare_jacobian(loss_bc, bc_diffmode, y, Constant(cache.p)) + end + + nonbc_diffmode = AutoSparse(get_dense_ad(jac_alg.nonbc_diffmode), + sparsity_detector = __default_sparsity_detector(jac_alg.nonbc_diffmode), + coloring_algorithm = __default_coloring_algorithm(jac_alg.nonbc_diffmode)) + cache_collocation = if iip + DI.prepare_jacobian( + loss_collocation, resid_collocation, nonbc_diffmode, y, Constant(cache.p)) + else + DI.prepare_jacobian(loss_collocation, nonbc_diffmode, y, Constant(cache.p)) + end + + J_bc = if iip + DI.jacobian(loss_bc, resid_bc, cache_bc, bc_diffmode, y, Constant(cache.p)) + else + DI.jacobian(loss_bc, cache_bc, bc_diffmode, y, Constant(cache.p)) + end + J_c = if iip + DI.jacobian(loss_collocation, resid_collocation, cache_collocation, + nonbc_diffmode, y, Constant(cache.p)) + else + DI.jacobian( + loss_collocation, cache_collocation, nonbc_diffmode, y, Constant(cache.p)) + end + jac_prototype = vcat(J_bc, J_c) + + jac = if iip + @closure (J, + u, + p) -> __mirk_mpoint_jacobian!( + J, J_c, u, bc_diffmode, nonbc_diffmode, cache_bc, cache_collocation, + loss_bc, loss_collocation, resid_bc, resid_collocation, L, cache.p) + else + @closure (u, + p) -> __mirk_mpoint_jacobian( + jac_prototype, J_c, u, bc_diffmode, nonbc_diffmode, cache_bc, + cache_collocation, loss_bc, loss_collocation, L, cache.p) + end + + resid_prototype = vcat(resid_bc, resid_collocation) + return __construct_internal_problem( + cache.prob, cache.problem_type, cache.alg, loss, jac, jac_prototype, + resid_prototype, bcresid_prototype, f_prototype, y, cache.p, cache.M, N) +end + +# Dispatch for problems with constraints +function __construct_problem( + cache::MIRKCache{iip}, y, loss_bc::BC, loss_collocation::C, loss::LF, + ::StandardBVProblem, constraint::Val{false}) where {iip, BC, C, LF} (; jac_alg) = cache.alg + (; f_prototype, bcresid_prototype) = cache (; bc_diffmode) = jac_alg N = length(cache.mesh) - resid_bc = cache.bcresid_prototype + resid_bc = bcresid_prototype L = length(resid_bc) resid_collocation = safe_similar(y, cache.M * (N - 1)) + resid_prototype = vcat(resid_bc, resid_collocation) cache_bc = if iip DI.prepare_jacobian(loss_bc, resid_bc, bc_diffmode, y, Constant(cache.p)) @@ -482,10 +600,9 @@ function __construct_problem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_colloca cache_collocation, loss_bc, loss_collocation, L, cache.p) end - resid_prototype = vcat(resid_bc, resid_collocation) return __construct_internal_problem( - cache.prob, cache.problem_type, cache.alg, loss, jac, - jac_prototype, resid_prototype, y, cache.p, cache.M, N) + cache.prob, cache.problem_type, cache.alg, loss, jac, jac_prototype, + resid_prototype, bcresid_prototype, f_prototype, y, cache.p, cache.M, N) end function __mirk_mpoint_jacobian!( @@ -530,21 +647,68 @@ function __mirk_mpoint_jacobian( return J end -function __construct_problem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_collocation::C, - loss::LF, ::TwoPointBVProblem) where {iip, BC, C, LF} +function __construct_problem( + cache::MIRKCache{iip}, y, loss_bc::BC, loss_collocation::C, loss::LF, + ::TwoPointBVProblem, constraint::Val{true}) where {iip, BC, C, LF} + (; jac_alg) = cache.alg + (; f_prototype, bcresid_prototype) = cache + N = length(cache.mesh) + L_f_prototype = length(f_prototype) + + resid = vcat(@view(bcresid_prototype[1:prod(cache.resid_size[1])]), + safe_similar(y, L_f_prototype * (N - 1)), + @view(bcresid_prototype[(prod(cache.resid_size[1]) + 1):end])) + + diffmode = if jac_alg.diffmode isa AutoSparse + AutoSparse(get_dense_ad(jac_alg.diffmode); + sparsity_detector = __default_sparsity_detector(jac_alg.diffmode), + coloring_algorithm = __default_coloring_algorithm(jac_alg.diffmode)) + else + jac_alg.diffmode + end + + diffcache = if iip + DI.prepare_jacobian(loss, resid, diffmode, y, Constant(cache.p)) + else + DI.prepare_jacobian(loss, diffmode, y, Constant(cache.p)) + end + + jac_prototype = if iip + DI.jacobian(loss, resid, diffcache, diffmode, y, Constant(cache.p)) + else + DI.jacobian(loss, diffcache, diffmode, y, Constant(cache.p)) + end + + jac = if iip + @closure ( + J, u, p) -> __mirk_2point_jacobian!(J, u, diffmode, diffcache, loss, resid, p) + else + @closure ( + u, p) -> __mirk_2point_jacobian(u, jac_prototype, diffmode, diffcache, loss, p) + end + + resid_prototype = copy(resid) + return __construct_internal_problem( + cache.prob, cache.problem_type, cache.alg, loss, jac, jac_prototype, + resid_prototype, bcresid_prototype, f_prototype, y, cache.p, cache.M, N) +end + +function __construct_problem( + cache::MIRKCache{iip}, y, loss_bc::BC, loss_collocation::C, loss::LF, + ::TwoPointBVProblem, constraint::Val{false}) where {iip, BC, C, LF} (; jac_alg) = cache.alg + (; f_prototype, bcresid_prototype) = cache N = length(cache.mesh) - resid = vcat(@view(cache.bcresid_prototype[1:prod(cache.resid_size[1])]), + resid = vcat(@view(bcresid_prototype[1:prod(cache.resid_size[1])]), safe_similar(y, cache.M * (N - 1)), - @view(cache.bcresid_prototype[(prod(cache.resid_size[1]) + 1):end])) + @view(bcresid_prototype[(prod(cache.resid_size[1]) + 1):end])) diffmode = if jac_alg.diffmode isa AutoSparse sparse_jacobian_prototype = __generate_sparse_jacobian_prototype( cache, cache.problem_type, - @view(cache.bcresid_prototype[1:prod(cache.resid_size[1])]), - @view(cache.bcresid_prototype[(prod(cache.resid_size[1]) + 1):end]), - cache.M, N) + @view(bcresid_prototype[1:prod(cache.resid_size[1])]), + @view(bcresid_prototype[(prod(cache.resid_size[1]) + 1):end]), cache.M, N) AutoSparse(get_dense_ad(jac_alg.diffmode); sparsity_detector = ADTypes.KnownJacobianSparsityDetector(sparse_jacobian_prototype), coloring_algorithm = __default_coloring_algorithm(jac_alg.diffmode)) @@ -574,8 +738,8 @@ function __construct_problem(cache::MIRKCache{iip}, y, loss_bc::BC, loss_colloca resid_prototype = copy(resid) return __construct_internal_problem( - cache.prob, cache.problem_type, cache.alg, loss, jac, - jac_prototype, resid_prototype, y, cache.p, cache.M, N) + cache.prob, cache.problem_type, cache.alg, loss, jac, jac_prototype, + resid_prototype, bcresid_prototype, f_prototype, y, cache.p, cache.M, N) end function __mirk_2point_jacobian!(J, x, diffmode, diffcache, loss_fn::L, resid, p) where {L} diff --git a/lib/BoundaryValueDiffEqMIRK/test/dynamic_optimization_tests.jl b/lib/BoundaryValueDiffEqMIRK/test/dynamic_optimization_tests.jl new file mode 100644 index 000000000..963f7ad03 --- /dev/null +++ b/lib/BoundaryValueDiffEqMIRK/test/dynamic_optimization_tests.jl @@ -0,0 +1,57 @@ +@testitem "Rocket launching problem" begin + using BoundaryValueDiffEqMIRK, OptimizationIpopt + h_0 = 1 # Initial height + v_0 = 0 # Initial velocity + m_0 = 1.0 # Initial mass + m_T = 0.6 # Final mass + g_0 = 1 # Gravity at the surface + h_c = 500 # Used for drag + c = 0.5 * sqrt(g_0 * h_0) # Thrust-to-fuel mass + D_c = 0.5 * 620 * m_0 / g_0 # Drag scaling + u_t_max = 3.5 * g_0 * m_0 # Maximum thrust + T_max = 0.2 # Number of seconds + T = 1_000 # Number of time steps + Δt = 0.2 / T; # Time per discretized step + + tspan = (0.0, 0.2) + drag(x_h, x_v) = D_c * x_v^2 * exp(-h_c * (x_h - h_0) / h_0) + g(x_h) = g_0 * (h_0 / x_h)^2 + function rocket_launch!(du, u, p, t) + # u_t is the control variable (thrust) + x_v, x_h, x_m, u_t = u[1], u[2], u[3], u[4] + du[1] = (u_t-drag(x_h, x_v))/x_m - g(x_h) + du[2] = x_v + du[3] = -u_t/c + end + function rocket_launch_bc!(res, u, p, t) + res[1] = u(0.0)[1] - v_0 + res[2] = u(0.0)[2] - h_0 + res[3] = u(0.0)[3] - m_0 + res[4] = u(0.2)[4] - 0.0 + end + function rocket_launch_bc_a!(res, ua, p) + res[1] = ua[1] - v_0 + res[2] = ua[2] - h_0 + res[3] = ua[3] - m_0 + end + function rocket_launch_bc_b!(res, ub, p) + res[1] = ub[4] - 0.0 + end + cost_fun(u, p) = -u[end - 2] #Final altitude x_h. To minimize, only temporary, need to use temporary solution interpolation here similar to what we do in boundary condition evaluations. + u0 = [v_0, h_0, m_T, 3.0] + rocket_launch_fun_mp = BVPFunction( + rocket_launch!, rocket_launch_bc!; cost = cost_fun, f_prototype = zeros(3)) + rocket_launch_prob_mp = BVProblem(rocket_launch_fun_mp, u0, tspan; + lb = [0.0, h_0, m_T, 0.0], ub = [Inf, Inf, m_0, u_t_max]) + sol = solve(rocket_launch_prob_mp, MIRK4(; optimize = IpoptOptimizer()); dt = Δt, adaptive = false) + @test SciMLBase.successful_retcode(sol) + + rocket_launch_fun_tp = BVPFunction( + rocket_launch!, (rocket_launch_bc_a!, rocket_launch_bc_b!); + cost = cost_fun, f_prototype = zeros(3), + bcresid_prototype = (zeros(3), zeros(1)), twopoint = Val(true)) + rocket_launch_prob_tp = TwoPointBVProblem(rocket_launch_fun_tp, u0, tspan; + lb = [0.0, h_0, m_T, 0.0], ub = [Inf, Inf, m_0, u_t_max]) + sol = solve(rocket_launch_prob_tp, MIRK4(; optimize = IpoptOptimizer()); dt = Δt, adaptive = false) + @test SciMLBase.successful_retcode(sol) +end