From b8bd7ab8b395a59b5120a8d8368997ce9c31898d Mon Sep 17 00:00:00 2001 From: ChrisRackauckas Date: Sun, 3 Aug 2025 11:23:59 -0400 Subject: [PATCH] Move RecipesBase and RuntimeGeneratedFunctions to extensions MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit moves RecipesBase and RuntimeGeneratedFunctions from direct dependencies to weak dependencies with corresponding extensions, reducing the load time footprint of SciMLBase. Changes: - Moved RecipesBase from deps to weakdeps in Project.toml - Moved RuntimeGeneratedFunctions from deps to weakdeps in Project.toml - Created SciMLBaseRecipesBaseExt.jl extension containing all @recipe definitions - Created SciMLBaseRuntimeGeneratedFunctionsExt.jl extension with numargs method - Removed RecipesBase import from main SciMLBase.jl module - Removed RuntimeGeneratedFunctions import from main SciMLBase.jl module - Removed all @recipe function definitions from original source files - Removed numargs method for RuntimeGeneratedFunctions from utils.jl The plotting functionality is now only available when RecipesBase is explicitly loaded, maintaining backward compatibility while reducing the default dependency footprint. The RuntimeGeneratedFunctions support for numargs is similarly conditional. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude --- Project.toml | 6 +- ext/SciMLBaseRecipesBaseExt.jl | 500 +++++++++++++++++++ ext/SciMLBaseRuntimeGeneratedFunctionsExt.jl | 20 + src/SciMLBase.jl | 3 +- src/ensemble/ensemble_solutions.jl | 69 --- src/integrator_interface.jl | 129 ----- src/solutions/solution_interface.jl | 168 ------- src/utils.jl | 13 - 8 files changed, 524 insertions(+), 384 deletions(-) create mode 100644 ext/SciMLBaseRecipesBaseExt.jl create mode 100644 ext/SciMLBaseRuntimeGeneratedFunctionsExt.jl diff --git a/Project.toml b/Project.toml index 92b029a3a8..ddbef19308 100644 --- a/Project.toml +++ b/Project.toml @@ -22,10 +22,8 @@ Moshi = "2e0e35c7-a2e4-4343-998d-7ef72827ed2d" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" -RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" -RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" SciMLStructures = "53ae85a6-f571-4167-b2af-e1d143709226" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" @@ -41,6 +39,7 @@ PartialFunctions = "570af359-4316-4cb7-8c74-252c00c2016b" PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" PythonCall = "6099a3de-0909-46bc-b1f4-468b9a2dfc0d" RCall = "6f49c342-dc21-5d91-9882-a32aef131414" +RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] @@ -51,6 +50,8 @@ SciMLBasePartialFunctionsExt = "PartialFunctions" SciMLBasePyCallExt = "PyCall" SciMLBasePythonCallExt = "PythonCall" SciMLBaseRCallExt = "RCall" +SciMLBaseRecipesBaseExt = "RecipesBase" +SciMLBaseRuntimeGeneratedFunctionsExt = "RuntimeGeneratedFunctions" SciMLBaseZygoteExt = ["Zygote", "ChainRulesCore"] [compat] @@ -81,7 +82,6 @@ Printf = "1.10" PyCall = "1.96" PythonCall = "0.9.15" RCall = "0.14.0" -RecipesBase = "1.3.4" RecursiveArrayTools = "3.35" Reexport = "1" RuntimeGeneratedFunctions = "0.5.12" diff --git a/ext/SciMLBaseRecipesBaseExt.jl b/ext/SciMLBaseRecipesBaseExt.jl new file mode 100644 index 0000000000..5028f9dd73 --- /dev/null +++ b/ext/SciMLBaseRecipesBaseExt.jl @@ -0,0 +1,500 @@ +module SciMLBaseRecipesBaseExt + +using SciMLBase +using RecipesBase +import RecursiveArrayTools + +# Need to import the plotting-related functions +import SciMLBase: DEFAULT_PLOT_FUNC, isdenseplot, plottable_indices, interpret_vars, + get_all_timeseries_indexes, ContinuousTimeseries, DiscreteTimeseries, + solution_slice, add_labels!, AbstractTimeseriesSolution, AbstractEnsembleSolution, + AbstractNoTimeSolution, EnsembleSummary, DEIntegrator, AbstractSDEIntegrator, + getindepsym_defaultt, getname, hasname, u_n, AbstractDEAlgorithm + +# Recipe for AbstractTimeseriesSolution +@recipe function f(sol::AbstractTimeseriesSolution; + plot_analytic = false, + denseplot = isdenseplot(sol), + plotdensity = min(Int(1e5), + sol.tslocation == 0 ? + (sol.prob isa SciMLBase.AbstractDiscreteProblem ? + max(1000, 100 * length(sol)) : + max(1000, 10 * length(sol))) : + 1000 * sol.tslocation), plotat = nothing, + tspan = nothing, + vars = nothing, idxs = nothing) + if vars !== nothing + Base.depwarn( + "To maintain consistency with solution indexing, keyword argument vars will be removed in a future version. Please use keyword argument idxs instead.", + :f; force = true) + (idxs !== nothing) && + error("Simultaneously using keywords vars and idxs is not supported. Please only use idxs.") + idxs = vars + end + + if plot_analytic && (sol.u_analytic === nothing) + throw(ArgumentError("No analytic solution was found but `plot_analytic` was set to `true`.")) + end + + idxs = idxs === nothing ? plottable_indices(sol.u[1]) : idxs + if !(idxs isa Union{Tuple, AbstractArray}) + vars = interpret_vars([idxs], sol) + else + vars = interpret_vars(idxs, sol) + end + disc_vars = Tuple[] + cont_vars = Tuple[] + for var in vars + tsidxs = union(get_all_timeseries_indexes(sol, var[2]), + get_all_timeseries_indexes(sol, var[3])) + if ContinuousTimeseries() in tsidxs || isempty(tsidxs) + push!(cont_vars, var) + else + push!(disc_vars, var) + end + end + + plot_vecs = [] + labels = [] + + # Handle continuous variables + if !isempty(cont_vars) + int_vars = cont_vars + + if tspan === nothing + if plotat === nothing + if denseplot + # Generate the points from the plot from dense function + start_idx = sol.tslocation == 0 ? 1 : sol.tslocation + end_idx = length(sol.t) + plott = collect(range(sol.t[start_idx], sol.t[end_idx]; length = plotdensity)) + plot_timeseries = sol(plott) + if plot_analytic + plot_analytic_timeseries = [sol.prob.f.analytic(sol.prob.u0, + sol.prob.p, t) + for t in plott] + end + else + plot_timeseries = sol.u + plott = sol.t + if plot_analytic + plot_analytic_timeseries = sol.u_analytic + end + end + else + plot_timeseries = sol(plotat) + plott = plotat + if plot_analytic + plot_analytic_timeseries = [sol.prob.f.analytic(sol.prob.u0, + sol.prob.p, t) for t in plott] + end + end + else + _tspan = tspan isa Number ? (sol.t[1], tspan) : tspan + start_idx = findfirst(x -> x >= _tspan[1], sol.t) + end_idx = findlast(x -> x <= _tspan[2], sol.t) + if denseplot + plott = collect(range(_tspan...; length = plotdensity)) + plot_timeseries = sol(plott) + if plot_analytic + plot_analytic_timeseries = [sol.prob.f.analytic(sol.prob.u0, + sol.prob.p, t) for t in plott] + end + else + if start_idx === nothing + start_idx = 1 + end + if end_idx === nothing + end_idx = length(sol.t) + end + plott = @view sol.t[start_idx:end_idx] + plot_timeseries = @view sol.u[start_idx:end_idx] + if plot_analytic + plot_analytic_timeseries = @view sol.u_analytic[start_idx:end_idx] + end + end + end + + dims = length(int_vars[1]) + for var in int_vars + @assert length(var) == dims + end + # Should check that all have the same dims! + + for i in 2:dims + push!(plot_vecs, []) + end + + labels = String[]# Array{String, 2}(1, length(int_vars)*(1+plot_analytic)) + strs = String[] + varsyms = SciMLBase.variable_symbols(sol) + + for x in int_vars + for j in 2:dims + if denseplot + if (x[j] isa Integer && x[j] == 0) || + isequal(x[j], SciMLBase.getindepsym_defaultt(sol)) + push!(plot_vecs[j - 1], plott) + else + # For the dense plotting case, we use getindex on the timeseries + if plot_timeseries isa AbstractArray + if x[j] isa Integer + # Simple integer indexing + push!(plot_vecs[j - 1], [u[x[j]] for u in plot_timeseries]) + else + # Symbol indexing + push!(plot_vecs[j - 1], Vector(sol(plott; idxs = x[j]))) + end + else + # Single value case + push!(plot_vecs[j - 1], Vector(sol(plott; idxs = x[j]))) + end + end + else # just get values + if x[j] == 0 + push!(plot_vecs[j - 1], plott) + elseif x[j] == 1 && !(eltype(plot_timeseries) <: AbstractArray) + push!(plot_vecs[j - 1], plot_timeseries) + else + if x[j] isa Integer + push!(plot_vecs[j - 1], [u[x[j]] for u in plot_timeseries]) + else + # Symbol indexing + push!(plot_vecs[j - 1], [sol(t, idxs = x[j]) for t in plott]) + end + end + end + + if !isempty(varsyms) && x[j] isa Integer + push!(strs, String(getname(varsyms[x[j]]))) + elseif hasname(x[j]) + push!(strs, String(getname(x[j]))) + else + push!(strs, "u[$(x[j])]") + end + end + add_labels!(labels, x, dims, sol, strs) + end + + if plot_analytic + for x in int_vars + for j in 2:dims + if denseplot + if (x[j] isa Integer && x[j] == 0) || + isequal(x[j], SciMLBase.getindepsym_defaultt(sol)) + push!(plot_vecs[j - 1], plott) + else + push!(plot_vecs[j - 1], + u_n(plot_analytic_timeseries, x[j], sol, plott, + plot_analytic_timeseries)) + end + else # Just get values + if x[j] == 0 + push!(plot_vecs[j - 1], plott) + elseif x[j] == 1 && + !(eltype(plot_analytic_timeseries) <: AbstractArray) + push!(plot_vecs[j - 1], plot_analytic_timeseries) + else + push!(plot_vecs[j - 1], + u_n(plot_analytic_timeseries, x[j], sol, plott, + plot_analytic_timeseries)) + end + end + end + add_labels!(labels, x, dims, sol, strs) + end + end + + xflip --> sol.tdir < 0 + + if denseplot + seriestype --> :path + else + seriestype --> :scatter + end + + # Special case labels when idxs = (:x,:y,:z) or (:x) or [:x,:y] .. + if idxs isa Tuple && (typeof(idxs[1]) == Symbol && typeof(idxs[2]) == Symbol) + xlabel --> idxs[1] + ylabel --> idxs[2] + if length(idxs) > 2 + zlabel --> idxs[3] + end + end + if getindex.(int_vars, 1) == zeros(length(int_vars)) || + getindex.(int_vars, 2) == zeros(length(int_vars)) + xlabel --> "t" + end + + linewidth --> 3 + #xtickfont --> font(11) + #ytickfont --> font(11) + #legendfont --> font(11) + #guidefont --> font(11) + label --> reshape(labels, 1, length(labels)) + (plot_vecs...,) + + # Handle discrete variables + elseif !isempty(disc_vars) + int_vars = disc_vars + + if sol.tslocation != 0 + start_idx = sol.tslocation + else + start_idx = 1 + end + + if tspan === nothing + end_idx = length(sol.t) + else + _tspan = tspan isa Number ? (sol.t[1], tspan) : tspan + end_idx = findlast(x -> x <= _tspan[2], sol.t) + if end_idx === nothing + end_idx = length(sol.t) + end + end + + plott = sol.t[start_idx:end_idx] + plot_timeseries = sol.u[start_idx:end_idx] + + dims = length(int_vars[1]) + for var in int_vars + @assert length(var) == dims + end + + for i in 2:dims + push!(plot_vecs, []) + end + + labels = String[] + strs = String[] + varsyms = SciMLBase.variable_symbols(sol) + + for x in int_vars + for j in 2:dims + if x[j] == 0 + push!(plot_vecs[j - 1], plott) + elseif x[j] == 1 && !(eltype(plot_timeseries) <: AbstractArray) + push!(plot_vecs[j - 1], plot_timeseries) + else + if x[j] isa Integer + push!(plot_vecs[j - 1], [u[x[j]] for u in plot_timeseries]) + else + # Symbol indexing for discrete case + push!(plot_vecs[j - 1], [sol[ti, x[j]] for ti in 1:length(plott)]) + end + end + + if !isempty(varsyms) && x[j] isa Integer + push!(strs, String(getname(varsyms[x[j]]))) + elseif hasname(x[j]) + push!(strs, String(getname(x[j]))) + else + push!(strs, "u[$(x[j])]") + end + end + add_labels!(labels, x, dims, sol, strs) + end + + seriestype --> :steppost + if getindex.(int_vars, 1) == zeros(length(int_vars)) || + getindex.(int_vars, 2) == zeros(length(int_vars)) + xlabel --> "t" + end + + linewidth --> 3 + label --> reshape(labels, 1, length(labels)) + (plot_vecs...,) + end +end + +# Recipe for AbstractEnsembleSolution +@recipe function f(sim::AbstractEnsembleSolution; idxs = nothing, + summarize = true, error_style = :ribbon, ci_type = :quantile, linealpha = 0.4, zorder = 1) + + if idxs === nothing + if sim.u[1] isa SciMLBase.AbstractTimeseriesSolution + idxs = 1:length(sim.u[1].u[1]) + else + idxs = 1:length(sim.u[1]) + end + end + + if !(idxs isa Union{Tuple, AbstractArray}) + idxs = [idxs] + end + + if summarize + summary = EnsembleSummary(sim; quantiles = [0.05, 0.95]) + if error_style == :ribbon + ribbon --> (summary.qlow[:, idxs], summary.qhigh[:, idxs]) + elseif error_style == :bars + yerror --> (summary.qlow[:, idxs], summary.qhigh[:, idxs]) + end + summary.t, summary.med[:, idxs] + else + alpha --> linealpha + # Plot all trajectories + for i in eachindex(sim.u) + @series begin + if sim.u[i] isa SciMLBase.AbstractTimeseriesSolution + idxs --> idxs + sim.u[i] + else + # For non-timeseries solutions + sim.u[i][idxs] + end + end + end + end +end + +# Recipe for EnsembleSummary +@recipe function f(sim::EnsembleSummary; idxs = nothing, error_style = :ribbon) + if idxs === nothing + idxs = 1:size(sim.med, 2) + end + + if !(idxs isa Union{Tuple, AbstractArray}) + idxs = [idxs] + end + + if error_style == :ribbon + ribbon --> (sim.qlow[:, idxs], sim.qhigh[:, idxs]) + elseif error_style == :bars + yerror --> (sim.qlow[:, idxs], sim.qhigh[:, idxs]) + end + sim.t, sim.med[:, idxs] +end + +# Recipe for DEIntegrator +@recipe function f(integrator::DEIntegrator; + denseplot = (integrator.opts.calck || + integrator isa AbstractSDEIntegrator) && + integrator.iter > 0, + plotdensity = 10, + plot_analytic = false, vars = nothing, idxs = nothing) + if vars !== nothing + Base.depwarn( + "To maintain consistency with solution indexing, keyword argument vars will be removed in a future version. Please use keyword argument idxs instead.", + :f; force = true) + (idxs !== nothing) && + error("Simultaneously using keywords vars and idxs is not supported. Please only use idxs.") + idxs = vars + end + + int_vars = interpret_vars(idxs, integrator.sol) + + if denseplot + # Generate the points from the plot from dense function + plott = collect(range(integrator.tprev, integrator.t; length = plotdensity)) + if plot_analytic + plot_analytic_timeseries = [integrator.sol.prob.f.analytic( + integrator.sol.prob.u0, + integrator.sol.prob.p, + t) for t in plott] + end + else + plott = nothing + end + + dims = length(int_vars[1]) + for var in int_vars + @assert length(var) == dims + end + # Should check that all have the same dims! + + plot_vecs = [] + for i in 2:dims + push!(plot_vecs, []) + end + + labels = String[]# Array{String, 2}(1, length(int_vars)*(1+plot_analytic)) + strs = String[] + varsyms = SciMLBase.variable_symbols(integrator) + + for x in int_vars + for j in 2:dims + if denseplot + if (x[j] isa Integer && x[j] == 0) || + isequal(x[j], getindepsym_defaultt(integrator)) + push!(plot_vecs[j - 1], plott) + else + push!(plot_vecs[j - 1], Vector(integrator(plott; idxs = x[j]))) + end + else # just get values + if x[j] == 0 + push!(plot_vecs[j - 1], integrator.t) + elseif x[j] == 1 && !(integrator.u isa AbstractArray) + push!(plot_vecs[j - 1], integrator.u) + else + push!(plot_vecs[j - 1], integrator.u[x[j]]) + end + end + + if !isempty(varsyms) && x[j] isa Integer + push!(strs, String(getname(varsyms[x[j]]))) + elseif hasname(x[j]) + push!(strs, String(getname(x[j]))) + else + push!(strs, "u[$(x[j])]") + end + end + add_labels!(labels, x, dims, integrator.sol, strs) + end + + if plot_analytic + for x in int_vars + for j in 1:dims + if denseplot + push!(plot_vecs[j], + u_n(plot_timeseries, x[j], sol, plott, plot_timeseries)) + else # Just get values + if x[j] == 0 + push!(plot_vecs[j], integrator.t) + elseif x[j] == 1 && !(integrator.u isa AbstractArray) + push!(plot_vecs[j], + integrator.sol.prob.f(Val{:analytic}, integrator.t, + integrator.sol[1])) + else + push!(plot_vecs[j], + integrator.sol.prob.f(Val{:analytic}, integrator.t, + integrator.sol[1])[x[j]]) + end + end + end + add_labels!(labels, x, dims, integrator.sol, strs) + end + end + + xflip --> integrator.tdir < 0 + + if denseplot + seriestype --> :path + else + seriestype --> :scatter + end + + # Special case labels when idxs = (:x,:y,:z) or (:x) or [:x,:y] .. + if idxs isa Tuple && (typeof(idxs[1]) == Symbol && typeof(idxs[2]) == Symbol) + xlabel --> idxs[1] + ylabel --> idxs[2] + if length(idxs) > 2 + zlabel --> idxs[3] + end + end + if getindex.(int_vars, 1) == zeros(length(int_vars)) || + getindex.(int_vars, 2) == zeros(length(int_vars)) + xlabel --> "t" + end + + linewidth --> 3 + #xtickfont --> font(11) + #ytickfont --> font(11) + #legendfont --> font(11) + #guidefont --> font(11) + label --> reshape(labels, 1, length(labels)) + (plot_vecs...,) +end + +end \ No newline at end of file diff --git a/ext/SciMLBaseRuntimeGeneratedFunctionsExt.jl b/ext/SciMLBaseRuntimeGeneratedFunctionsExt.jl new file mode 100644 index 0000000000..8e5eb1c565 --- /dev/null +++ b/ext/SciMLBaseRuntimeGeneratedFunctionsExt.jl @@ -0,0 +1,20 @@ +module SciMLBaseRuntimeGeneratedFunctionsExt + +using SciMLBase +using RuntimeGeneratedFunctions + +function SciMLBase.numargs(f::RuntimeGeneratedFunctions.RuntimeGeneratedFunction{ + T, + V, + W, + I +}) where { + T, + V, + W, + I +} + (length(T),) +end + +end \ No newline at end of file diff --git a/src/SciMLBase.jl b/src/SciMLBase.jl index 49d9e246d8..06fb190691 100644 --- a/src/SciMLBase.jl +++ b/src/SciMLBase.jl @@ -4,7 +4,7 @@ if isdefined(Base, :Experimental) && @eval Base.Experimental.@max_methods 1 end using ConstructionBase -using RecipesBase, RecursiveArrayTools +using RecursiveArrayTools using SciMLStructures using SymbolicIndexingInterface using DocStringExtensions @@ -19,7 +19,6 @@ import Logging, ArrayInterface import IteratorInterfaceExtensions import CommonSolve: solve, init, step!, solve! import FunctionWrappersWrappers -import RuntimeGeneratedFunctions import EnumX import ADTypes: ADTypes, AbstractADType import Accessors: @set, @reset, @delete, @insert diff --git a/src/ensemble/ensemble_solutions.jl b/src/ensemble/ensemble_solutions.jl index ab25d11813..ffcb46ea7d 100644 --- a/src/ensemble/ensemble_solutions.jl +++ b/src/ensemble/ensemble_solutions.jl @@ -184,76 +184,7 @@ end ### Plot Recipes -@recipe function f(sim::AbstractEnsembleSolution; - zcolors = sim.u isa AbstractArray ? fill(nothing, length(sim.u)) : - nothing, - trajectories = eachindex(sim)) - for i in trajectories - size(sim.u[i].u, 1) == 0 && continue - @series begin - legend := false - xlims --> (-Inf, Inf) - ylims --> (-Inf, Inf) - zlims --> (-Inf, Inf) - marker_z --> zcolors[i] - sim.u[i] - end - end -end -@recipe function f(sim::EnsembleSummary; - idxs = sim.u.u[1] isa AbstractArray ? eachindex(sim.u.u[1]) : - 1, - error_style = :ribbon, ci_type = :quantile) - if ci_type == :SEM - if sim.u.u[1] isa AbstractArray - u = vecarr_to_vectors(sim.u) - else - u = [sim.u.u] - end - if sim.u.u[1] isa AbstractArray - ci_low = vecarr_to_vectors(VectorOfArray([sqrt.(sim.v.u[i] / sim.num_monte) .* - 1.96 for i in 1:length(sim.v)])) - ci_high = ci_low - else - ci_low = [[sqrt(sim.v.u[i] / length(sim.num_monte)) .* 1.96 - for i in 1:length(sim.v)]] - ci_high = ci_low - end - elseif ci_type == :quantile - if sim.med.u[1] isa AbstractArray - u = vecarr_to_vectors(sim.med) - else - u = [sim.med.u] - end - if sim.u.u[1] isa AbstractArray - ci_low = u - vecarr_to_vectors(sim.qlow) - ci_high = vecarr_to_vectors(sim.qhigh) - u - else - ci_low = [u[1] - sim.qlow.u] - ci_high = [sim.qhigh.u - u[1]] - end - else - error("ci_type choice not valid. Must be `:SEM` or `:quantile`") - end - for i in idxs - @series begin - legend --> false - linewidth --> 3 - fillalpha --> 0.2 - if error_style == :ribbon - ribbon --> (ci_low[i], ci_high[i]) - elseif error_style == :bars - yerror --> (ci_low[i], ci_high[i]) - elseif error_style == :none - nothing - else - error("error_style not recognized") - end - sim.t, u[i] - end - end -end function (sol::AbstractEnsembleSolution)(args...; kwargs...) [s(args...; kwargs...) for s in sol] diff --git a/src/integrator_interface.jl b/src/integrator_interface.jl index 3b13762858..becaaf7767 100644 --- a/src/integrator_interface.jl +++ b/src/integrator_interface.jl @@ -773,135 +773,6 @@ end Base.length(iter::TimeChoiceIterator) = length(iter.ts) -@recipe function f(integrator::DEIntegrator; - denseplot = (integrator.opts.calck || - integrator isa AbstractSDEIntegrator) && - integrator.iter > 0, - plotdensity = 10, - plot_analytic = false, vars = nothing, idxs = nothing) - if vars !== nothing - Base.depwarn( - "To maintain consistency with solution indexing, keyword argument vars will be removed in a future version. Please use keyword argument idxs instead.", - :f; force = true) - (idxs !== nothing) && - error("Simultaneously using keywords vars and idxs is not supported. Please only use idxs.") - idxs = vars - end - - int_vars = interpret_vars(idxs, integrator.sol) - - if denseplot - # Generate the points from the plot from dense function - plott = collect(range(integrator.tprev, integrator.t; length = plotdensity)) - if plot_analytic - plot_analytic_timeseries = [integrator.sol.prob.f.analytic( - integrator.sol.prob.u0, - integrator.sol.prob.p, - t) for t in plott] - end - else - plott = nothing - end - - dims = length(int_vars[1]) - for var in int_vars - @assert length(var) == dims - end - # Should check that all have the same dims! - - plot_vecs = [] - for i in 2:dims - push!(plot_vecs, []) - end - - labels = String[]# Array{String, 2}(1, length(int_vars)*(1+plot_analytic)) - strs = String[] - varsyms = variable_symbols(integrator) - @show plott - - for x in int_vars - for j in 2:dims - if denseplot - if (x[j] isa Integer && x[j] == 0) || - isequal(x[j], getindepsym_defaultt(integrator)) - push!(plot_vecs[j - 1], plott) - else - push!(plot_vecs[j - 1], Vector(integrator(plott; idxs = x[j]))) - end - else # just get values - if x[j] == 0 - push!(plot_vecs[j - 1], integrator.t) - elseif x[j] == 1 && !(integrator.u isa AbstractArray) - push!(plot_vecs[j - 1], integrator.u) - else - push!(plot_vecs[j - 1], integrator.u[x[j]]) - end - end - - if !isempty(varsyms) && x[j] isa Integer - push!(strs, String(getname(varsyms[x[j]]))) - elseif hasname(x[j]) - push!(strs, String(getname(x[j]))) - else - push!(strs, "u[$(x[j])]") - end - end - add_labels!(labels, x, dims, integrator.sol, strs) - end - - if plot_analytic - for x in int_vars - for j in 1:dims - if denseplot - push!(plot_vecs[j], - u_n(plot_timeseries, x[j], sol, plott, plot_timeseries)) - else # Just get values - if x[j] == 0 - push!(plot_vecs[j], integrator.t) - elseif x[j] == 1 && !(integrator.u isa AbstractArray) - push!(plot_vecs[j], - integrator.sol.prob.f(Val{:analytic}, integrator.t, - integrator.sol[1])) - else - push!(plot_vecs[j], - integrator.sol.prob.f(Val{:analytic}, integrator.t, - integrator.sol[1])[x[j]]) - end - end - end - add_labels!(labels, x, dims, integrator.sol, strs) - end - end - - xflip --> integrator.tdir < 0 - - if denseplot - seriestype --> :path - else - seriestype --> :scatter - end - - # Special case labels when idxs = (:x,:y,:z) or (:x) or [:x,:y] ... - if idxs isa Tuple && (typeof(idxs[1]) == Symbol && typeof(idxs[2]) == Symbol) - xlabel --> idxs[1] - ylabel --> idxs[2] - if length(idxs) > 2 - zlabel --> idxs[3] - end - end - if getindex.(int_vars, 1) == zeros(length(int_vars)) || - getindex.(int_vars, 2) == zeros(length(int_vars)) - xlabel --> "t" - end - - linewidth --> 3 - #xtickfont --> font(11) - #ytickfont --> font(11) - #legendfont --> font(11) - #guidefont --> font(11) - label --> reshape(labels, 1, length(labels)) - (plot_vecs...,) -end function step!(integ::DEIntegrator, dt, stop_at_tdt = false) (dt * integ.tdir) < 0 * oneunit(dt) && error("Cannot step backward.") diff --git a/src/solutions/solution_interface.jl b/src/solutions/solution_interface.jl index f08a73ce9e..835ab7cba8 100644 --- a/src/solutions/solution_interface.jl +++ b/src/solutions/solution_interface.jl @@ -212,174 +212,6 @@ used for plotting. plottable_indices(x::AbstractArray) = 1:length(x) plottable_indices(x::Number) = 1 -@recipe function f(sol::AbstractTimeseriesSolution; - plot_analytic = false, - denseplot = isdenseplot(sol), - plotdensity = min(Int(1e5), - sol.tslocation == 0 ? - (sol.prob isa AbstractDiscreteProblem ? - max(1000, 100 * length(sol)) : - max(1000, 10 * length(sol))) : - 1000 * sol.tslocation), plotat = nothing, - tspan = nothing, - vars = nothing, idxs = nothing) - if vars !== nothing - Base.depwarn( - "To maintain consistency with solution indexing, keyword argument vars will be removed in a future version. Please use keyword argument idxs instead.", - :f; force = true) - (idxs !== nothing) && - error("Simultaneously using keywords vars and idxs is not supported. Please only use idxs.") - idxs = vars - end - - if plot_analytic && (sol.u_analytic === nothing) - throw(ArgumentError("No analytic solution was found but `plot_analytic` was set to `true`.")) - end - - idxs = idxs === nothing ? plottable_indices(sol.u[1]) : idxs - if !(idxs isa Union{Tuple, AbstractArray}) - vars = interpret_vars([idxs], sol) - else - vars = interpret_vars(idxs, sol) - end - disc_vars = Tuple[] - cont_vars = Tuple[] - for var in vars - tsidxs = union(get_all_timeseries_indexes(sol, var[2]), - get_all_timeseries_indexes(sol, var[3])) - if ContinuousTimeseries() in tsidxs || isempty(tsidxs) - push!(cont_vars, var) - else - push!(disc_vars, (var..., only(tsidxs))) - end - end - idxs = identity.(cont_vars) - vars = identity.(cont_vars) - tdir = sign(sol.t[end] - sol.t[1]) - xflip --> tdir < 0 - seriestype --> :path - - @series begin - if idxs isa Union{AbstractArray, Tuple} && isempty(idxs) - label --> nothing - ([], []) - else - tscale = get(plotattributes, :xscale, :identity) - plot_vecs, - labels = diffeq_to_arrays(sol, plot_analytic, denseplot, - plotdensity, tspan, vars, tscale, plotat) - - # Special case labels when idxs = (:x,:y,:z) or (:x) or [:x,:y] ... - if idxs isa Tuple && vars[1][1] === DEFAULT_PLOT_FUNC - val = hasname(vars[1][2]) ? String(getname(vars[1][2])) : vars[1][2] - if val isa Integer - if val == 0 - val = "t" - else - val = "u[$val]" - end - end - xguide --> val - val = hasname(vars[1][3]) ? String(getname(vars[1][3])) : vars[1][3] - if val isa Integer - if val == 0 - val = "t" - else - val = "u[$val]" - end - end - yguide --> val - if length(idxs) > 2 - val = hasname(vars[1][4]) ? String(getname(vars[1][4])) : vars[1][4] - if val isa Integer - if val == 0 - val = "t" - else - val = "u[$val]" - end - end - zguide --> val - end - end - - if (!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 1))) && - getindex.(vars, 1) == zeros(length(vars))) || - (!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 2))) && - getindex.(vars, 2) == zeros(length(vars))) || - all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 1)) || - all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 2)) - xguide --> "$(getindepsym_defaultt(sol))" - end - if length(vars[1]) >= 3 && - ((!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 3))) && - getindex.(vars, 3) == zeros(length(vars))) || - all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 3))) - yguide --> "$(getindepsym_defaultt(sol))" - end - if length(vars[1]) >= 4 && - ((!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 4))) && - getindex.(vars, 4) == zeros(length(vars))) || - all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 4))) - zguide --> "$(getindepsym_defaultt(sol))" - end - - if (!any(!isequal(NotSymbolic()), symbolic_type.(getindex.(vars, 2))) && - getindex.(vars, 2) == zeros(length(vars))) || - all(t -> Symbol(t) == getindepsym_defaultt(sol), getindex.(vars, 2)) - if tspan === nothing - if tdir > 0 - xlims --> (sol.t[1], sol.t[end]) - else - xlims --> (sol.t[end], sol.t[1]) - end - else - xlims --> (tspan[1], tspan[end]) - end - end - - label --> reshape(labels, 1, length(labels)) - (plot_vecs...,) - end - end - for (func, xvar, yvar, tsidx) in disc_vars - partition = sol.discretes[tsidx] - ts = current_time(partition) - if tspan !== nothing - tstart = searchsortedfirst(ts, tspan[1]) - tend = searchsortedlast(ts, tspan[2]) - if tstart == lastindex(ts) + 1 || tend == firstindex(ts) - 1 - continue - end - else - tstart = firstindex(ts) - tend = lastindex(ts) - end - ts = ts[tstart:tend] - - if symbolic_type(xvar) == NotSymbolic() && xvar == 0 - xvar = only(independent_variable_symbols(sol)) - end - xvals = sol(ts; idxs = xvar).u - # xvals = getsym(sol, xvar)(sol, tstart:tend) - yvals = getp(sol, yvar)(sol, tstart:tend) - tmpvals = map(func, xvals, yvals) - xvals = getindex.(tmpvals, 1) - yvals = getindex.(tmpvals, 2) - # Scatterplot of points - @series begin - seriestype := :line - linestyle --> :dash - markershape --> :o - markersize --> repeat([2, 0], length(ts) - 1) - markeralpha --> repeat([1, 0], length(ts) - 1) - label --> string(hasname(yvar) ? getname(yvar) : yvar) - - x = vec([xvals[1:(end - 1)]'; xvals[2:end]']) - y = repeat(yvals, inner = 2)[1:(end - 1)] - x, y - end - end -end function diffeq_to_arrays(sol, plot_analytic, denseplot, plotdensity, tspan, vars, tscale, plotat) diff --git a/src/utils.jl b/src/utils.jl index ecded5af15..35b892706c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -13,19 +13,6 @@ function numargs(f) end end -function numargs(f::RuntimeGeneratedFunctions.RuntimeGeneratedFunction{ - T, - V, - W, - I -}) where { - T, - V, - W, - I -} - (length(T),) -end numargs(f::ComposedFunction) = numargs(f.inner)