@@ -107,9 +107,163 @@ end
107107 maybe_codegen_scimlproblem (expression, SteadyStateProblem{iip}, args; kwargs... )
108108end
109109
110+ @fallback_iip_specialize function SemilinearODEFunction {iip, specialize} (
111+ sys:: System ; u0 = nothing , p = nothing , t = nothing ,
112+ semiquadratic_form = nothing ,
113+ stiff_linear = true , stiff_quadratic = false , stiff_nonlinear = false ,
114+ eval_expression = false , eval_module = @__MODULE__ ,
115+ expression = Val{false }, sparse = false , check_compatibility = true ,
116+ jac = false , checkbounds = false , cse = true , initialization_data = nothing ,
117+ analytic = nothing , kwargs... ) where {iip, specialize}
118+ check_complete (sys, SemilinearODEFunction)
119+ check_compatibility && check_compatible_system (SemilinearODEFunction, sys)
120+
121+ if semiquadratic_form === nothing
122+ semiquadratic_form = calculate_semiquadratic_form (sys; sparse)
123+ sys = add_semiquadratic_parameters (sys, semiquadratic_form... )
124+ end
125+
126+ A, B, C = semiquadratic_form
127+ M = calculate_massmatrix (sys)
128+ _M = concrete_massmatrix (M; sparse, u0)
129+ dvs = unknowns (sys)
130+
131+ f1,
132+ f2 = generate_semiquadratic_functions (
133+ sys, A, B, C; stiff_linear, stiff_quadratic,
134+ stiff_nonlinear, expression, wrap_gfw = Val{true },
135+ eval_expression, eval_module, kwargs... )
136+
137+ if jac
138+ Cjac = (C === nothing || ! stiff_nonlinear) ? nothing : Symbolics. jacobian (C, dvs)
139+ _jac = generate_semiquadratic_jacobian (
140+ sys, A, B, C, Cjac; sparse, expression,
141+ wrap_gfw = Val{true }, eval_expression, eval_module, kwargs... )
142+ _W_sparsity = get_semiquadratic_W_sparsity (
143+ sys, A, B, C, Cjac; stiff_linear, stiff_quadratic, stiff_nonlinear, mm = M)
144+ W_prototype = calculate_W_prototype (_W_sparsity; u0, sparse)
145+ else
146+ _jac = nothing
147+ W_prototype = nothing
148+ end
149+
150+ observedfun = ObservedFunctionCache (
151+ sys; expression, steady_state = false , eval_expression, eval_module, checkbounds, cse)
152+
153+ args = (; f1)
154+ kwargs = (; jac = _jac, jac_prototype = W_prototype)
155+ f1 = maybe_codegen_scimlfn (expression, ODEFunction{iip, specialize}, args; kwargs... )
156+
157+ args = (; f1, f2)
158+ kwargs = (;
159+ sys = sys,
160+ jac = _jac,
161+ mass_matrix = _M,
162+ jac_prototype = W_prototype,
163+ observed = observedfun,
164+ analytic,
165+ initialization_data)
166+
167+ return maybe_codegen_scimlfn (
168+ expression, SplitFunction{iip, specialize}, args; kwargs... )
169+ end
170+
171+ @fallback_iip_specialize function SemilinearODEProblem {iip, spec} (
172+ sys:: System , op, tspan; check_compatibility = true , u0_eltype = nothing ,
173+ expression = Val{false }, callback = nothing , sparse = false ,
174+ stiff_linear = true , stiff_quadratic = false , stiff_nonlinear = false ,
175+ jac = false , kwargs... ) where {
176+ iip, spec}
177+ check_complete (sys, SemilinearODEProblem)
178+ check_compatibility && check_compatible_system (SemilinearODEProblem, sys)
179+
180+ A, B, C = semiquadratic_form = calculate_semiquadratic_form (sys; sparse)
181+ eqs = equations (sys)
182+ dvs = unknowns (sys)
183+
184+ sys = add_semiquadratic_parameters (sys, A, B, C)
185+ if A != = nothing
186+ linear_matrix_param = unwrap (getproperty (sys, LINEAR_MATRIX_PARAM_NAME))
187+ else
188+ linear_matrix_param = nothing
189+ end
190+ if B != = nothing
191+ quadratic_forms = [unwrap (getproperty (sys, get_quadratic_form_name (i)))
192+ for i in 1 : length (eqs)]
193+ diffcache_par = unwrap (getproperty (sys, DIFFCACHE_PARAM_NAME))
194+ else
195+ quadratic_forms = diffcache_par = nothing
196+ end
197+
198+ op = to_varmap (op, dvs)
199+ floatT = calculate_float_type (op, typeof (op))
200+ _u0_eltype = something (u0_eltype, floatT)
201+
202+ guess = copy (guesses (sys))
203+ defs = copy (defaults (sys))
204+ if A != = nothing
205+ guess[linear_matrix_param] = fill (NaN , size (A))
206+ defs[linear_matrix_param] = A
207+ end
208+ if B != = nothing
209+ for (par, mat) in zip (quadratic_forms, B)
210+ guess[par] = fill (NaN , size (mat))
211+ defs[par] = mat
212+ end
213+ cachelen = jac ? length (dvs) * length (eqs) : length (dvs)
214+ defs[diffcache_par] = DiffCache (zeros (DiffEqBase. value (_u0_eltype), cachelen))
215+ end
216+ @set! sys. guesses = guess
217+ @set! sys. defaults = defs
218+
219+ f, u0,
220+ p = process_SciMLProblem (SemilinearODEFunction{iip, spec}, sys, op;
221+ t = tspan != = nothing ? tspan[1 ] : tspan, expression, check_compatibility,
222+ semiquadratic_form, sparse, u0_eltype, stiff_linear, stiff_quadratic, stiff_nonlinear, jac, kwargs... )
223+
224+ kwargs = process_kwargs (sys; expression, callback, kwargs... )
225+
226+ args = (; f, u0, tspan, p)
227+ maybe_codegen_scimlproblem (expression, SplitODEProblem{iip}, args; kwargs... )
228+ end
229+
230+ """
231+ $(TYPEDSIGNATURES)
232+
233+ Add the necessary parameters for [`SemilinearODEProblem`](@ref) given the matrices
234+ `A`, `B`, `C` returned from [`calculate_semiquadratic_form`](@ref).
235+ """
236+ function add_semiquadratic_parameters (sys:: System , A, B, C)
237+ eqs = equations (sys)
238+ n = length (eqs)
239+ var_to_name = copy (get_var_to_name (sys))
240+ if B != = nothing
241+ for i in eachindex (B)
242+ B[i] === nothing && continue
243+ par = get_quadratic_form_param ((n, n), i)
244+ var_to_name[get_quadratic_form_name (i)] = par
245+ sys = with_additional_constant_parameter (sys, par)
246+ end
247+ par = get_diffcache_param (Float64)
248+ var_to_name[DIFFCACHE_PARAM_NAME] = par
249+ sys = with_additional_nonnumeric_parameter (sys, par)
250+ end
251+ if A != = nothing
252+ par = get_linear_matrix_param ((n, n))
253+ var_to_name[LINEAR_MATRIX_PARAM_NAME] = par
254+ sys = with_additional_constant_parameter (sys, par)
255+ end
256+ @set! sys. var_to_name = var_to_name
257+ if get_parent (sys) != = nothing
258+ @set! sys. parent = add_semiquadratic_parameters (get_parent (sys), A, B, C)
259+ end
260+ return sys
261+ end
262+
110263function check_compatible_system (
111264 T:: Union {Type{ODEFunction}, Type{ODEProblem}, Type{DAEFunction},
112- Type{DAEProblem}, Type{SteadyStateProblem}},
265+ Type{DAEProblem}, Type{SteadyStateProblem}, Type{SemilinearODEFunction},
266+ Type{SemilinearODEProblem}},
113267 sys:: System )
114268 check_time_dependent (sys, T)
115269 check_not_dde (sys)
0 commit comments