diff --git a/Project.toml b/Project.toml index 7d1f81085..2d903f380 100644 --- a/Project.toml +++ b/Project.toml @@ -11,7 +11,6 @@ CommonWorldInvalidations = "f70d9fcc-98c5-4d4a-abd7-e4cdeebd8ca8" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" -Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf" DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07" @@ -22,9 +21,13 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" +Moshi = "2e0e35c7-a2e4-4343-998d-7ef72827ed2d" +MultivariatePolynomials = "102ac46a-7ee4-5c85-9060-abc95bfdeaa3" +MutableArithmetics = "d8a4904e-b15c-11e9-3269-09a3773c0cb0" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +Preferences = "21216c6a-2e73-6563-6e65-726566657250" Primes = "27ebfcd6-29c5-5fa9-bf4b-fb8fc14df3ae" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" @@ -42,6 +45,7 @@ TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c" [weakdeps] D3Trees = "e3df1716-f71e-5df9-9e2d-98e193103c45" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Groebner = "0b43b601-686d-58a3-8a1c-6623616c7cd4" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" @@ -52,6 +56,7 @@ SymPyPythonCall = "bc8888f7-b21e-4b7c-a06a-5d9c9496438c" [extensions] SymbolicsD3TreesExt = "D3Trees" +SymbolicsDistributionsExt = "Distributions" SymbolicsForwardDiffExt = "ForwardDiff" SymbolicsGroebnerExt = "Groebner" SymbolicsLuxExt = "Lux" @@ -84,12 +89,16 @@ Latexify = "0.16" LogExpFunctions = "0.3" Lux = "1" MacroTools = "0.5" +Moshi = "0.3.7" +MultivariatePolynomials = "0.5.12" +MutableArithmetics = "1.6.5" NaNMath = "1" Nemo = "0.46, 0.47, 0.48, 0.49, 0.52" OffsetArrays = "1.15.0" PkgBenchmark = "0.2" PreallocationTools = "0.4" PrecompileTools = "1" +Preferences = "1.5.0" Primes = "0.5" RecipesBase = "1.1" Reexport = "1" diff --git a/ext/SymbolicsDistributionsExt.jl b/ext/SymbolicsDistributionsExt.jl new file mode 100644 index 000000000..9334bc95f --- /dev/null +++ b/ext/SymbolicsDistributionsExt.jl @@ -0,0 +1,37 @@ +module SymbolicsDistributionsExt + +using Symbolics +using Symbolics: Num, Arr, VartypeT, unwrap, BasicSymbolic +using Distributions + + +for f in [pdf, logpdf, cdf, logcdf, quantile] + @eval function (::$(typeof(f)))(dist::Distributions.Distribution, x::Num) + $f(dist, unwrap(x)) + end + @eval function (::$(typeof(f)))(dist::Distributions.Distribution, x::Arr) + $f(dist, unwrap(x)) + end + @eval function (::$(typeof(f)))(dist::BasicSymbolic{VartypeT}, x::Num) + $f(dist, unwrap(x)) + end + @eval function (::$(typeof(f)))(dist::BasicSymbolic{VartypeT}, x::Arr) + $f(dist, unwrap(x)) + end + @eval function (::$(typeof(f)))(dist::BasicSymbolic{VartypeT}, x) where {T} + $f(dist, unwrap(x)) + end +end + +for f in [Distributions.Uniform, Distributions.Normal] + for (T1, T2) in Iterators.product(Iterators.repeated([Real, BasicSymbolic{VartypeT}, Num], 2)...) + if T1 != Num && T2 != Num + continue + end + @eval function (::Type{$f})(a::$T1, b::$T2) + $f(unwrap(a), unwrap(b)) + end + end +end + +end diff --git a/ext/SymbolicsGroebnerExt.jl b/ext/SymbolicsGroebnerExt.jl index 8f142558f..1be9e2790 100644 --- a/ext/SymbolicsGroebnerExt.jl +++ b/ext/SymbolicsGroebnerExt.jl @@ -5,19 +5,20 @@ const Nemo = Groebner.Nemo using Symbolics using Symbolics: Num, symtype, BasicSymbolic import Symbolics.PrecompileTools +import Symbolics.Bijections +import Symbolics: SymbolicUtils +import Symbolics: DP, MP function Symbolics.groebner_basis(polynomials::Vector{Num}; ordering=InputOrdering(), kwargs...) - polynoms, pvar2sym, sym2term = Symbolics.symbol_to_poly(polynomials) - sym2term_for_groebner = Dict{Any,Any}(v1 => k for (k, (v1, v2)) in sym2term) - all_sym_vars = Groebner.ordering_variables(ordering) - missed = setdiff(all_sym_vars, Set(collect(keys(sym2term_for_groebner)))) - for var in missed - sym2term_for_groebner[var] = var - end - ordering = Groebner.ordering_transform(ordering, sym2term_for_groebner ) - basis = Groebner.groebner(polynoms; ordering=ordering, kwargs...) - PolyType = symtype(first(polynomials)) - Symbolics.poly_to_symbol(basis, pvar2sym, sym2term, PolyType) + polynoms, poly_to_bs = Symbolics.symbol_to_poly(polynomials) + basis = groebner_basis_poly(polynoms, poly_to_bs; ordering, kwargs...) + Symbolics.poly_to_symbol(basis, poly_to_bs) +end + +function groebner_basis_poly(polynoms::Vector{<:DP.Polynomial}, poly_to_bs::Bijections.Bijection; ordering=InputOrdering(), kwargs...) + bs_to_poly = Bijections.active_inv(poly_to_bs) + ordering = Groebner.ordering_transform(ordering, bs_to_poly) + return Groebner.groebner(polynoms; ordering=ordering, kwargs...) end """ @@ -41,8 +42,8 @@ julia> @variables x y; julia> is_groebner_basis([x^2 - y^2, x*y^2 + x, y^3 + y]) ``` """ -function Symbolics.is_groebner_basis(polynomials::Vector{<:Union{Num, BasicSymbolic{<:Number}}}; kwargs...) - polynoms, _, _ = Symbolics.symbol_to_poly(polynomials) +function Symbolics.is_groebner_basis(polynomials::Vector{<:Union{Num, BasicSymbolic}}; kwargs...) + polynoms, _ = Symbolics.symbol_to_poly(polynomials) Groebner.isgroebner(polynoms; kwargs...) end @@ -66,7 +67,13 @@ function nemo_crude_evaluate(poly::Nemo.MPolyRingElem, varmap) end function nemo_crude_evaluate(poly::Nemo.FracElem, varmap) - nemo_crude_evaluate(numerator(poly), varmap) // nemo_crude_evaluate(denominator(poly), varmap) + num = nemo_crude_evaluate(numerator(poly), varmap) + den = nemo_crude_evaluate(denominator(poly), varmap) + if num isa Num || den isa Num + num / den + else + num // den + end end function nemo_crude_evaluate(poly::Nemo.ZZRingElem, varmap) @@ -86,10 +93,10 @@ function gen_separating_var(vars) end # Given a GB in k[params][vars] produces a GB in k(params)[vars] -function demote(gb, vars::Vector{Num}, params::Vector{Num}) - isequal(gb, [1]) && return gb - - gb = Symbolics.wrap.(SymbolicUtils.toterm.(gb)) +function demote(gb, gb_as_poly, bs_to_poly, vars::Vector{Num}, params::Vector{Num}) + length(gb) == 1 && SymbolicUtils._isone(gb[1]) && return gb + isequal(gb, [1]) && return gb + # gb = Symbolics.wrap.(SymbolicUtils.toterm.(gb)) Symbolics.check_polynomial.(gb) all_vars = [vars..., params...] @@ -97,16 +104,16 @@ function demote(gb, vars::Vector{Num}, params::Vector{Num}) sym_to_nemo = Dict(all_vars .=> nemo_all_vars) nemo_to_sym = Dict(v => k for (k, v) in sym_to_nemo) - nemo_gb = Symbolics.substitute(gb, sym_to_nemo) - nemo_gb = Symbolics.substitute(nemo_gb, sym_to_nemo) + pvar_to_nemo = [bs_to_poly[v] for v in all_vars] => nemo_all_vars + nemo_gb = [poly(pvar_to_nemo) for poly in gb_as_poly] - nemo_vars = filter(v -> string(v) in string.(vars), nemo_all_vars) - nemo_params = filter(v -> string(v) in string.(params), nemo_all_vars) + nemo_vars = view(nemo_all_vars, 1:length(vars)) + nemo_params = view(nemo_all_vars, length(vars)+1:length(all_vars)) ring_flat = parent(nemo_vars[1]) ring_param, params_demoted = Nemo.polynomial_ring(Nemo.base_ring(ring_flat), map(string, nemo_params)) ring_demoted, vars_demoted = Nemo.polynomial_ring(Nemo.fraction_field(ring_param), map(string, nemo_vars), internal_ordering=:lex) - varmap = Dict((nemo_vars .=> vars_demoted)..., (nemo_params .=> params_demoted)...) + varmap = Dict(vcat(nemo_vars .=> vars_demoted, nemo_params .=> params_demoted)) gb_demoted = map(f -> ring_demoted(nemo_crude_evaluate(f, varmap)), nemo_gb) result = empty(gb_demoted) while true @@ -124,7 +131,7 @@ function demote(gb, vars::Vector{Num}, params::Vector{Num}) end @assert all(f -> isone(Nemo.leading_coefficient(f)), result) - sym_to_nemo = Dict(sym => nem for sym in all_vars for nem in [vars_demoted..., params_demoted...] if isequal(string(sym),string(nem))) + sym_to_nemo = Dict(all_vars .=> [vars_demoted; params_demoted]) nemo_to_sym = Dict(v => k for (k, v) in sym_to_nemo) final_result = Num[] @@ -150,8 +157,7 @@ function solve_zerodim(eqs::Vector, vars::Vector{Num}; dropmultiplicity=true, wa # AAECC 9, 433–461 (1999). https://doi.org/10.1007/s002000050114 rng = Groebner.Random.Xoshiro(42) - - all_indeterminates = reduce(union, map(Symbolics.get_variables, eqs)) + all_indeterminates = collect(reduce(union!, map(Symbolics.get_variables, eqs))) params = map(Symbolics.Num ∘ Symbolics.wrap, setdiff(all_indeterminates, vars)) # Use a new variable to separate the input polynomials (Reference above) @@ -179,10 +185,12 @@ function solve_zerodim(eqs::Vector, vars::Vector{Num}; dropmultiplicity=true, wa push!(new_eqs, separating_form) - new_eqs = Symbolics.groebner_basis(new_eqs, ordering=Lex(vcat(vars, params))) + polynoms, poly_to_bs = Symbolics.symbol_to_poly(new_eqs) + basis = groebner_basis_poly(polynoms, poly_to_bs; ordering = Lex(vcat(vars, params))) + new_eqs = Symbolics.poly_to_symbol(basis, poly_to_bs) # handle "unsolvable" case - if isequal(1, new_eqs[1]) + if SymbolicUtils._iszero(new_eqs[1]) return [] end @@ -190,12 +198,11 @@ function solve_zerodim(eqs::Vector, vars::Vector{Num}; dropmultiplicity=true, wa all_present = Symbolics.get_variables(new_eqs[i]) if length(intersect(all_present, vars)) < 1 deleteat!(new_eqs, i) + deleteat!(basis, i) end end - - new_eqs = demote(new_eqs, vars, params) + new_eqs = demote(new_eqs, basis, Bijections.active_inv(poly_to_bs), vars, params) new_eqs = map(Symbolics.unwrap, new_eqs) - # condition for positive dimensionality, i.e. infinite solutions if length(new_eqs) < length(vars) warns && @warn("Infinite number of solutions") @@ -209,13 +216,15 @@ function solve_zerodim(eqs::Vector, vars::Vector{Num}; dropmultiplicity=true, wa # xn - fn(T, params) = 0 generating = !(length(new_eqs) == length(vars)) if length(new_eqs) == length(vars) - generating |= !(isequal(setdiff(Symbolics.get_variables(new_eqs[1]), params), [new_var])) + vars_in_1 = Symbolics.get_variables(new_eqs[1]) + setdiff!(vars_in_1, params) + generating |= !(length(vars_in_1) == 1 && isequal(first(vars_in_1), new_var)) for i in eachindex(new_eqs)[2:end] present_vars = setdiff(Symbolics.get_variables(new_eqs[i]), new_var) present_vars = setdiff(present_vars, params) isempty(present_vars) && (generating = false; break;) - var_i = present_vars[1] - condition1 = isequal(present_vars, [var_i]) + var_i = first(present_vars) + condition1 = length(present_vars) == 1 condition2 = Symbolics.degree(new_eqs[i], var_i) == 1 generating |= !(condition1 && condition2) end @@ -234,7 +243,9 @@ function solve_zerodim(eqs::Vector, vars::Vector{Num}; dropmultiplicity=true, wa # first, solve the first minimal polynomial @assert length(new_eqs) == length(vars) - @assert isequal(setdiff(Symbolics.get_variables(new_eqs[1]), params), [new_var]) + vars_in_1 = Symbolics.get_variables(new_eqs[1]) + setdiff!(vars_in_1, params) + @assert length(vars_in_1) == 1 && isequal(first(vars_in_1), new_var) minpoly_sols = Symbolics.symbolic_solve(Symbolics.wrap(new_eqs[1]), new_var, dropmultiplicity=dropmultiplicity) solutions = [Dict{Num, Any}(new_var => sol) for sol in minpoly_sols] @@ -243,10 +254,10 @@ function solve_zerodim(eqs::Vector, vars::Vector{Num}; dropmultiplicity=true, wa # second, iterate over eqs and sub each found solution # then add the roots of the remaining unknown variables for (i, eq) in enumerate(new_eqs) - present_vars = setdiff(Symbolics.get_variables(eq), params) - present_vars = setdiff(present_vars, new_var) + present_vars = setdiff!(Symbolics.get_variables(eq), params) + present_vars = setdiff!(present_vars, new_var) @assert length(present_vars) == 1 - var_tosolve = present_vars[1] + var_tosolve = first(present_vars) @assert Symbolics.degree(eq, var_tosolve) == 1 @assert !isempty(solutions) for roots in solutions @@ -270,7 +281,7 @@ end function transendence_basis(sys, vars) J = Symbolics.jacobian(sys, vars) x0 = Dict(v => rand(-10:10) for v in vars) - J_x0 = substitute(J, x0) + J_x0 = map(Symbolics.value, substitute(J, x0)) rk, rref = Nemo.rref(Nemo.matrix(Nemo.QQ, J_x0)) pivots = Int[] for i in 1:length(sys) @@ -287,7 +298,7 @@ function Symbolics.solve_multivar(eqs::Vector, vars::Vector{Num}; dropmultiplici isempty(tr_basis) && return nothing vars_gen = setdiff(vars, tr_basis) sol = solve_zerodim(eqs, vars_gen; dropmultiplicity=dropmultiplicity, warns=warns) - + sol === nothing && return nothing for roots in sol for x in tr_basis roots[x] = x diff --git a/ext/SymbolicsLuxExt.jl b/ext/SymbolicsLuxExt.jl index 26eedaead..e114c4ed0 100644 --- a/ext/SymbolicsLuxExt.jl +++ b/ext/SymbolicsLuxExt.jl @@ -7,13 +7,20 @@ using Lux.Random: AbstractRNG, default_rng using Symbolics.SymbolicUtils @static if isdefined(Lux.NilSizePropagation, :recursively_nillify) - function Lux.NilSizePropagation.recursively_nillify(x::SymbolicUtils.BasicSymbolic{<:Vector{<:Real}}) + function Lux.NilSizePropagation.recursively_nillify(x::SymbolicUtils.BasicSymbolic) + @assert SymbolicUtils.symtype(x) <: Vector{<:Real} Lux.NilSizePropagation.recursively_nillify(Symbolics.wrap(x)) end end -function LuxCore.outputsize(model::SymbolicUtils.BasicSymbolic{<:LuxCore.AbstractLuxLayer}, x::Symbolics.Arr, rng::AbstractRNG) - LuxCore.outputsize(Symbolics.getdefaultval(model), x, rng) +function LuxCore.outputsize(model::SymbolicUtils.BasicSymbolic, x::Symbolics.Arr, rng::AbstractRNG) + @assert SymbolicUtils.symtype(model) <: LuxCore.AbstractLuxLayer + concrete_model = if SymbolicUtils.isconst(model) + SymbolicUtils.unwrap_const(model) + else + Symbolics.getdefaultval(model) + end + LuxCore.outputsize(concrete_model, x, rng) end @register_array_symbolic LuxCore.stateless_apply( diff --git a/ext/SymbolicsNemoExt.jl b/ext/SymbolicsNemoExt.jl index 9c9f31db1..886f5abc7 100644 --- a/ext/SymbolicsNemoExt.jl +++ b/ext/SymbolicsNemoExt.jl @@ -1,6 +1,7 @@ module SymbolicsNemoExt using Nemo import Symbolics.PrecompileTools +import Symbolics.Bijections if isdefined(Base, :get_extension) using Symbolics @@ -39,17 +40,19 @@ end function Symbolics.factor_use_nemo(poly::Num) Symbolics.check_polynomial(poly) Symbolics.degree(poly) == 0 && return poly, Num[] - vars = Symbolics.get_variables(poly) + mp_polys, poly_to_bs = Symbolics.symbol_to_poly([poly]) + mp_poly = only(mp_polys) + vars = collect(Symbolics.get_variables(poly)) + bs_to_poly = Bijections.active_inv(poly_to_bs) + poly_vars = map(Base.Fix1(getindex, bs_to_poly), vars) nemo_ring, nemo_vars = Nemo.polynomial_ring(Nemo.QQ, map(string, vars)) - sym_to_nemo = Dict(vars .=> nemo_vars) - nemo_to_sym = Dict(v => k for (k, v) in sym_to_nemo) - nemo_poly = Symbolics.substitute(poly, sym_to_nemo) + nemo_poly = mp_poly(poly_vars => nemo_vars) nemo_fac = Nemo.factor(nemo_poly) nemo_unit = Nemo.unit(nemo_fac) nemo_factors = collect(keys(nemo_fac.fac)) sym_unit = Rational(Nemo.coeff(nemo_unit, 1)) + nemo_to_sym = Dict(nemo_vars .=> vars) sym_factors = map(f -> Symbolics.wrap(nemo_crude_evaluate(f, nemo_to_sym)), nemo_factors) - for (i, fac) in enumerate(sym_factors) sym_factors[i] = fac^(collect(values(nemo_fac.fac))[i]) end @@ -58,19 +61,19 @@ function Symbolics.factor_use_nemo(poly::Num) end # Helps with precompilation time -PrecompileTools.@setup_workload begin - @variables a b c x y z - expr_with_params = expand((x + b)*(x^2 + 2x + 1)*(x^2 - a)) - equation1 = a*log(x)^b + c ~ 0 - equation_polynomial = 9^x + 3^x + 2 - exp_eq = 5*2^(x+1) + 7^(x+3) - PrecompileTools.@compile_workload begin - symbolic_solve(equation1, x) - symbolic_solve(equation_polynomial, x) - symbolic_solve(exp_eq) - symbolic_solve(expr_with_params, x, dropmultiplicity=false) - symbolic_solve(x^10 - a^10, x, dropmultiplicity=false) - end -end +# PrecompileTools.@setup_workload begin +# @variables a b c x y z +# expr_with_params = expand((x + b)*(x^2 + 2x + 1)*(x^2 - a)) +# equation1 = a*log(x)^b + c ~ 0 +# equation_polynomial = 9^x + 3^x + 2 +# exp_eq = 5*2^(x+1) + 7^(x+3) +# PrecompileTools.@compile_workload begin +# symbolic_solve(equation1, x) +# symbolic_solve(equation_polynomial, x) +# symbolic_solve(exp_eq) +# symbolic_solve(expr_with_params, x, dropmultiplicity=false) +# symbolic_solve(x^10 - a^10, x, dropmultiplicity=false) +# end +# end end # module diff --git a/src/Symbolics.jl b/src/Symbolics.jl index 17ecb3d80..df1a712ea 100644 --- a/src/Symbolics.jl +++ b/src/Symbolics.jl @@ -26,9 +26,10 @@ import DomainSets: Domain, DomainSets using TermInterface import TermInterface: maketerm, iscall, operation, arguments, metadata -import SymbolicUtils: Term, Add, Mul, Pow, Sym, Div, BasicSymbolic, -FnType, @rule, Rewriters, substitute, symtype, -promote_symtype, isadd, ismul, ispow, isterm, issym, isdiv +import SymbolicUtils: Term, Add, Mul, Sym, Div, BasicSymbolic, Const, + FnType, @rule, Rewriters, substitute, symtype, shape, unwrap, unwrap_const, + promote_symtype, isadd, ismul, ispow, isterm, issym, isdiv, BSImpl, scalarize, + Operator, _iszero, _isone, search_variables, search_variables! using SymbolicUtils.Code @@ -54,6 +55,23 @@ RuntimeGeneratedFunctions.init(@__MODULE__) import SciMLPublic: @public +using Moshi.Match: @match + +import Preferences: @load_preference + +const DEFAULT_VARTYPE_PREF = @load_preference("vartype", "SymReal") +const VartypeT = @static if DEFAULT_VARTYPE_PREF == "SymReal" + SymReal +elseif DEFAULT_VARTYPE_PREF == "SafeReal" + SafeReal +elseif DEFAULT_VARTYPE_PREF == "TreeReal" + TreeReal +else + error(""" + Invalid vartype preference: $DEFAULT_VARTYPE_PREF. Must be one of "SymReal", \ + "SafeReal" or "TreeReal". + """) +end # re-export export simplify, substitute @@ -90,13 +108,17 @@ sqrt(2) """ substitute -export Equation, ConstrainedEquation +export Equation include("equations.jl") export Inequality, ≲, ≳ include("inequality.jl") import Bijections, DynamicPolynomials +import DynamicPolynomials as DP +import MultivariatePolynomials as MP +import MutableArithmetics as MA + export tosymbol, terms, factors include("utils.jl") @@ -124,10 +146,6 @@ export SymbolicsSparsityDetector include("adtypes.jl") -export Difference, DiscreteUpdate - -include("difference.jl") - export infimum, supremum include("domains.jl") @@ -152,7 +170,6 @@ import Libdl include("build_function.jl") export build_function -import Distributions include("extra_functions.jl") using Latexify @@ -169,7 +186,6 @@ include("parsing.jl") export parse_expr_to_symbolic include("error_hints.jl") -include("struct.jl") include("operators.jl") include("limits.jl") @@ -177,7 +193,7 @@ export limit # Hacks to make wrappers "nicer" const NumberTypes = Union{AbstractFloat,Integer,Complex{<:AbstractFloat},Complex{<:Integer}} -(::Type{T})(x::SymbolicUtils.Symbolic) where {T<:NumberTypes} = throw(ArgumentError("Cannot convert Sym to $T since Sym is symbolic and $T is concrete. Use `substitute` to replace the symbolic unwraps.")) +(::Type{T})(x::SymbolicUtils.BasicSymbolic) where {T<:NumberTypes} = throw(ArgumentError("Cannot convert Sym to $T since Sym is symbolic and $T is concrete. Use `substitute` to replace the symbolic unwraps.")) for T in [Num, Complex{Num}] @eval begin #(::Type{S})(x::$T) where {S<:Union{NumberTypes,AbstractArray}} = S(Symbolics.unwrap(x))::S @@ -196,9 +212,7 @@ for T in [Num, Complex{Num}] SymbolicUtils.hasmetadata(x::$T, t) = SymbolicUtils.hasmetadata(unwrap(x), t) Broadcast.broadcastable(x::$T) = x - end - for S in [:(Symbolic{<:FnType}), :CallWithMetadata] - @eval (f::$S)(x::$T, y...) = wrap(f(unwrap(x), unwrap.(y)...)) + SymbolicUtils.scalarize(x::$T) = scalarize(unwrap(x)) end end @@ -545,7 +559,7 @@ include("inverse.jl") export rootfunction, left_continuous_function, right_continuous_function, @register_discontinuity include("discontinuities.jl") -@public Arr, CallWithMetadata, NAMESPACE_SEPARATOR, Unknown, VariableDefaultValue, VariableSource +@public Arr, NAMESPACE_SEPARATOR, Unknown, VariableDefaultValue, VariableSource @public _parse_vars, derivative, gradient, jacobian, sparsejacobian, hessian, sparsehessian @public get_variables, get_variables!, get_differential_vars, getparent, option_to_metadata_type, scalarize, shape @public unwrap, variable, wrap diff --git a/src/array-lib.jl b/src/array-lib.jl index 7748347e8..147e9c810 100644 --- a/src/array-lib.jl +++ b/src/array-lib.jl @@ -14,201 +14,39 @@ struct GetindexPosthookCtx end setmetadata(x, GetindexPosthookCtx, f) end end -function Base.getindex(x::SymArray, idx::CartesianIndex) - return x[Tuple(idx)...] -end - -function Base.getindex(x::SymArray, idx...) - idx = unwrap.(idx) - meta = metadata(unwrap(x)) - if iscall(x) && (op = operation(x)) isa Operator - args = arguments(x) - return op(only(args)[idx...]) - elseif shape(x) !== Unknown() && all(i -> i isa Integer, idx) - II = CartesianIndices(axes(x)) - ii = CartesianIndex(idx) - @boundscheck begin - if !in(ii, II) - throw(BoundsError(x, idx)) - end - end - res = Term{eltype(symtype(x))}(getindex, [x, Tuple(ii)...]; metadata = meta) - elseif all(i -> symtype(i) <: Integer, idx) - shape(x) !== Unknown() && @boundscheck begin - if length(idx) > 1 - for (a, i) in zip(axes(x), idx) - if i isa Integer && !(i in a) - throw(BoundsError(x, idx)) - end - end - end - end - res = Term{eltype(symtype(x))}(getindex, [x, idx...]; metadata = meta) - elseif length(idx) == 1 && symtype(first(idx)) <: CartesianIndex - i = first(idx) - ii = i isa CartesianIndex ? Tuple(i) : arguments(i) - - return getindex(x, ii...) - else - input_idx = [] - output_idx = [] - ranges = Dict{BasicSymbolic,AbstractRange}() - subscripts = makesubscripts(length(idx)) - for (j, i) in enumerate(idx) - if symtype(i) <: Integer - push!(input_idx, i) - elseif i isa Colon - push!(output_idx, subscripts[j]) - push!(input_idx, subscripts[j]) - elseif i isa AbstractVector - isym = subscripts[j] - push!(output_idx, isym) - push!(input_idx, isym) - ranges[isym] = i - else - error("Don't know how to index by $i") - end - end - - term = Term{Any}(getindex, [x, idx...]; metadata = meta) - T = eltype(symtype(x)) - N = ndims(x) - count(i -> symtype(i) <: Integer, idx) - res = ArrayOp(atype(symtype(x)){T,N}, - (output_idx...,), - x[input_idx...], - +, - term, - ranges) - end - - if hasmetadata(x, GetindexPosthookCtx) - f = getmetadata(x, GetindexPosthookCtx, identity) - f(res, x, idx...) - else - res - end -end # Wrapped array should wrap the elements too function Base.getindex(x::Arr, idx...) wrap(unwrap(x)[idx...]) end -function Base.getindex(x::Arr, idx::Symbolic{<:Integer}...) - wrap(unwrap(x)[idx...]) -end -function Base.getindex(x::Arr, I::Symbolic{CartesianIndex}) - wrap(unwrap(x)[tup(I)...]) -end -Base.getindex(I::Symbolic{CartesianIndex}, i::Integer) = tup(I)[i] - -function Base.getindex(A::AbstractArray{T}, I::Symbolic{CartesianIndex}) where {T} - term(getindex, A, tup(I)..., type=T) -end - -function Base.CartesianIndex(x::Symbolic{<:Integer}, xs::Symbolic{<:Integer}...) - term(CartesianIndex, x, xs..., type=CartesianIndex) -end - - -import Base: +, -, * -tup(c::CartesianIndex) = Tuple(c) -tup(c::Symbolic{CartesianIndex}) = iscall(c) ? arguments(c) : error("Cartesian index not found") - -@wrapped function -(x::CartesianIndex, y::CartesianIndex) - CartesianIndex((tup(x) .- tup(y))...) -end -@wrapped function +(x::CartesianIndex, y::CartesianIndex) - CartesianIndex((tup(x) .+ tup(y))...) +const SymIdxT = Union{Num, BasicSymbolic{VartypeT}} +function Base.getindex(x::Arr, idx::SymIdxT, idxs...) + wrap(unwrap(x)[idx, idxs...]) end - -@wrapped function *(x::CartesianIndex, y::CartesianIndex) - CartesianIndex((tup(x) .* tup(y))...) +function Base.getindex(x::Arr, i1, idx::SymIdxT, idxs...) + wrap(unwrap(x)[i1, idx, idxs...]) end - -@wrapped function *(a::Integer, x::CartesianIndex) - CartesianIndex((a * tup(x))...) +function Base.getindex(x::Arr, i1::SymIdxT, idx::SymIdxT, idxs...) + wrap(unwrap(x)[i1, idx, idxs...]) end - -@wrapped function *(x::CartesianIndex, b::Integer) - CartesianIndex((tup(x) * b)...) +function Base.getindex(x::Arr, i1, i2, idx::SymIdxT, idxs...) + wrap(unwrap(x)[i1, i2, idx, idxs...]) end - - -function propagate_ndims(::typeof(getindex), x, idx...) - ndims(x) - count(x -> symtype(x) <: Integer, idx) +function Base.getindex(x::Arr, i1, i2::SymIdxT, idx::SymIdxT, idxs...) + wrap(unwrap(x)[i1, i2, idx, idxs...]) end - -function propagate_shape(::typeof(getindex), x, idx...) - @oops axes = shape(x) - - idx1 = to_indices(CartesianIndices(axes), axes, idx) - ([1:length(x) for x in idx1 if !(symtype(x) <: Number)]...,) -end - -propagate_eltype(::typeof(getindex), x, idx...) = geteltype(x) - -function SymbolicUtils.promote_symtype(::typeof(getindex), X, ii...) - @assert all(i -> i <: Integer, ii) - eltype(X) +function Base.getindex(x::Arr, i1::SymIdxT, i2::SymIdxT, idx::SymIdxT, idxs...) + wrap(unwrap(x)[i1, i2, idx, idxs...]) end +import Base: +, -, * #### Broadcast #### -# - -using Base.Broadcast - -Base.broadcastable(s::SymArray) = s -struct SymBroadcast <: Broadcast.BroadcastStyle end -Broadcast.BroadcastStyle(::Type{<:SymArray}) = SymBroadcast() -Broadcast.result_style(::SymBroadcast) = SymBroadcast() -Broadcast.BroadcastStyle(::SymBroadcast, ::Broadcast.BroadcastStyle) = SymBroadcast() - -isonedim(x, i) = shape(x) == Unknown() ? false : isone(size(x, i)) - -function Broadcast.copy(bc::Broadcast.Broadcasted{SymBroadcast}) - # Do the thing here - args = inner_unwrap.(bc.args) - ndim = mapfoldl(ndims, max, args, init=0) - subscripts = makesubscripts(ndim) - - onedim_count = mapreduce(+, args) do x - if ndims(x) != 0 - map(i -> isonedim(x, i) ? 1 : 0, 1:ndim) - else - map(i -> 1, 1:ndim) - end - end - - extruded = map(x -> x < length(args), onedim_count) - - expr_args′ = map(args) do x - if ndims(x) != 0 - subs = map(i -> extruded[i] && isonedim(x, i) ? - 1 : subscripts[i], 1:ndims(x)) - x[subs...] - elseif x isa Base.RefValue - x[] - else - x - end - end - expr = term(bc.f, expr_args′...) # Imagine x .=> y -- if you don't have a term - # then you get pairs, and index matcher cannot - # recurse into pairs - Atype = propagate_atype(broadcast, bc.f, args...) - args = map(x -> x isa Base.RefValue ? Term{Any}(Ref, [x[]]) : x, args) - ArrayOp(Atype{symtype(expr),ndim}, - (subscripts...,), - expr, - +, - Term{Any}(broadcast, [bc.f, args...])) -end # On wrapper: struct SymWrapBroadcast <: Broadcast.BroadcastStyle end -Base.broadcastable(s::Arr) = s +Broadcast.broadcastable(s::Arr) = s Broadcast.BroadcastStyle(::Type{<:Arr}) = SymWrapBroadcast() @@ -216,185 +54,75 @@ Broadcast.result_style(::SymWrapBroadcast) = SymWrapBroadcast() Broadcast.BroadcastStyle(::SymWrapBroadcast, ::Broadcast.BroadcastStyle) = SymWrapBroadcast() -Broadcast.BroadcastStyle(::SymBroadcast, +Broadcast.BroadcastStyle(::SymbolicUtils.SymBroadcast, ::SymWrapBroadcast) = Broadcast.Unknown() -function Broadcast.copy(bc::Broadcast.Broadcasted{SymWrapBroadcast}) - args = map(bc.args) do arg - if arg isa Broadcast.Broadcasted - return Broadcast.copy(arg) - else - return arg - end - end - wrap(broadcast(bc.f, map(unwrap, args)...)) +unwrap_broadcasts(head, args...) = (unwrap_broadcast(head), unwrap_broadcasts(args...)...) +unwrap_broadcasts() = () +unwrap_broadcast(x) = unwrap(x) +function unwrap_broadcast(bc::Broadcast.Broadcasted{SymWrapBroadcast}) + Broadcast.Broadcasted{SymbolicUtils.SymBroadcast{VartypeT}}(bc.f, unwrap_broadcasts(bc.args...), bc.axes) end +function Broadcast.copy(bc::Broadcast.Broadcasted{SymWrapBroadcast}) + return wrap(copy(unwrap_broadcast(bc))) +end -#################### TRANSPOSE ################ -# -@wrapped function Base.adjoint(A::AbstractMatrix) - @syms i::Int j::Int - @arrayop (i, j) A[j, i] term = A' -end false - -@wrapped function Base.adjoint(b::AbstractVector) - @syms i::Int - @arrayop (1, i) b[i] term = b' -end false - -import Base: *, \ - -using LinearAlgebra - -isdot(A, b) = isadjointvec(A) && ndims(b) == 1 - -isadjointvec(A::Adjoint) = ndims(parent(A)) == 1 -isadjointvec(A::Transpose) = ndims(parent(A)) == 1 +#################### POLYADIC ################ -function isadjointvec(A) - if iscall(A) - (operation(A) === (adjoint) || - operation(A) == (transpose)) && ndims(arguments(A)[1]) == 1 - else - false - end +function *(x::Arr, args...) + return wrap(*(unwrap(x), args...)) end -isadjointvec(A::ArrayOp) = isadjointvec(A.term) - -__symtype(x::Type{<:Symbolic{T}}) where T = T -function symeltype(A) - T = eltype(A) - T <: Symbolic ? __symtype(T) : T +function *(a::SymbolicUtils.PolyadicNumericOpFirstArgT, b::Arr, bs...) + return wrap(*(a, unwrap(b), bs...)) end -# TODO: add more such methods -function getindex(A::AbstractArray, i::Symbolic{<:Integer}, ii::Symbolic{<:Integer}...) - Term{symeltype(A)}(getindex, [A, i, ii...]) +function *(a::LinearAlgebra.Adjoint{T, <: AbstractVector}, b::Arr, bs...) where {T} + return wrap(*(a, unwrap(b), bs...)) end - -function getindex(A::AbstractArray, i::Int, j::Symbolic{<:Integer}) - Term{symeltype(A)}(getindex, [A, i, j]) +function *(a::LinearAlgebra.Adjoint{T, <: AbstractVector}, b::Arr, c::AbstractVector, bs...) where {T} + return wrap(*(a, unwrap(b), unwrap(c), bs...)) end - -function getindex(A::AbstractArray, j::Symbolic{<:Integer}, i::Int) - Term{symeltype(A)}(getindex, [A, j, i]) +function *(a::Number, b::Arr, bs...) + return wrap(*(unwrap(a), unwrap(b), bs...)) end - -function getindex(A::Arr, i::Int, j::Symbolic{<:Integer}) - wrap(unwrap(A)[i, j]) +function *(x1::Arr, x2::BasicSymbolic{VartypeT}, args...) + return wrap(*(unwrap(x1), x2, args...)) end - -function getindex(A::Arr, j::Symbolic{<:Integer}, i::Int) - wrap(unwrap(A)[j, i]) +function *(x1::Arr, x2::Arr, args...) + return wrap(*(unwrap(x1), unwrap(x2), args...)) end - -function _matmul(A, B) - A = inner_unwrap(A) - B = inner_unwrap(B) - @syms i::Int j::Int k::Int - if isadjointvec(A) - op = operation(A.term) - return op(op(B) * first(arguments(A.term))) - end - return @arrayop (i, j) A[i, k] * B[k, j] term = (A * B) +function *(x1::Arr, x2::AbstractMatrix, args...) + return wrap(*(unwrap(x1), x2, args...)) end - -@wrapped (*)(A::AbstractMatrix, B::AbstractMatrix) = _matmul(A, B) false -@wrapped (*)(A::AbstractVector, B::AbstractMatrix) = _matmul(A, B) false -# Resolve ambiguity with Base.*(::Adjoint{T, <:AbstractVector} where T, ::AbstractMatrix) -(*)(x::Adjoint{T, <:AbstractVector} where {T}, A::Symbolics.Arr{<:Any, 2}) = wrap(_matmul(unwrap(x), unwrap(A))) - -function _matvec(A, b) - A = inner_unwrap(A) - b = inner_unwrap(b) - @syms i::Int k::Int - sym_res = @arrayop (i,) A[i, k] * b[k] term=(A*b) - if isdot(A, b) - return sym_res[1] - else - return sym_res - end +function *(x1::Arr, x2::AbstractVector, args...) + return wrap(*(unwrap(x1), x2, args...)) end -@wrapped (*)(A::AbstractMatrix, b::AbstractVector) = _matvec(A, b) false - -# specialize `dot` to dispatch on `Symbolic{<:Number}` to eventually work for -# arrays of (possibly unwrapped) Symbolic types, see issue #831 -@wrapped LinearAlgebra.dot(x::Number, y::Number) = conj(x) * y false - -#################### MAP-REDUCE ################ -# - -@wrapped Base.map(f, x::AbstractArray) = _map(f, x) false -@wrapped Base.map(f, x::AbstractArray, xs...) = _map(f, x, xs...) false -@wrapped Base.map(f, x, y::AbstractArray, z...) = _map(f, x, y, z...) false -@wrapped Base.map(f, x, y, z::AbstractArray, w...) = _map(f, x, y, z, w...) false - -function _map(f, x, xs...) - N = ndims(x) - idx = makesubscripts(N) - x = inner_unwrap(x) - xs = inner_unwrap.(xs) - - expr = f(map(a -> a[idx...], [x, xs...])...) - - Atype = propagate_atype(map, f, x, xs...) - ArrayOp(Atype{symtype(expr),N}, - (idx...,), - expr, - +, - Term{Any}(map, [f, x, xs...])) +function *(x1::AbstractMatrix, x2::Arr, args...) + return wrap(*(x1, unwrap(x2), args...)) end - -@inline _mapreduce(f, g, x, dims, kw) = mapreduce(f, g, x; dims=dims, kw...) - -function scalarize_op(::typeof(_mapreduce), t) - f, g, x, dims, kw = arguments(t) - # we wrap and unwrap to make things work smoothly. - # we need the result unwrapped to allow recursive scalarize to work. - unwrap(_mapreduce(f, g, collect(wrap(x)), dims, kw)) +function *(x1::Arr, x2::Arr, x3::Arr, args...) + return wrap(*(unwrap(x1), unwrap(x2), unwrap(x3), args...)) +end +function *(x1::Arr, x2::Arr, x3::Arr, x4::Arr, args...) + return wrap(*(unwrap(x1), unwrap(x2), unwrap(x3), unwrap(x4), args...)) end -@wrapped function Base.mapreduce(f, g, x::AbstractArray; dims=:, kw...) - idx = makesubscripts(ndims(x)) - out_idx = [dims == (:) || i in dims ? 1 : idx[i] for i = 1:ndims(x)] - expr = f(x[idx...]) - T = symtype(g(expr, expr)) - if dims === (:) - return Term{T}(_mapreduce, [f, g, x, dims, (kw...,)]) - end - - Atype = propagate_atype(_mapreduce, f, g, x, dims, (kw...,)) - ArrayOp(Atype{T,ndims(x)}, - (out_idx...,), - expr, - g, - Term{Any}(_mapreduce, [f, g, x, dims, (kw...,)])) -end false - -for (ff, opts) in [ - any => (identity, (|), false), - all => (identity, (&), true)] +function +(x::Arr, args...) + return +(unwrap(x), args...) +end +function +(x::Arr, y::AbstractArray, args...) + return +(unwrap(x), y, args...) +end +function +(x1::Arr, x2::Arr, args...) + return +(unwrap(x1), unwrap(x2), args...) +end - f, g, init = opts - @eval @wrapped function (::$(typeof(ff)))(x::AbstractArray; - dims=:, init=$init) - mapreduce($f, $g, x, dims=dims, init=init) - end false - @eval @wrapped function (::$(typeof(ff)))(f::Function, x::AbstractArray; - dims=:, init=$init) - mapreduce(f, $g, x, dims=dims, init=init) - end false +function +(a::SymbolicUtils.PolyadicNumericOpFirstArgT, b::Arr, bs...) + return +(a, unwrap(b), bs...) end -for (ff, opts) in [sum => (identity, +), - prod => (identity, *)] +#################### MAP-REDUCE ################ - f, g = opts - @eval @wrapped function (::$(typeof(ff)))(x::AbstractArray; kw...) - mapreduce($f, $g, x; kw...) - end false - @eval @wrapped function (::$(typeof(ff)))(f::Function, x::AbstractArray; kw...) - mapreduce(f, $g, x; kw...) - end false -end +SymbolicUtils.@map_methods Arr unwrap wrap +SymbolicUtils.@mapreduce_methods Arr unwrap wrap diff --git a/src/arrays.jl b/src/arrays.jl index 4d60edd3e..ecb9ac34a 100644 --- a/src/arrays.jl +++ b/src/arrays.jl @@ -1,503 +1,26 @@ using SymbolicUtils -using SymbolicUtils: @capture using StaticArraysCore import Base: eltype, length, ndims, size, axes, eachindex -export @arrayop, ArrayMaker, @makearray, @setview, @setview! - -### Store Shape as a metadata in Term{<:AbstractArray} objects -struct ArrayShapeCtx end - - -#= - There are 2 types of array terms: - `ArrayOp{T<:AbstractArray}` and `Term{<:AbstractArray}` - - - ArrayOp represents a Einstein-notation-inspired array operation. - it contains a field `term` which is a `Term` that represents the - operation that resulted in the `ArrayOp`. - I.e. will be `A*b` for the operation `(i,) => A[i,j] * b[j]` for example. - It can be `nothing` if not known. - - calling `shape` on an `ArrayOp` will return the shape of the array or `Unknown()` - - do not rely on the `symtype` or `shape` information of the `.term` when looking at an `ArrayOp`. - call `shape`, `symtype` and `ndims` directly on the `ArrayOp`. - - `Term{<:AbstractArray}` - - calling `shape` on it will return the shape of the array or `Unknown()`, and uses - the `ArrayShapeCtx` metadata context to store this. - - The Array type parameter must contain the dimension. -=# - -#### ArrayOp #### - -""" - ArrayOp(output_idx, expr, reduce) - -A tensor expression where `output_idx` are the output indices -`expr`, is the tensor expression and `reduce` is the function -used to reduce over contracted dimensions. -""" -struct ArrayOp{T<:AbstractArray} <: Symbolic{T} - output_idx # output indices - expr # Used in pattern matching - # Useful to infer eltype - reduce - term - shape - ranges::Dict{BasicSymbolic, AbstractRange} # index range each index symbol can take, - # optional for each symbol - metadata -end - -function ArrayOp(T::Type, output_idx, expr, reduce, term, ranges=Dict(); metadata=nothing) - sh = make_shape(output_idx, unwrap(expr), ranges) - ArrayOp{T}(output_idx, unwrap(expr), reduce, term, sh, ranges, metadata) -end - -function ArrayOp(a::AbstractArray) - i = makesubscripts(ndims(a)) - # TODO: formalize symtype(::Type) then! - ArrayOp(symtype(a), (i...,), a[i...], +, a) -end - -ConstructionBase.constructorof(s::Type{<:ArrayOp{T}}) where {T} = ArrayOp{T} - -function SymbolicUtils.maketerm(::Type{<:ArrayOp}, f, args, m) - args = map(args) do arg - if iscall(arg) && operation(arg) == Ref - inner = only(arguments(arg)) - if symbolic_type(inner) == NotSymbolic() - return Ref(inner) - else - return inner - end - else - return arg - end - end - - t = f(args...) - t isa Symbolic && !isnothing(m) ? - metadata(t, m) : t -end - -SymbolicUtils.sorted_arguments(s::ArrayOp) = sorted_arguments(s.term) - -shape(aop::ArrayOp) = aop.shape - -SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.ArrayOp}) = ArraySymbolic() - -const show_arrayop = Ref{Bool}(false) -function Base.show(io::IO, aop::ArrayOp) - if iscall(aop.term) && !show_arrayop[] - show(io, aop.term) - else - print(io, "@arrayop") - print(io, "(_[$(join(string.(aop.output_idx), ","))] := $(aop.expr))") - if aop.reduce != + - print(io, " ($(aop.reduce))") - end - - if !isempty(aop.ranges) - print(io, " ", join(["$k in $v" for (k, v) in aop.ranges], ", ")) - end - end -end - -Base.summary(io::IO, aop::ArrayOp) = Base.array_summary(io, aop, shape(aop)) -function Base.showarg(io::IO, aop::ArrayOp, toplevel) - show(io, aop) - toplevel && print(io, "::", typeof(aop)) - return nothing -end - -symtype(a::ArrayOp{T}) where {T} = T -iscall(a::ArrayOp) = true -function operation(a::ArrayOp) - isnothing(a.term) ? typeof(a) : operation(a.term) -end -function arguments(a::ArrayOp) - isnothing(a.term) ? [a.output_idx, a.expr, a.reduce, - a.term, a.shape, a.ranges, metadata(a)] : - arguments(a.term) -end - -function Base.isequal(a::ArrayOp, b::ArrayOp) - a === b && return true - isequal(a.shape, b.shape) && - isequal(a.ranges, b.ranges) && - isequal(a.output_idx, b.output_idx) && - isequal(a.reduce, b.reduce) && - isequal(operation(a), operation(b)) && - isequal(a.expr, b.expr) -end - -function Base.hash(a::ArrayOp, u::UInt) - hash(a.shape, hash(a.ranges, hash(a.expr, hash(a.output_idx, hash(operation(a), u))))) -end - -macro arrayop(output_idx, expr, options...) - rs = [] - reduce = + - call = nothing - - extra = [] - isexpr = MacroTools.isexpr - for o in options - if isexpr(o, :call) && o.args[1] == :in - push!(rs, :($(o.args[2]) => $(o.args[3]))) - elseif isexpr(o, :(=)) && o.args[1] == :reduce - reduce = o.args[2] - elseif isexpr(o, :(=)) && o.args[1] == :term - call = o.args[2] - else - push!(extra, o) - end - end - if length(extra) == 1 - @warn("@arrayop is deprecated, use @arrayop term= instead") - call = output_idx - output_idx = expr - expr = extra[1] - end - @assert output_idx.head == :tuple - - oidxs = filter(x->x isa Symbol, output_idx.args) - iidxs = find_indices(expr) - idxs = union(oidxs, iidxs) - fbody = call2term(deepcopy(expr)) - oftype(x,T) = :($x::$T) - aop = gensym("aop") - quote - let - @syms $(map(x->oftype(x, Int), idxs)...) - - expr = $fbody - #TODO: proper Atype - $ArrayOp(Array{$symtype(expr), - $(length(output_idx.args))}, - $output_idx, - expr, - $reduce, - $(call2term(call)), - Dict($(rs...))) - - end - end |> esc -end - -const SymArray = Union{ArrayOp, Symbolic{<:AbstractArray}} -const SymMat = Union{ArrayOp{<:AbstractMatrix}, Symbolic{<:AbstractMatrix}} -const SymVec = Union{ArrayOp{<:AbstractVector}, Symbolic{<:AbstractVector}} - -### Propagate ### -# -## Shape ## - - -function axis_in(a, b) - first(a) >= first(b) && last(a) <= last(b) -end - -function make_shape(output_idx, expr, ranges=Dict()) - matches = idx_to_axes(expr) - for (sym, ms) in matches - to_check = filter(m->!(shape(m.A) isa Unknown), ms) - # Only check known dimensions. It may be "known symbolically" - isempty(to_check) && continue - restricted = false - if haskey(ranges, sym) - ref_axis = ranges[sym] - restricted = true - else - ref_axis = axes(first(to_check).A, first(to_check).dim) - end - reference = ref_axis - for i in (restricted ? 1 : 2):length(ms) - m = ms[i] - s=shape(m.A) - if s !== Unknown() - if restricted - if !axis_in(ref_axis, axes(m.A, m.dim)) - throw(DimensionMismatch("expected $(ref_axis) to be within axes($(m.A), $(m.dim))")) - end - elseif !isequal(axes(m.A, m.dim), reference) - throw(DimensionMismatch("expected axes($(m.A), $(m.dim)) = $(reference)")) - end - end - end - end - - sz = map(output_idx) do i - if issym(i) - if haskey(ranges, i) - return axes(ranges[i], 1) - end - if !haskey(matches, i) - error("There was an error processing arrayop expression $expr.\n" * - "Dimension of output index $i in $output_idx could not be inferred") - end - mi = matches[i] - @assert !isempty(mi) - ext = get_extents(mi) - ext isa Unknown && return Unknown() - return 1:(length(ext)) - elseif i isa Integer - return 1:(1) - end - end - # TODO: maybe we can remove this restriction? - if any(x->x isa Unknown, sz) - Unknown() - else - sz - end -end - - -function ranges(a::ArrayOp) - rs = Dict{BasicSymbolic, Any}() - ax = idx_to_axes(a.expr) - for i in keys(ax) - if haskey(a.ranges, i) - rs[i] = a.ranges[i] - else - rs[i] = ax[i] #get_extents(ax[i]) - end - end - return rs -end - -## Eltype ## - -eltype(aop::ArrayOp) = symtype(aop.expr) - -## Ndims ## -function ndims(aop::ArrayOp) - length(aop.output_idx) -end - - -### Utils ### - - -# turn `f(x...)` into `term(f, x...)` -# -function call2term(expr, arrs=[]) - !(expr isa Expr) && return :($unwrap($expr)) - if expr.head == :call - if expr.args[1] == :(:) - return expr - end - return Expr(:call, term, map(call2term, expr.args)...) - elseif expr.head == :ref - return Expr(:ref, call2term(expr.args[1]), expr.args[2:end]...) - elseif expr.head == Symbol("'") - return Expr(:call, term, adjoint, map(call2term, expr.args)...) - end - - return Expr(expr.head, map(call2term, expr.args)...) -end - -# Find all symbolic indices in expr -function find_indices(expr, idxs=[]) - !(expr isa Expr) && return idxs - if expr.head == :ref - return append!(idxs, filter(x->x isa Symbol, expr.args[2:end])) - elseif expr.head == :call && expr.args[1] == :getindex || expr.args[1] == getindex - return append!(idxs, filter(x->x isa Symbol, expr.args[3:end])) - else - foreach(x->find_indices(x, idxs), expr.args) - return idxs - end -end - -struct AxisOf - A - dim - boundary -end - -function Base.get(a::AxisOf) - @oops shape(a.A) - axes(a.A, a.dim) -end - -function get_extents(xs) - boundaries = map(x->x.boundary, xs) - if all(iszero∘wrap, boundaries) - get(first(xs)) - else - ii = findfirst(x->issym(x) || iscall(x), boundaries) - if !isnothing(ii) - error("Could not find the boundary from symbolic index $(xs[ii]). Please manually specify the range of indices.") - end - extent = get(first(xs)) - start_offset = -reduce(min, filter(x->x<0, boundaries), init=0) - end_offset = reduce(max, filter(x->x>0, boundaries), init=0) - - (first(extent) + start_offset):(last(extent) - end_offset) - end -end - -get_extents(x::AbstractRange) = x - -## Walk expr looking for symbols used in getindex expressions -# Returns a dictionary of Sym to a vector of AxisOf objects. -# The vector has as many elements as the number of times the symbol -# appears in the expr. AxisOf has three fields: -# A: the array whose indexing it appears in -# dim: The dimension of the array indexed -# boundary: how much padding is this indexing requiring, for example -# boundary is 2 for x[i + 2], and boundary = -2 for x[i - 2] -function idx_to_axes(expr, dict=Dict{Any, Vector}(), ranges=Dict()) - if iscall(expr) - if operation(expr) === (getindex) - args = arguments(expr) - for (axis, idx_expr) in enumerate(@views args[2:end]) - if issym(idx_expr) || iscall(idx_expr) - vs = get_variables(idx_expr) - isempty(vs) && continue - sym = only(get_variables(idx_expr)) - axesvec = Base.get!(() -> [], dict, sym) - push!(axesvec, AxisOf(first(args), axis, idx_expr - sym)) - end - end - else - idx_to_axes(operation(expr), dict) - foreach(ex->idx_to_axes(ex, dict), arguments(expr)) - end - end - dict -end - - -#### Term{<:AbstractArray} -# - -""" - array_term(f, args...; - container_type = propagate_atype(f, args...), - eltype = propagate_eltype(f, args...), - size = map(length, propagate_shape(f, args...)), - ndims = propagate_ndims(f, args...)) - -Create a term of `Term{<: AbstractArray}` which -is the representation of `f(args...)`. - -Default arguments: -- `container_type=propagate_atype(f, args...)` - the container type, - i.e. `Array` or `StaticArray` etc. -- `eltype=propagate_eltype(f, args...)` - the output element type. -- `size=map(length, propagate_shape(f, args...))` - the - output array size. `propagate_shape` returns a tuple of index ranges. -- `ndims=propagate_ndims(f, args...)` the output dimension. - -`propagate_shape`, `propagate_atype`, `propagate_eltype` may -return `Unknown()` to say that the output cannot be determined -""" -function array_term(f, args...; - container_type = propagate_atype(f, args...), - eltype = propagate_eltype(f, args...), - size = Unknown(), - ndims = size !== Unknown() ? length(size) : propagate_ndims(f, args...), - shape = size !== Unknown() ? Tuple(map(x->1:x, size)) : propagate_shape(f, args...)) - - if container_type == Unknown() - # There's always a fallback for this - container_type = propagate_atype(f, args...) - end +### Wrapper type for dispatch - if eltype == Unknown() - eltype = Base.propagate_eltype(container_type) - end +@symbolic_wrap struct Arr{T,N} <: AbstractArray{T, N} + value::BasicSymbolic{VartypeT} - if ndims == Unknown() - ndims = if shape == Unknown() - Any + function Arr{T, N}(ex) where {T, N} + if is_wrapper_type(T) + @assert symtype(ex) <: AbstractArray{<:wraps_type(T), N} else - length(shape) + @assert symtype(ex) <: AbstractArray{T, N} end + new{T, N}(Const{VartypeT}(ex)) end - S = container_type{eltype, ndims} - setmetadata(Term{S}(f, Any[args...]), ArrayShapeCtx, shape) -end - -""" - shape(s::Any) - -Returns `axes(s)` or `Unknown()`. -""" -shape(s) = axes(s) - -""" - shape(s::SymArray) - -Extract the shape metadata from a SymArray. -If not known, returns `Unknown()` -""" -function shape(s::Symbolic{<:AbstractArray}) - if hasmetadata(s, ArrayShapeCtx) - getmetadata(s, ArrayShapeCtx) - else - Unknown() - end -end - -## `propagate_` interface: -# used in the `array_term` construction. - -atype(::Type{<:Array}) = Array -atype(::Type{<:SArray}) = SArray -atype(::Type) = AbstractArray - -_propagate_atype(::Type{T}, ::Type{T}) where {T} = T -_propagate_atype(::Type{<:Array}, ::Type{<:SArray}) = Array -_propagate_atype(::Type{<:SArray}, ::Type{<:Array}) = Array -_propagate_atype(::Any, ::Any) = AbstractArray -_propagate_atype(T) = T -_propagate_atype() = AbstractArray - -function propagate_atype(f, args...) - As = [atype(symtype(T)) - for T in Iterators.filter(x->x <: Symbolic{<:AbstractArray}, typeof.(args))] - if length(As) <= 1 - _propagate_atype(As...) - else - foldl(_propagate_atype, As) - end -end - -function propagate_eltype(f, args...) - As = [eltype(symtype(T)) - for T in Iterators.filter(x->symtype(x) <: AbstractArray, args)] - promote_type(As...) -end - -function propagate_ndims(f, args...) - if propagate_shape(f, args...) == Unknown() - error("Could not determine the output dimension of $f$args") - else - length(propagate_shape(f, args...)) - end -end - -function propagate_shape(f, args...) - error("Don't know how to propagate shape for $f$args") -end - -### Wrapper type for dispatch - -@symbolic_wrap struct Arr{T,N} <: AbstractArray{T, N} - value end Base.hash(x::Arr, u::UInt) = hash(unwrap(x), u) Base.isequal(a::Arr, b::Arr) = isequal(unwrap(a), unwrap(b)) -Base.isequal(a::Arr, b::Symbolic) = isequal(unwrap(a), b) -Base.isequal(a::Symbolic, b::Arr) = isequal(b, a) - -ArrayOp(x::Arr) = unwrap(x) +Base.isequal(a::Arr, b::BasicSymbolic) = isequal(unwrap(a), b) +Base.isequal(a::BasicSymbolic, b::Arr) = isequal(b, a) function Arr(x) A = symtype(x) @@ -505,14 +28,8 @@ function Arr(x) Arr{maybewrap(eltype(A)), ndims(A)}(x) end -const ArrayLike{T,N} = Union{ - ArrayOp{AbstractArray{T,N}}, - Symbolic{AbstractArray{T,N}}, - Arr{T,N}, - SymbolicUtils.Term{AbstractArray{T, N}} -} # Like SymArray but includes Arr and Term{Arr} - -unwrap(x::Arr) = x.value +SymbolicUtils.unwrap(x::Arr) = x.value +SymbolicUtils.symtype(x::Arr) = symtype(unwrap(x)) maybewrap(T) = has_symwrapper(T) ? wrapper_type(T) : T # These methods allow @wrapped methods to be more specific and not overwrite @@ -528,494 +45,71 @@ function Base.show(io::IO, arr::Arr) iscall(x) && print(io, "(") print(io, unwrap(arr)) iscall(x) && print(io, ")") - if !(shape(x) isa Unknown) - print(io, "[", join(string.(axes(arr)), ","), "]") + print(io, "[") + if shape(x) isa SymbolicUtils.Unknown + print(io, shape(x)) + else + print(io, join(string.(axes(arr)), ",")) end + print(io, "]") end Base.show(io::IO, ::MIME"text/plain", arr::Arr) = show(io, arr) ################# Base array functions -# - -# basic -# these methods are not symbolic but work if we know this info. - -geteltype(s::SymArray) = geteltype(symtype(s)) -geteltype(::Type{<:AbstractArray{T}}) where {T} = T -geteltype(::Type{<:AbstractArray}) = Unknown() - -ndims(s::SymArray) = ndims(symtype(s)) -ndims(::Type{<:Arr{<:Any, N}}) where N = N - -function eltype(A::Union{Arr, SymArray}) - T = geteltype(unwrap(A)) - T === Unknown() && error("eltype of $A not known") - return T -end - -function length(A::Union{Arr, SymArray}) - s = shape(unwrap(A)) - s === Unknown() && error("length of $A not known") - return prod(length, s) -end - -function size(A::Union{Arr, SymArray}) - s = shape(unwrap(A)) - s === Unknown() && error("size of $A not known") - return length.(s) -end - -function size(A::SymArray, i::Integer) - @assert(i > 0) - i > ndims(A) ? 1 : size(A)[i] -end - -function axes(A::Union{Arr, SymArray}) - s = shape(unwrap(A)) - s === Unknown() && error("axes of $A not known") - return s -end - - -function axes(A::SymArray, i) - s = shape(A) - s === Unknown() && error("axes of $A not known") - return i <= length(s) ? s[i] : 1:(1) -end - -function eachindex(A::Union{Arr, SymArray}) - s = shape(unwrap(A)) - s === Unknown() && error("eachindex of $A not known") - return CartesianIndices(s) -end - -function get_variables!(vars, e::Arr, varlist=nothing) - foreach(x -> get_variables!(vars, x, varlist), collect(e)) - vars -end - - -### Scalarize - -scalarize(a::Array) = map(scalarize, a) -scalarize(term::Symbolic{<:AbstractArray}, idx) = term[idx...] -val2num(::Val{n}) where n = n - -function replace_by_scalarizing(ex, dict) - rule = @rule(getindex(~x, ~~i) => - scalarize(~x, (map(j->substitute(j, dict), ~~i)...,))) - - function rewrite_operation(x) - if iscall(x) && iscall(operation(x)) - f = operation(x) - ff = replace_by_scalarizing(f, dict) - if metadata(x) !== nothing - maketerm(typeof(x), ff, arguments(x), metadata(x)) - else - ff(arguments(x)...) - end - else - nothing - end - end - - prewalk_if(x->!(x isa ArrayOp || x isa ArrayMaker), - Rewriters.PassThrough(Chain([rewrite_operation, rule])), - ex) -end - -function prewalk_if(cond, f, t) - t′ = cond(t) ? f(t) : return t - if iscall(t′) - if metadata(t′) !== nothing - return maketerm(typeof(t′), TermInterface.head(t′), - map(x->prewalk_if(cond, f, x), children(t′)), metadata(t′)) - else - TermInterface.head(t′)(map(x->prewalk_if(cond, f, x), children(t′))...) - end - else - return t′ - end -end - -function scalarize(arr::AbstractArray, idx) - arr[idx...] -end - -function scalarize(arr, idx) - if iscall(arr) - scalarize_op(operation(arr), arr, idx) - else - error("scalarize is not defined for $arr at idx=$idx") - end -end - -scalarize_op(f, arr) = arr - -struct ScalarizeCache end - -function scalarize_op(f, arr, idx) - if hasmetadata(arr, ScalarizeCache) && getmetadata(arr, ScalarizeCache)[] !== nothing - getmetadata(arr, ScalarizeCache)[][idx...] - else - # wrap and unwrap to call generic methods - thing = unwrap(f(scalarize.(map(wrap, arguments(arr)))...)) - if metadata(arr) != nothing - # forward any metadata - try - thing = metadata(thing, metadata(arr)) - catch err - @warn "could not attach metadata of subexpression $arr to the scalarized form at idx" - end - end - if !hasmetadata(arr, ScalarizeCache) - arr = setmetadata(arr, ScalarizeCache, Ref{Any}(nothing)) - end - getmetadata(arr, ScalarizeCache)[] = thing - thing[idx...] - end -end - -@wrapped function Base.:(\)(A::AbstractMatrix, b::AbstractVecOrMat) - t = array_term(\, A, b) - setmetadata(t, ScalarizeCache, Ref{Any}(nothing)) -end false - -@wrapped function Base.inv(A::AbstractMatrix) - t = array_term(inv, A) - setmetadata(t, ScalarizeCache, Ref{Any}(nothing)) -end false - -_det(x, lp) = det(x, laplace=lp) - -function scalarize_op(f::typeof(_det), arr) - unwrap(det(map(wrap, collect(arguments(arr)[1])), laplace=arguments(arr)[2])) -end - -@wrapped function LinearAlgebra.det(x::AbstractMatrix; laplace=true) - Term{eltype(x)}(_det, [x, laplace]) -end false - - -# A * x = b -# A ∈ R^(m x n) x ∈ R^(n, k) = b ∈ R^(m, k) -propagate_ndims(::typeof(\), A, b) = ndims(b) -propagate_ndims(::typeof(inv), A) = ndims(A) - -# A(m,k) * B(k,n) = C(m,n) -# A(m,k) \ C(m,n) = B(k,n) -function propagate_shape(::typeof(\), A, b) - if ndims(b) == 1 - (axes(A,2),) - else - (axes(A,2), axes(b, 2)) - end -end - -function propagate_shape(::typeof(inv), A) - @oops shp = shape(A) - @assert ndims(A) == 2 && reverse(shp) == shp "Inv called on a non-square matrix" - shp -end - -function scalarize(arr::ArrayOp, idx) - @assert length(arr.output_idx) == length(idx) - - axs = ranges(arr) - - iidx = collect(keys(axs)) - contracted = setdiff(iidx, arr.output_idx) - - axes = [get_extents(axs[c]) for c in contracted] - summed = if isempty(contracted) - arr.expr - else - mapreduce(arr.reduce, Iterators.product(axes...)) do idx - replace_by_scalarizing(arr.expr, Dict(contracted .=> idx)) - end - end - - dict = Dict(oi => (unwrap(i) isa Symbolic ? unwrap(i) : get_extents(axs[oi])[i]) - for (oi, i) in zip(arr.output_idx, idx) if unwrap(oi) isa Symbolic) - - replace_by_scalarizing(summed, dict) -end - -scalarize(arr::Arr, idx) = wrap(scalarize(unwrap(arr), - unwrap.(idx))) - - -eval_array_term(op, arr) = arr -eval_array_term(op::typeof(inv), arr) = inv(scalarize(wrap(arguments(arr)[1]))) #issue 653 -eval_array_term(op::Arr) = wrap(eval_array_term(unwrap(op))) -eval_array_term(op) = eval_array_term(operation(op), op) - -""" - $(TYPEDSIGNATURES) - -Replace all occurrences of array symbolics (variables and subexpressions) in expression -`arr` with scalarized variants. -""" -function scalarize(arr) - if arr isa Arr || arr isa Symbolic{<:AbstractArray} - if iscall(arr) - arr = eval_array_term(arr) - end - map(Iterators.product(axes(arr)...)) do i - scalarize(arr[i...]) # Use arr[i...] here to trigger any getindex hooks - end - elseif iscall(arr) && operation(arr) == getindex - args = arguments(arr) - scalarize(args[1], (args[2:end]...,)) - elseif arr isa Num - wrap(scalarize(unwrap(arr))) - elseif iscall(arr) && symtype(arr) <: Number - t = maketerm(typeof(arr), operation(arr), map(scalarize, arguments(arr)), metadata(arr)) - iscall(t) ? scalarize_op(operation(t), t) : t - else - arr - end -end - -@wrapped Base.isempty(x::AbstractArray) = shape(unwrap(x)) !== Unknown() && _iszero(length(x)) false -Base.collect(x::Arr) = scalarize(x) -Base.collect(x::SymArray) = scalarize(x) -isarraysymbolic(x) = unwrap(x) isa Symbolic && SymbolicUtils.symtype(unwrap(x)) <: AbstractArray +Base.IndexStyle(::Type{<:Arr}) = Base.IndexStyle(BasicSymbolic{VartypeT}) +Base.length(A::Arr) = length(unwrap(A)) +Base.size(A::Arr) = size(unwrap(A)) +Base.axes(A::Arr) = axes(unwrap(A)) +Base.eachindex(A::Arr) = eachindex(unwrap(A)) + +function SymbolicUtils.search_variables!(buffer, expr::Arr; kw...) + SymbolicUtils.search_variables!(buffer, unwrap(expr); kw...) +end + +# cannot use `@wrapped` since it will define `\(::BasicSymbolic, ::BasicSymbolic)` +# and because `\(::Arr, ::BasicSymbolic)` will be ambiguous. +for (T1, T2) in [ + (Arr{<:Any, 2}, Arr{<:Any, 1}), + (Arr{<:Any, 2}, Arr{<:Any, 2}), + (AbstractArray{<:Any, 2}, Arr{<:Any, 1}), + (AbstractArray{<:Any, 2}, Arr{<:Any, 2}), + (Arr{<:Any, 2}, AbstractArray{<:Any, 1}), + (Arr{<:Any, 2}, AbstractArray{<:Any, 2}), + (Arr{<:Any, 2}, BasicSymbolic{SymReal}), + (Arr{<:Any, 2}, BasicSymbolic{SafeReal}), + (Arr{<:Any, 2}, BasicSymbolic{TreeReal}), + (BasicSymbolic{SymReal}, Arr{<:Any, 1}), + (BasicSymbolic{SafeReal}, Arr{<:Any, 1}), + (BasicSymbolic{TreeReal}, Arr{<:Any, 1}), + (BasicSymbolic{SymReal}, Arr{<:Any, 2}), + (BasicSymbolic{SafeReal}, Arr{<:Any, 2}), + (BasicSymbolic{TreeReal}, Arr{<:Any, 2}), +] + @eval function Base.:(\)(A::$T1, b::$T2) + unwrap(A) \ unwrap(b) + end +end + +Base.inv(A::Arr{<:Any, 2}) = wrap(inv(unwrap(A))) +LinearAlgebra.det(A::Arr{<:Any, 2}) = wrap(det(unwrap(A))) +LinearAlgebra.adjoint(A::Arr{<:Any, 2}) = wrap(adjoint(unwrap(A))) +LinearAlgebra.adjoint(A::Arr{<:Any, 1}) = wrap(adjoint(unwrap(A))) + +SymbolicUtils.scalarize(x::Arr) = SymbolicUtils.scalarize(unwrap(x)) + +Base.isempty(x::Arr) = isempty(unwrap(x)) +Base.collect(x::Arr) = wrap.(collect(unwrap(x))) +isarraysymbolic(x) = false +# this should be validated in the constructor +isarraysymbolic(x::Arr) = true +isarraysymbolic(x::BasicSymbolic) = symtype(x) <: AbstractArray Base.convert(::Type{<:Array{<:Any, N}}, arr::Arr{<:Any, N}) where {N} = scalarize(arr) - -### Stencils - -struct ArrayMaker{T, AT<:AbstractArray} <: Symbolic{AT} - shape - sequence - metadata -end - -function ArrayMaker(a::ArrayLike; eltype=eltype(a)) - ArrayMaker{eltype}(size(a), Any[axes(a) => a]) -end - -function arraymaker(T, shape, views, seq...) - ArrayMaker{T}(shape, [(views .=> seq)...]) -end - -iscall(x::ArrayMaker) = true -operation(x::ArrayMaker) = arraymaker -arguments(x::ArrayMaker) = [eltype(x), shape(x), map(first, x.sequence), map(last, x.sequence)...] - -shape(am::ArrayMaker) = am.shape - -function ArrayMaker{T}(sz::NTuple{N, Integer}, seq::Array=[]; atype=Array, metadata=nothing) where {N,T} - ArrayMaker{T, atype{T, N}}(map(x->1:x, sz), seq, metadata) -end - -(::Type{ArrayMaker{T}})(i::Int...; atype=Array) where {T} = ArrayMaker{T}(i, atype=atype) - -function Base.show(io::IO, ac::ArrayMaker) - print(io, Expr(:call, :ArrayMaker, ac.shape, - Expr(:block, ac.sequence...))) -end - -function get_indexers(ex) - @assert ex.head == :ref - arr = ex.args[1] - args = map(((i,x),)->x == Symbol(":") ? :(1:lastindex($arr, $i)) : x, enumerate(ex.args[2:end])) - replace_ends(arr, args) -end - -function search_and_replace(expr, key, val) - isequal(expr, key) && return val - - expr isa Expr ? - Expr(expr.head, map(x->search_and_replace(x, key,val), expr.args)...) : - expr -end - -function replace_ends(arr, idx) - [search_and_replace(ix, :end, :(lastindex($arr, $i))) - for (i, ix) in enumerate(idx)] -end - -macro setview!(definition, arrayop) - setview(definition, arrayop, true) -end - -macro setview(definition, arrayop) - setview(definition, arrayop, false) -end - -output_index_ranges(c::CartesianIndices) = c.indices -output_index_ranges(ix...) = ix - -function setview(definition, arrayop, inplace) - output_view = get_indexers(definition) - output_ref = definition.args[1] - - function check_assignment(vw, op) - try Base.Broadcast.broadcast_shape(map(length, vw), size(op)) - catch err - if err isa DimensionMismatch - throw(DimensionMismatch("setview did not work while assigning indices " * - "$vw to $op. LHS has size $(map(length, vw)) "* - "and RHS has size $(size(op)) " * - "-- they need to be broadcastable.")) - else - rethrow(err) - end - end - end - - function push(inplace) - if inplace - function (am, vw, op) - check_assignment(vw, op) - # assert proper size match - push!(am.sequence, vw => op) - am - end - else - function (am, vw, op) - check_assignment(vw, op) - if am isa ArrayMaker - typeof(am)(am.shape, vcat(am.sequence, vw => op)) - else - am = ArrayMaker(am) - push!(am.sequence, vw => op) - am - end - end - end - end - quote - $(push(inplace))($output_ref, - $output_index_ranges($(output_view...)), $unwrap($arrayop)) - end |> esc -end - -macro makearray(definition, sequence) - output_shape = get_indexers(definition) - output_name = definition.args[1] - - seq = map(filter(x->!(x isa LineNumberNode), sequence.args)) do pair - @assert pair.head == :call && pair.args[1] == :(=>) - # TODO: make sure the same symbol is used for the lhs array - :(@setview! $(pair.args[2]) $(pair.args[3])) - end - - quote - $output_name = $ArrayMaker{Real}(map(length, ($(output_shape...),))) - $(seq...) - $output_name = $wrap($output_name) - end |> esc -end - -function best_order(output_idx, ks, rs) - unique!(filter(issym, vcat(reverse(output_idx)..., collect(ks)))) -end - -function _cat(x, xs...; dims) - arrays = (x, xs...) - if dims isa Integer - sz = Base.cat_size_shape(Base.dims2cat(dims), arrays...) - T = reduce(promote_type, eltype.(xs), init=eltype(x)) - newdim = cumsum(map(a->size(a, dims), arrays)) - start = 1 - A = ArrayMaker{T}(sz...) - for (dim, array) in zip(newdim, arrays) - idx = CartesianIndices(ntuple(n -> n==dims ? - (start:dim) : (1:sz[n]), length(sz))) - start = dim + 1 - - @setview! A[idx] array - end - return A - else - error("Block diagonal concatenation not supported") - end -end - -# Base.cat(x::Arr, xs...; dims) = _cat(x, xs...; dims) -# Base.cat(x::AbstractArray, y::Arr, xs...; dims) = _cat(x, y, xs...; dims) - -# vv uncomment these for a major release -# Base.vcat(x::Arr, xs::AbstractVecOrMat...) = cat(x, xs..., dims=1) -# Base.hcat(x::Arr, xs::AbstractVecOrMat...) = cat(x, xs..., dims=2) -# Base.vcat(x::AbstractVecOrMat, y::Arr, xs::AbstractVecOrMat...) = _cat(x, y, xs..., dims=1) -# Base.hcat(x::AbstractVecOrMat, y::Arr, xs::AbstractVecOrMat...) = _cat(x, y, xs..., dims=2) -# Base.vcat(x::Arr, y::Arr) = _cat(x, y, dims=1) # plug ambiguity -# Base.hcat(x::Arr, y::Arr) = _cat(x, y, dims=2) - -function scalarize(x::ArrayMaker) - T = eltype(x) - A = Array{wrapper_type(T)}(undef, size(x)) - for (vw, arr) in x.sequence - if any(x->x isa AbstractArray, vw) - A[vw...] .= scalarize(arr) - else - A[vw...] = scalarize(arr) - end - end - A -end - -function scalarize(x::ArrayMaker, idx) - for (vw, arr) in reverse(x.sequence) # last one wins - if any(x->issym(x) || iscall(x), idx) - return term(getindex, x, idx...) - end - if all(in.(idx, vw)) - if symtype(arr) <: AbstractArray - # Filter out non-array indices because the RHS will be one dim less - el = [searchsortedfirst(v, i) - for (v, i) in zip(vw, idx) if v isa AbstractArray] - return scalarize(arr[el...]) - else - return arr - end - end - end - if !any(x->issym(x) || iscall(x), idx) && all(in.(idx, axes(x))) - throw(UndefRefError()) - end - - throw(BoundsError(x, idx)) -end - - -### Codegen - -function SymbolicUtils.Code.toexpr(x::ArrayOp, st) - haskey(st.rewrites, x) && return st.rewrites[x] - - if iscall(x.term) - toexpr(x.term, st) - else - _array_toexpr(x, st) - end -end - function SymbolicUtils.Code.toexpr(x::Arr, st) toexpr(unwrap(x), st) end -function SymbolicUtils.Code.toexpr(x::ArrayMaker, st) - _array_toexpr(x, st) -end - -function _array_toexpr(x, st) - outsym = Symbol("_out") - N = length(shape(x)) - ex = Let( - [ - Assignment(outsym, term(zeros, Float64, term(map, length, shape(x)))), - Assignment(Symbol("%$outsym"), inplace_expr(x, outsym)) - ], outsym, false) - - toexpr(ex, st) -end - """ $(TYPEDSIGNATURES) @@ -1027,31 +121,15 @@ end function inplace_expr(x, out_array, intermediates = nothing) x = unwrap(x) - if symtype(x) <: Number + if SymbolicUtils.isarrayop(x) && x.term === nothing + return x + elseif symtype(x) <: Number return term(broadcast_assign!, out_array, x) else return term(copy!, out_array, x) end end -function inplace_expr(x::ArrayMaker, out, intermediates = nothing) - steps = Assignment[] - - _intermediates = something(intermediates, Dict()) - for (i, (vw, op)) in enumerate(x.sequence) - out′ = Symbol(out, "_", i) - push!(steps, Assignment(out′, term(view, out, vw...))) - push!(steps, Assignment(Symbol("%$out′"), inplace_expr(unwrap(op), out′, _intermediates))) - end - - expr = Let(steps, nothing, false) - if intermediates === nothing && !isempty(_intermediates) - steps = [map(k -> Assignment(_intermediates[k], k), collect(keys(_intermediates))); steps] - expr = Let(steps, nothing, false) - end - return expr -end - function inplace_expr(x::AbstractArray, out, intermediates = nothing) expr = SetArray(false, out, x, true) if intermediates !== nothing @@ -1069,85 +147,9 @@ function inplace_builtin(term, outsym) return nothing end -function find_inter(acc, expr) - if !issym(expr) && symtype(expr) <: AbstractArray - push!(acc, expr) - elseif iscall(expr) - foreach(x -> find_inter(acc, x), arguments(expr)) - end - acc -end - -function get_inputs(x::ArrayOp) - unique(find_inter([], x.expr)) -end - -function similar_arrayvar(ex, name) - Sym{symtype(ex)}(name) #TODO: shape? -end - -function reset_to_one(range) - @assert step(range) == 1 - 1:(length(range)) -end - -function reset_sym(i) - Sym{Int}(Symbol(nameof(i), "′")) -end - -function inplace_expr(x::ArrayOp, outsym = :_out, intermediates = nothing) - if x.term !== nothing - ex = inplace_builtin(x.term, outsym) - if ex !== nothing - return ex - end - end - - inters = filter(!issym, get_inputs(x)) - intermediate_exprs = map(enumerate(inters)) do (i, ex) - if intermediates !== nothing - if haskey(intermediates, ex) - return ex => intermediates[ex] - else - sym = similar_arrayvar(ex, Symbol(outsym, :_input_, i)) - intermediates[ex] = sym - return ex => sym - end - else - return ex => similar_arrayvar(ex, Symbol(outsym, :_input_, i)) - end - end - - - rs = copy(ranges(x)) - loops = best_order(x.output_idx, keys(rs), rs) - expr = substitute(unwrap(x.expr), Dict(intermediate_exprs)) - - out_idxs = map(reset_sym, x.output_idx) - inner_expr = SetArray(false, outsym, [AtIndex(term(CartesianIndex, out_idxs...), term(x.reduce, term(getindex, outsym, out_idxs...), expr))]) - - loops = foldl(reverse(loops), init=inner_expr) do acc, k - if any(isequal(k), x.output_idx) - k′ = reset_sym(k) - loopvar = Symbol("%$k$k′") - ext = get_extents(rs[k]) - if isone(first(ext)) && isone(step(ext)) - ext = Base.OneTo(last(ext)) - end - ForLoop(loopvar, term(zip, ext, term(reset_to_one, ext)), Let([DestructuredArgs([k, reset_sym(k)], loopvar)], acc, false)) - else - ForLoop(k, get_extents(rs[k]), acc) - end - end - - if intermediates === nothing && !isempty(intermediate_exprs) - return Let(map(x -> Assignment(x[2], x[1]), intermediate_exprs), loops, false) - end - return loops -end - hasnode(r::Function, y::Arr) = _hasnode(r, y) -hasnode(r::Union{Num, Symbolic, Arr}, y::Arr) = occursin(unwrap(r), unwrap(y)) +hasnode(r::Union{Num, Arr}, y::Arr) = occursin(unwrap(r), unwrap(y)) +hasnode(r::Arr, y::BasicSymbolic) = occursin(unwrap(r), unwrap(y)) #= """ diff --git a/src/build_function.jl b/src/build_function.jl index 3045dd389..5c8ea9d67 100644 --- a/src/build_function.jl +++ b/src/build_function.jl @@ -86,7 +86,6 @@ end # Scalar output unwrap_nometa(x) = unwrap(x) -unwrap_nometa(x::CallWithMetadata) = unwrap(x.f) function destructure_arg(arg::Union{AbstractArray, Tuple,NamedTuple}, inbounds, name) if !(arg isa Arr) DestructuredArgs(map(unwrap_nometa, arg), name, inbounds=inbounds, create_bindings=false) @@ -107,7 +106,7 @@ SymbolicUtils.Code.cse_inside_expr(sym, ::Symbolics.Operator, args...) = false # don't CSE inside `getindex` of things created via `@variables` # EXCEPT called variables function SymbolicUtils.Code.cse_inside_expr(sym, ::typeof(getindex), x::BasicSymbolic, idxs...) - return !hasmetadata(sym, VariableSource) || hasmetadata(sym, CallWithParent) + return !hasmetadata(sym, VariableSource) || SymbolicUtils.is_called_function_symbolic(x) end function _build_function(target::JuliaTarget, op, args...; @@ -122,6 +121,9 @@ function _build_function(target::JuliaTarget, op, args...; nanmath = true, kwargs...) op = _recursive_unwrap(op) + if symtype(op) <: AbstractArray + return _build_function(target, wrap(op), args...; conv, expression, expression_module, checkbounds, states, linenumbers, cse, nanmath, kwargs...) + end states.rewrites[:nanmath] = nanmath dargs = map((x) -> destructure_arg(x[2], !checkbounds, default_arg_name(x[1])), enumerate(collect(args))) fun = Func(dargs, [], op) @@ -149,7 +151,7 @@ end SymbolicUtils.Code.get_rewrites(x::Arr) = SymbolicUtils.Code.get_rewrites(unwrap(x)) -function _build_function(target::JuliaTarget, op::Union{Arr, ArrayOp, SymbolicUtils.BasicSymbolic{<:AbstractArray}}, args...; +function _build_function(target::JuliaTarget, op::Arr, args...; conv = toexpr, expression = Val{true}, expression_module = @__MODULE__(), @@ -182,8 +184,10 @@ function _build_function(target::JuliaTarget, op::Union{Arr, ArrayOp, SymbolicUt oop_expr = Code.cse(oop_expr) iip_expr = Code.cse(iip_expr) end - oop_expr = conv(oop_expr, states) + if SymbolicUtils.isarrayop(op) && !haskey(states.rewrites, :arrayop_output) + states.rewrites[:arrayop_output] = outsym + end iip_expr = conv(iip_expr, states) if !checkbounds @@ -338,7 +342,7 @@ function _build_function(target::JuliaTarget, rhss::AbstractArray, args...; if iip - out = Sym{Any}(DEFAULT_OUTSYM) + out = Sym{VartypeT}(DEFAULT_OUTSYM; type = Any, shape = SymbolicUtils.Unknown(-1)) iip_expr = Func(vcat(out, dargs), [], postprocess_fbody(set_array(parallel, dargs, out, @@ -603,10 +607,11 @@ function buildvarnumbercache(args...) return Dict(varnumsdict) end -function numbered_expr(O::Symbolic,varnumbercache,args...;varordering = args[1],offset = 0, +function numbered_expr(O::BasicSymbolic,varnumbercache,args...;varordering = args[1],offset = 0, states = LazyState(), lhsname=:du,rhsnames=[Symbol("MTK$i") for i in 1:length(args)]) O = value(O) + O isa BasicSymbolic || return O if (issym(O) || issym(operation(O))) || (iscall(O) && operation(O) == getindex) (j,i) = get(varnumbercache, O, (nothing, nothing)) if !isnothing(j) diff --git a/src/complex.jl b/src/complex.jl index b221fa30f..79ece3764 100644 --- a/src/complex.jl +++ b/src/complex.jl @@ -1,15 +1,7 @@ -abstract type AbstractComplexTerm{T} <: Symbolic{Complex{T}} -end - -struct ComplexTerm{T} <: AbstractComplexTerm{T} - re - im -end - -Base.imag(c::Symbolic{Complex{T}}) where {T} = term(imag, c) SymbolicUtils.promote_symtype(::typeof(imag), ::Type{Complex{T}}) where {T} = T Base.promote_rule(::Type{Complex{T}}, ::Type{S}) where {T<:Real, S<:Num} = Complex{S} # 283 +is_wrapper_type(::Type{Complex{Num}}) = true has_symwrapper(::Type{<:Complex{T}}) where {T<:Real} = true wraps_type(::Type{Complex{Num}}) = Complex{Real} iswrapped(::Complex{Num}) = true @@ -17,26 +9,24 @@ function wrapper_type(::Type{Complex{T}}) where T Symbolics.has_symwrapper(T) ? Complex{wrapper_type(T)} : Complex{T} end -symtype(a::ComplexTerm{T}) where T = Complex{T} -iscall(a::ComplexTerm) = true -operation(a::ComplexTerm{T}) where T = Complex{T} -arguments(a::ComplexTerm) = [a.re, a.im] -metadata(a::ComplexTerm) = metadata(a.re) - -function maketerm(T::Type{<:ComplexTerm}, f, args, metadata) - if f <: Complex - ComplexTerm{real(f)}(args...) - else - maketerm(typeof(first(args)), f, args, metadata) +function SymbolicUtils.unwrap(a::Complex{<:Num}) + re, img = unwrap(real(a)), unwrap(imag(a)) + if SymbolicUtils.isconst(re) && SymbolicUtils.isconst(img) + return Const{VartypeT}(complex(unwrap_const(re), unwrap_const(img))) end + if iscall(re) && operation(re) === real && iscall(img) && operation(img) === imag && isequal(arguments(re)[1], arguments(img)[1]) + return arguments(re)[1] + end + sT = promote_type(symtype(re), symtype(img)) + return Term{VartypeT}(complex, SymbolicUtils.ArgsT{vartype(re)}((re, img)); type = Complex{sT}, shape = SymbolicUtils.ShapeVecT()) end -function Base.show(io::IO, mime::MIME"text/plain", a::ComplexTerm) - print(io, "ComplexTerm(") - show(io, mime, wrap(a)) - print(io, ")") +function Base.Complex{Num}(x::BasicSymbolic{VartypeT}) + Complex{Num}(wrap(real(x)), wrap(imag(x))) end +const IM = Sym{VartypeT}(:im; type = Number) + function Base.show(io::IO, a::Complex{Num}) rr = unwrap(real(a)) ii = unwrap(imag(a)) @@ -48,53 +38,5 @@ function Base.show(io::IO, a::Complex{Num}) return print(io, arguments(rr)[1]) end - i = Sym{Real}(:im) - show(io, real(a) + i * imag(a)) -end - -function unwrap(a::Complex{<:Num}) - re, img = unwrap(real(a)), unwrap(imag(a)) - if re isa Real && img isa Real - return re + im * img - end - T = promote_type(symtype(re), symtype(img)) - ComplexTerm{T}(re, img) -end -wrap(a::ComplexTerm) = Complex(wrap.(arguments(a))...) -wrap(a::Symbolic{<:Complex}) = Complex(wrap(real(a)), wrap(imag(a))) - -SymbolicUtils.@number_methods( - ComplexTerm, - unwrap(f(wrap(a))), - unwrap(f(wrap(a), wrap(b))), - ) - -function Base.isequal(a::ComplexTerm{T}, b::ComplexTerm{S}) where {T,S} - T === S && isequal(a.re, b.re) && isequal(a.im, b.im) -end - -function Base.hash(a::ComplexTerm{T}, h::UInt) where T - hash(hash(a.im, hash(a.re, hash(T, hash(h ⊻ 0x1af5d7582250ac4a))))) -end - -Base.iszero(x::Complex{<:Num}) = iszero(real(x)) && iszero(imag(x)) -Base.isone(x::Complex{<:Num}) = isone(real(x)) && iszero(imag(x)) -_iszero(x::Complex{<:Num}) = _iszero(unwrap(x)) -_isone(x::Complex{<:Num}) = _isone(unwrap(x)) - -function SymbolicIndexingInterface.hasname(x::ComplexTerm) - a = arguments(unwrap(x.im))[1] - b = arguments(unwrap(x.re))[1] - return isequal(a, b) && hasname(a) -end - -function _getname(x::ComplexTerm, val) - a = arguments(unwrap(x.im))[1] - b = arguments(unwrap(x.re))[1] - if isequal(a, b) - return _getname(a, val) - end - if val == _fail - throw(ArgumentError("Variable $x doesn't have a name.")) - end + show(io, real(a) + IM * imag(a)) end diff --git a/src/diff.jl b/src/diff.jl index 08a73268d..7ab1eeed0 100644 --- a/src/diff.jl +++ b/src/diff.jl @@ -1,6 +1,3 @@ -abstract type Operator end -propagate_shape(::Operator, x) = axes(x) - """ $(TYPEDEF) @@ -37,19 +34,22 @@ struct Differential <: Operator end function (D::Differential)(x) x = unwrap(x) - if isarraysymbolic(x) - array_term(D, x) - else - term(D, x) - end + term(D, x) end (D::Differential)(x::Union{AbstractFloat, Integer}) = wrap(0) (D::Differential)(x::Union{Num, Arr}) = wrap(D(unwrap(x))) -(D::Differential)(x::Complex{Num}) = wrap(ComplexTerm{Real}(D(unwrap(real(x))), D(unwrap(imag(x))))) -SymbolicUtils.promote_symtype(::Differential, T) = T +(D::Differential)(x::Complex{Num}) = Complex{Num}(wrap(D(unwrap(real(x)))), wrap(D(unwrap(imag(x))))) SymbolicUtils.isbinop(f::Differential) = false +function (s::SymbolicUtils.Substituter)(x::Differential) + Differential(s(x.x)) +end + +function SymbolicUtils.operator_to_term(d::Differential, ex::BasicSymbolic{T}) where {T} + return diff2term(ex) +end + is_derivative(x) = iscall(x) ? operation(x) isa Differential : false Base.:*(D1, D2::Differential) = D1 ∘ D2 @@ -63,9 +63,6 @@ Base.nameof(D::Differential) = :Differential Base.:(==)(D1::Differential, D2::Differential) = isequal(D1.x, D2.x) Base.hash(D::Differential, u::UInt) = hash(D.x, xor(u, 0xdddddddddddddddd)) -_isfalse(occ::Bool) = occ === false -_isfalse(occ::Symbolic) = iscall(occ) && _isfalse(operation(occ)) - """ $(TYPEDSIGNATURES) @@ -191,8 +188,6 @@ function _recursive_hasoperator(op, O) else if isadd(O) || ismul(O) any(_recursive_hasoperator(op), keys(O.dict)) - elseif ispow(O) - _recursive_hasoperator(op)(O.base) || _recursive_hasoperator(op)(O.exp) elseif isdiv(O) _recursive_hasoperator(op)(O.num) || _recursive_hasoperator(op)(O.den) else @@ -227,6 +222,23 @@ function Base.showerror(io::IO, err::DerivativeNotDefinedError) show(io, MIME"text/plain"(), err_str) end +function symdiff_substitute_filter(ex::BasicSymbolic{T}) where {T} + SymbolicUtils.default_substitute_filter(ex) || @match ex begin + BSImpl.Term(; f) && if f isa Differential end => true + _ => false + end +end + +""" + $(TYPEDSIGNATURES) + +Identical to `substitute` except it also substitutes inside `Differential` operator +applications. +""" +function substitute_in_deriv(ex, rules; kw...) + substitute(ex, rules; kw..., filterer = symdiff_substitute_filter) +end + """ executediff(D, arg, simplify=false; occurrences=nothing) @@ -237,7 +249,7 @@ passed differential and not any other Differentials it encounters. # Arguments - `D::Differential`: The differential to apply -- `arg::Symbolic`: The symbolic expression to apply the differential on. +- `arg::BasicSymbolic`: The symbolic expression to apply the differential on. - `simplify::Bool=false`: Whether to simplify the resulting expression using [`SymbolicUtils.simplify`](@ref). - `occurrences=nothing`: Information about the occurrences of the independent @@ -303,12 +315,12 @@ function executediff(D, arg, simplify=false; throw_no_derivative=false) c = 0 inner_function = arguments(arg)[1] if iscall(a) || isequal(a, D.x) - t1 = SymbolicUtils.substitute(inner_function, Dict(op.domain.variables => a)) + t1 = substitute_in_deriv(inner_function, Dict(op.domain.variables => a)) t2 = executediff(D, a, simplify; throw_no_derivative) c -= t1*t2 end if iscall(b) || isequal(b, D.x) - t1 = SymbolicUtils.substitute(inner_function, Dict(op.domain.variables => b)) + t1 = substitute_in_deriv(inner_function, Dict(op.domain.variables => b)) t2 = executediff(D, b, simplify; throw_no_derivative) c += t1*t2 end @@ -349,7 +361,7 @@ function executediff(D, arg, simplify=false; throw_no_derivative=false) if _iszero(x) continue - elseif x isa Symbolic + elseif x isa BasicSymbolic push!(exprs, x) else c += x @@ -376,7 +388,7 @@ This function recursively traverses a symbolic expression, applying the chain ru and other derivative rules to expand any derivatives it encounters. # Arguments -- `O::Symbolic`: The symbolic expression to expand. +- `O::BasicSymbolic`: The symbolic expression to expand. - `simplify::Bool=false`: Whether to simplify the resulting expression using [`SymbolicUtils.simplify`](@ref). @@ -398,7 +410,7 @@ julia> dfx = expand_derivatives(Dx(f)) (k*((2abs(x - y)) / y - 2z)*ifelse(signbit(x - y), -1, 1)) / y ``` """ -function expand_derivatives(O::Symbolic, simplify=false; throw_no_derivative=false) +function expand_derivatives(O::BasicSymbolic, simplify=false; throw_no_derivative=false) if iscall(O) && isa(operation(O), Differential) arg = only(arguments(O)) arg = expand_derivatives(arg, false; throw_no_derivative) @@ -417,14 +429,12 @@ function expand_derivatives(n::Num, simplify=false; kwargs...) wrap(expand_derivatives(value(n), simplify; kwargs...)) end function expand_derivatives(n::Complex{Num}, simplify=false; kwargs...) - wrap(ComplexTerm{Real}(expand_derivatives(real(n), simplify; kwargs...), - expand_derivatives(imag(n), simplify; kwargs...))) + re = expand_derivatives(real(n), simplify; kwargs...) + img = expand_derivatives(imag(n), simplify; kwargs...) + Complex{Num}(wrap(re), wrap(img)) end expand_derivatives(x, simplify=false; kwargs...) = x -_iszero(x) = false -_isone(x) = false - # Don't specialize on the function here """ $(SIGNATURES) @@ -461,7 +471,7 @@ sin(x) ``` """ derivative_idx(O::Any, ::Any) = 0 -function derivative_idx(O::Symbolic, idx) +function derivative_idx(O::BasicSymbolic, idx) iscall(O) ? derivative(operation(O), (arguments(O)...,), Val(idx)) : 0 end @@ -826,48 +836,48 @@ end hessian(O, vars::Arr; kwargs...) = hessian(O, collect(vars); kwargs...) -isidx(x) = x isa TermCombination +isidx(x) = unwrap_const(x) isa TermCombination -basic_mkterm(t, g, args, m) = metadata(Term{Any}(g, args), m) +basic_mkterm(t, g, args, m) = metadata(Term{VartypeT}(g, args; type = Any), m) const _scalar = one(TermCombination) -const linearity_rules = [ - @rule +(~~xs) => reduce(+, filter(isidx, ~~xs), init=_scalar) - @rule *(~~xs) => reduce(*, filter(isidx, ~~xs), init=_scalar) +const linearity_rules = ( + (@rule +(~~xs) => reduce(+, filter(isidx, ~~xs), init=_scalar)), + (@rule *(~~xs) => reduce(*, filter(isidx, ~~xs), init=_scalar)), - @rule (~f)(~x) => isidx(~x) ? combine_terms_1(linearity_1(~f), ~x) : _scalar - @rule (^)(~x::isidx, ~y) => ~y isa Number && isone(~y) ? ~x : (~x) * (~x) - @rule (~f)(~x, ~y) => combine_terms_2(linearity_2(~f), isidx(~x) ? ~x : _scalar, isidx(~y) ? ~y : _scalar) + (@rule (~f)(~x) => isidx(~x) ? combine_terms_1(linearity_1(~f), ~x) : _scalar), + (@rule (^)(~x::isidx, ~y) => ~y isa Number && isone(~y) ? ~x : (~x) * (~x)), + (@rule (~f)(~x, ~y) => combine_terms_2(linearity_2(~f), isidx(~x) ? ~x : _scalar, isidx(~y) ? ~y : _scalar)), - @rule ~x::issym => 0 + (@rule ~x::issym => 0), # `ifelse(cond, x, y)` can be written as cond * x + (1 - cond) * y # where condition `cond` is considered constant in differentiation - @rule ifelse(~cond, ~x, ~y) => (isidx(~x) ? ~x : _scalar) + (isidx(~y) ? ~y : _scalar) + (@rule ifelse(~cond, ~x, ~y) => (isidx(~x) ? ~x : _scalar) + (isidx(~y) ? ~y : _scalar)), # Fallback: Unknown functions with arbitrary number of arguments have non-zero partial derivatives # Functions with 1 and 2 arguments are already handled above - @rule (~f)(~~xs) => reduce(+, filter(isidx, ~~xs); init=_scalar)^2 -] -const linearity_rules_affine = [ - @rule +(~~xs) => reduce(+, filter(isidx, ~~xs), init=_scalar) - @rule *(~~xs) => reduce(*, filter(isidx, ~~xs), init=_scalar) + (@rule (~f)(~~xs) => reduce(+, filter(isidx, ~~xs); init=_scalar)^2), +) +const linearity_rules_affine = ( + (@rule +(~~xs) => reduce(+, filter(isidx, ~~xs), init=_scalar)), + (@rule *(~~xs) => reduce(*, filter(isidx, ~~xs), init=_scalar)), - @rule (~f)(~x) => isidx(~x) ? combine_terms_1(linearity_1(~f), ~x) : _scalar - @rule (^)(~x::isidx, ~y) => ~y isa Number && isone(~y) ? ~x : (~x) * (~x) - @rule (~f)(~x, ~y) => combine_terms_2(linearity_2(~f), isidx(~x) ? ~x : _scalar, isidx(~y) ? ~y : _scalar) + (@rule (~f)(~x) => isidx(~x) ? combine_terms_1(linearity_1(~f), ~x) : _scalar), + (@rule (^)(~x::isidx, ~y) => ~y isa Number && isone(~y) ? ~x : (~x) * (~x)), + (@rule (~f)(~x, ~y) => combine_terms_2(linearity_2(~f), isidx(~x) ? ~x : _scalar, isidx(~y) ? ~y : _scalar)), - @rule ~x::issym => 0 + (@rule ~x::issym => 0), # if the condition is dependent on the variable, do not consider this as affine - @rule ifelse(~cond::isidx, ~x, ~y) => (~cond)^2 + (@rule ifelse(~cond::isidx, ~x, ~y) => (~cond)^2), # `ifelse(cond, x, y)` can be written as cond * x + (1 - cond) * y # where condition `cond` is considered constant in differentiation - @rule ifelse(~cond::(!isidx), ~x, ~y) => (isidx(~x) ? ~x : _scalar) + (isidx(~y) ? ~y : _scalar) + (@rule ifelse(~cond::(!isidx), ~x, ~y) => (isidx(~x) ? ~x : _scalar) + (isidx(~y) ? ~y : _scalar)), # Fallback: Unknown functions with arbitrary number of arguments have non-zero partial derivatives # Functions with 1 and 2 arguments are already handled above - @rule (~f)(~~xs) => reduce(+, filter(isidx, ~~xs); init=_scalar)^2 -] + (@rule (~f)(~~xs) => reduce(+, filter(isidx, ~~xs); init=_scalar)^2), +) const linearity_propagator = Fixpoint(Postwalk(Chain(linearity_rules); maketerm=basic_mkterm)) const affine_linearity_propagator = Fixpoint(Postwalk(Chain(linearity_rules_affine); maketerm=basic_mkterm)) @@ -901,7 +911,7 @@ function hessian_sparsity(expr, vars::AbstractVector; full::Bool=true, linearity u = map(value, vars) dict = Dict(ui => TermCombination(Set([Dict(i=>1)])) for (i, ui) in enumerate(u)) f = Rewriters.Prewalk(x-> get(dict, x, x); maketerm=basic_mkterm)(expr) - lp = linearity_propagator(f) + lp = unwrap_const(linearity_propagator(f)) S = _sparse(lp, length(u)) S = full ? S : tril(S) end diff --git a/src/diffeqs/diffeq_helpers.jl b/src/diffeqs/diffeq_helpers.jl index 26e34823a..dc3b05b19 100644 --- a/src/diffeqs/diffeq_helpers.jl +++ b/src/diffeqs/diffeq_helpers.jl @@ -12,7 +12,7 @@ function _get_der_order(expr, x, t) return maximum(_get_der_order.(factors(expr), Ref(x), Ref(t))) end - return _get_der_order(substitute(expr, Dict(Differential(t)(x) => x)), x, t) + 1 + return _get_der_order(substitute_in_deriv(expr, Dict(Differential(t)(x) => x)), x, t) + 1 end # takes into account fractions @@ -45,7 +45,7 @@ function reduce_order(eq, x, t, ys) # reduction of order y_sub = Dict([[(Dt^i)(x) => ys[i+1] for i=0:n-1]; (Dt^n)(x) => variable(:𝒴)]) - eq = substitute(eq, y_sub) + eq = substitute_in_deriv(eq, y_sub) # isolate (Dt^n)(x) f = symbolic_linear_solve(eq, variable(:𝒴), check=false) @@ -60,7 +60,7 @@ function unreduce_order(expr, x, t, ys) Dt = Differential(t) rev_y_sub = Dict(ys[i] => (Dt^(i-1))(x) for i in eachindex(ys)) - return substitute(expr, rev_y_sub) + return substitute_in_deriv(expr, rev_y_sub) end function is_solution(solution, eq::Equation, x, t) @@ -76,9 +76,9 @@ function is_solution(solution, eq, x, t) return false end - expr = substitute(eq, Dict(x => solution)) + expr = substitute_in_deriv(eq, Dict(x => solution)) expr = expand(expand_derivatives(expr)) - return isequal(expr, 0) + return SymbolicUtils._iszero(expr) end function _parse_trig(expr, t) @@ -94,4 +94,4 @@ function _parse_trig(expr, t) end return nothing -end \ No newline at end of file +end diff --git a/src/diffeqs/diffeqs.jl b/src/diffeqs/diffeqs.jl index 0308f87ec..6948bbcb5 100644 --- a/src/diffeqs/diffeqs.jl +++ b/src/diffeqs/diffeqs.jl @@ -59,7 +59,7 @@ function is_linear_ode(expr, x, t) @assert n >= 1 "ODE must have at least one derivative" y_sub = Dict([[(Dt^i)(x) => ys[i+1] for i=0:n-1]; (Dt^n)(x) => variable(:𝒴)]) - expr = substitute(expr, y_sub) + expr = substitute_in_deriv(expr, y_sub) # isolate (Dt^n)(x) f = symbolic_linear_solve(expr, variable(:𝒴), check=false) @@ -291,7 +291,7 @@ function const_coeff_solve(eq::SymbolicLinearODE) @variables 𝓇 p = characteristic_polynomial(eq, 𝓇) roots = symbolic_solve(p, 𝓇, dropmultiplicity = false) - + roots = map(value, roots) # Handle complex + repeated roots solutions = exp.(roots * eq.t) for i in eachindex(solutions)[1:(end - 1)] @@ -327,7 +327,7 @@ function integrating_factor_solve(eq::SymbolicLinearODE) else v = exp(sympy_integrate(p, eq.t)) end - solution = (1 / v) * ((isequal(eq.q, 0) ? 0 : sympy_integrate(eq.q * v, eq.t)) + eq.C[1]) + solution = (1 / v) * ((SymbolicUtils._iszero(eq.q) ? 0 : sympy_integrate(eq.q * v, eq.t)) + eq.C[1]) if !isempty(Symbolics.get_variables(solution, variable(:Integral))) return nothing @@ -415,12 +415,12 @@ function exp_trig_particular_solution(eq::SymbolicLinearODE) @variables 𝓈 p = characteristic_polynomial(eq, 𝓈) Ds = Differential(𝓈) - while isequal(substitute(expand_derivatives((Ds^k)(p)), Dict(𝓈 => r+b*im)), 0) + while SymbolicUtils._iszero(substitute_in_deriv(expand_derivatives((Ds^k)(p)), Dict(𝓈 => r+b*im))) k += 1 end rrf = expand(simplify(a * exp((r + b * im) * eq.t) * eq.t^k / - (substitute(expand_derivatives((Ds^k)(p)), Dict(𝓈 => r+b*im))))) + (substitute_in_deriv(expand_derivatives((Ds^k)(p)), Dict(𝓈 => r+b*im))))) return is_sin ? imag(rrf) : real(rrf) end @@ -447,12 +447,12 @@ function resonant_response_formula(eq::SymbolicLinearODE) @variables 𝓈 p = characteristic_polynomial(eq, 𝓈) Ds = Differential(𝓈) - while isequal(substitute(expand_derivatives((Ds^k)(p)), Dict(𝓈 => r)), 0) + while SymbolicUtils._iszero(substitute_in_deriv(expand_derivatives((Ds^k)(p)), Dict(𝓈 => r))) k += 1 end return expand(simplify(a * exp(r * eq.t) * eq.t^k / - (substitute(expand_derivatives((Ds^k)(p)), Dict(𝓈 => r))))) + (substitute_in_deriv(expand_derivatives((Ds^k)(p)), Dict(𝓈 => r))))) end function method_of_undetermined_coefficients(eq::SymbolicLinearODE) @@ -466,7 +466,7 @@ function method_of_undetermined_coefficients(eq::SymbolicLinearODE) degree = max(Symbolics.degree(eq.q, eq.t), Symbolics.degree.(eq.p, eq.t)...) # just a starting point a = Symbolics.variables(:𝒶, 1:degree+1) form = sum(a[n]*eq.t^(n-1) for n = 1:degree+1) - eq_subbed = substitute(get_expression(eq), Dict(eq.x => form)) + eq_subbed = substitute_in_deriv(get_expression(eq), Dict(eq.x => form)) eq_subbed = eq_subbed.lhs - eq_subbed.rhs eq_subbed = expand_derivatives(eq_subbed) @@ -476,8 +476,8 @@ function method_of_undetermined_coefficients(eq::SymbolicLinearODE) coeff_solution = nothing end - if degree > 0 && coeff_solution !== nothing && !isempty(coeff_solution) && isequal(expand(substitute(eq_subbed, coeff_solution[1])), 0) - return substitute(form, coeff_solution[1]) + if degree > 0 && coeff_solution !== nothing && !isempty(coeff_solution) && SymbolicUtils._iszero(expand(substitute_in_deriv(eq_subbed, coeff_solution[1]))) + return substitute_in_deriv(form, coeff_solution[1]) end # exponential @@ -487,13 +487,13 @@ function method_of_undetermined_coefficients(eq::SymbolicLinearODE) r = coeff[2] form = a_form*exp(r*eq.t) - eq_subbed = substitute(get_expression(eq), Dict(eq.x => form)) + eq_subbed = substitute_in_deriv(get_expression(eq), Dict(eq.x => form)) eq_subbed = expand_derivatives(eq_subbed) eq_subbed = simplify(expand((eq_subbed.lhs - eq_subbed.rhs) / exp(r*eq.t))) coeff_solution = solve_interms_ofvar(eq_subbed, eq.t) if coeff_solution !== nothing && !isempty(coeff_solution) - return substitute(form, coeff_solution[1]) + return substitute_in_deriv(form, coeff_solution[1]) end end @@ -505,9 +505,9 @@ function method_of_undetermined_coefficients(eq::SymbolicLinearODE) if parsed !== nothing ω = parsed[1] form = 𝒶*cos(ω*eq.t) + 𝒷*sin(ω*eq.t) - eq_subbed = substitute(get_expression(eq), Dict(eq.x => form)) + eq_subbed = substitute_in_deriv(get_expression(eq), Dict(eq.x => form)) eq_subbed = expand_derivatives(eq_subbed) - eq_subbed = expand(substitute(eq_subbed.lhs - eq_subbed.rhs, Dict(cos(ω*eq.t)=>𝒸𝓈, sin(ω*eq.t)=>𝓈𝓃))) + eq_subbed = expand(substitute_in_deriv(eq_subbed.lhs - eq_subbed.rhs, Dict(cos(ω*eq.t)=>𝒸𝓈, sin(ω*eq.t)=>𝓈𝓃))) cos_eq = simplify(sum(filter(term -> !isempty(Symbolics.get_variables(term, 𝒸𝓈)), terms(eq_subbed)))/𝒸𝓈) sin_eq = simplify(sum(filter(term -> !isempty(Symbolics.get_variables(term, 𝓈𝓃)), terms(eq_subbed)))/𝓈𝓃) if !isempty(Symbolics.get_variables(cos_eq, [eq.t,𝓈𝓃,𝒸𝓈])) || !isempty(Symbolics.get_variables(sin_eq, [eq.t,𝓈𝓃,𝒸𝓈])) @@ -517,7 +517,7 @@ function method_of_undetermined_coefficients(eq::SymbolicLinearODE) end if coeff_solution !== nothing && !isempty(coeff_solution) - return substitute(form, coeff_solution[1]) + return substitute_in_deriv(form, coeff_solution[1]) end end end @@ -545,18 +545,18 @@ function solve_symbolic_IVP(ivp::IVP) for i in eachindex(ivp.initial_conditions) eq::Num = expand_derivatives((Dt(ivp.eq)^(i-1))(general_solution)) - ivp.initial_conditions[i] - eq = substitute(eq, Dict(ivp.eq.t => 0), fold=false) + eq = substitute_in_deriv(eq, Dict(ivp.eq.t => 0), fold=false) # make sure exp, sin, and cos don't evaluate to floats - exp0 = substitute(exp(ivp.eq.t), Dict(ivp.eq.t => 0), fold=false) - sin0 = substitute(sin(ivp.eq.t), Dict(ivp.eq.t => 0), fold=false) - cos0 = substitute(cos(ivp.eq.t), Dict(ivp.eq.t => 0), fold=false) + exp0 = substitute_in_deriv(exp(ivp.eq.t), Dict(ivp.eq.t => 0), fold=false) + sin0 = substitute_in_deriv(sin(ivp.eq.t), Dict(ivp.eq.t => 0), fold=false) + cos0 = substitute_in_deriv(cos(ivp.eq.t), Dict(ivp.eq.t => 0), fold=false) - eq = expand(simplify(substitute(eq, Dict(exp0 => 1, sin0 => 0, cos0 => 1), fold=false))) + eq = expand(simplify(substitute_in_deriv(eq, Dict(exp0 => 1, sin0 => 0, cos0 => 1), fold=false))) push!(eqs, eq) end - return expand(simplify(substitute(general_solution, symbolic_solve(eqs, ivp.eq.C)[1]))) + return expand(simplify(substitute_in_deriv(general_solution, symbolic_solve(eqs, ivp.eq.C)[1]))) end """ @@ -622,7 +622,7 @@ function solve_clairaut(expr, x, t) end C = Symbolics.variable(:C, 1) # constant of integration - f = substitute(f, Dict(Dt(x) => C)) + f = substitute_in_deriv(f, Dict(Dt(x) => C)) if !isempty(Symbolics.get_variables(f, [x])) return nothing end @@ -650,7 +650,7 @@ function linearize_bernoulli(expr, x, t, v) if Symbolics.hasderiv(Symbolics.value(term)) facs = _true_factors(term) leading_coeff = prod(filter(fac -> !Symbolics.hasderiv(Symbolics.value(fac)), facs)) - if !isequal(term//leading_coeff, Dt(x)) + if !isequal(term/leading_coeff, Dt(x)) return nothing end elseif !isempty(Symbolics.get_variables(term, [x])) @@ -670,7 +670,11 @@ function linearize_bernoulli(expr, x, t, v) end p //= leading_coeff - q //= leading_coeff + if q isa Union{Num, BasicSymbolic{VartypeT}} + q /= leading_coeff + else + q //= leading_coeff + end return SymbolicLinearODE(v, t, [p*(1-n)], q*(1-n)), n end @@ -693,4 +697,4 @@ function solve_bernoulli(expr, x, t) end return simplify(solution^(1//(1-n))) -end \ No newline at end of file +end diff --git a/src/diffeqs/systems.jl b/src/diffeqs/systems.jl index 3f586a875..41652f30c 100644 --- a/src/diffeqs/systems.jl +++ b/src/diffeqs/systems.jl @@ -93,10 +93,9 @@ function symbolic_eigen(A::Matrix{<:Number}) # find eigenvalues first p = det(λ*I - A) ~ 0 # polynomial to solve values = symbolic_solve(p, λ) # solve polynomial - + values = map(unwrap_const, values) # then, find eigenvectors S::Matrix{Number} = Matrix(I, size(A, 1), 0) # matrix storing vertical eigenvectors - for value in values eqs = (value*I - A) * v# .~ zeros(size(A, 1)) # equations to give eigenvectors eqs = substitute(eqs, Dict(v[1] => 1)) # set first element to 1 to constrain solution space @@ -107,6 +106,7 @@ function symbolic_eigen(A::Matrix{<:Number}) if sol[1] isa Dict sol = [sol[1][var] for var in v[2:end]] end + sol = map(unwrap_const, sol) vec::Vector{Number} = prepend!(sol, [1]) # add back the 1 (representing v_1) from substitution S = [S vec] # add vec to matrix end diff --git a/src/difference.jl b/src/difference.jl deleted file mode 100644 index b3921d01d..000000000 --- a/src/difference.jl +++ /dev/null @@ -1,63 +0,0 @@ -""" -$(TYPEDEF) - -Represents a difference operator. - -# Fields -$(FIELDS) - -# Examples - -```jldoctest -julia> using Symbolics - -julia> @variables t; - -julia> Δ = Difference(t; dt=0.01) -(::Difference) (generic function with 2 methods) -``` -""" -struct Difference <: Operator - """Fixed Difference""" - t - dt - update::Bool - Difference(t; dt, update=false) = new(value(t), dt, update) -end -(D::Difference)(t) = Term{symtype(t)}(D, [t]) -(D::Difference)(t::Num) = Num(D(value(t))) -SymbolicUtils.promote_symtype(::Difference, t) = t -""" -$(SIGNATURES) - -Represents a discrete update (shift) operator with the semantics -``` -DiscreteUpdate(t; dt=0.01)(y) ~ y(t+dt) -``` - -# Examples - -```jldoctest -julia> using Symbolics - -julia> @variables t; - -julia> U = DiscreteUpdate(t; dt=0.01) -(::Difference) (generic function with 2 methods) -``` -""" -DiscreteUpdate(t; dt) = Difference(t; dt=dt, update=true) - -Base.show(io::IO, D::Difference) = print(io, "Difference(", D.t, "; dt=", D.dt, ", update=", D.update, ")") - -Base.:(==)(D1::Difference, D2::Difference) = isequal(D1.t, D2.t) && isequal(D1.dt, D2.dt) && isequal(D1.update, D2.update) -Base.hash(D::Difference, u::UInt) = hash(D.dt, hash(D.t, xor(u, 0x055640d6d952f101))) - -Base.:^(D::Difference, n::Integer) = _repeat_apply(D, n) - -""" - hasdiff(O) - -Returns true if the expression or equation `O` contains [`Difference`](@ref) terms (this include [`DiscreteUpdate`](@ref)). -""" -hasdiff(O) = recursive_hasoperator(Difference, O) \ No newline at end of file diff --git a/src/domains.jl b/src/domains.jl index 02d23f3bb..c3f07c6c6 100644 --- a/src/domains.jl +++ b/src/domains.jl @@ -6,7 +6,7 @@ struct VarDomainPairing domain::Domain end -const DomainedVar = Union{Symbolic{<:Number}, Num} +const DomainedVar = Union{BasicSymbolic, Num} Base.:∈(variable::DomainedVar,domain::Domain) = VarDomainPairing(value(variable),domain) Base.:∈(variable::DomainedVar,domain::Interval) = VarDomainPairing(value(variable),domain) diff --git a/src/equations.jl b/src/equations.jl index 4f7f300b5..6c1565b72 100644 --- a/src/equations.jl +++ b/src/equations.jl @@ -2,103 +2,6 @@ const NAMESPACE_SEPARATOR = '₊' hide_lhs(_) = false -### -### Connection -### -struct Connection - systems -end -Base.broadcastable(x::Connection) = Ref(x) -Connection() = Connection(nothing) -Base.hash(c::Connection, seed::UInt) = hash(c.systems, (0xc80093537bdc1311 % UInt) ⊻ seed) -hide_lhs(_::Connection) = true - -function connect(sys1, sys2, syss...) - syss = (sys1, sys2, syss...) - length(unique(nameof, syss)) == length(syss) || error("connect takes distinct systems!") - Equation(Connection(), Connection(syss)) # the RHS are connected systems -end - -function Base.show(io::IO, c::Connection) - print(io, "connect(") - if c.systems isa AbstractArray || c.systems isa Tuple - n = length(c.systems) - for (i, s) in enumerate(c.systems) - str = join(split(string(nameof(s)), NAMESPACE_SEPARATOR), '.') - print(io, str) - i != n && print(io, ", ") - end - end - print(io, ")") -end - -### -### State machine -### -_nameof(s) = nameof(s) -_nameof(s::Union{Int, Symbol}) = s -abstract type StateMachineOperator end -Base.broadcastable(x::StateMachineOperator) = Ref(x) -hide_lhs(_::StateMachineOperator) = true -struct InitialState <: StateMachineOperator - s -end -Base.show(io::IO, s::InitialState) = print(io, "initial_state(", _nameof(s.s), ")") -initial_state(s) = Equation(InitialState(nothing), InitialState(s)) - -Base.@kwdef struct Transition{A, B, C} <: StateMachineOperator - from::A = nothing - to::B = nothing - cond::C = nothing - immediate::Bool = true - reset::Bool = true - synchronize::Bool = false - priority::Int = 1 - function Transition(from, to, cond, immediate, reset, synchronize, priority) - cond = unwrap(cond) - new{typeof(from), typeof(to), typeof(cond)}(from, to, cond, immediate, - reset, synchronize, - priority) - end -end -function Base.:(==)(transition1::Transition, transition2::Transition) - transition1.from == transition2.from && - transition1.to == transition2.to && - isequal(transition1.cond, transition2.cond) && - transition1.immediate == transition2.immediate && - transition1.reset == transition2.reset && - transition1.synchronize == transition2.synchronize && - transition1.priority == transition2.priority -end - -""" - transition(from, to, cond; immediate::Bool = true, reset::Bool = true, synchronize::Bool = false, priority::Int = 1) - -Create a transition from state `from` to state `to` that is enabled when transitioncondition `cond` evaluates to `true`. - -# Arguments: -- `from`: The source state of the transition. -- `to`: The target state of the transition. -- `cond`: A transition condition that evaluates to a Bool, such as `ticksInState() >= 2`. -- `immediate`: If `true`, the transition will fire at the same tick as it becomes true, if `false`, the actions of the state are evaluated first, and the transition fires during the next tick. -- `reset`: If true, the destination state `to` is reset to its initial condition when the transition fires. -- `synchronize`: If true, the transition will only fire if all sub-state machines in the source state are in their final (terminal) state. A final state is one that has no outgoing transitions. -- `priority`: If a state has more than one outgoing transition, all outgoing transitions must have a unique priority. The transitions are evaluated in priority order, i.e., the transition with priority 1 is evaluated first. -""" -function transition(from, to, cond; - immediate::Bool = true, reset::Bool = true, synchronize::Bool = false, - priority::Int = 1) - Equation(Transition(), Transition(; from, to, cond, immediate, reset, - synchronize, priority)) -end -function Base.show(io::IO, s::Transition) - print(io, _nameof(s.from), " → ", _nameof(s.to), " if (", s.cond, ") [") - print(io, "immediate: ", Int(s.immediate), ", ") - print(io, "reset: ", Int(s.reset), ", ") - print(io, "sync: ", Int(s.synchronize), ", ") - print(io, "prio: ", s.priority, "]") -end - """ $(TYPEDEF) @@ -127,7 +30,7 @@ function Base.show(io::IO, eq::Equation) end end -scalarize(eq::Equation) = scalarize(eq.lhs) .~ scalarize(eq.rhs) +SymbolicUtils.scalarize(eq::Equation) = scalarize(eq.lhs) .~ scalarize(eq.rhs) SymbolicUtils.simplify(x::Equation; kw...) = simplify(x.lhs; kw...) ~ simplify(x.rhs; kw...) # ambiguity for T in [:Pair, :Any] @@ -201,11 +104,9 @@ end canonical_form(eq::Equation) = eq.lhs - eq.rhs ~ 0 -get_variables(eq::Equation) = unique(vcat(get_variables(eq.lhs), get_variables(eq.rhs))) - -struct ConstrainedEquation - constraints - eq +function SymbolicUtils.search_variables!(buffer, eq::Equation; kw...) + SymbolicUtils.search_variables!(buffer, eq.lhs; kw...) + SymbolicUtils.search_variables!(buffer, eq.rhs; kw...) end function expand_derivatives(eq::Equation, simplify=false) diff --git a/src/extra_functions.jl b/src/extra_functions.jl index 0adc291e3..ed8790f3f 100644 --- a/src/extra_functions.jl +++ b/src/extra_functions.jl @@ -1,42 +1,14 @@ -@register_symbolic Base.binomial(n, k)::Int true -function _binomial(nothing, n, k) - begin - args = [n, k] - unwrapped_args = map(Symbolics.unwrap, args) - res = if !(any((x->begin - SymbolicUtils.issym(x) || SymbolicUtils.iscall(x) - end), unwrapped_args)) - Base.binomial(unwrapped_args...) - else - SymbolicUtils.Term{Int}(Base.binomial, unwrapped_args) - end - if typeof.(args) == typeof.(unwrapped_args) - return res - else - return Symbolics.wrap(res) - end +for (T1, T2) in Iterators.product([Number, BasicSymbolic{VartypeT}, Num], [Integer, BasicSymbolic{VartypeT}, Num]) + if T1 != Num && T2 != Num + continue end -end - -for (T1, T2) in ((Symbolics.SymbolicUtils.Symbolic{<:Real}, Int64), - (Num, Int64), - (Real, Symbolics.SymbolicUtils.Symbolic{<:Int64}), - (Symbolics.SymbolicUtils.Symbolic{<:Real}, Symbolics.SymbolicUtils.Symbolic{<:Int64}), - (Num, Symbolics.SymbolicUtils.Symbolic{<:Int64})) - - @eval function Base.binomial(n::$T1, k::$T2) - if any(Symbolics.iswrapped, (n, k)) - Symbolics.wrap(_binomial(nothing, Symbolics.unwrap(n), Symbolics.unwrap(k))) - else - _binomial(nothing, n, k) - end + @eval function Base.binomial(a::$T1, b::$T2) + binomial(unwrap(a), unwrap(b)) end end -@register_symbolic Base.sign(x)::Int derivative(::typeof(sign), args::NTuple{1,Any}, ::Val{1}) = 0 -@register_symbolic Base.signbit(x)::Bool derivative(::typeof(signbit), args::NTuple{1,Any}, ::Val{1}) = 0 derivative(::typeof(abs), args::NTuple{1,Any}, ::Val{1}) = ifelse(signbit(args[1]),-one(args[1]),one(args[1])) @@ -57,10 +29,6 @@ function derivative(::typeof(max), args::NTuple{2,Any}, ::Val{2}) ifelse(x > y, zero(y), one(y)) end -@register_symbolic Base.ceil(x) -@register_symbolic Base.floor(x) -@register_symbolic Base.factorial(x) - function derivative(::Union{typeof(ceil),typeof(floor),typeof(factorial)}, args::NTuple{1,Any}, ::Val{1}) zero(args[1]) end @@ -68,7 +36,14 @@ end @register_symbolic Base.rand(x) @register_symbolic Base.randn(x) -@register_symbolic Base.clamp(x, y, z) +for (T1, T2, T3) in Iterators.product(Iterators.repeated((Num, BasicSymbolic{VartypeT}, Number), 3)...) + if T1 != Num && T2 != Num && T3 != Num + continue + end + @eval function Base.clamp(a::$T1, b::$T2, c::$T3) + wrap(clamp(unwrap(a), unwrap(b), unwrap(c))) + end +end function derivative(::typeof(Base.clamp), args::NTuple{3, Any}, ::Val{1}) x, l, h = args @@ -78,21 +53,28 @@ function derivative(::typeof(Base.clamp), args::NTuple{3, Any}, ::Val{1}) ifelse(xh, z, o)) end -@register_symbolic Distributions.pdf(dist,x) -@register_symbolic Distributions.logpdf(dist,x) -@register_symbolic Distributions.cdf(dist,x) -@register_symbolic Distributions.logcdf(dist,x) -@register_symbolic Distributions.quantile(dist,x) - -@register_symbolic Distributions.Uniform(mu,sigma) false -@register_symbolic Distributions.Normal(mu,sigma) false - -@register_symbolic ∈(x::Real, y::AbstractArray)::Bool -@register_symbolic ∪(x, y) -@register_symbolic ∩(x, y) -@register_symbolic ∨(x, y) -@register_symbolic ∧(x, y) -@register_symbolic ⊆(x, y) +for T1 in [Real, Num, BasicSymbolic{VartypeT}], T2 in [AbstractArray, Arr, BasicSymbolic{VartypeT}] + if T1 != Num && T2 != Arr + continue + end + @eval function Base.in(x::$T1, y::$T2) + return in(unwrap(x), unwrap(y)) + end +end +for (T1, T2) in Iterators.product(Iterators.repeated([AbstractArray, Arr, BasicSymbolic{VartypeT}], 2)...) + if T1 != Arr && T2 != Arr + continue + end + @eval function Base.union(a::$T1, b::$T2) + union(unwrap(a), unwrap(b)) + end + @eval function Base.intersect(a::$T1, b::$T2) + intersect(unwrap(a), unwrap(b)) + end + @eval function Base.issubset(a::$T1, b::$T2) + issubset(unwrap(a), unwrap(b)) + end +end LinearAlgebra.norm(x::Num, p::Real) = abs(x) @@ -103,11 +85,4 @@ derivative(::typeof(>=), ::NTuple{2, Any}, ::Val{i}) where {i} = 0 derivative(::typeof(==), ::NTuple{2, Any}, ::Val{i}) where {i} = 0 derivative(::typeof(!=), ::NTuple{2, Any}, ::Val{i}) where {i} = 0 -@register_symbolic SpecialFunctions.expinti(x::Real) derivative(::typeof(expinti), args::NTuple{1,Any}, ::Val{1}) = exp(args[1])/args[1] - -@register_symbolic SpecialFunctions.expint(nu, z) - -@register_symbolic SpecialFunctions.sinint(x) - -@register_symbolic SpecialFunctions.cosint(x) diff --git a/src/inequality.jl b/src/inequality.jl index c38296e46..370a55bd8 100644 --- a/src/inequality.jl +++ b/src/inequality.jl @@ -23,7 +23,7 @@ Base.hash(a::Inequality, salt::UInt) = hash(a.lhs, hash(a.rhs, hash(a.relational @enum RelationalOperator leq geq # strict less than or strict greater than are not supported by any solver -function scalarize(ineq::Inequality) +function SymbolicUtils.scalarize(ineq::Inequality) if ineq.relational_op == leq scalarize(ineq.lhs) ≲ scalarize(ineq.rhs) else @@ -108,7 +108,10 @@ function canonical_form(cs::Inequality; form=leq) end end -get_variables(ineq::Inequality) = unique(vcat(get_variables(ineq.lhs), get_variables(ineq.rhs))) +function SymbolicUtils.search_variables!(buffer, ineq::Inequality; kw...) + search_variables!(buffer, ineq.lhs; kw...) + search_variables!(buffer, ineq.rhs; kw...) +end SymbolicUtils.simplify(cs::Inequality; kw...) = Inequality(simplify(cs.lhs; kw...), simplify(cs.rhs; kw...), cs.relational_op) diff --git a/src/integral.jl b/src/integral.jl index 0995330c5..d335582dd 100644 --- a/src/integral.jl +++ b/src/integral.jl @@ -30,8 +30,8 @@ function (I::Integral)(x::Union{Rational, AbstractIrrational, AbstractFloat, Int a, b = value.(DomainSets.endpoints(domain)) wrap((b - a)*x) end -(I::Integral)(x::Complex) = wrap(ComplexTerm{Real}(I(unwrap(real(x))), I(unwrap(imag(x))))) -(I::Integral)(x) = Term{SymbolicUtils.symtype(x)}(I, [x]) +(I::Integral)(x::Complex) = Complex{Num}(wrap(I(unwrap(real(x)))), wrap(I(unwrap(imag(x))))) +(I::Integral)(x) = Term{VartypeT}(I, [x]; type = SymbolicUtils.symtype(x), shape = SymbolicUtils.shape(x)) (I::Integral)(x::Num) = Num(I(Symbolics.value(x))) SymbolicUtils.promote_symtype(::Integral, x) = x diff --git a/src/latexify_recipes.jl b/src/latexify_recipes.jl index 30d0b5ca6..aa9b09cec 100644 --- a/src/latexify_recipes.jl +++ b/src/latexify_recipes.jl @@ -64,7 +64,7 @@ recipe(n) = latexify_derivatives(cleanup_exprs(_toexpr(n))) snakecase --> true safescripts --> true - return recipe(n) + return recipe(value(n)) end @latexrecipe function f(z::Complex{Num}) @@ -72,16 +72,9 @@ end mult_symbol --> "" index --> :subscript - iszero(z.im) && return :($(recipe(z.re))) - iszero(z.re) && return :($(recipe(z.im)) * $im) - return :($(recipe(z.re)) + $(recipe(z.im)) * $im) -end - -@latexrecipe function f(n::ArrayOp) - env --> :equation - mult_symbol --> "" - index --> :subscript - return recipe(n.term) + iszero(z.im) && return :($(recipe(value(z.re)))) + iszero(z.re) && return :($(recipe(value(z.im))) * $im) + return :($(recipe(value(z.re))) + $(recipe(value(z.im))) * $im) end @latexrecipe function f(n::Function) @@ -98,18 +91,10 @@ end mult_symbol --> "" index --> :subscript - return unwrap(n) + return value(n) end -@latexrecipe function f(n::CallWithMetadata) - env --> :equation - mult_symbol --> "" - index --> :subscript - - return n.f -end - -@latexrecipe function f(n::Symbolic) +@latexrecipe function f(n::BasicSymbolic) env --> :equation mult_symbol --> "" index --> :subscript @@ -133,28 +118,23 @@ end env --> :equation index --> :subscript - if hide_lhs(eq.lhs) || !(eq.lhs isa Union{Number, AbstractArray, Symbolic}) - return eq.rhs + if hide_lhs(eq.lhs) || !(eq.lhs isa Union{Number, AbstractArray, BasicSymbolic}) + return value(eq.rhs) else return Expr(:(=), Num(eq.lhs), Num(eq.rhs)) end end -@latexrecipe function f(c::Connection) - index --> :subscript - return Expr(:call, :connect, map(nameof, c.systems)...) -end - Base.show(io::IO, ::MIME"text/latex", x::RCNum) = print(io, "\$\$ " * latexify(x) * " \$\$") -Base.show(io::IO, ::MIME"text/latex", x::Symbolic) = print(io, "\$\$ " * latexify(x) * " \$\$") +Base.show(io::IO, ::MIME"text/latex", x::BasicSymbolic) = print(io, "\$\$ " * latexify(x) * " \$\$") Base.show(io::IO, ::MIME"text/latex", x::Equation) = print(io, "\$\$ " * latexify(x) * " \$\$") Base.show(io::IO, ::MIME"text/latex", x::Vector{Equation}) = print(io, "\$\$ " * latexify(x) * " \$\$") Base.show(io::IO, ::MIME"text/latex", x::AbstractArray{<:RCNum}) = print(io, "\$\$ " * latexify(x) * " \$\$") -_toexpr(O::ArrayOp; latexwrapper = default_latex_wrapper) = _toexpr(O.term; latexwrapper) - # `_toexpr` is only used for latexify function _toexpr(O; latexwrapper = default_latex_wrapper) + O = unwrap(O) + SymbolicUtils.isconst(O) && return value(O) if ismul(O) m = O numer = Any[] @@ -167,9 +147,8 @@ function _toexpr(O; latexwrapper = default_latex_wrapper) push!(numer, _toexpr(term)) continue end - - base = term.base - pow = term.exp + base, pow = arguments(term) + pow = value(pow) isneg = (pow isa Number && pow < 0) || (iscall(pow) && operation(pow) === (-) && length(arguments(pow)) == 1) if !isneg if _isone(pow) diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl index aef33569f..39a8601d4 100644 --- a/src/linear_algebra.jl +++ b/src/linear_algebra.jl @@ -7,7 +7,7 @@ function nterms(t) end # Soft pivoted -# Note: we call this function with a matrix of Union{SymbolicUtils.Symbolic, Any} +# Note: we call this function with a matrix of Union{SymbolicUtils.BasicSymbolic, Any} function sym_lu(A; check=true) SINGULAR = typemax(Int) m, n = size(A) @@ -28,7 +28,7 @@ function sym_lu(A; check=true) p[k] = kp - if amin == SINGULAR && !(amin isa Symbolic) && (amin isa Number) && iszero(info) + if amin == SINGULAR && !(amin isa BasicSymbolic) && (amin isa Number) && iszero(info) info = k end @@ -132,7 +132,7 @@ function _solve(A::AbstractMatrix, b::AbstractArray, do_simplify) do_simplify ? SymbolicUtils.simplify_fractions.(sol) : sol end -LinearAlgebra.ldiv!(A::UpperTriangular{<:Union{Symbolic,RCNum}}, b::AbstractVector{<:Union{Symbolic,RCNum}}, x::AbstractVector{<:Union{Symbolic,RCNum}} = b) = symsub!(A, b, x) +LinearAlgebra.ldiv!(A::UpperTriangular{<:Union{BasicSymbolic,RCNum}}, b::AbstractVector{<:Union{BasicSymbolic,RCNum}}, x::AbstractVector{<:Union{BasicSymbolic,RCNum}} = b) = symsub!(A, b, x) function symsub!(A::UpperTriangular, b::AbstractVector, x::AbstractVector = b) LinearAlgebra.require_one_based_indexing(A, b, x) n = size(A, 2) @@ -152,7 +152,7 @@ function symsub!(A::UpperTriangular, b::AbstractVector, x::AbstractVector = b) x end -LinearAlgebra.ldiv!(A::UnitLowerTriangular{<:Union{Symbolic,RCNum}}, b::AbstractVector{<:Union{Symbolic,RCNum}}, x::AbstractVector{<:Union{Symbolic,RCNum}} = b) = symsub!(A, b, x) +LinearAlgebra.ldiv!(A::UnitLowerTriangular{<:Union{BasicSymbolic,RCNum}}, b::AbstractVector{<:Union{BasicSymbolic,RCNum}}, x::AbstractVector{<:Union{BasicSymbolic,RCNum}} = b) = symsub!(A, b, x) function symsub!(A::UnitLowerTriangular, b::AbstractVector, x::AbstractVector = b) LinearAlgebra.require_one_based_indexing(A, b, x) n = size(A, 2) @@ -220,7 +220,7 @@ end function LinearAlgebra.norm(x::AbstractArray{<:RCNum}, p::Real=2) p = value(p) - issym = p isa Symbolic + issym = p isa BasicSymbolic if !issym && p == 2 sqrt(sum(x->abs2(x), x)) elseif !issym && isone(p) @@ -258,7 +258,7 @@ function linear_expansion(ts::AbstractArray, xs::AbstractArray) @label FINISH return A, bvec, islinear end -# _linear_expansion always returns `Symbolic` +# _linear_expansion always returns `BasicSymbolic` function _linear_expansion(t::Equation, x) a₂, b₂, islinear = linear_expansion(t.rhs, x) islinear || return (a₂, b₂, false) @@ -272,7 +272,7 @@ is_expansion_leaf(t) = !iscall(t) || (operation(t) isa Operator) @noinline expansion_check(op) = op isa Operator && error("The operation is an Operator. This should never happen.") function _linear_expansion(t, x) t = value(t) - t isa Symbolic || return (0, t, true) + t isa BasicSymbolic || return (0, t, true) x = value(x) is_expansion_leaf(t) && return trivial_linear_expansion(t, x) isequal(t, x) && return (1, 0, true) @@ -332,9 +332,9 @@ function _linear_expansion(t, x) elseif op === getindex arrt, idxst... = arguments(t) isequal(arrt, arrx) && return (0, t, true) - shape(arrt) == Unknown() && return (0, t, true) + shape(arrt) isa SymbolicUtils.Unknown && return (0, t, true) - indexed_t = OffsetArrays.Origin(map(first, axes(arrt)))(Symbolics.scalarize(arrt))[idxst...] + indexed_t = OffsetArrays.Origin(map(first, axes(arrt)))(Symbolics.scalarize(arrt))[unwrap_const.(idxst)...] # when indexing a registered function/callable symbolic # scalarizing and indexing leads to the same symbolic variable # which causes a StackOverflowError without this @@ -388,7 +388,6 @@ end ### # Pretty much just copy-pasted from stdlib -SparseArrays.SparseMatrixCSC{Tv,Ti}(M::StridedMatrix) where {Tv<:RCNum,Ti} = _sparse(Tv, Ti, M) function _sparse(::Type{Tv}, ::Type{Ti}, M) where {Tv, Ti} nz = count(!_iszero, M) colptr = zeros(Ti, size(M, 2) + 1) diff --git a/src/linearity.jl b/src/linearity.jl index 3eae625cc..487a0f2a4 100644 --- a/src/linearity.jl +++ b/src/linearity.jl @@ -64,6 +64,22 @@ end # to make Mul and Add work Base.:*(::Number, comb::TermCombination) = comb +function Base.:*(x::BasicSymbolic{VartypeT}, comb::TermCombination) + @assert SymbolicUtils.isconst(x) + unwrap_const(x) * comb +end +function Base.:*(comb::TermCombination, x::BasicSymbolic{VartypeT}) + @assert SymbolicUtils.isconst(x) + comb * unwrap_const(x) +end +function Base.:+(x::BasicSymbolic{VartypeT}, comb::TermCombination) + @assert SymbolicUtils.isconst(x) + unwrap_const(x) + comb +end +function Base.:+(comb::TermCombination, x::BasicSymbolic{VartypeT}) + @assert SymbolicUtils.isconst(x) + comb + unwrap_const(x) +end function Base.:^(comb::TermCombination, ::Number) isone(comb) && return comb iszero(comb) && return _scalar diff --git a/src/num.jl b/src/num.jl index 9eaf4af7b..c61709ab7 100644 --- a/src/num.jl +++ b/src/num.jl @@ -1,10 +1,22 @@ @symbolic_wrap struct Num <: Real - val::Any + val::BasicSymbolic{VartypeT} + + function Num(ex) + # need `<: Number` instead of `<: Real` to allow the primitive `@number_methods` + # methods below to infer. They could be made to infer `Union{Complex{Num}, Num}` + # by manually checking the `symtype` of the result and branching instead of using + # `wrap`. However, this causes issues with LinearAlgebra methods because it + # preallocates a buffer using the inferred result type, and then tries to + # e.g. set an integer into it, which fails because `convert(::Type{Union{..}}, ::T)` + # doesn't work. + @assert symtype(ex) <: Number + return new(Const{VartypeT}(ex)) + end end const RCNum = Union{Num, Complex{Num}} -unwrap(x::Num) = x.val +SymbolicUtils.unwrap(x::Num) = x.val """ Num(val) @@ -17,7 +29,6 @@ const show_numwrap = Ref(false) Num(x::Num) = x # ideally this should never be called (n::Num)(args...) = Num(value(n)(map(value, args)...)) -value(x) = unwrap(x) SymbolicUtils.@number_methods(Num, Num(f(value(a))), @@ -31,6 +42,10 @@ Base.typemin(::Type{Num}) = Num(-Inf) Base.typemax(::Type{Num}) = Num(Inf) Base.float(x::Num) = x +function SymbolicUtils.search_variables!(buffer, expr::Num; kw...) + SymbolicUtils.search_variables!(buffer, unwrap(expr); kw...) +end + """ ifelse(cond::Num, x, y) @@ -89,15 +104,11 @@ end substitute(expr, s::Pair; kw...) = substituter([s[1] => s[2]])(expr; kw...) substitute(expr, s::Vector; kw...) = substituter(s)(expr; kw...) -function _unwrap_callwithmeta(x) - x = value(x) - return x isa CallWithMetadata ? x.f : x -end function subrules_to_dict(pairs) if pairs isa Pair pairs = (pairs,) end - return Dict(_unwrap_callwithmeta(k) => value(v) for (k, v) in pairs) + return Dict(k => value(v) for (k, v) in pairs) end function substituter(pairs) dict = subrules_to_dict(pairs) @@ -110,7 +121,7 @@ Base.nameof(n::Num) = nameof(value(n)) Base.iszero(x::Num) = SymbolicUtils.fraction_iszero(unwrap(x)) Base.isone(x::Num) = SymbolicUtils.fraction_isone(unwrap(x)) -import SymbolicUtils: <ₑ, Symbolic, Term, operation, arguments +import SymbolicUtils: <ₑ, Term, operation, arguments function Base.show(io::IO, n::Num) show_numwrap[] ? print(io, :(Num($(value(n))))) : Base.show(io, value(n)) @@ -118,25 +129,14 @@ end Base.promote_rule(::Type{<:Number}, ::Type{<:Num}) = Num Base.promote_rule(::Type{BigFloat}, ::Type{<:Num}) = Num -Base.promote_rule(::Type{<:Symbolic{<:Number}}, ::Type{<:Num}) = Num -function Base.getproperty(t::Union{Add, Mul, Pow, Term}, f::Symbol) - if f === :op - Base.depwarn( - "`x.op` is deprecated, use `operation(x)` instead", :getproperty) - operation(t) - elseif f === :args - Base.depwarn("`x.args` is deprecated, use `arguments(x)` instead", - :getproperty) - arguments(t) - else - getfield(t, f) - end -end <ₑ(s::Num, x) = value(s) <ₑ value(x) <ₑ(s, x::Num) = value(s) <ₑ value(x) <ₑ(s::Num, x::Num) = value(s) <ₑ value(x) -Num(q::AbstractIrrational) = Num(Term(identity, [q])) +function Num(q::AbstractIrrational) + args = SymbolicUtils.ArgsT{VartypeT}((q,)) + Num(Term{VartypeT}(identity, args; type = Real, shape = SymbolicUtils.ShapeVecT())) +end for T in (Integer, Rational) @eval Base.:(^)(n::Num, i::$T) = Num(value(n)^i) @@ -179,12 +179,8 @@ end @num_method Base.isequal begin va = value(a) vb = value(b) - if va isa SymbolicUtils.BasicSymbolic{Real} && vb isa SymbolicUtils.BasicSymbolic{Real} - isequal(va, vb)::Bool - else - isequal(va, vb)::Bool - end -end (AbstractFloat, Number, Symbolic) + isequal(va, vb)::Bool +end (AbstractFloat, Number, BasicSymbolic) # Provide better error message for symbolic variables in ranges function Base.:(:)(a::Integer, b::Num) @@ -249,28 +245,25 @@ end Base.to_index(x::Num) = Base.to_index(value(x)) -Base.hash(x::Num, h::UInt) = hash(value(x), h)::UInt +Base.hash(x::Num, h::UInt) = hash(unwrap(x), h)::UInt -Base.convert(::Type{Num}, x::Symbolic{<:Number}) = Num(x) -Base.convert(::Type{Num}, x::Number) = Num(x) +function Base.convert(::Type{Num}, x::BasicSymbolic) + symtype(x) <: Real || error("`symtype` must be `<:Real`") + Num(x) +end +# TODO: `Const{T}` instead of `Const{SymReal}` +Base.convert(::Type{Num}, x::Number) = Num(Const{SymReal}(x)) Base.convert(::Type{Num}, x::Num) = x Base.convert(::Type{T}, x::AbstractArray{Num}) where {T <: Array{Num}} = T(map(Num, x)) -function Base.convert(::Type{Sym}, x::Num) - value(x) isa Sym ? value(x) : error("cannot convert $x to Sym") -end function LinearAlgebra.lu( x::Union{Adjoint{<:RCNum}, Transpose{<:RCNum}, Array{<:RCNum}}; check = true, kw...) sym_lu(x; check = check) end -_iszero(x::Number) = iszero(x) -_isone(x::Number) = isone(x) -_iszero(::Symbolic) = false -_isone(::Symbolic) = false -_iszero(x::Num) = _iszero(value(x))::Bool -_isone(x::Num) = _isone(value(x))::Bool +SymbolicUtils._iszero(x::Num) = SymbolicUtils._iszero(value(x)) +SymbolicUtils._isone(x::Num) = SymbolicUtils._isone(value(x)) Code.cse(x::Num) = Code.cse(unwrap(x)) diff --git a/src/parsing.jl b/src/parsing.jl index 4ccd33f04..d7ca71f84 100644 --- a/src/parsing.jl +++ b/src/parsing.jl @@ -118,7 +118,7 @@ function parse_expr_to_symbolic(ex::Expr, mod::Union{Module,AbstractDict}) else # Treat as symbolic function or term x = parse_expr_to_symbolic(op, mod) - return Term{Real}(x, parsed_args) + return Term{VartypeT}(x, parsed_args; type = Real) end elseif ex.head == :ref arr = parse_expr_to_symbolic(ex.args[1], mod) diff --git a/src/register.jl b/src/register.jl index 8d1c01dc1..6721fa84f 100644 --- a/src/register.jl +++ b/src/register.jl @@ -1,5 +1,3 @@ -using SymbolicUtils: Symbolic - """ @register_symbolic(expr, define_promotion = true, Ts = [Real]) @@ -23,34 +21,37 @@ overwriting. ``` See `@register_array_symbolic` to register functions which return arrays. """ -macro register_symbolic(expr, define_promotion = true, Ts = :([]), wrap_arrays = true) - f, ftype, argnames, Ts, ret_type = destructure_registration_expr(expr, Ts) +macro register_symbolic(expr, define_promotion = true, wrap_arrays = true) + f, ftype, argnames, Ts, ret_type = destructure_registration_expr(expr) args′ = map((a, T) -> :($a::$T), argnames, Ts) ret_type = isnothing(ret_type) ? Real : ret_type - + N = length(args′) + symbolicT = Union{BasicSymbolic{VartypeT}, AbstractArray{BasicSymbolic{VartypeT}}} fexpr = :(Symbolics.@wrapped function $f($(args′...)) - args = [$(argnames...),] - unwrapped_args = map($nested_unwrap, args) - res = if !any($is_symbolic_or_array_of_symbolic, unwrapped_args) - $f(unwrapped_args...) # partial-eval if all args are unwrapped - else - $Term{$ret_type}($f, unwrapped_args) - end - if typeof.(args) == typeof.(unwrapped_args) - return res - else - return $wrap(res) - end - end $wrap_arrays) + args = ($(argnames...),) + if Base.Cartesian.@nany $N i -> args[i] isa $symbolicT + args = Base.Cartesian.@ntuple $N i -> $Const{$VartypeT}(args[i]) + $Term{$VartypeT}($f, $(SymbolicUtils.ArgsT){$VartypeT}(args); type = $ret_type, shape = $(SymbolicUtils.ShapeVecT())) + else + $f($(argnames...)) + end + end $wrap_arrays) if define_promotion fexpr = :($fexpr; (::$typeof($promote_symtype))(::$ftype, args...) = $ret_type) + promote_expr = quote + function (::$(typeof(SymbolicUtils.promote_shape)))(::$ftype, args::$(SymbolicUtils.ShapeT)...) + @nospecialize args + $(SymbolicUtils.ShapeVecT)() + end + end + fexpr = :($fexpr; $promote_expr) end esc(fexpr) end -function destructure_registration_expr(expr, Ts) +function destructure_registration_expr(expr) if expr.head === :(::) ret_type = expr.args[2] expr = expr.args[1] @@ -58,8 +59,6 @@ function destructure_registration_expr(expr, Ts) ret_type = nothing end @assert expr.head === :call - @assert Ts.head === :vect - Ts = Ts.args f = expr.args[1] args = expr.args[2:end] @@ -92,7 +91,7 @@ function is_symbolic_or_array_of_symbolic(arr::AbstractArray) end symbolic_eltype(x) = eltype(x) -symbolic_eltype(::AbstractArray{symT}) where {eT, symT <: SymbolicUtils.Symbolic{eT}} = eT +symbolic_eltype(x::AbstractArray{BasicSymbolic{T}}) where {T} = eltype(symtype(Const{T}(x))) symbolic_eltype(::AbstractArray{Num}) = Real symbolic_eltype(::AbstractArray{symT}) where {eT, symT <: Arr{eT}} = eT @@ -103,33 +102,46 @@ function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs ex.args[1] => ex.args[2] end |> Dict + shape_expr = if haskey(defs, :size) + quote + sz = $(defs[:size]) + nd = length(sz) + sh = $(SymbolicUtils.ShapeVecT)(map(Base.UnitRange{Int} ∘ Base.OneTo, sz)) + end + else + quote + nd = $(get(defs, :ndims, -1)) + sh = $(SymbolicUtils.Unknown)(sh) + end + end + eltype_expr = get(defs, :eltype, Any) + container_type = get(defs, :container_type, Array) args′ = map((a, T) -> :($a::$T), argnames, Ts) + N = length(args′) + symbolicT = Union{BasicSymbolic{VartypeT}, AbstractArray{BasicSymbolic{VartypeT}}} + assigns = macroexpand(@__MODULE__, :(Base.Cartesian.@nexprs $N i -> ($argnames[i] = args[i]))) fexpr = quote @wrapped function $f($(args′...)) - args = [$(argnames...),] - unwrapped_args = map($nested_unwrap, args) - eltype = $symbolic_eltype - res = if !any($is_symbolic_or_array_of_symbolic, unwrapped_args) - $f(unwrapped_args...) # partial-eval if all args are unwrapped - elseif $ret_type == nothing || ($ret_type <: AbstractArray) - $array_term($(Expr(:parameters, [Expr(:kw, k, v) for (k, v) in defs]...)), $f, unwrapped_args...) - else - $Term{$ret_type}($f, unwrapped_args) - end - - if typeof.(args) == typeof.(unwrapped_args) - return res + args = ($(argnames...),) + if Base.Cartesian.@nany $N i -> args[i] isa $symbolicT + args = Base.Cartesian.@ntuple $N i -> $Const{$VartypeT}(args[i]) + $assigns + $shape_expr + eltype = $eltype ∘ $symtype + type = if nd == -1 + $container_type{$eltype_expr} + else + $container_type{$eltype_expr, nd} + end + $Term{$VartypeT}($f, $(SymbolicUtils.ArgsT){$VartypeT}(args); type, shape = sh) else - return $wrap(res) + $f($(argnames...)) end end $wrap_arrays end |> esc if define_promotion - container_type = get(defs, :container_type, :($propagate_atype(f, $(argnames...)))) - etype = get(defs, :eltype, :($propagate_eltype(f, $(argnames...)))) - ndim = get(defs, :ndims, nothing) is_callable_struct = f isa Expr && f.head == :(::) fn_arg = if is_callable_struct f @@ -141,18 +153,25 @@ function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs else :f end + + shape_args = [:($name::$(SymbolicUtils.ShapeT)) for name in argnames] promote_expr = quote function (::$typeof($promote_symtype))($fn_arg, $(argnames...)) f = $fn_arg_name container_type = $container_type - etype = $etype - $( - if ndim === nothing - :(return container_type{etype}) - else - :(ndim = $ndim; return container_type{etype, ndim}) - end - ) + nd = $(get(defs, :ndims, -1)) + etype = $eltype_expr + if nd == -1 + return container_type{etype} + else + return container_type{etype, nd} + end + end + function (::$(typeof(SymbolicUtils.promote_shape)))($fn_arg, $(shape_args...)) + @nospecialize $(argnames...) + size = identity + $shape_expr + return sh end end |> esc fexpr = :($fexpr; $promote_expr) @@ -194,6 +213,6 @@ overloads for one function, all the rest of the registers must set overwriting. """ macro register_array_symbolic(expr, block, define_promotion = true, wrap_arrays = true) - f, ftype, argnames, Ts, ret_type = destructure_registration_expr(expr, :([])) + f, ftype, argnames, Ts, ret_type = destructure_registration_expr(expr) register_array_symbolic(f, ftype, argnames, Ts, ret_type, block, define_promotion, wrap_arrays) end diff --git a/src/rewrite-helpers.jl b/src/rewrite-helpers.jl index efa8b864e..99e0a2b54 100644 --- a/src/rewrite-helpers.jl +++ b/src/rewrite-helpers.jl @@ -1,5 +1,5 @@ """ - replacenode(expr::Symbolic, rules...) + replacenode(expr::BasicSymbolic, rules...) Walk the expression and replacenode subexpressions according to `rules`. `rules` could be rules constructed with `@rule`, a function, or a pair where the @@ -18,12 +18,12 @@ function replacenode(expr::Num, r::Pair, rules::Pair...; fixpoint = false) end # Fix ambiguity replacenode(expr::Num, rules...; fixpoint = false) = _replacenode(unwrap(expr), rules...; fixpoint) -replacenode(expr::Symbolic, rules...; fixpoint = false) = _replacenode(unwrap(expr), rules...; fixpoint) -replacenode(expr::Symbolic, r::Pair, rules::Pair...; fixpoint = false) = _replacenode(expr, r, rules...; fixpoint) +replacenode(expr::BasicSymbolic, rules...; fixpoint = false) = _replacenode(unwrap(expr), rules...; fixpoint) +replacenode(expr::BasicSymbolic, r::Pair, rules::Pair...; fixpoint = false) = _replacenode(expr, r, rules...; fixpoint) replacenode(expr::Number, rules...; fixpoint = false) = expr replacenode(expr::Number, r::Pair, rules::Pair...; fixpoint = false) = expr -function _replacenode(expr::Symbolic, rules...; fixpoint = false) +function _replacenode(expr::BasicSymbolic, rules...; fixpoint = false) rs = map(r -> r isa Pair ? (x -> isequal(x, unwrap(r[1])) ? unwrap(r[2]) : nothing) : r, rules) R = Prewalk(Chain(rs)) if fixpoint @@ -52,12 +52,12 @@ D = Differential(t) hasnode(Symbolics.is_derivative, X + D(X) + D(X^2)) # returns `true`. ``` """ -function hasnode(r::Function, y::Union{Num, Symbolic}) +function hasnode(r::Function, y::Union{Num, BasicSymbolic}) _hasnode(r, y) end -hasnode(r::Num, y::Union{Num, Symbolic}) = occursin(unwrap(r), unwrap(y)) -hasnode(r::Symbolic, y::Union{Num, Symbolic}) = occursin(unwrap(r), unwrap(y)) -hasnode(r::Union{Num, Symbolic, Function}, y::Number) = false +hasnode(r::Num, y::Union{Num, BasicSymbolic}) = occursin(unwrap(r), unwrap(y)) +hasnode(r::BasicSymbolic, y::Union{Num, BasicSymbolic}) = occursin(unwrap(r), unwrap(y)) +hasnode(r::Union{Num, BasicSymbolic, Function}, y::Number) = false function _hasnode(r, y) y = unwrap(y) diff --git a/src/semipoly.jl b/src/semipoly.jl index 121a23c0d..3bf0e6050 100644 --- a/src/semipoly.jl +++ b/src/semipoly.jl @@ -3,251 +3,59 @@ using DataStructures export semipolynomial_form, semilinear_form, semiquadratic_form, polynomial_coeffs -import SymbolicUtils: unsorted_arguments - -""" -$(TYPEDEF) - -# Attributes -$(TYPEDFIELDS) -""" -struct SemiMonomial - "monomial" - p::Union{S, N} where {S <: Symbolic, N <: Real} - "coefficient" - coeff::Any -end - -Base.:+(a::SemiMonomial) = a -function Base.:+(a::SemiMonomial, b::SemiMonomial) - Term(+, [a, b]) -end -function Base.:+(m::SemiMonomial, t) - if iscall(t) && operation(t) == (+) - return Term(+, [arguments(t); m]) - end - Term(+, [m, t]) -end -Base.:+(t, m::SemiMonomial) = m + t - -Base.:*(m::SemiMonomial) = m -function Base.:*(a::SemiMonomial, b::SemiMonomial) - SemiMonomial(a.p * b.p, a.coeff * b.coeff) -end -Base.:*(m::SemiMonomial, n::Number) = SemiMonomial(m.p, m.coeff * n) -function Base.:*(m::SemiMonomial, t::Symbolic) - if iscall(t) - op = operation(t) - if op == (+) - args = collect(all_terms(t)) - return Term(+, (m,) .* args) - elseif op == (*) - return Term(*, [arguments(t); m]) - end - end - Term(*, [t, m]) -end -Base.:*(t, m::SemiMonomial) = m * t - -function Base.:/(a::SemiMonomial, b::SemiMonomial) - SemiMonomial(a.p / b.p, a.coeff / b.coeff) -end - -function Base.:^(base::SemiMonomial, exp::Real) - SemiMonomial(base.p^exp, base.coeff^exp) -end - -# return a dictionary of exponents with respect to variables -function pdegrees(x) - if ismul(x) - return x.dict - elseif isdiv(x) - num_dict = pdegrees(x.num) - den_dict = pdegrees(x.den) - inv_den_dict = Dict(keys(den_dict) .=> map(-, values(den_dict))) - mergewith(+, num_dict, inv_den_dict) - elseif ispow(x) - dict = pdegrees(x.base) - degrees = map(degree -> degree * x.exp, values(dict)) - Dict(keys(dict) .=> degrees) - elseif issym(x) || iscall(x) - return Dict(x=>1) - elseif x isa Number - return Dict() - else - error("pdegrees for $x unknown") - end -end - -pdegree(x::Number) = 0 -function pdegree(x::Symbolic) - degree_dict = pdegrees(x) - if isempty(degree_dict) - return 0 - end - sum(values(degree_dict)) -end - -issemimonomial(x) = x isa SemiMonomial - -# Return true is `m` is a `SemiMonomial`, satisfies the definition of a monomial and -# its degree is less than or equal to `degree_bound`. -# If `m` is a constant about `vars`, return true if `consts = true` and return false if -# `consts = false`. -function isboundedmonomial(m, vars, degree_bound::Real; consts = true)::Bool - if !(m isa SemiMonomial) - return false - end - degree_dict = pdegrees(m.p) - if isempty(degree_dict) - return consts && !has_vars(m.coeff, vars) - end - degrees = values(degree_dict) - for degree in degrees - if !isinteger(degree) || degree < 0 - return false - end - end - if sum(degrees) > degree_bound - return false - end - !has_vars(m.coeff, vars) -end - -# Return true if the degrees of `m` are all 0s and its coefficient is a `Real`. -Base.:isreal(m::SemiMonomial) = m.p isa Number && isone(m.p) && unwrap(m.coeff) isa Real -Base.:isreal(::Symbolic) = false - -# Transform `m` to a `Real`. -# Assume `isreal(m) == true`, otherwise calling this function does not make sense. -function Base.:real(m::SemiMonomial)::Real - if isinteger(m.coeff) - return Int(m.coeff) - end - return m.coeff -end - -symtype(m::SemiMonomial) = symtype(m.p) - -issym(::SemiMonomial) = true - -Base.:nameof(m::SemiMonomial) = Symbol(:SemiMonomial, m.p, m.coeff) - -isop(x, op) = iscall(x) && operation(x) === op -isop(op) = Base.Fix2(isop, op) - -simpleterm(T, f, args, m) = Term{SymbolicUtils._promote_symtype(f, args)}(f, args) - -function mark_and_exponentiate(expr, vars) - # Step 1 - # Mark all the interesting variables -- substitute without recursing into nl forms - expr′ = mark_vars(expr, vars) - - # Step 2 - # Construct and propagate BoundedDegreeMonomial for ^ and * and / - - # does not do fraction simplification - rules = [@rule (~a::issemimonomial)^(~b::isreal) => (~a)^real(~b) - @rule (~a::isop(+))^(~b::isreal) => expand(Pow((~a), real(~b))) - @rule *(~~xs::(xs -> all(issemimonomial, xs))) => *(~~xs...) - @rule *(~~xs::(xs -> any(isop(+), xs))) => expand(Term(*, ~~xs)) - @rule (~a::isop(+)) / (~b::issemimonomial) => +(map(x->x/~b, arguments(~a))...) - @rule (~a::issemimonomial) / (~b::issemimonomial) => (~a) / (~b)] - expr′ = Postwalk(RestartedChain(rules), maketerm = simpleterm)(expr′) -end - -function semipolyform_terms(expr, vars) - expr = mark_and_exponentiate(expr, vars) - if iscall(expr) && operation(expr) == (+) - args = collect(all_terms(expr)) - return args - elseif isreal(expr) && iszero(real(expr)) # when `expr` is just a 0 - return [] - else - return [expr] - end -end -semipolyform_terms(vars) = Base.Fix2(semipolyform_terms, vars) - -""" -$(TYPEDSIGNATURES) - -Return true if `expr` contains any variables in `vars`. -""" -function has_vars(expr, vars)::Bool - if symbolic_type(expr) == ArraySymbolic() && shape(expr) != Unknown() - for i in eachindex(expr) - expr[i] in vars && return true +const SemipolyDictT = Dict{BasicSymbolic{VartypeT}, BasicSymbolic{VartypeT}} + +function canonicalize_poly(poly_to_bs, bs_to_poly, poly, degree, vars::AbstractSet) + subskeys = SymbolicUtils.PolyVarT[] + subsvals = SymbolicUtils.PolynomialT[] + for pvar in MP.variables(poly) + var = poly_to_bs[pvar] + if isdiv(var) + var = SymbolicUtils.flatten_fractions(var) end - end - if expr in vars - return true - elseif iscall(expr) - for arg in arguments(expr) - if has_vars(arg, vars) - return true + @match var begin + BSImpl.Div(; num, den) => begin + den_has_vars = SymbolicUtils.query!(in(vars), den) + # We only care about terms up to degree `degree`. So we only care about + # numerators with degree <= `2degree + 1`, since if the denominator is + # degree `degree` this entire variable goes in the residual anyway. The +1 + # is in case `degree == 0`. + result, resid = semipolynomial_form(num, vars, 2degree + 1; consts = false) + newpoly = SymbolicUtils.to_poly!(poly_to_bs, bs_to_poly, resid / den) + if newpoly isa SymbolicUtils.PolyVarT + newpoly = MP.polynomial(newpoly, SymbolicUtils.PolyCoeffT) + end + for (monomial, coeff) in result + if den_has_vars + monomial = monomial / den + @match monomial begin + BSImpl.Div(; num = n2, den = d2) => begin + den_has_vars2 = isequal(den, d2) || SymbolicUtils.query!(in(vars), den) + if den_has_vars2 + coeff *= monomial + monomial = 1 + end + end + _ => nothing + end + else + coeff = coeff / den + end + mono = SymbolicUtils.to_poly!(poly_to_bs, bs_to_poly, monomial, false) + coeff = if SymbolicUtils.isconst(coeff) + unwrap_const(coeff) + else + SymbolicUtils.basicsymbolic_to_polyvar(bs_to_poly, coeff) + end + MA.operate!(+, newpoly, mono * coeff) + end + push!(subskeys, pvar) + push!(subsvals, newpoly) end - end - elseif expr isa Array - for el in expr - has_vars(el, vars) && return true + _ => nothing end end - return false -end - -function mark_vars(expr, vars) - if expr in vars - return SemiMonomial(expr, 1) - elseif !iscall(expr) - return SemiMonomial(1, expr) - end - op = operation(expr) - if op === (^) || op == (/) - args = arguments(expr) - @assert length(args) == 2 - return Term{symtype(expr)}(op, map(mark_vars(vars), args)) - end - args = arguments(expr) - if op === (+) || op === (*) - return Term{symtype(expr)}(op, map(mark_vars(vars), args)) - elseif length(args) == 1 - if op == sqrt - return mark_vars(args[1]^(1//2), vars) - elseif linearity_1(op) - return Term{symtype(expr)}(op, mark_vars(args[1], vars)) - end - end - return SemiMonomial(1, expr) -end -mark_vars(vars) = Base.Fix2(mark_vars, vars) - -function bifurcate_terms(terms, vars, degree::Real; consts = true) - # Step 4: Bifurcate polynomial and nonlinear parts: - monomial_indices = findall(t -> isboundedmonomial(t, vars, degree; consts = consts), - terms) - monomials = @view terms[monomial_indices] - polys_dict = Dict() - sizehint!(polys_dict, length(monomials)) - for m in monomials - if haskey(polys_dict, m.p) - polys_dict[m.p] += m.coeff - else - polys_dict[m.p] = m.coeff - end - end - if length(monomials) == length(terms) - return polys_dict, 0 - end - deleteat!(terms, monomial_indices) # the remaining elements in terms are not monomials - nl = cautious_sum(terms) - return polys_dict, nl -end - -function init_semipoly_vars(vars) - set = OrderedSet(unwrap.(vars)) - @assert length(set) == length(vars) # vars passed to semi-polynomial form must be unique - set + return MP.subs(poly, subskeys => subsvals) end """ @@ -266,12 +74,64 @@ a key `1` and the corresponding value will be the constant term. If `false`, the function semipolynomial_form(expr, vars, degree::Real; consts = true) if degree < 0 @warn "Degree for semi-polynomial form should be ≥ 0" - return Dict(), expr + return SemipolyDictT(), expr + end + vars = Set([unwrap(x) for x in vars]) + for v in vars + v isa BasicSymbolic{VartypeT} || continue + @match v begin + BSImpl.Term(; f, args) && if f === getindex end => push!(vars, args[1]) + _ => nothing + end end - vars = init_semipoly_vars(vars) expr = unwrap(expr) - terms = semipolyform_terms(expr, vars) - bifurcate_terms(terms, vars, degree; consts = consts) + expr = expand(expr, false) + poly_to_bs = Bijections.Bijection{SymbolicUtils.PolyVarT, BasicSymbolic{VartypeT}}() + bs_to_poly = Bijections.active_inv(poly_to_bs) + poly = SymbolicUtils.to_poly!(poly_to_bs, bs_to_poly, expr, false) + poly = canonicalize_poly(poly_to_bs, bs_to_poly, poly, degree, vars) + pvars = MP.variables(poly) + nonpoly_mask = falses(length(pvars)) + in_vars_mask = falses(length(pvars)) + for (i, pvar) in enumerate(pvars) + var = poly_to_bs[pvar] + in_vars_mask[i] = var in vars + in_vars_mask[i] && continue + nonpoly_mask[i] = SymbolicUtils.query!(in(vars), var) + end + result = SemipolyDictT() + constant = SymbolicUtils.zeropoly() + residual = SymbolicUtils.zeropoly() + for t in MP.terms(poly) + is_monomial_in_vars = true + monomial_degree = 0 + true_monomial = SymbolicUtils.MonomialT() + coeff_monomial = SymbolicUtils.MonomialT() + for (i, exp) in enumerate(MP.exponents(t)) + iszero(exp) && continue + monomial_degree += in_vars_mask[i] * exp + MA.operate!(*, in_vars_mask[i] ? true_monomial : coeff_monomial, MP.variables(t)[i] ^ exp) + is_monomial_in_vars &= !nonpoly_mask[i] + end + if is_monomial_in_vars && monomial_degree <= degree + if monomial_degree == 0 + MA.operate!(+, constant, t) + else + mono = SymbolicUtils.from_poly(poly_to_bs, true_monomial) + result[mono] = get(result, mono, 0) + MP.coefficient(t) * SymbolicUtils.from_poly(poly_to_bs, coeff_monomial) + end + else + MA.operate!(+, residual, t) + end + end + + if consts && !iszero(constant) + result[Const{VartypeT}(1)] = SymbolicUtils.from_poly(poly_to_bs, constant) + else + MA.operate!(+, residual, constant) + end + residual = SymbolicUtils.from_poly(poly_to_bs, residual) + return result, residual end """ @@ -287,13 +147,21 @@ a key `1` and the corresponding value will be the constant term. If `false`, the function semipolynomial_form(exprs::AbstractArray, vars, degree::Real; consts = true) if degree < 0 @warn "Degree for semi-polynomial form should be ≥ 0" - return fill(Dict(), length), exprs + return fill(SemipolyDictT(), size(exprs)), exprs + end + if any(iswrapped, vars) + vars = map(unwrap, vars) end - vars = init_semipoly_vars(vars) - exprs = unwrap.(exprs) - matches = map(semipolyform_terms(vars), exprs) - tmp = map(match -> bifurcate_terms(match, vars, degree; consts = consts), matches) - map(first, tmp), map(last, tmp) + if !(vars isa AbstractSet) + vars = Set(vars) + end + results = similar(exprs, SemipolyDictT) + residuals = similar(exprs, BasicSymbolic{VartypeT}) + + for (i, expr) in enumerate(exprs) + results[i], residuals[i] = semipolynomial_form(expr, vars, degree; consts) + end + return results, residuals end """ @@ -377,6 +245,7 @@ function semiquadratic_form(exprs, vars) push!(I2, i) if isop(k, ^) b, e = arguments(k) + e = unwrap_const(e) @assert e == 2 q = idxmap[b] j = div(q*(q+1), 2) @@ -410,37 +279,37 @@ function semiquadratic_form(exprs, vars) wrap.(nls)) end -## Utilities - -all_terms(x) = iscall(x) && operation(x) == (+) ? collect(Iterators.flatten(map(all_terms, arguments(x)))) : (x,) +isop(x, op) = iscall(x) && operation(x) === op +isop(op) = Base.Fix2(isop, op) -function unwrap_sp(m::SemiMonomial) - degree_dict = pdegrees(m.p) - # avoid making negative exponent in `Mul` dict - positive_dict = Dict() - negative_dict = Dict() - for (var, degree) in degree_dict - if isinteger(degree) - degree = Int(degree) - end - if degree > 0 - positive_dict[var] = degree - else - negative_dict[var] = -degree - end - end - m.coeff * Mul(symtype(m.p), 1, positive_dict) / Mul(symtype(m.p), 1, negative_dict) -end -function unwrap_sp(x) +function pdegrees(x) x = unwrap(x) - iscall(x) ? maketerm(typeof(x), - TermInterface.head(x), map(unwrap_sp, - TermInterface.children(x)), nothing) : x + if ismul(x) + return x.dict + elseif isdiv(x) + num_dict = pdegrees(x.num) + den_dict = pdegrees(x.den) + inv_den_dict = Dict(keys(den_dict) .=> map(-, values(den_dict))) + mergewith(+, num_dict, inv_den_dict) + elseif ispow(x) + base, exp = arguments(x) + dict = pdegrees(base) + degrees = map(degree -> degree * unwrap_const(exp), values(dict)) + Dict(keys(dict) .=> degrees) + elseif issym(x) || iscall(x) + return Dict(x=>1) + elseif SymbolicUtils.isconst(x) || x isa Number + return Dict() + else + error("pdegrees for $x unknown") + end end -function cautious_sum(nls) - if isempty(nls) +pdegree(x::Number) = 0 +function pdegree(x) + degree_dict = pdegrees(x) + if isempty(degree_dict) return 0 end - sum(unwrap_sp, nls) + sum(values(degree_dict)) end diff --git a/src/solver/attract.jl b/src/solver/attract.jl index 6d03778e8..6c32bc903 100644 --- a/src/solver/attract.jl +++ b/src/solver/attract.jl @@ -156,7 +156,7 @@ function attract_exponential(lhs, var) r_addexpon = Vector{Any}() #! format: off - push!(r_addexpon, @acrule (~b)^(~f::(contains_var)) + (~d)^(~g::(contains_var)) => ~f*term(slog, ~b) - ~g*term(slog, ~d) + term(log, term(complex, -1))) + push!(r_addexpon, @acrule (~b)^(~f::(contains_var)) + (~d)^(~g::(contains_var)) => ~f*term(slog, ~b) - ~g*term(slog, ~d) + term(slog, -1)) push!(r_addexpon, @acrule (~a)*(~b)^(~f::(contains_var)) + (~d)^(~g::(contains_var)) => ~f*term(slog, ~b) - ~g*term(slog, ~d) + term(slog, -~a)) push!(r_addexpon, @acrule (~a)*(~b)^(~f::(contains_var)) + (~c)*(~d)^(~g::(contains_var)) => ~f*term(slog, ~b) - ~g*term(slog, ~d) + term(slog, -(~a)//(~c))) #! format: on diff --git a/src/solver/ia_helpers.jl b/src/solver/ia_helpers.jl index 00d3a6e93..f5b515cbb 100644 --- a/src/solver/ia_helpers.jl +++ b/src/solver/ia_helpers.jl @@ -40,7 +40,7 @@ function n_func_occ(expr, var) n_occurrences(arg, var) == 0 && continue # x - if !iscall(arg) && isequal(var, get_variables(arg)[1]) && !outside + if !iscall(arg) && isequal(var, first(get_variables(arg))) && !outside n += 1 outside = true continue @@ -57,7 +57,7 @@ function n_func_occ(expr, var) case_1_pow = oper_arg === (^) && n_occurrences(args_arg[2], var) == 0 && n_occurrences(args_arg[1], var) != 0 && check_poly_inunivar(args_arg[1], var) && - n_occurrences(arg, var) != 0 && !(args_arg[2] isa Number) + n_occurrences(arg, var) != 0 && !(value(args_arg[2]) isa Number) case_2_pow = oper_arg === (^) && n_occurrences(args_arg[2], var) != 0 && n_occurrences(args_arg[1], var) == 0 case_3_pow = oper_arg === (^) && n_occurrences(args_arg[2], var) == 0 && @@ -83,6 +83,16 @@ function n_func_occ(expr, var) # n(2 / x) = 1; n(x/x^2) = 2? elseif oper_arg === (/) + num, den = args_arg + if SymbolicUtils.isconst(den) + if is_var_outside(num) + n += 1 + outside = true + elseif !check_poly_inunivar(num, var) + n += n_func_occ(num, var) + end + continue + end n += n_func_occ(numerator(arg), var) n += n_func_occ(denominator(arg), var) diff --git a/src/solver/ia_main.jl b/src/solver/ia_main.jl index 193d3ef10..aef6b0883 100644 --- a/src/solver/ia_main.jl +++ b/src/solver/ia_main.jl @@ -21,7 +21,7 @@ function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, peri else a, b, islin = linear_expansion(lhs - new_var, var) if islin - lhs_roots = [-b // a] + lhs_roots = [-b / a] else lhs_roots = [RootsOf(lhs - new_var, var)] if warns @@ -32,7 +32,7 @@ function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, peri for i in eachindex(lhs_roots) for j in eachindex(rhs) - if iscall(lhs_roots[i]) && operation(lhs_roots[i]) == RootsOf + if iscall(lhs_roots[i]) && operation(lhs_roots[i]) === RootsOf _args = copy(parent(arguments(lhs_roots[i]))) _args[1] = substitute(_args[1], Dict(new_var => rhs[j]), fold = false) T = typeof(lhs_roots[i]) @@ -53,7 +53,7 @@ function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, peri return nothing, conditions end - old_lhs = deepcopy(lhs) + old_lhs = lhs oper = operation(lhs) args = arguments(lhs) @@ -86,32 +86,33 @@ function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, peri lhs = unwrap(numerator(lhs)) else # 2 / x = y - rhs = map(sol -> term(/, numerator(lhs), sol), rhs) + rhs = map(sol -> /(numerator(lhs), sol), rhs) lhs = unwrap(denominator(lhs)) end elseif oper === (^) var_in_base = any(isequal(x, var) for x in get_variables(args[1])) var_in_pow = n_occurrences(args[2], var) != 0 - if var_in_base && !var_in_pow && args[2] isa Integer + a2 = unwrap_const(args[2]) + if var_in_base && !var_in_pow && a2 isa Integer lhs = args[1] - power = args[2] + power = a2 new_roots = [] if complex_roots for i in eachindex(rhs) - for k in 0:(args[2] - 1) - r = term(^, rhs[i], (1 // power)) - c = term(*, 2 * (k), pi) * im / power + for k in 0:(a2 - 1) + r = ^(rhs[i], (1 // power)) + c = *(2 * (k), pi) * im / power root = r * Base.MathConstants.e^c push!(new_roots, root) end end else for i in eachindex(rhs) - push!(new_roots, term(^, rhs[i], (1 // power))) + push!(new_roots, ^(rhs[i], (1 // power))) if iseven(power) - push!(new_roots, term(-, new_roots[end])) + push!(new_roots, -new_roots[end]) end end end @@ -120,10 +121,10 @@ function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, peri elseif var_in_base && !var_in_pow lhs = args[1] s, power = filter_stuff(args[2]) - rhs = map(sol -> term(^, sol, 1 // power), rhs) + rhs = map(sol ->^(sol, 1 / power), rhs) else lhs = args[2] - rhs = map(sol -> term(/, term(slog, sol), term(slog, args[1])), rhs) + rhs = map(sol -> /(slog(sol), slog(args[1])), rhs) end elseif has_left_inverse(oper) lhs = args[1] @@ -135,12 +136,12 @@ function isolate(lhs, var; warns=true, conditions=[], complex_roots = true, peri new_var = (@variables $new_var)[1] period = fundamental_period(oper) rhs = map( - sol -> term(invop, sol) + - term(*, period, new_var), + sol -> invop(sol) + + *(period, new_var), rhs) @info string(new_var) * " ϵ" * " Ζ" else - rhs = map(sol -> term(invop, sol), rhs) + rhs = map(sol -> invop(sol), rhs) end end @@ -305,7 +306,7 @@ function ia_solve(lhs, var; warns = true, complex_roots = true, periodic_roots = domain_error = false for j in eachindex(conditions) condition, t = conditions[j] - cond_val = substitute(condition, Dict(var=>eval(toexpr(sols[i])))) + cond_val = unwrap_const(substitute(condition, Dict(var=>eval(toexpr(sols[i]))))) cond_val isa Complex && continue domain_error |= !t(cond_val, 0) end diff --git a/src/solver/main.jl b/src/solver/main.jl index 9082a9518..ba5265fb2 100644 --- a/src/solver/main.jl +++ b/src/solver/main.jl @@ -146,7 +146,7 @@ function symbolic_solve(expr, x::T; dropmultiplicity = true, warns = true) where expr_univar = false x_univar = false - if (T == Num || T == SymbolicUtils.BasicSymbolic{Real}) + if (T === Num || T === BasicSymbolic{VartypeT} && symtype(x) <: Real) x_univar = true check_x(x) else @@ -154,7 +154,6 @@ function symbolic_solve(expr, x::T; dropmultiplicity = true, warns = true) where check_x(var) end end - if !(expr isa Vector) expr_univar = true expr = expr isa Equation ? expr.lhs - expr.rhs : expr @@ -241,9 +240,13 @@ function symbolic_solve(expr; x...) sub_vars = reduce(vcat, subs isa Dict ? collect(keys(subs)) : collect.(keys.(subs))) end - vars_list = get_variables.(filtered) - filt_vars = unique(isa(vars_list, AbstractArray) && all(isa.(vars_list, AbstractArray)) ? reduce(vcat, vars_list) : vars_list) - + vars_list = Set{BasicSymbolic{VartypeT}}() + for ex in wrap(filtered) + SymbolicUtils.search_variables!(vars_list, ex) + end + # vars_list = get_variables.(filtered) + # filt_vars = unique(isa(vars_list, AbstractArray) && all(isa.(vars_list, AbstractArray)) ? reduce(vcat, vars_list) : vars_list) + filt_vars = collect(vars_list) vars = isempty(sub_vars) ? filt_vars : setdiff(filt_vars, sub_vars) vars = wrap.(vars) @assert all(v isa Num for v in vars) "All variables should be Nums or BasicSymbolics" @@ -281,16 +284,15 @@ function solve_univar(expression, x; dropmultiplicity=true, strict=true) args = [] mult_n = 1 expression = unwrap(expression) - expression = expression isa PolyForm ? SymbolicUtils.toterm(expression) : expression # handle multiplicities (repeated roots), i.e. (x+1)^20 if iscall(expression) expr = unwrap(simplify((copy(wrap(expression))))) args = arguments(expr) operation = SymbolicUtils.operation(expr) - if isequal(operation, ^) && args[2] isa Int64 + if isequal(operation, ^) && SymbolicUtils.isconst(args[2]) && (a2 = unwrap_const(args[2]); a2 isa Int64) expression = wrap(args[1]) - mult_n = args[2] + mult_n = a2 end end diff --git a/src/solver/nemo_stuff.jl b/src/solver/nemo_stuff.jl index 4cbef51b2..26b3da07b 100644 --- a/src/solver/nemo_stuff.jl +++ b/src/solver/nemo_stuff.jl @@ -5,11 +5,11 @@ function check_polynomial(poly; strict=true) vars = get_variables(poly) distr, rem = polynomial_coeffs(poly, vars) if strict - @assert isequal(rem, 0) "Not a polynomial" - @assert all(c -> c isa Integer || c isa Rational, collect(values(distr))) "Coefficients must be integer or rational" + @assert SymbolicUtils._iszero(rem) "Not a polynomial" + @assert all(c -> unwrap_const(c) isa Union{Integer, Rational}, collect(values(distr))) "Coefficients must be integer or rational" return true else - return isequal(rem, 0) + return SymbolicUtils._iszero(rem) end end diff --git a/src/solver/polynomialization.jl b/src/solver/polynomialization.jl index 6b0c9e8cd..b437bd9c5 100644 --- a/src/solver/polynomialization.jl +++ b/src/solver/polynomialization.jl @@ -43,7 +43,7 @@ function turn_to_poly(expr, var) expr = unwrap(expr) !iscall(expr) && return (expr, Dict()) - args = copy(parent(arguments(expr))) + args = copy(parent(sorted_arguments(expr))) sub = 0 broken = Ref(false) @@ -107,8 +107,8 @@ julia> trav_pow(unwrap(x^2), x, Ref(false), 3^x) """ function trav_pow(arg, var, broken, sub) args_arg = arguments(arg) - base = args_arg[1] - power = args_arg[2] + base = value(args_arg[1]) + power = value(args_arg[2]) # case 1: log(x)^2 .... 9^x = 3^2^x = 3^2x = (3^x)^2 !isequal(add_sub(sub, base, var, broken), false) && power isa Integer && return arg, base @@ -125,7 +125,7 @@ function trav_pow(arg, var, broken, sub) broken[] = true return arg, false end - new_b = term(^, new_b, p) + new_b = ^(new_b, p) return new_b, sub end @@ -359,7 +359,7 @@ function attract_and_solve_sqrtpoly(lhs, var) for root in roots if isapprox( - substitute(lhs, Dict(var => eval(Symbolics.toexpr(root)))), 0, atol = 1e-4) + value(substitute(lhs, Dict(var => eval(Symbolics.toexpr(root))))), 0, atol = 1e-4) push!(answers, root) end end diff --git a/src/solver/postprocess.jl b/src/solver/postprocess.jl index b0f5c20c5..6b1bcffbd 100644 --- a/src/solver/postprocess.jl +++ b/src/solver/postprocess.jl @@ -30,7 +30,8 @@ end function _postprocess_root(x::SymbolicUtils.BasicSymbolic) !iscall(x) && return x - x = Symbolics.term(operation(x), map(_postprocess_root, arguments(x))...) + x = maketerm(BasicSymbolic{VartypeT}, operation(x), map(_postprocess_root, arguments(x)), nothing) + iscall(x) || return x oper = operation(x) # sqrt(0), cbrt(0) => 0 @@ -46,22 +47,22 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic) args = arguments(x) # (X)^0 => 1 - if oper === (^) && isequal(args[2], 0) && !isequal(args[1], 0) + if oper === (^) && SymbolicUtils._iszero(args[2]) && !SymbolicUtils._iszero(args[1]) return 1 end # (X)^1 => X - if oper === (^) && isequal(args[2], 1) + if oper === (^) && SymbolicUtils._isone(args[2]) return args[1] end # (0)^X => 0 - if oper === (^) && isequal(args[1], 0) && !isequal(args[2], 0) + if oper === (^) && SymbolicUtils._iszero(args[1]) && !SymbolicUtils._iszero(args[2]) return 0 end # y / 0 => Inf - if oper === (/) && !isequal(numerator(x), 0) && isequal(denominator(x), 0) + if oper === (/) && !SymbolicUtils._iszero(numerator(x)) && SymbolicUtils._iszero(denominator(x)) return Inf end @@ -76,14 +77,18 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic) end square, squarefree end - arg = arguments(x)[1] + arg = value(arguments(x)[1]) if arg isa Integer square, squarefree = squarefree_decomp(arg) if arg < 0 square = im * square end if !isone(square) - return square * Symbolics.term(Symbolics.operation(x), squarefree) + if isone(squarefree) + return square + else + return square * Symbolics.term(Symbolics.operation(x), squarefree; type = symtype(x), shape = shape(x)) + end end elseif arg isa Rational n, d = numerator(arg), denominator(arg) @@ -95,7 +100,11 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic) nd_square = n_square // d_square nd_squarefree = n_squarefree // d_squarefree if !isone(nd_square) - return nd_square * Symbolics.term(Symbolics.operation(x), nd_squarefree) + if isone(squarefree) + return nd_square + else + return nd_square * Symbolics.term(Symbolics.operation(x), nd_squarefree; type = symtype(x), shape = shape(x)) + end end end end @@ -103,6 +112,7 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic) # (sqrt(N))^M => N^div(M, 2)*sqrt(N)^(mod(M, 2)) if oper === (^) arg1, arg2 = arguments(x) + arg2 = unwrap_const(arg2) if iscall(arg1) && (operation(arg1) === sqrt || operation(arg1) === ssqrt) if arg2 isa Integer isequal(arg2, 2) && return arguments(arg1)[1] @@ -121,10 +131,10 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic) if oper === (+) args = arguments(x) for arg in args - if isequal(arg, 0) + if SymbolicUtils.isconst(arg) && isequal(value(arg), 0) after_removing = setdiff(args, arg) isone(length(after_removing)) && return after_removing[1] - return Symbolics.term(+, after_removing) + return Symbolics.term(+, after_removing; type = symtype(x), shape = shape(x)) end end end @@ -133,10 +143,13 @@ function _postprocess_root(x::SymbolicUtils.BasicSymbolic) end function postprocess_root(x) - math_consts = (Base.MathConstants.pi, Base.MathConstants.e) + math_consts = Set(Any[Base.MathConstants.pi, Base.MathConstants.e]) while true - old_x = deepcopy(x) - contains_math_const = any([Symbolics.n_occurrences(x, c) > 0 for c in math_consts]) + old_x = x + x isa Union{Num, BasicSymbolic{VartypeT}} || return x + x = unwrap(x) + contains_math_const = SymbolicUtils.query!(in(math_consts) ∘ value, x) + # contains_math_const = any([Symbolics.n_occurrences(x, c) > 0 for c in math_consts]) if contains_math_const x = _postprocess_root(x) else @@ -184,7 +197,7 @@ function convert_consts(x) inv_opers = [asin, acos, atan] if any(isequal(oper, o) for o in inv_opers) && isempty(Symbolics.get_variables(x)) - val = Symbolics.symbolic_to_float(x) + val = value(Symbolics.symbolic_to_float(x)) for (exact, evald) in inv_pairs if isapprox(evald, val) return exact diff --git a/src/solver/preprocess.jl b/src/solver/preprocess.jl index 5f11a3422..78e1c6fb5 100644 --- a/src/solver/preprocess.jl +++ b/src/solver/preprocess.jl @@ -70,6 +70,7 @@ julia> filter_stuff(123) ``` """ function filter_stuff(expr) + expr = value(expr) if expr isa Integer return Dict(), expr elseif expr isa Rational || expr isa AbstractFloat || expr isa Complex @@ -109,22 +110,23 @@ julia> RootFinding._filter_poly(x*sqrt(2), x) function _filter_poly(expr, var) expr = unwrap(expr) vars = get_variables(expr) - if !isequal(vars, []) && isequal(vars[1], expr) + if !isempty(vars) && isequal(first(vars), expr) return (Dict{Any, Any}(), expr) - elseif isequal(vars, []) + elseif isempty(vars) return filter_stuff(expr) end args = copy(parent(arguments(expr))) - if expr isa ComplexTerm + if symtype(expr) <: Complex subs1, subs2 = Dict(), Dict() expr1, expr2 = 0, 0 - - if !isequal(expr.re, 0) - subs1, expr1 = _filter_poly(expr.re, var) + rr = real(expr) + ii = imag(expr) + if !isequal(rr, 0) + subs1, expr1 = _filter_poly(rr, var) end - if !isequal(expr.im, 0) - subs2, expr2 = _filter_poly(expr.im, var) + if !isequal(ii, 0) + subs2, expr2 = _filter_poly(ii, var) end subs = merge(subs1, subs2) @@ -138,27 +140,26 @@ function _filter_poly(expr, var) oper = operation(expr) return subs, term(oper, args...) end - subs = Dict{Any, Any}() for (i, arg) in enumerate(args) # handle constants - arg = unwrap(arg) + arg = value(arg) vars = get_variables(arg) - if isequal(vars, []) + if isempty(vars) if arg isa Integer - args[i] = bigify(args[i]) + args[i] = Const{VartypeT}(bigify(args[i])) continue elseif arg isa Rational || arg isa AbstractFloat || arg isa Complex - args[i] = comp_rational(arg, 1) + args[i] = Const{VartypeT}(comp_rational(arg, 1)) continue end - args[i] = sub(subs, args[i]) + args[i] = Const{VartypeT}(sub(subs, args[i])) continue end # handle "x" as an argument if length(vars) == 1 - if isequal(arg, var) || isequal(vars[1], arg) + if isequal(arg, var) || isequal(first(vars), arg) continue end end @@ -170,8 +171,10 @@ function _filter_poly(expr, var) continue end # filter(args[1]), filter[args[2]] and then merge - subs1, monomial[1] = _filter_poly(monomial[1], var) - subs2, monomial[2] = _filter_poly(monomial[2], var) + subs1, __monomial_1 = _filter_poly(monomial[1], var) + subs2, __monomial_2 = _filter_poly(monomial[2], var) + monomial[1] = Const{VartypeT}(__monomial_1) + monomial[2] = Const{VartypeT}(__monomial_2) merge!(subs, merge(subs1, subs2)) args[i] = maketerm(typeof(arg), oper, monomial, metadata(arg)) @@ -182,7 +185,7 @@ function _filter_poly(expr, var) subs_of_monom = Dict{Any, Any}() for (j, x) in enumerate(monomial) vars = get_variables(x) - if (!isempty(vars) && isequal(vars[1], x)) + if (!isempty(vars) && isequal(first(vars), x)) continue elseif x isa Integer monomial[j] = bigify(monomial[j]) @@ -192,7 +195,8 @@ function _filter_poly(expr, var) continue end # filter each arg and then merge - new_subs, monomial[j] = _filter_poly(monomial[j], var) + new_subs, __monomial_j = _filter_poly(monomial[j], var) + monomial[j] = Const{VartypeT}(__monomial_j) merge!(subs_of_monom, new_subs) end merge!(subs, subs_of_monom) @@ -245,10 +249,10 @@ function filter_poly(og_expr, var; assumptions=false) vars = get_variables(expr) # handle edge cases - if !isequal(vars, []) && isequal(vars[1], expr) + if !isempty(vars) && isequal(first(vars), expr) assumptions && return Dict{Any, Any}(), expr, [] return (Dict{Any, Any}(), expr) - elseif isequal(vars, []) + elseif isempty(vars) assumptions && return filter_stuff(expr), [] return filter_stuff(expr) end @@ -259,7 +263,7 @@ function filter_poly(og_expr, var; assumptions=false) # reassemble expr to avoid variables remembering original values issue and clean args = arguments(expr) oper = operation(expr) - new_expr, assum_array = clean_f(term(oper, args...), var, subs) + new_expr, assum_array = clean_f(term(oper, args...; type = symtype(expr), shape = shape(expr)), var, subs) assumptions && return subs, new_expr, assum_array return subs, new_expr @@ -301,7 +305,7 @@ function sdegree(coeffs, var) degree = 0 vars = collect(keys(coeffs)) for n in vars - isequal(n, 1) && continue + SymbolicUtils._isone(n) && continue isequal(n, var) && degree > 1 && continue if isequal(n, var) && degree < 1 @@ -310,9 +314,7 @@ function sdegree(coeffs, var) end args = arguments(n) - if args[2] > degree - degree = args[2] - end + degree = max(unwrap_const(args[2]), degree) end return degree end diff --git a/src/solver/solve_helpers.jl b/src/solver/solve_helpers.jl index f70ebdb05..cc73d10e1 100644 --- a/src/solver/solve_helpers.jl +++ b/src/solver/solve_helpers.jl @@ -33,11 +33,14 @@ function ssqrt(n) return sqrt(n) end - if n isa SymbolicUtils.BasicSymbolic{Real} + if symtype(n) === Real return term(ssqrt, n) end end +SymbolicUtils.promote_type(::typeof(ssqrt), ::Type{T}) where {T} = T +SymbolicUtils.promote_shape(::typeof(ssqrt), @nospecialize(sh::SymbolicUtils.ShapeT)) = sh + derivative(::typeof(ssqrt), args...) = substitute(derivative(sqrt, args...), sqrt => ssqrt) function scbrt(n) @@ -52,11 +55,13 @@ function scbrt(n) return (n)^(1 / 3) end - if n isa SymbolicUtils.BasicSymbolic{Real} + if symtype(n) === Real return term(scbrt, n) end end +SymbolicUtils.promote_type(::typeof(scbrt), ::Type{T}) where {T} = T +SymbolicUtils.promote_shape(::typeof(scbrt), @nospecialize(sh::SymbolicUtils.ShapeT)) = sh derivative(::typeof(scbrt), args...) = substitute(derivative(cbrt, args...), cbrt => scbrt) function slog(n) @@ -74,6 +79,9 @@ function slog(n) return term(slog, n) end +SymbolicUtils.promote_type(::typeof(slog), ::Type{T}) where {T} = T +SymbolicUtils.promote_shape(::typeof(slog), @nospecialize(sh::SymbolicUtils.ShapeT)) = sh + derivative(::typeof(slog), args...) = substitute(derivative(log, args...), log => slog) const RootsOf = (SymbolicUtils.@syms roots_of(poly,var))[1] @@ -85,9 +93,9 @@ Base.show(io::IO, r::typeof(slog)) = print(io, "slog") function check_expr_validity(expr) type_expr = typeof(expr) valid_type = false - - if type_expr <: Number || type_expr == Num || type_expr == SymbolicUtils.BasicSymbolic{Real} || - type_expr == Complex{Num} || type_expr == ComplexTerm{Real} || type_expr == SymbolicUtils.BasicSymbolic{Complex{Real}} + st = symtype(expr) + if type_expr <: Number || type_expr == Num || st <: Real || + type_expr == Complex{Num} || st <: Complex{Real} valid_type = true end iscall(unwrap(expr)) && @assert !hasderiv(unwrap(expr)) "Differential equations are not currently supported" @@ -103,14 +111,14 @@ end function check_poly_inunivar(poly, var) subs, filtered = filter_poly(poly, var) - coeffs, constant = polynomial_coeffs(filtered, [var]) - return isequal(constant, 0) + coeffs, constant = polynomial_coeffs(filtered, var isa Array ? var : [var]) + return SymbolicUtils._iszero(constant) end # converts everything to BIG function bigify(n) - n = unwrap(n) - if n isa ComplexTerm || n isa Float64 || n isa Irrational + n = value(n) + if n isa Float64 || n isa Irrational return n end @@ -118,7 +126,7 @@ function bigify(n) !iscall(n) && return n args = copy(parent(arguments(n))) for i in eachindex(args) - args[i] = bigify(args[i]) + args[i] = Const{VartypeT}(bigify(args[i])) end n = maketerm(typeof(n), operation(n), args, metadata(n)) return n @@ -144,7 +152,7 @@ function bigify(n) end function comp_rational(x, y) - x, y = wrap(bigify(x)), wrap(bigify(y)) + x, y = bigify(unwrap(x)), bigify(unwrap(y)) if !(unwrap(x) isa AbstractFloat || x isa Complex) && !(unwrap(y) isa AbstractFloat || y isa Complex) r = x // y diff --git a/src/solver/univar.jl b/src/solver/univar.jl index 8886f6cc9..eebba4bb5 100644 --- a/src/solver/univar.jl +++ b/src/solver/univar.jl @@ -1,22 +1,24 @@ -function get_roots_deg1(expression, x) - subs, filtered_expr = filter_poly(expression, x) - coeffs, constant = polynomial_coeffs(filtered_expr, [x]) - +function get_roots_deg1(expression, x, subs, coeffs) @assert isequal(sdegree(coeffs, x), 1) "Expected a polynomial of degree 1 in $x, got $expression" m = get(coeffs, x, 0) c = get(coeffs, x^0, 0) - - root = -c // m + root = -c / m root = unwrap(ssubs(root, subs)) return [root] end +function get_roots_deg1(expression, x) + subs, filtered_expr = filter_poly(expression, x) + coeffs, = polynomial_coeffs(filtered_expr, [x]) + get_roots_deg1(expression, x, subs, coeffs) +end + function get_deg2_with_coeffs(a, b, c) a, b, c = bigify(a), bigify(b), bigify(c) - root1 = (-b + term(ssqrt, (b^2 - 4(a * c)))) // 2a - root2 = (-b - term(ssqrt, (b^2 - 4(a * c)))) // 2a + root1 = (-b + term(ssqrt, (b^2 - 4(a * c)))) / 2a + root2 = (-b - term(ssqrt, (b^2 - 4(a * c)))) / 2a return [root1, root2] end @@ -28,11 +30,11 @@ function get_roots_deg2(expression, x) @assert isequal(sdegree(coeffs, x), 2) "Expected a polynomial of degree 2 in $x, got $expression" - results = (unwrap(ssubs(get(coeffs, x^i, 0), subs)) for i in 2:-1:0) + results = (bigify(ssubs(get(coeffs, x^i, 0), subs)) for i in 2:-1:0) a, b, c = results - root1 = (-b + term(ssqrt, (b^2 - 4(a * c)))) // 2a - root2 = (-b - term(ssqrt, (b^2 - 4(a * c)))) // 2a + root1 = (-b + term(ssqrt, (b^2 - 4(a * c)))) / 2a + root2 = (-b - term(ssqrt, (b^2 - 4(a * c)))) / 2a return [root1, root2] end @@ -43,18 +45,18 @@ function get_roots_deg3(expression, x) @assert isequal(sdegree(coeffs, x), 3) "Expected a polynomial of degree 3 in $x, got $expression" - results = (unwrap(ssubs(get(coeffs, x^i, 0), subs)) for i in 3:-1:0) + results = (bigify(unwrap(ssubs(get(coeffs, x^i, 0), subs))) for i in 3:-1:0) a, b, c, d = results - Q = (((3 * a * c) - b^2)) // (9a^2) - R = ((9 * a * b * c - ((27 * (a^2) * d) + 2b^3))) // (54a^3) + Q = (((3 * a * c) - b^2)) / (9a^2) + R = ((9 * a * b * c - ((27 * (a^2) * d) + 2b^3))) / (54a^3) S = term(scbrt, (R + term(ssqrt, (Q^3 + R^2)))) T = term(scbrt, (R - term(ssqrt, (Q^3 + R^2)))) - root1 = S + T - (b // (3 * a)) - root2 = -((S + T) // 2) - (b // (3 * a)) + (im * (term(ssqrt, 3)) / 2) * (S - T) - root3 = -((S + T) // 2) - (b // (3 * a)) - (im * (term(ssqrt, 3)) / 2) * (S - T) + root1 = S + T - (b / (3 * a)) + root2 = -((S + T) / 2) - (b // (3 * a)) + (im * (term(ssqrt, 3)) / 2) * (S - T) + root3 = -((S + T) / 2) - (b // (3 * a)) - (im * (term(ssqrt, 3)) / 2) * (S - T) return [root1, root2, root3] end @@ -65,7 +67,7 @@ function get_roots_deg4(expression, x) @assert isequal(sdegree(coeffs, x), 4) "Expected a polynomial of degree 4 in $x, got $expression" - results = (unwrap(ssubs(get(coeffs, x^i, 0), subs)) for i in 4:-1:0) + results = (bigify(unwrap(ssubs(get(coeffs, x^i, 0), subs))) for i in 4:-1:0) a, b, c, d, e = results p = (8(a * c) - 3(b^2)) // (8(a^2)) @@ -84,12 +86,12 @@ function get_roots_deg4(expression, x) # Yassin: this thing is a problem for parametric for root in roots_m vars = get_variables(root) - if isequal(vars, []) && !isequal(eval(toexpr(root)), 0) + if isempty(vars) && !SymbolicUtils._iszero(eval(toexpr(root))) m = unwrap(copy(wrap(root))) break end end - if isequal(m, 0) + if SymbolicUtils._iszero(m) @info "Assuming $(roots_m[1] != 0)" m = roots_m[1] end @@ -105,9 +107,9 @@ end function get_yroots(m, p, q) a = 1 b1 = term(ssqrt, 2m) - c1 = (p // 2) + m - (q // (2 * term(ssqrt, 2m))) + c1 = (p / 2) + m - (q / (2 * term(ssqrt, 2m))) b2 = -term(ssqrt, 2m) - c2 = (p // 2) + m + (q // (2 * term(ssqrt, 2m))) + c2 = (p / 2) + m + (q / (2 * term(ssqrt, 2m))) root1, root2 = get_deg2_with_coeffs(a, b1, c1) root3, root4 = get_deg2_with_coeffs(a, b2, c2) @@ -119,7 +121,7 @@ function get_roots(expression, x) subs, filtered_expr = filter_poly(expression, x) coeffs, constant = polynomial_coeffs(filtered_expr, [x]) - @assert isequal(constant, 0) "Expected a polynomial in $x, got $expression" + @assert SymbolicUtils._iszero(constant) "Expected a polynomial in $x, got $expression" degree = sdegree(coeffs, x) @@ -130,7 +132,7 @@ function get_roots(expression, x) end if degree == 1 - return get_roots_deg1(expression, x) + return get_roots_deg1(expression, x, subs, coeffs) end if degree == 2 diff --git a/src/struct.jl b/src/struct.jl deleted file mode 100644 index dc649a33b..000000000 --- a/src/struct.jl +++ /dev/null @@ -1,84 +0,0 @@ -struct Struct{T} <: Real -end - -""" - symstruct(T) - -Create a symbolic wrapper for struct from a given struct `T`. -""" -symstruct(::Type{T}) where T = Struct{T} -Struct{T}(vals...) where T = T(vals...) - -function Base.hash(x::Struct{T}, seed::UInt) where T - h1 = hash(T, seed) - h2 ⊻ (0x0e39036b7de2101a % UInt) -end - -""" - juliatype(s::Type{<:Struct}) - -Get the Julia type that `s` is representing. -""" -juliatype(::Type{Struct{T}}) where T = T -getelements(s::Type{<:Struct}) = fieldnames(juliatype(s)) -getelementtypes(s::Type{<:Struct}) = fieldtypes(juliatype(s)) - -typed_getfield(obj, ::Val{fieldname}) where fieldname = getfield(obj, fieldname) - -""" - symbolic_getproperty(ss, name::Symbol) - -Symbolic term corresponding to accessing the field with name `name`. -""" -function symbolic_getproperty(ss, name::Symbol) - s = symtype(ss) - idx = findfirst(isequal(name), getelements(s)) - idx === nothing && error("$(juliatype(s)) doesn't have field $(name)!") - T = getelementtypes(s)[idx] - if isstructtype(T) - T = Struct{T} - end - SymbolicUtils.term(typed_getfield, ss, Val{name}(), type = T) -end -function symbolic_getproperty(s::Union{Arr, Num}, name::Symbol) - wrap(symbolic_getproperty(unwrap(s), name)) -end - -""" - symbolic_setproperty!(ss, name::Symbol) - -Symbolic term corresponding to modifying the field with name `name` to val `val`. -""" -function symbolic_setproperty!(ss, name::Symbol, val) - s = symtype(ss) - idx = findfirst(isequal(name), getelements(s)) - idx === nothing && error("$(juliatype(s)) doesn't have field $(name)!") - T = getelementtypes(s)[idx] - SymbolicUtils.term(setfield!, ss, Meta.quot(name), val, type = T) -end -function symbolic_setproperty!(s::Union{Arr, Num}, name::Symbol, val) - wrap(symbolic_setproperty!(unwrap(s), name, val)) -end - -function symbolic_constructor(s::Type{<:Struct}, vals...) - N = length(getelements(s)) - N′ = length(vals) - N′ == N || error("$(juliatype(s)) needs $N field. Got $N′ fields!") - SymbolicUtils.term(s, vals..., type = s) -end - -# We cannot precisely derive the type after `getfield` due to SU limitations, -# so give up and just say Real. -function SymbolicUtils.promote_symtype(::typeof(typed_getfield), ::Type{<:Struct{T}}, v::Type{Val{fieldname}}) where {T, fieldname} - FT = fieldtype(T, fieldname) - if isstructtype(FT) - return Struct{FT} - end - FT -end - -function SymbolicUtils.promote_symtype(s::Type{<:Struct{T}}, _...) where T - s -end - -SymbolicUtils.promote_symtype(::typeof(setfield!), ::Type{<:Struct}, _, ::Type{T}) where T = T diff --git a/src/taylor.jl b/src/taylor.jl index f322b6c1f..dff5e301a 100644 --- a/src/taylor.jl +++ b/src/taylor.jl @@ -85,9 +85,9 @@ function taylor_coeff(f, x, n = missing; rationalize=true, kwargs...) # TODO: error if x is not a "pure variable" D = Differential(x) n! = factorial(n) - c = (D^n)(f) / n! # TODO: optimize the implementation for multiple n with a loop that avoids re-differentiating the same expressions + c = (D^n)(f) # TODO: optimize the implementation for multiple n with a loop that avoids re-differentiating the same expressions c = expand_derivatives(c) - c = substitute(c, x => 0; kwargs...) + c = substitute(c, x => 0; kwargs...) / n! if rationalize && unwrap(c) isa Number # TODO: make rational coefficients "organically" and not using rationalize (see https://github.com/JuliaSymbolics/Symbolics.jl/issues/1299) c = unwrap(c) diff --git a/src/utils.jl b/src/utils.jl index 8359430e9..cbe195aea 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -13,77 +13,64 @@ function flatten_expr!(x) end xs end -function build_expr(head::Symbol, args) - ex = Expr(head) - append!(ex.args, args) - ex + +""" + $(TYPEDSIGNATURES) + +Equivalent to `unwrap_const ∘ unwrap`. +""" +value(x) = unwrap_const(unwrap(x)) + +function is_singleton(e) + if iscall(e) + op = operation(e) + op === getindex && return true + iscall(op) && return is_singleton(op) # recurse to reach getindex for array element variables + return issym(op) && !hasmetadata(e, CallWithParent) + else + return issym(e) + end end """ - get_variables(e, varlist = nothing; sort::Bool = false) + get_variables(e, varlist = nothing) -Return a vector of variables appearing in `e`, optionally restricting to variables in `varlist`. +Return a vector of variables appearing in e, optionally restricting to variables in varlist. -Note that the returned variables are not wrapped in the `Num` type. +Note that the returned variables are not wrapped in the Num type. -# Examples -```jldoctest +Examples +≡≡≡≡≡≡≡≡ + +```julia julia> @variables t x y z(t); -julia> Symbolics.get_variables(x + y + sin(z); sort = true) +julia> Symbolics.get_variables(x + y + sin(z)) 3-element Vector{SymbolicUtils.BasicSymbolic}: x y z(t) -julia> Symbolics.get_variables(x - y; sort = true) +julia> Symbolics.get_variables(x - y) 2-element Vector{SymbolicUtils.BasicSymbolic}: x y ``` """ -function get_variables(e::Num, varlist = nothing; sort::Bool = false) - get_variables(value(e), varlist; sort) -end -function get_variables(e, varlist = nothing; sort::Bool = false) - vars = Vector{BasicSymbolic}() - get_variables!(vars, e, varlist) - if sort - sort!(vars; by = SymbolicUtils.get_degrees) - end - vars -end - -get_variables!(vars, e::Num, varlist=nothing) = get_variables!(vars, value(e), varlist) -get_variables!(vars, e, varlist=nothing) = vars - -function is_singleton(e) - if iscall(e) - op = operation(e) - op === getindex && return true - iscall(op) && return is_singleton(op) # recurse to reach getindex for array element variables - return issym(op) && !hasmetadata(e, CallWithParent) - else - return issym(e) - end +function get_variables(e; kw...) + return search_variables(e; kw...) end -get_variables!(vars, e::Number, varlist=nothing) = vars - -function get_variables!(vars, e::Symbolic, varlist=nothing) - if is_singleton(e) - if isnothing(varlist) || any(isequal(e), varlist) - push!(vars, e) +function _get_is_atomic(varlist) + let vars = Set(varlist) + function _is_atomic(ex) + SymbolicUtils.default_is_atomic(ex) && ex in vars end - else - get_variables!(vars, operation(e), varlist) - foreach(x -> get_variables!(vars, x, varlist), arguments(e)) end - return (vars isa AbstractVector) ? unique!(vars) : vars end -function get_variables!(vars, e::Equation, varlist=nothing) - get_variables!(vars, e.rhs, varlist) +function get_variables(e, varlist; kw...) + search_variables(e; kw..., is_atomic = _get_is_atomic(varlist)) end """ @@ -126,7 +113,7 @@ get_differential_vars!(vars, e, varlist=nothing) = vars get_differential_vars!(vars, e::Number, varlist=nothing) = vars -function get_differential_vars!(vars, e::Symbolic, varlist=nothing) +function get_differential_vars!(vars, e::BasicSymbolic, varlist=nothing) if is_derivative(e) if isnothing(varlist) || any(isequal(e), varlist) push!(vars, e) @@ -148,11 +135,11 @@ function get_differential_vars!(vars, e::Equation, varlist=nothing) end # Sym / Term --> Symbol -Base.Symbol(x::Union{Num,Symbolic}) = tosymbol(x) +Base.Symbol(x::Num) = Symbol(unwrap(x)) tosymbol(t::Num; kwargs...) = tosymbol(value(t); kwargs...) """ - diff2term(x, x_metadata::Dict{Datatype, Any}) -> Symbolic + diff2term(x, x_metadata::Dict{Datatype, Any}) -> BasicSymbolic Convert a differential variable to a `Term`. Note that it only takes a `Term` not a `Num`. @@ -192,7 +179,7 @@ function diff2term(O, O_metadata::Union{Dict, Nothing, Base.ImmutableDict}=nothi d_separator = 'ˍ' if ds === nothing - return maketerm(typeof(O), TermInterface.head(O), map(diff2term, children(O)), + return maketerm(typeof(O), TermInterface.head(O), map(diff2term, arguments(O)), O_metadata isa Nothing ? metadata(O) : Base.ImmutableDict(metadata(O)..., O_metadata...)) else @@ -207,7 +194,7 @@ function diff2term(O, O_metadata::Union{Dict, Nothing, Base.ImmutableDict}=nothi error("diff2term case not handled: $oldop") end newname = occursin(d_separator, opname) ? Symbol(opname, ds) : Symbol(opname, d_separator, ds) - return setname(maketerm(typeof(O), rename(oldop, newname), children(O), O_metadata isa Nothing ? + return setname(maketerm(typeof(O), rename(oldop, newname), arguments(O), O_metadata isa Nothing ? metadata(O) : Base.ImmutableDict(metadata(O)..., O_metadata...)), newname) end end @@ -215,7 +202,7 @@ end setname(v, name) = setmetadata(v, Symbolics.VariableSource, (:variables, name)) """ - tosymbol(x::Union{Num,Symbolic}; states=nothing, escape=true) -> Symbol + tosymbol(x::Union{Num,BasicSymbolic}; states=nothing, escape=true) -> Symbol Convert `x` to a symbol. `states` are the states of a system, and `escape` means if the target has escapes like `val"y(t)"`. If `escape` is false, then @@ -265,7 +252,7 @@ function tosymbol(t; states=nothing, escape=true) end end -function lower_varname(var::Symbolic, idv, order) +function lower_varname(var::BasicSymbolic, idv, order) order == 0 && return var D = Differential(idv) for _ in 1:order @@ -274,34 +261,9 @@ function lower_varname(var::Symbolic, idv, order) return diff2term(var) end -### OOPS - -struct Unknown end - -macro oops(ex) - quote - tmp = $(esc(ex)) - if tmp === Unknown() - return Unknown() - else - tmp - end - end -end - -function makesubscripts(n) - set = 'i':'z' - m = length(set) - map(1:n) do i - repeats = ceil(Int, i / m) - c = set[(i-1) % m + 1] - Sym{Int}(Symbol(join([c for _ in 1:repeats], ""))) - end -end - function var_from_nested_derivative(x,i=0) x = unwrap(x) - if issym(x) || x isa CallWithMetadata + if issym(x) (x, i) elseif iscall(x) operation(x) isa Differential ? @@ -334,14 +296,17 @@ julia> Symbolics.degree(x^2) function degree(p, sym=nothing) p = value(p) sym = value(sym) - if p isa Number + if SymbolicUtils.isconst(p) || p isa Number return 0 end if isequal(p, sym) return 1 end if isterm(p) - if sym === nothing + if operation(p) === (^) + base, exp = arguments(p) + return unwrap_const(exp) * degree(base, sym) + elseif sym === nothing return 1 else return Int(isequal(p, sym)) @@ -350,8 +315,6 @@ function degree(p, sym=nothing) return sum(degree(k^v, sym) for (k, v) in zip(keys(p.dict), values(p.dict))) elseif isadd(p) return maximum(degree(key, sym) for key in keys(p.dict)) - elseif ispow(p) - return p.exp * degree(p.base, sym) elseif isdiv(p) return degree(p.num, sym) - degree(p.den, sym) elseif issym(p) @@ -392,7 +355,7 @@ function coeff(p, sym=nothing) # if `sym` is a product, iteratively compute the coefficient w.r.t. each term in `sym` if iscall(value(sym)) && operation(value(sym)) === (*) for t in arguments(value(sym)) - @assert !(t isa Number) "`coeff(p, sym)` does not allow `sym` containing numerical factors" + @assert !(t isa Number || SymbolicUtils.isconst(t)) "`coeff(p, sym)` does not allow `sym` containing numerical factors" p = coeff(p, t) end return p @@ -400,13 +363,11 @@ function coeff(p, sym=nothing) p, sym = value(p), value(sym) - if isequal(sym, 1) + if _isone(sym) sym = nothing end - if issym(p) || isterm(p) - sym === nothing ? 0 : Int(isequal(p, sym)) - elseif ispow(p) + if issym(p) || SymbolicUtils.isconst(p) || isterm(p) sym === nothing ? 0 : Int(isequal(p, sym)) elseif isadd(p) if sym===nothing @@ -431,7 +392,7 @@ function coeff(p, sym=nothing) end else p isa Number && return sym === nothing ? p : 0 - p isa Symbolic && return coeff(p, sym) + p isa BasicSymbolic && return coeff(p, sym) throw(DomainError(p, "Datatype $(typeof(p)) not accepted.")) end end @@ -441,9 +402,29 @@ end const DP = DynamicPolynomials # extracting underlying polynomial and coefficient type from Polyforms underlyingpoly(x::Number) = x -underlyingpoly(pf::PolyForm) = pf.p coefftype(x::Number) = typeof(x) -coefftype(pf::PolyForm) = DP.coefficient_type(underlyingpoly(pf)) +coefftype(x::DP.Polynomial) = eltype(MP.coefficients(x)) + +as_concrete_polynomial(x::Number) = x +function as_concrete_polynomial(x::DP.Polynomial) + coeffs = MP.coefficients(x) + isconcretetype(eltype(coeffs)) && return x + isempty(coeffs) && return poly_to_coefftype(Int, x) + T = typeof(coeffs[1]) + for coeff in coeffs + T = promote_type(T, typeof(coeff)) + end + poly_to_coefftype(T, x) +end + +function as_concrete_polynomial(x::SymbolicUtils.PolyVarT) + mv = DP.MonomialVector{SymbolicUtils.PolyVarOrder, SymbolicUtils.MonomialOrder}([x], [Int[1]]) + return DP.Polynomial(Int[1], mv) +end + +function poly_to_coefftype(::Type{T}, x::DP.Polynomial) where {T} + DP.Polynomial(Vector{T}(MP.coefficients(x)), MP.monomials(x)) +end #= Converts an array of symbolic polynomials @@ -456,31 +437,27 @@ function symbol_to_poly(sympolys::AbstractArray) stdsympolys = map(unwrap, sympolys) sort!(stdsympolys, lt=(<ₑ)) - pvar2sym = Bijections.Bijection{Any,Any}() - sym2term = Dict{BasicSymbolic,Any}() - polyforms = map(f -> PolyForm(f, pvar2sym, sym2term), stdsympolys) + symidx = findfirst(x -> x isa BasicSymbolic, stdsympolys) + varT = vartype(stdsympolys[symidx]) + poly_to_bs = Bijections.Bijection{SymbolicUtils.PolyVarT, BasicSymbolic{varT}}() + bs_to_poly = Bijections.active_inv(poly_to_bs) + polyforms = map(f -> as_concrete_polynomial(SymbolicUtils.to_poly!(poly_to_bs, bs_to_poly, f)), stdsympolys) # Discover common coefficient type commontype = mapreduce(coefftype, promote_type, polyforms, init=Int) @assert commontype <: Union{Integer,Rational} "Only integer and rational coefficients are supported as input." - # Convert all to DP.Polynomial, so that coefficients are of same type, - # and constants are treated as polynomials - # We also need this because Groebner does not support abstract types as input - polynoms = Vector{DP.Polynomial{DP.Commutative{DP.CreationOrder},DP.Graded{DP.LexOrder},commontype}}(undef, length(sympolys)) - for (i, pf) in enumerate(polyforms) - polynoms[i] = underlyingpoly(pf) - end + polynoms = map(Base.Fix1(poly_to_coefftype, commontype), polyforms) - polynoms, pvar2sym, sym2term + polynoms, poly_to_bs end #= Converts an array of AbstractPolynomialLike`s into an array of symbolic expressions mapping variables w.r.t pvar2sym =# -function poly_to_symbol(polys, pvar2sym, sym2term, ::Type{T}) where {T} - map(f -> PolyForm{T}(f, pvar2sym, sym2term), polys) +function poly_to_symbol(polys, poly_to_bs) + map(Base.Fix1(SymbolicUtils.from_poly, poly_to_bs), polys) end """ @@ -501,7 +478,7 @@ function symbolic_to_float end symbolic_to_float(x::Num) = symbolic_to_float(unwrap(x)) symbolic_to_float(x::Number) = x function symbolic_to_float(x::SymbolicUtils.BasicSymbolic) - substitute(x,Dict()) + unwrap_const(substitute(x,Dict())) end """ @@ -516,7 +493,7 @@ julia> numerator(x/y) x ``` """ -function Base.numerator(x::Union{Num, Symbolic}) +function Base.numerator(x::Union{Num, BasicSymbolic}) x = unwrap(x) if iscall(x) && operation(x) == / x = arguments(x)[1] # get numerator @@ -536,7 +513,7 @@ julia> denominator(x/y) y ``` """ -function Base.denominator(x::Union{Num, Symbolic}) +function Base.denominator(x::Union{Num, BasicSymbolic}) x = unwrap(x) if iscall(x) && operation(x) == / x = arguments(x)[2] # get denominator @@ -622,14 +599,14 @@ false function evaluate end function evaluate(eq::Equation, subs) - lhs = fast_substitute(eq.lhs, subs) - rhs = fast_substitute(eq.rhs, subs) + lhs = substitute(eq.lhs, subs) + rhs = substitute(eq.rhs, subs) return isequal(lhs, rhs) end function evaluate(ineq::Inequality, subs) - lhs = fast_substitute(ineq.lhs, subs) - rhs = fast_substitute(ineq.rhs, subs) + lhs = substitute(ineq.lhs, subs) + rhs = substitute(ineq.rhs, subs) if (ineq.relational_op == geq) return isless(rhs, lhs) elseif (ineq.relational_op == leq) @@ -639,4 +616,6 @@ function evaluate(ineq::Inequality, subs) end end - +vartype_from_args(::BasicSymbolic{T}, args...) where {T} = T +vartype_from_args(_, args...) = vartype_from_args(args...) +vartype_from_args() = error("Cannot infer `vartype`.") diff --git a/src/variable.jl b/src/variable.jl index fa751a6c2..456ba7f65 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -28,67 +28,28 @@ Symbolic metadata key for storing the macro used to create a symbolic variable. """ struct VariableSource <: AbstractVariableMetadata end -function recurse_and_apply(f, x) - if symtype(x) <: AbstractArray - getindex_posthook(x) do r,x,i... - recurse_and_apply(f, r) - end +function setdefaultval(x, val) + sh = shape(x) + if sh isa SymbolicUtils.Unknown + @assert sh.ndims == -1 || ndims(val) == sh.ndims """ + Variable $x must have default of matching `ndims`. Got $val with `ndims` \ + $(ndims(val)). + """ else - f(x) + @assert isempty(sh) || symtype(x) <: FnType || size(x) == size(val) """ + Variable $x must have default of matching size. Got $val with size \ + $(size(val)). + """ end + setmetadata(x, VariableDefaultValue, val) end -function set_scalar_metadata(x, V, val) - val = unwrap(val) - if val isa AbstractArray - val = unwrap.(val) - end - if symtype(x) <: AbstractArray - x = if val isa AbstractArray - getindex_posthook(x) do r,x,i... - set_scalar_metadata(r, V, val[i...]) - end - else - getindex_posthook(x) do r,x,i... - set_scalar_metadata(r, V, val) - end - end - end - setmetadata(x, V, val) -end -setdefaultval(x, val) = set_scalar_metadata(x, VariableDefaultValue, val) - -struct GetindexParent end - -function scalarize_getindex(x, parent=Ref{Any}(x)) - if symtype(x) <: AbstractArray - parent[] = getindex_posthook(x) do r,x,i... - scalarize_getindex(r, parent) - end - else - xx = unwrap(scalarize(x)) - xx = metadata(xx, metadata(x)) - if symtype(xx) <: FnType - setmetadata(CallWithMetadata(xx, metadata(xx)), GetindexParent, parent[]) - else - setmetadata(xx, GetindexParent, parent[]) - end - end -end - - function map_subscripts(indices) str = string(indices) join(IndexMap[c] for c in str) end -function unwrap_runtime_var(v) - isruntime = Meta.isexpr(v, :$) && length(v.args) == 1 - isruntime && (v = v.args[1]) - return isruntime, v -end - # Build variables more easily """ $(TYPEDSIGNATURES) @@ -101,11 +62,9 @@ macro. `transform` is an optional function that takes constructed variables and custom postprocessing to them, returning the created variables. This function returns the `Expr` for constructing the parsed variables. """ -parse_vars(macroname, type, x, transform=identity) = _parse_vars(macroname, type, x, transform=identity) - -function _parse_vars(macroname, type, x, transform=identity) +function parse_vars(macroname, type, x, transform = identity) ex = Expr(:block) - var_names = Symbol[] + var_names = Expr(:vect) # if parsing things in the form of # begin # x @@ -118,167 +77,141 @@ function _parse_vars(macroname, type, x, transform=identity) isoption(ex) = Meta.isexpr(ex, [:vect, :vcat, :hcat]) while cursor < length(x) cursor += 1 - v = x[cursor] + var_expr = x[cursor] - # We need lookahead to the next `v` to parse - # `@variables x [connect=Flow,unit=u]` - nv = cursor < length(x) ? x[cursor+1] : nothing - val = unit = connect = options = nothing - - # x = 1, [connect = flow; unit = u"m^3/s"] - if Meta.isexpr(v, :(=)) - v, val = v.args + default = nothing + options = nothing + if Meta.isexpr(var_expr, :(=)) + var_expr, default = var_expr.args # defaults with metadata for function variables - if Meta.isexpr(val, :block) - Base.remove_linenums!(val) - val = only(val.args) + if Meta.isexpr(default, :block) + Base.remove_linenums!(default) + default = only(default.args) end - if Meta.isexpr(val, :tuple) && length(val.args) == 2 && isoption(val.args[2]) - options = val.args[2].args - val = val.args[1] + if Meta.isexpr(default, :tuple) && length(default.args) == 2 && isoption(default.args[2]) + options = default.args[2].args + default = default.args[1] end + default = esc(default) end + parse_result = SymbolicUtils.parse_variable(var_expr; default_type = type) + handle_nonconcrete_symtype!(parse_result) + sym = SymbolicUtils.sym_from_parse_result(parse_result, VartypeT) + sym = handle_maybe_dependent_variable!(parse_result, sym, type) - type′ = type - - if Meta.isexpr(v, :(::)) - v, type′ = v.args - type′ = type′ === :Complex ? Complex{type} : type′ - end - - - # x [connect = flow; unit = u"m^3/s"] - if isoption(nv) - options = nv.args + if options === nothing && cursor < length(x) && isoption(x[cursor + 1]) + options = x[cursor + 1].args cursor += 1 end + sym = _add_metadata(parse_result, sym, default, macroname, options) + sym = handle_maybe_callandwrap!(parse_result, sym) + sym = Expr(:call, wrap, sym) - isruntime, v = unwrap_runtime_var(v) - iscall = Meta.isexpr(v, :call) - isarray = Meta.isexpr(v, :ref) - if iscall && Meta.isexpr(v.args[1], :ref) && !call_args_are_function(map(last∘unwrap_runtime_var, @view v.args[2:end])) - @warn("The variable syntax $v is deprecated. Use $(Expr(:ref, Expr(:call, v.args[1].args[1], v.args[2]), v.args[1].args[2:end]...)) instead. - The former creates an array of functions, while the latter creates an array valued function. - The deprecated syntax will cause an error in the next major release of Symbolics. - This change will facilitate better implementation of various features of Symbolics.") - end - issym = v isa Symbol - @assert iscall || isarray || issym "@$macroname expects a tuple of expressions or an expression of a tuple (`@$macroname x y z(t) v[1:3] w[1:2,1:4]` or `@$macroname x y z(t) v[1:3] w[1:2,1:4] k=1.0`)" - - if isarray && Meta.isexpr(v.args[1], :call) - # This is the new syntax - isruntime, fname = unwrap_runtime_var(v.args[1].args[1]) - call_args = map(last∘unwrap_runtime_var, @view v.args[1].args[2:end]) - size = v.args[2:end] - var_name, expr = construct_dep_array_vars(macroname, fname, type′, call_args, size, val, options, transform, isruntime) - elseif iscall - isruntime, fname = unwrap_runtime_var(v.args[1]) - call_args = map(last∘unwrap_runtime_var, @view v.args[2:end]) - var_name, expr = construct_vars(macroname, fname, type′, call_args, val, options, transform, isruntime) - elseif isarray - var_name, expr = construct_vars(macroname, v, type′, nothing, val, options, transform, isruntime) + if parse_result[:isruntime] + varname = Symbol(parse_result[:name]) else - var_name, expr = construct_vars(macroname, v, type′, nothing, val, options, transform, isruntime) + varname = esc(parse_result[:name]) end - - push!(var_names, var_name) - push!(ex.args, expr) + push!(var_names.args, varname) + push!(ex.args, Expr(:(=), varname, sym)) end - rhs = build_expr(:vect, var_names) - push!(ex.args, rhs) + push!(ex.args, var_names) return ex end -call_args_are_function(_) = false -function call_args_are_function(call_args::AbstractArray) - !isempty(call_args) && (call_args[end] == :(..) || all(Base.Fix2(Meta.isexpr, :(::)), call_args)) -end - -function construct_dep_array_vars(macroname, lhs, type, call_args, indices, val, prop, transform, isruntime) - ndim = :($length(($(indices...),))) - if call_args_are_function(call_args) - vname, fntype = function_name_and_type(lhs) - # name was already unwrapped before calling this function and is of the form $x - if isruntime - _vname = vname - else - # either no ::fnType or $x::fnType - vname, fntype = function_name_and_type(lhs) - isruntime, vname = unwrap_runtime_var(vname) - if isruntime - _vname = vname - else - _vname = Meta.quot(vname) - end - end - argtypes = arg_types_from_call_args(call_args) - ex = :($CallWithMetadata($Sym{$FnType{$argtypes, Array{$type, $ndim}, $(fntype...)}}($_vname))) - else - vname = lhs - if isruntime - _vname = vname - else - _vname = Meta.quot(vname) - end - ex = :($Sym{$FnType{Tuple, Array{$type, $ndim}}}($_vname)(map($unwrap, ($(call_args...),))...)) - end - ex = :($setmetadata($ex, $ArrayShapeCtx, ($(indices...),))) - - if val !== nothing - ex = :($setdefaultval($ex, $val)) - end - ex = setprops_expr(ex, prop, macroname, Meta.quot(vname)) - #ex = :($scalarize_getindex($ex)) - - ex = :($wrap($ex)) - - ex = :($transform($ex)) - if isruntime - vname = gensym(Symbol(vname)) +function handle_nonconcrete_symtype!(parse_result) + type = parse_result[:type] + if type == :Complex + parse_result[:type] = :(Complex{Real}) end - vname, :($vname = $ex) -end - -function construct_vars(macroname, v, type, call_args, val, prop, transform, isruntime) - issym = v isa Symbol - isarray = !isruntime && Meta.isexpr(v, :ref) - if isarray - # this can't be an array of functions, since that was handled by `construct_dep_array_vars` - var_name = v.args[1] - if Meta.isexpr(var_name, :(::)) - var_name, type′ = var_name.args - type = type′ === :Complex ? Complex{type} : type′ + if Meta.isexpr(type, :curly) + if type.args[1] in (:Array, :Vector, :Matrix, Array, Vector, Matrix) && type.args[2] == :Complex + type.args[2] = :(Complex{Real}) end - isruntime, var_name = unwrap_runtime_var(var_name) - indices = v.args[2:end] - expr = _construct_array_vars(macroname, isruntime ? var_name : Meta.quot(var_name), type, call_args, val, prop, indices...) - elseif call_args_are_function(call_args) - var_name, fntype = function_name_and_type(v) - # name was already unwrapped before calling this function and is of the form $x - if isruntime - vname = var_name - else - # either no ::fnType or $x::fnType - var_name, fntype = function_name_and_type(v) - isruntime, var_name = unwrap_runtime_var(var_name) - if isruntime - vname = var_name - else - vname = Meta.quot(var_name) + if type.args[1] == :FnType || type.args[1] == SymbolicUtils.FnType + if Meta.isexpr(type.args[2], :curly) # Tuple{...} + for i in 2:length(type.args[2].args) + if type.args[2].args[i] == :Complex + type.args[2].args[i] = :(Complex{Real}) + end + end + end + if type.args[3] == :Complex + type.args[3] = :(Complex{Real}) + end + for parse_arg in parse_result[:args] + handle_nonconcrete_symtype!(parse_arg) end end - expr = construct_var(macroname, fntype == () ? vname : Expr(:(::), vname, fntype[1]), type, call_args, val, prop) + end + return nothing +end + +function parse_result_is_dependent_variable(parse_result) + # This means it is a function call + return haskey(parse_result, :args) && + # This checks `fnT` in `FnType{argsT, retT, fnT}` + parse_result[:type].args[4] === Nothing && + # This ensures all arguments have defined names + all(n -> n !== nothing && n != :.., + (get(arg, :name, nothing) for arg in parse_result[:args])) +end + +function handle_maybe_dependent_variable!(parse_result, sym, type) + # is a function call and the function doesn't have a type and all arguments + # are named + parse_result_is_dependent_variable(parse_result) || return sym + + args = parse_result[:args] + argnames = Any[get(arg, :name, nothing) for arg in args] + # if the last arg is a `Vararg`, splat it + if !isempty(args) && Meta.isexpr(args[end][:type], :curly) && args[end][:type].args[1] == :Vararg + argnames[end] = Expr(:..., argnames[end]) + end + # Turn the result into something of the form `@variables x(..)`. + # This makes it so that the `FnType` is recognized as a dependent variable + # according to `SymbolicUtils.is_function_symtype` + parse_result[:args] = [SymbolicUtils.parse_variable(:(..); default_type = type)] + parse_result[:type].args[2] = Tuple + # Re-create the `Sym` + sym = SymbolicUtils.sym_from_parse_result(parse_result, VartypeT) + # Change the type + parse_result[:type] = parse_result[:type].args[3] + # Call the `Sym` with the arguments to create a dependent variable. + map!(esc, argnames, argnames) + sym = Expr(:call, sym) + append!(sym.args, argnames) + return sym +end + +function handle_maybe_callandwrap!(parse_result, sym) + type = parse_result[:type] + if Meta.isexpr(type, :curly) && (type.args[1] == :FnType || type.args[1] === SymbolicUtils.FnType) + sym = Expr(:call, CallAndWrap, sym) + end + return sym +end + +function _add_metadata(parse_result, var::Expr, default, macroname::Symbol, metadata::Union{Nothing, Vector{Any}}) + @nospecialize var default metadata + if default !== nothing + var = Expr(:call, setdefaultval, var, default) + end + varname = parse_result[:name] + if parse_result[:isruntime] + varname = esc(varname) else - var_name = v - if Meta.isexpr(v, :(::)) - var_name, type′ = v.args - type = type′ === :Complex ? Complex{type} : type′ - end - expr = construct_var(macroname, isruntime ? var_name : Meta.quot(var_name), type, call_args, val, prop) + varname = Meta.quot(varname) + end + var = Expr(:call, setmetadata, var, VariableSource, Expr(:tuple, Meta.quot(macroname), varname)) + metadata === nothing && return var + for ex in metadata + Meta.isexpr(ex, :(=)) || error("Metadata must of the form of `key = value`") + key, value = ex.args + key_type = option_to_metadata_type(Val{key}())::DataType + var = Expr(:call, setmetadata, var, key_type, esc(value)) end - lhs = isruntime ? gensym(Symbol(var_name)) : var_name - rhs = :($transform($expr)) - lhs, :($lhs = $rhs) + return var end """ @@ -313,164 +246,6 @@ option_to_metadata_type(::Val{:_____!_internal_3}) = error("Invalid option") option_to_metadata_type(::Val{:_____!_internal_4}) = error("Invalid option") option_to_metadata_type(::Val{:_____!_internal_5}) = error("Invalid option") -function setprops_expr(expr, props, macroname, varname) - expr = :($setmetadata($expr, $VariableSource, ($(Meta.quot(macroname)), $varname,))) - isnothing(props) && return expr - for opt in props - if !Meta.isexpr(opt, :(=)) - throw(Base.Meta.ParseError( - "Variable properties must be in " * - "the form of `a = b`. Got $opt.")) - end - - lhs, rhs = opt.args - - @assert lhs isa Symbol "the lhs of an option must be a symbol" - expr = :($set_scalar_metadata($expr, - $(option_to_metadata_type(Val{lhs}())), - $rhs)) - end - expr -end - -struct CallWithMetadata{T,M} <: Symbolic{T} - f::Symbolic{T} - metadata::M -end - -for f in [:iscall, :operation, :arguments] - @eval SymbolicUtils.$f(x::CallWithMetadata) = $f(x.f) -end - -SymbolicUtils.Code.toexpr(x::CallWithMetadata, st) = SymbolicUtils.Code.toexpr(x.f, st) - -CallWithMetadata(f) = CallWithMetadata(f, nothing) - -SymbolicIndexingInterface.symbolic_type(::Type{<:CallWithMetadata}) = ScalarSymbolic() - -# HACK: -# A `DestructuredArgs` with `create_bindings = false` doesn't create a `Let` block, and -# instead adds the assignments to the rewrites dictionary. This is problematic, because -# if the `DestructuredArgs` contains a `CallWithMetadata` the key in the `Dict` will be -# a `CallWithMetadata` which won't match against the operation of the called symbolic. -# This is the _only_ hook we have and relies on the `DestructuredArgs` being converted -# into a list of `Assignment`s before being added to the `Dict` inside `toexpr(::Let, st)`. -# The callable symbolic is unwrapped so it matches the operation of the called version. -SymbolicUtils.Code.Assignment(f::CallWithMetadata, x) = SymbolicUtils.Code.Assignment(f.f, x) - -function Base.show(io::IO, c::CallWithMetadata) - show(io, c.f) - print(io, "⋆") -end - -struct CallWithParent end - -function (f::CallWithMetadata)(args...) - setmetadata(metadata(unwrap(f.f(map(unwrap, args)...)), metadata(f)), CallWithParent, f) -end - -Base.isequal(a::CallWithMetadata, b::CallWithMetadata) = isequal(a.f, b.f) -Base.hash(x::CallWithMetadata, h::UInt) = hash(x.f, h) - -function arg_types_from_call_args(call_args) - if length(call_args) == 1 && only(call_args) == :.. - return Tuple - end - Ts = map(call_args) do arg - if arg == :.. - Vararg - elseif arg isa Expr && arg.head == :(::) - if length(arg.args) == 1 - arg.args[1] - elseif arg.args[1] == :.. - :(Vararg{$(arg.args[2])}) - else - arg.args[2] - end - else - error("Invalid call argument $arg") - end - end - return :(Tuple{$(Ts...)}) -end - -function function_name_and_type(var_name) - if var_name isa Expr && Meta.isexpr(var_name, :(::), 2) - var_name.args[1], (var_name.args[2],) - else - var_name, () - end -end - -function construct_var(macroname, var_name, type, call_args, val, prop) - expr = if call_args === nothing - :($Sym{$type}($var_name)) - elseif call_args_are_function(call_args) - # function syntax is (x::TFunc)(.. or ::TArg1, ::TArg2)::TRet - # .. is Vararg - # (..)::ArgT is Vararg{ArgT} - var_name, fntype = function_name_and_type(var_name) - argtypes = arg_types_from_call_args(call_args) - :($CallWithMetadata($Sym{$FnType{$argtypes, $type, $(fntype...)}}($var_name))) - # This elseif handles the special case with e.g. variables on the form - # @variables X(deps...) where deps is a vector (which length might be unknown). - elseif (call_args isa Vector) && (length(call_args) == 1) && (call_args[1] isa Expr) && - call_args[1].head == :(...) && (length(call_args[1].args) == 1) - :($Sym{$FnType{NTuple{$length($(call_args[1].args[1])), Any}, $type}}($var_name)($value.($(call_args[1].args[1]))...)) - else - :($Sym{$FnType{NTuple{$(length(call_args)), Any}, $type}}($var_name)($(map(x->:($value($x)), call_args)...))) - end - - if val !== nothing - expr = :($setdefaultval($expr, $val)) - end - - :($wrap($(setprops_expr(expr, prop, macroname, var_name)))) -end - -struct CallWith - args -end - -(c::CallWith)(f) = unwrap(f(c.args...)) - -function _construct_array_vars(macroname, var_name, type, call_args, val, prop, indices...) - # TODO: just use Sym here - ndim = :($length(($(indices...),))) - - need_scalarize = false - expr = if call_args === nothing - ex = :($Sym{Array{$type, $ndim}}($var_name)) - :($setmetadata($ex, $ArrayShapeCtx, ($(indices...),))) - elseif call_args_are_function(call_args) - need_scalarize = true - var_name, fntype = function_name_and_type(var_name) - argtypes = arg_types_from_call_args(call_args) - ex = :($Sym{Array{$FnType{$argtypes, $type, $(fntype...)}, $ndim}}($var_name)) - ex = :($setmetadata($ex, $ArrayShapeCtx, ($(indices...),))) - :($map($CallWithMetadata, $ex)) - else - # [(R -> R)(R) ....] - need_scalarize = true - ex = :($Sym{Array{$FnType{Tuple, $type}, $ndim}}($var_name)) - ex = :($setmetadata($ex, $ArrayShapeCtx, ($(indices...),))) - :($map($CallWith(($(call_args...),)), $ex)) - end - - if val !== nothing - expr = :($setdefaultval($expr, $val)) - end - expr = setprops_expr(expr, prop, macroname, var_name) - if need_scalarize - expr = :($scalarize_getindex($expr)) - end - - expr = :($wrap($expr)) - - return expr -end - - """ Define one or more unknown variables. @@ -535,56 +310,22 @@ julia> (t, a, b, c) ``` """ macro variables(xs...) - esc(_parse_vars(:variables, Real, xs)) + parse_vars(:variables, Real, xs) end const _fail = Dict() -_getname(x, _) = nameof(x) -_getname(x::Symbol, _) = x -function _getname(x::Symbolic, val) - issym(x) && return nameof(x) - if iscall(x) && issym(operation(x)) - return nameof(operation(x)) - end - if !hasmetadata(x, Symbolics.GetindexParent) && iscall(x) && operation(x) == getindex - return _getname(arguments(x)[1], val) - end - ss = getsource(x, nothing) - if ss === nothing - ss = getsource(getparent(x), val) - end - ss === _fail && throw(ArgumentError("Variable $x doesn't have a source defined.")) - ss[2] -end - getsource(x, val=_fail) = getmetadata(unwrap(x), VariableSource, val) SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.Num}) = ScalarSymbolic() SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.Arr}) = ArraySymbolic() -function SymbolicIndexingInterface.symbolic_type(::Type{T}) where {S <: AbstractArray, T <: Symbolic{S}} - ArraySymbolic() -end -# need this otherwise the `::Type{<:BasicSymbolic}` method in SymbolicUtils is -# more specific -function SymbolicIndexingInterface.symbolic_type(::Type{T}) where {S <: AbstractArray, T <: BasicSymbolic{S}} - ArraySymbolic() -end SymbolicIndexingInterface.hasname(x::Union{Num,Arr,Complex{Num}}) = hasname(unwrap(x)) - -function SymbolicIndexingInterface.hasname(x::Symbolic) - issym(x) || !iscall(x) || iscall(x) && (issym(operation(x)) || operation(x) == getindex && hasname(arguments(x)[1])) +function SymbolicIndexingInterface.getname(x::Union{Num, Arr, Complex{Num}}) + SymbolicIndexingInterface.getname(unwrap(x)) end -# This is type piracy, but changing it breaks precompilation for MTK because it relies on this falling back to -# `_getname` which calls `nameof` which returns the name of the system, when `x::AbstractSystem`. -# FIXME: In a breaking release -function SymbolicIndexingInterface.getname(x, val = _fail) - _getname(unwrap(x), val) -end - -function SymbolicIndexingInterface.symbolic_evaluate(ex::Union{Num, Arr, Symbolic, Equation, Inequality}, d::Dict; kwargs...) +function SymbolicIndexingInterface.symbolic_evaluate(ex::Union{Num, Arr, BasicSymbolic, Equation, Inequality}, d::Dict; kwargs...) val = fixpoint_sub(ex, d; kwargs...) return _recursive_unwrap(val) end @@ -618,6 +359,15 @@ function _recursive_unwrap(val::AbstractSparseArray) end end +struct FPSubFilterer{O} end + +function (::FPSubFilterer{O})(ex::BasicSymbolic{T}) where {T, O} + @match ex begin + BSImpl.Term(; f) && if f isa Operator end => !(f isa O) + _ => SymbolicUtils.default_substitute_filter(ex) + end +end + """ fixpoint_sub(expr, dict; operator = Nothing, maxiters = 1000) @@ -629,16 +379,14 @@ specified to prevent substitution of expressions inside operators of the given t `maxiters` keyword is used to limit the number of times the substitution can occur to avoid infinite loops in cases where the substitutions in `dict` are circular (e.g. `[x => y, y => x]`). - -See also: [`fast_substitute`](@ref). """ function fixpoint_sub(x, dict; operator = Nothing, maxiters = 1000) dict = subrules_to_dict(dict) - y = fast_substitute(x, dict; operator) + y = substitute(x, dict; filterer=FPSubFilterer{operator}()) iters = maxiters while !isequal(x, y) && iters > 0 y = x - x = fast_substitute(y, dict; operator) + x = substitute(y, dict; filterer=FPSubFilterer{operator}()) iters -= 1 end @@ -655,106 +403,6 @@ function fixpoint_sub(x::SparseMatrixCSC, dict; operator = Nothing, maxiters = 1 return sparse(I, J, V, m, n) end -const Eq = Union{Equation, Inequality} -""" - fast_substitute(expr, dict; operator = Nothing) - -Given a symbolic expression, equation or inequality `expr` perform the substitutions in -`dict`. This only performs the substitutions once. For example, -`fast_substitute(x, Dict(x => y, y => 3))` will return `y`. The `operator` keyword can be -specified to prevent substitution of expressions inside operators of the given type. - -See also: [`fixpoint_sub`](@ref). -""" -function fast_substitute(eq::Eq, subs; operator = Nothing) - if eq isa Inequality - Inequality(fast_substitute(eq.lhs, subs; operator), - fast_substitute(eq.rhs, subs; operator), - eq.relational_op) - else - Equation(fast_substitute(eq.lhs, subs; operator), - fast_substitute(eq.rhs, subs; operator)) - end -end -function fast_substitute(eq::T, subs::Pair; operator = Nothing) where {T <: Eq} - T(fast_substitute(eq.lhs, subs; operator), fast_substitute(eq.rhs, subs; operator)) -end -function fast_substitute(eqs::AbstractArray, subs; operator = Nothing) - fast_substitute.(eqs, (subs,); operator) -end -function fast_substitute(eqs::AbstractArray, subs::Pair; operator = Nothing) - fast_substitute.(eqs, (subs,); operator) -end -function fast_substitute(eqs::SparseMatrixCSC, subs; operator = Nothing) - I, J, V = findnz(eqs) - V = fast_substitute(V, subs; operator) - m, n = size(eqs) - return sparse(I, J, V, m, n) -end -for (exprType, subsType) in Iterators.product((Num, Symbolics.Arr), (Any, Pair)) - @eval function fast_substitute(expr::$exprType, subs::$subsType; operator = Nothing) - fast_substitute(value(expr), subs; operator) - end -end -function fast_substitute(expr, subs; operator = Nothing) - if (_val = get(subs, expr, nothing)) !== nothing - return _val - end - iscall(expr) || return expr - op = fast_substitute(operation(expr), subs; operator) - args = SymbolicUtils.arguments(expr) - if !(op isa operator) - canfold = Ref(!(op isa Symbolic)) - args = let canfold = canfold - map(args) do x - symbolic_type(x) == NotSymbolic() && !is_array_of_symbolics(x) && return x - x′ = fast_substitute(x, subs; operator) - canfold[] = canfold[] && (symbolic_type(x′) == NotSymbolic() && !is_array_of_symbolics(x′)) - x′ - end - end - if op === getindex && symbolic_type(args[1]) == NotSymbolic() - canfold[] = true - end - canfold[] && return op(args...) - end - maketerm(typeof(expr), - op, - args, - metadata(expr)) -end -function fast_substitute(expr, pair::Pair; operator = Nothing) - a, b = pair - isequal(expr, a) && return b - if a isa AbstractArray - for (ai, bi) in zip(a, b) - expr = fast_substitute(expr, ai => bi; operator) - end - end - iscall(expr) || return expr - op = fast_substitute(operation(expr), pair; operator) - args = SymbolicUtils.arguments(expr) - if !(op isa operator) - canfold = Ref(!(op isa Symbolic)) - args = let canfold = canfold - map(args) do x - symbolic_type(x) == NotSymbolic() && !is_array_of_symbolics(x) && return x - x′ = fast_substitute(x, pair; operator) - canfold[] = canfold[] && (symbolic_type(x′) == NotSymbolic() && !is_array_of_symbolics(x′)) - x′ - end - end - if op === getindex && symbolic_type(args[1]) == NotSymbolic() - canfold[] = true - end - canfold[] && return op(args...) - end - maketerm(typeof(expr), - op, - args, - metadata(expr)) -end - function is_array_of_symbolics(x) symbolic_type(x) == ArraySymbolic() && return true symbolic_type(x) == ScalarSymbolic() && return false @@ -786,7 +434,13 @@ function getdefaultval(x, val=_fail) if val !== _fail return val else - error("$x has no default value") + @match x begin + BSImpl.Term(; f, args) && if f === getindex end => begin + idxs = Iterators.map(unwrap_const, Iterators.drop(args, 1)) + return getdefaultval(args[1], val)[idxs...] + end + _ => error("$x has no default value") + end end end @@ -828,11 +482,8 @@ Also see `variables`. """ function variable(name, idx...; T=Real) name_ij = Symbol(name, join(map_subscripts.(idx), "ˏ")) - v = Sym{T}(name_ij) - if T <: FnType - v = CallWithMetadata(v) - end - Num(setmetadata(v, VariableSource, (:variables, name_ij))) + v = Sym{VartypeT}(name_ij; type = T) + wrap(setmetadata(v, VariableSource, (:variables, name_ij))) end ##### Renaming ##### @@ -845,69 +496,90 @@ end # x_t # sys.x -function rename_getindex_source(x, parent=x) - getindex_posthook(x) do r,x,i... - hasmetadata(r, GetindexParent) ? setmetadata(r, GetindexParent, parent) : r +function renamed_metadata(metadata::Union{Nothing, SymbolicUtils.MetadataT}, name::Symbol) + @nospecialize metadata + if metadata === nothing + return metadata + elseif metadata isa Base.ImmutableDict{DataType, Any} + newmeta = Base.ImmutableDict{DataType, Any}() + for (k, v) in metadata + if k === VariableSource + v = v::NTuple{2, Symbol} + v = (v[1], name) + end + newmeta = Base.ImmutableDict(newmeta, k, v) + end + return newmeta end + error() end -function rename_metadata(from, to, name) - if hasmetadata(from, VariableSource) - s = getmetadata(from, VariableSource) - to = setmetadata(to, VariableSource, (s[1], name)) - end - if hasmetadata(from, GetindexParent) - s = getmetadata(from, GetindexParent) - to = setmetadata(to, GetindexParent, rename(s, name)) +rename(x::Union{Num, Arr}, name) = wrap(rename(unwrap(x), name)) + +function rename(x::BasicSymbolic{T}, newname::Symbol) where {T} + @match x begin + BSImpl.Sym(; name, type, shape, metadata) => begin + metadata = renamed_metadata(metadata, newname) + return BSImpl.Sym{T}(newname; type, shape, metadata) + end + BSImpl.Term(; f, args, type, shape, metadata) && if f === getindex end => begin + newargs = copy(parent(args)) + newargs[1] = rename(newargs[1], newname) + return BSImpl.Term{T}(f, newargs; type, shape, metadata) + end + BSImpl.Term(; f, args, type, shape, metadata) && if f isa BasicSymbolic{T} end => begin + f = rename(f, newname) + metadata = renamed_metadata(metadata, newname) + return BSImpl.Term{T}(f, args; type, shape, metadata) + end + _ => error("Cannot rename $x.") end - return to end -rename(x::Union{Num, Arr}, name) = wrap(rename(unwrap(x), name)) -function rename(x::ArrayOp, name) - t = x.term - args = arguments(t) - # Hack: - @assert operation(t) === (map) && (args[1] isa CallWith || args[1] == CallWithMetadata) - rn = rename(args[2], name) - - xx = metadata(operation(t)(args[1], rn), metadata(x)) - rename_getindex_source(rename_metadata(x, xx, name)) -end +#### Callable ### -function rename(x::CallWithMetadata, name) - rename_metadata(x, CallWithMetadata(rename(x.f, name), x.metadata), name) +struct CallAndWrap{T} + f::BasicSymbolic{VartypeT} end -function rename(x::Symbolic, name) - if issym(x) - xx = @set! x.name = name - xx = rename_metadata(x, xx, name) - symtype(xx) <: AbstractArray ? rename_getindex_source(xx) : xx - elseif iscall(x) && operation(x) === getindex - rename(arguments(x)[1], name)[arguments(x)[2:end]...] - elseif iscall(x) && symtype(operation(x)) <: FnType || operation(x) isa CallWithMetadata - xx = @set x.f = rename(operation(x), name) - @set! xx.hash = Ref{UInt}(0) - return rename_metadata(x, xx, name) +rettype_from_fntype(::Type{T}) where {A, R, T <: SymbolicUtils.FnType{A, R}} = R + +function CallAndWrap(f::BasicSymbolic{VartypeT}) + @assert symtype(f) <: SymbolicUtils.FnType + R = rettype_from_fntype(symtype(f)) + if hasmethod(wrapper_type, Tuple{Type{R}}) + CallAndWrap{wrapper_type(R)}(f) else - error("can't rename $x to $name") + f end end -# Deprecation below - -struct Variable{T} end - -function (::Type{Variable{T}})(s, i...) where {T} - Base.depwarn("Variable{T}(name, idx...) is deprecated, use variable(name, idx...; T=T)", :Variable) - variable(s, i...; T=T) +function (caw::CallAndWrap{T})(args...) where {T} + T(caw.f(args...)) end -(::Type{Variable})(s, i...) = Variable{Real}(s, i...) +SymbolicIndexingInterface.symbolic_type(::Type{<:CallAndWrap}) = ScalarSymbolic() + +Base.isequal(a::CallAndWrap, b::CallAndWrap) = isequal(a.f, b.f) +Base.isequal(a::BasicSymbolic{VartypeT}, b::CallAndWrap) = isequal(a, b.f) +Base.isequal(a::CallAndWrap, b::BasicSymbolic{VartypeT}) = isequal(a.f, b) +Base.hash(x::CallAndWrap, h::UInt) = hash(x.f, h) -function (::Type{Sym{T}})(s, x, i...) where {T} - Base.depwarn("Sym{T}(name, x, idx...) is deprecated, use variable(name, x, idx...; T=T)", :Variable) - variable(s, x, i...; T=T) +function Base.show(io::IO, caw::CallAndWrap) + show(io, caw.f) + print(io, "⋆") end -(::Type{Sym})(s, x, i...) = Sym{Real}(s, x, i...) + +has_symwrapper(::Type{T}) where {A, R, T <: SymbolicUtils.FnType{A, R}} = has_symwrapper(R) +wrapper_type(::Type{T}) where {A, R, T <: SymbolicUtils.FnType{A, R}} = CallAndWrap{wrapper_type(R)} +is_wrapper_type(::Type{T}) where {T <: CallAndWrap} = true +wraps_type(::Type{T}) where {W, T <: CallAndWrap{W}} = FnType{A, R} where {A, R <: wraps_type(W)} +iswrapped(::CallAndWrap) = true +SymbolicUtils.unwrap(x::CallAndWrap) = x.f +SymbolicUtils.symtype(x::CallAndWrap) = symtype(x.f) +SymbolicIndexingInterface.getname(x::CallAndWrap) = getname(x.f) +SymbolicIndexingInterface.hasname(x::CallAndWrap) = hasname(x.f) +Symbolics.rename(x::CallAndWrap, name) = CallAndWrap(rename(x.f, name)) +SymbolicUtils.getmetadata(x::CallAndWrap, args...) = SymbolicUtils.getmetadata(x.f, args...) +SymbolicUtils.setmetadata(x::CallAndWrap, args...) = SymbolicUtils.setmetadata(x.f, args...) +SymbolicUtils.hasmetadata(x::CallAndWrap, args...) = SymbolicUtils.hasmetadata(x.f, args...) diff --git a/src/wrapper-types.jl b/src/wrapper-types.jl index bf30f8913..db6ae09f9 100644 --- a/src/wrapper-types.jl +++ b/src/wrapper-types.jl @@ -17,8 +17,9 @@ function set_where(subt, supert) Expr(:where, supert, Ts...) end -function SymbolicIndexingInterface.getname(x::Expr) - @assert x.head == :curly +function _get_type_name(x::Union{Symbol, Expr}) + x isa Symbol && return x + @assert Meta.isexpr(x, :curly) return x.args[1] end @@ -26,7 +27,7 @@ macro symbolic_wrap(expr) @assert expr isa Expr && expr.head == :struct @assert expr.args[2].head == :(<:) sig = expr.args[2] - T = getname(sig.args[1]) + T = _get_type_name(sig.args[1]) supertype = set_where(sig.args[1], sig.args[2]) quote @@ -41,12 +42,6 @@ macro symbolic_wrap(expr) end iswrapped(x) = false -""" - $(TYPEDSIGNATURES) - -Return the symbolic or non-symbolic value wrapped by a type such as `Num`. -""" -unwrap(x) = x """ $(TYPEDSIGNATURES) @@ -112,30 +107,30 @@ function wrap_func_expr(mod, expr, wrap_arrays = true) # for every argument find the types that # should be allowed as argument. These are: # - # (1) T (2) wrapper_type(T) (3) Symbolic{T} + # (1) T (2) wrapper_type(T) (3) BasicSymbolic # # However later while emitting methods we omit the one # method where all arguments are (1) since those are # expected to be defined outside Symbolics if arg isa Expr && arg.head == :(::) T = Base.eval(mod, arg.args[2]) - Ts = has_symwrapper(T) ? (T, :(Symbolics.SymbolicUtils.Symbolic{<:$T}), wrapper_type(T)) : - (T,:(Symbolics.SymbolicUtils.Symbolic{<:$T})) + Ts = has_symwrapper(T) ? (T, BasicSymbolic{VartypeT}, wrapper_type(T)) : + (T, BasicSymbolic{VartypeT}) if T <: AbstractArray && wrap_arrays eT = eltype(T) if eT == Any eT = Real end _arr_type_fn = if hasmethod(ndims, Tuple{Type{T}}) - (elT) -> :(AbstractArray{T, $(ndims(T))} where {T <: $elT}) + (elT) -> AbstractArray{S, ndims(T)} where {S <: elT} else - (elT) -> :(AbstractArray{T} where {T <: $elT}) + (elT) -> AbstractArray{S} where {S <: elT} end if has_symwrapper(eT) - Ts = (Ts..., _arr_type_fn(:(Symbolics.SymbolicUtils.Symbolic{<:$eT})), + Ts = (Ts..., AbstractArray{BasicSymbolic{VartypeT}}, _arr_type_fn(wrapper_type(eT))) else - Ts = (Ts..., _arr_type_fn(:(Symbolics.SymbolicUtils.Symbolic{<:$eT}))) + Ts = (Ts..., AbstractArray{BasicSymbolic{VartypeT}}) end end Ts @@ -158,19 +153,40 @@ function wrap_func_expr(mod, expr, wrap_arrays = true) :($n::$T) end - fbody = :(if any($iswrapped, ($(names...),)) - $wrap($impl_name($self, $([:($unwrap($arg)) for arg in names]...))) - else - $impl_name($self, $(names...)) - end) + any_wrapper = false + impl_args = map(enumerate(names)) do (i, name) + if is_wrapper_type(Ts[i]) + any_wrapper = true + :($unwrap($name)) + elseif Ts[i] <: AbstractArray && is_wrapper_type(Ts[i].var.ub) + any_wrapper = true + :($_recursive_unwrap($name)) + else + name + end + end + implcall = :($impl_name($self, $(impl_args...))) + if any_wrapper + implcall = :($wrap($implcall)) + end + + body = Expr(:block) + for (i, T) in enumerate(Ts) + if T === BasicSymbolic{VartypeT} + push!(body.args, :(@assert $symtype($(names[i])) <: $(types[i][1]))) + elseif T === AbstractArray{BasicSymbolic{VartypeT}} && eltype(types[i][1]) !== Any + push!(body.args, :(@assert $symtype($(names[i])[1]) <: $(eltype(types[i][1])))) + end + end + push!(body.args, implcall) if isempty(kwargs) :(function $fname($(method_args...)) - $fbody + $body end) else :(function $fname($(method_args...); $(kwargs...)) - $fbody + $body end) end end diff --git a/test/arrays.jl b/test/arrays.jl index dc1450477..b316b2cae 100644 --- a/test/arrays.jl +++ b/test/arrays.jl @@ -1,8 +1,8 @@ using Symbolics using SymbolicUtils, Test -using Symbolics: symtype, shape, wrap, unwrap, Unknown, Arr, array_term, jacobian, @variables, value, get_variables, @arrayop, getname, metadata, scalarize +using Symbolics: symtype, shape, wrap, unwrap, Arr, jacobian, @variables, value, get_variables, @arrayop, getname, metadata, scalarize using Base: Slice -using SymbolicUtils: Sym, term, operation +using SymbolicUtils: Sym, term, operation, search_variables import LinearAlgebra: dot, Adjoint import ..limit2 @@ -11,10 +11,11 @@ Symbolics.option_to_metadata_type(::Val{:test_meta}) = TestMetaT @testset "arrays" begin @variables X[1:5, 1:5] Y[1:5, 1:5] - @test_throws BoundsError X[1000] + @test_throws ArgumentError X[1000] + @test_throws BoundsError X[10, 1] @test typeof(X) <: Arr - @test shape(X) == Slice.((1:5, 1:5)) - @test shape(Y) == Slice.((1:5, 1:5)) + @test shape(X) == [1:5, 1:5] + @test shape(Y) == [1:5, 1:5] A = Y[2, :] @test typeof(A) <: Arr{Num,1} @@ -23,22 +24,21 @@ Symbolics.option_to_metadata_type(::Val{:test_meta}) = TestMetaT B = A[3:5] @test axes(B) == (Slice(1:3),) - i = Sym{Int}(:i) - j = Sym{Int}(:j) + @syms i::Int j::Int @test symtype(X[i, j]) == Real @test symtype(X[1, j]) == Real @variables t x(t)[1:2] - @test isequal(get_variables(0 ~ x[1]), [x[1]]) - @test Set(get_variables(2x)) == Set(collect(x)) # both array elements are present - @test isequal(get_variables(2x[1]), [x[1]]) + @test isequal(collect(search_variables(0 ~ x[1])), [x[1]]) + @test search_variables(2x) == Set([x]) # both array elements are present + @test isequal(collect(search_variables(2x[1])), [x[1]]) end @testset "getname" begin @variables t x(t)[1:4] v = Symbolics.lower_varname(unwrap(x[2]), unwrap(t), 2) @test operation(v) == getindex - @test arguments(v)[2] == 2 + @test unwrap_const(arguments(v)[2]) == 2 @test getname(v) == getname(arguments(v)[1]) == Symbol("xˍtt") end @@ -48,16 +48,17 @@ end @test isequal(X[1, 1], wrap(term(getindex, unwrap(X), 1, 1))) XX = unwrap(X) - @test isequal(unwrap(X[1, :]), Symbolics.@arrayop((j,), XX[1, j], term=XX[1, :])) - @test isequal(unwrap(X[:, 2]), Symbolics.@arrayop((i,), XX[i, 2], term=XX[:, 2])) - @test isequal(unwrap(X[:, 2:3]), Symbolics.@arrayop((i, j), XX[i, j], (j in 2:3), term=XX[:, 2:3])) + idxterm = term(getindex, XX, 1, :; type = Vector{Real}, shape = [1:5]) + @test isequal(unwrap(X[1, :]), Symbolics.@arrayop((j,), XX[1, j], term=idxterm)) + idxterm = term(getindex, XX, :, 2; type = Vector{Real}, shape = [1:5]) + @test isequal(unwrap(X[:, 2]), Symbolics.@arrayop((i,), XX[i, 2], term=idxterm)) + idxterm = term(getindex, XX, :, 2:3; type = Matrix{Real}, shape = [1:5, 1:2]) + @test isequal(unwrap(X[:, 2:3]), Symbolics.@arrayop((i, j), XX[i, j], (j in 2:3), term=idxterm)) @variables t x(t)[1:4] @syms i::Int @test isequal(x[i], operation(unwrap(x))(t)[i]) - # https://github.com/JuliaSymbolics/Symbolics.jl/issues/842 - # getindex should keep metadata @variables tv v(tv)[1:2] [test_meta = 4] v2(tv)[1:3] [test_meta=[1, 2, 3]] @test !isnothing(metadata(unwrap(v))) @test hasmetadata(unwrap(v), TestMetaT) @@ -67,13 +68,10 @@ end vsw = unwrap.(vs) vs2 = scalarize(v2) vsw2 = unwrap.(vs2) - @test !isnothing(metadata(vsw[1])) - @test hasmetadata(vsw[1], TestMetaT) - @test getmetadata(vsw[1], TestMetaT) == 4 - @test getmetadata.(vsw2, TestMetaT) == [1, 2, 3] - @test !isnothing(metadata(unwrap(v[1]))) - @test hasmetadata(unwrap(v[1]), TestMetaT) - @test getmetadata(unwrap(v[1]), TestMetaT) == 4 + vswparent = arguments(vsw[1])[1] + @test !isnothing(metadata(vswparent)) + @test hasmetadata(vswparent, TestMetaT) + @test getmetadata(vswparent, TestMetaT) == 4 end @testset "maketerm" begin @@ -82,16 +80,16 @@ end T = unwrap(3A) @test isequal(T, Symbolics.maketerm(typeof(T), operation(T), arguments(T), nothing)) T2 = unwrap(3B) - @test isequal(T2, Symbolics.maketerm(typeof(T), operation(T), [*, 3, unwrap(B)], nothing)) + @test isequal(T2, Symbolics.maketerm(typeof(T), operation(T), [3, unwrap(B)], nothing)) T3 = unwrap(A .^ 2) @test isequal(T3, Symbolics.maketerm(typeof(T3), operation(T3), arguments(T3), nothing)) T4 = unwrap(A .* C) @test isequal(T4, Symbolics.maketerm(typeof(T4), operation(T4), arguments(T4), nothing)) end -getdef(v) = getmetadata(v, Symbolics.VariableDefaultValue) +getdef(v) = Symbolics.getdefaultval(v) @testset "broadcast & scalarize" begin - @variables A[1:5,1:3]=42 b[1:3]=[2, 3, 5] t x(t)[1:4] u[1:1] + @variables A[1:5,1:3]=42ones(5, 3) b[1:3]=[2, 3, 5] t x(t)[1:4] u[1:1] AA = Symbolics.scalarize(A) bb = Symbolics.scalarize(b) @test all(isequal(42), getdef.(AA)) @@ -109,7 +107,7 @@ getdef(v) = getmetadata(v, Symbolics.VariableDefaultValue) result1 = d_vec' * E # This was causing ambiguity error result2 = d_vec' * inv(E) * d_vec # The original failing expression from issue #575 @test size(result1) == (1, 3) - @test size(result2) == (1,) + @test size(result2) == () @test isequal(collect(sin.(x)), sin.([x[i] for i in 1:4])) @@ -120,7 +118,7 @@ getdef(v) = getmetadata(v, Symbolics.VariableDefaultValue) b[3] * A[1, 3]))) D = Differential(t) - @test isequal(collect(D.(x) .~ x), map(i -> D(x[i]) ~ x[i], eachindex(x))) + # @test isequal(collect(D.(x) .~ x), map(i -> D(x[i]) ~ x[i], eachindex(x))) @test_throws ArgumentError A ~ t @test isequal(D(x[1]), D(x)[1]) a = Symbolics.unwrap(D(x)[1]) @@ -131,7 +129,7 @@ getdef(v) = getmetadata(v, Symbolics.VariableDefaultValue) @test isequal(Symbolics.scalarize(u + u), [2u[1]]) # #417 - @test isequal(Symbolics.scalarize(x', (1,1)), x[1]) + @test isequal(Symbolics.scalarize(x'), Symbolics.scalarize(unwrap(x))') # #483 # examples by @gronniger @@ -153,23 +151,23 @@ getdef(v) = getmetadata(v, Symbolics.VariableDefaultValue) A4_ = wrap(@arrayop (i, j) A_[i, k] * A_[k, l] * A_[l, m] * A_[m, j]) A5_ = wrap(@arrayop (i, j) A_[i, k] * A_[k, l] * A_[l, m] * A_[m, n] * A_[n, j]) - @test isequal(Symbolics.scalarize((A*A)[k,k]), A[k, 1]*A[1, k] + A[2, k]*A[k, 2]) + @test_broken isequal(Symbolics.scalarize((A*A)[k,k]), A[k, 1]*A[1, k] + A[2, k]*A[k, 2]) # basic tests: - @test iszero((Symbolics.scalarize(A^2) * Symbolics.scalarize(A))[1,1] - - Symbolics.scalarize(A^3)[1,1]) + @test SymbolicUtils._iszero(expand((Symbolics.scalarize(A^2) * Symbolics.scalarize(A))[1,1] - + Symbolics.scalarize(A^3)[1,1])) @testset "nested scalarize" begin - @test isequal(substitute(Symbolics.scalarize(A2 ), repl_dict), test_mat^2) - @test isequal(substitute(Symbolics.scalarize(A3_), repl_dict), test_mat^3) - @test isequal(substitute(Symbolics.scalarize(A3 ), repl_dict), test_mat^3) - @test isequal(substitute(Symbolics.scalarize(A4_), repl_dict), test_mat^4) - @test isequal(substitute(Symbolics.scalarize(A4 ), repl_dict), test_mat^4) - @test isequal(substitute(Symbolics.scalarize(A5_), repl_dict), test_mat^5) - @test isequal(substitute(Symbolics.scalarize(A5 ), repl_dict), test_mat^5) - @test isequal(substitute(Symbolics.scalarize(A6 ), repl_dict), test_mat^6) - @test isequal(substitute(Symbolics.scalarize(A7 ), repl_dict), test_mat^7) + @test isequal(unwrap_const.(substitute(Symbolics.scalarize(A2 ), repl_dict)), test_mat^2) + @test isequal(unwrap_const.(substitute(Symbolics.scalarize(A3_), repl_dict)), test_mat^3) + @test isequal(unwrap_const.(substitute(Symbolics.scalarize(A3 ), repl_dict)), test_mat^3) + @test isequal(unwrap_const.(substitute(Symbolics.scalarize(A4_), repl_dict)), test_mat^4) + @test isequal(unwrap_const.(substitute(Symbolics.scalarize(A4 ), repl_dict)), test_mat^4) + @test isequal(unwrap_const.(substitute(Symbolics.scalarize(A5_), repl_dict)), test_mat^5) + @test isequal(unwrap_const.(substitute(Symbolics.scalarize(A5 ), repl_dict)), test_mat^5) + @test isequal(unwrap_const.(substitute(Symbolics.scalarize(A6 ), repl_dict)), test_mat^6) + @test isequal(unwrap_const.(substitute(Symbolics.scalarize(A7 ), repl_dict)), test_mat^7) end - @test isequal(Symbolics.scalarize(x', (1, 1)), x[1]) + @test isequal(Symbolics.scalarize(x'[1, 1]), x[1]) ##653 Symbolics.scalarize(inv(A)[1,1]) @@ -253,88 +251,88 @@ end #@test isequal(r(unwrap((X * Y) * b)), unwrap(X * (Y * b))) end -@testset "2D Diffusion Composed With Stencil Interface" begin - n = rand(8:32) - - @variables u[1:n, 1:n] - @makearray v[1:n, 1:n] begin - #interior - v[2:end-1, 2:end-1] => @arrayop (i, j) u[i-1, j] + u[i+1, j] + u[i, j-1] + u[i, j+1] - 4 * u[i, j] - #BCs - v[1, 1:end] => 0.0 - v[n, 1:end] => 0.0 - v[1:end, 1] => 0.0 - v[1:end, n] => 0.0 - end - - #2D Diffusion composed - @makearray ucx[1:n, 1:n] begin - ucx[1:end, 1:end] => 0.0 # fill zeros - ucx[2:end-1, 2:end-1] => @arrayop (i, j) u[i-1, j] + u[i+1, j] - 2 * u[i, j] (j in 2:n-1) - end - - @makearray ucy[1:n, 1:n] begin - ucy[1:end, 1:end] => 0.0 # fill zeros - ucy[2:end-1, 2:end-1] => @arrayop (i, j) u[i, j-1] + u[i, j+1] - 2 * u[i, j] (i in 2:n-1) - end - - uc = ucx .+ ucy - - global V, UC, UCX - V, UC, UCX = v, uc, (ucx, ucy) - @test isequal(collect(v), collect(uc)) -end - -@testset "ND Diffusion, Stencils with CartesianIndices" begin - n = rand(8:32) - N = 2 - - @variables t u(t)[fill(1:n, N)...] - - Igrid = CartesianIndices((fill(1:n, N)...,)) - Iinterior = CartesianIndices((fill(2:n-1, N)...,)) - - function unitindices(N::Int) #create unit CartesianIndex for each dimension - null = zeros(Int, N) - if N == 0 - return CartesianIndex() - else - return map(1:N) do i - unit_i = copy(null) - unit_i[i] = 1 - CartesianIndex(Tuple(unit_i)) - end - end - end - function Diffusion(N, n) - ē = unitindices(N) # for i.e N = 3 => ē = [CartesianIndex((1,0,0)),CartesianIndex((0,1,0)),CartesianIndex((0,0,1))] - - Dss = map(1:N) do d - ranges = CartesianIndices((map(i->d == i ? (2:n-1) : (1:n), 1:N)...,)) - @makearray x[1:n, 1:n] begin - x[1:n, 1:n] => 0 - x[ranges] => @arrayop (i, j) u[CartesianIndex(i, j)-ē[d]] + - u[CartesianIndex(i, j)+ē[d]] - 2 * u[i, j] - end - end - - return reduce((D1, D2) -> D1 .+ D2, Dss) - end - - D = Diffusion(N, n) - - @makearray Dxxu[1:n, 1:n] begin - Dxxu[1:n, 1:n] => 0 - Dxxu[2:end-1, 1:end] => @arrayop (i, j) u[i-1, j] + u[i+1, j] - 2 * u[i, j] - end - - @makearray Dyyu[1:n, 1:n] begin - Dyyu[1:n, 1:n] => 0 - Dyyu[1:end, 2:end-1] => @arrayop (i, j) u[i, j-1] + u[i, j+1] - 2 * u[i, j] - end - - @test isequal(collect(D), collect(Dxxu .+ Dyyu)) -end +# @testset "2D Diffusion Composed With Stencil Interface" begin +# n = rand(8:32) + +# @variables u[1:n, 1:n] +# @makearray v[1:n, 1:n] begin +# #interior +# v[2:end-1, 2:end-1] => @arrayop (i, j) u[i-1, j] + u[i+1, j] + u[i, j-1] + u[i, j+1] - 4 * u[i, j] +# #BCs +# v[1, 1:end] => 0.0 +# v[n, 1:end] => 0.0 +# v[1:end, 1] => 0.0 +# v[1:end, n] => 0.0 +# end + +# #2D Diffusion composed +# @makearray ucx[1:n, 1:n] begin +# ucx[1:end, 1:end] => 0.0 # fill zeros +# ucx[2:end-1, 2:end-1] => @arrayop (i, j) u[i-1, j] + u[i+1, j] - 2 * u[i, j] (j in 2:n-1) +# end + +# @makearray ucy[1:n, 1:n] begin +# ucy[1:end, 1:end] => 0.0 # fill zeros +# ucy[2:end-1, 2:end-1] => @arrayop (i, j) u[i, j-1] + u[i, j+1] - 2 * u[i, j] (i in 2:n-1) +# end + +# uc = ucx .+ ucy + +# global V, UC, UCX +# V, UC, UCX = v, uc, (ucx, ucy) +# @test isequal(collect(v), collect(uc)) +# end + +# @testset "ND Diffusion, Stencils with CartesianIndices" begin +# n = rand(8:32) +# N = 2 + +# @variables t u(t)[fill(1:n, N)...] + +# Igrid = CartesianIndices((fill(1:n, N)...,)) +# Iinterior = CartesianIndices((fill(2:n-1, N)...,)) + +# function unitindices(N::Int) #create unit CartesianIndex for each dimension +# null = zeros(Int, N) +# if N == 0 +# return CartesianIndex() +# else +# return map(1:N) do i +# unit_i = copy(null) +# unit_i[i] = 1 +# CartesianIndex(Tuple(unit_i)) +# end +# end +# end +# function Diffusion(N, n) +# ē = unitindices(N) # for i.e N = 3 => ē = [CartesianIndex((1,0,0)),CartesianIndex((0,1,0)),CartesianIndex((0,0,1))] + +# Dss = map(1:N) do d +# ranges = CartesianIndices((map(i->d == i ? (2:n-1) : (1:n), 1:N)...,)) +# @makearray x[1:n, 1:n] begin +# x[1:n, 1:n] => 0 +# x[ranges] => @arrayop (i, j) u[CartesianIndex(i, j)-ē[d]] + +# u[CartesianIndex(i, j)+ē[d]] - 2 * u[i, j] +# end +# end + +# return reduce((D1, D2) -> D1 .+ D2, Dss) +# end + +# D = Diffusion(N, n) + +# @makearray Dxxu[1:n, 1:n] begin +# Dxxu[1:n, 1:n] => 0 +# Dxxu[2:end-1, 1:end] => @arrayop (i, j) u[i-1, j] + u[i+1, j] - 2 * u[i, j] +# end + +# @makearray Dyyu[1:n, 1:n] begin +# Dyyu[1:n, 1:n] => 0 +# Dyyu[1:end, 2:end-1] => @arrayop (i, j) u[i, j-1] + u[i, j+1] - 2 * u[i, j] +# end + +# @test isequal(collect(D), collect(Dxxu .+ Dyyu)) +# end @testset "Brusselator stencil" begin n = 8 @@ -434,7 +432,7 @@ end args = arguments(sym) @test length(args) == 2 @test args[1] === unwrap(k) - @test args[2] === i + @test unwrap_const(args[2]) === i end @test_throws BoundsError k[-1] @@ -444,5 +442,6 @@ end @testset "Arrayop sorted_arguments" begin @variables x[1:3] y[1:3] sym = unwrap(x + y) - @test all(splat(isequal), zip(SymbolicUtils.sorted_arguments(sym), [+, x, y])) + @test isequal(SymbolicUtils.sorted_arguments(sym), [unwrap(x), unwrap(y)]) + @test operation(sym) === + end diff --git a/test/build_function_tests/intermediate-exprs-inplace.jl b/test/build_function_tests/intermediate-exprs-inplace.jl index 54fac9ef3..82447707d 100644 --- a/test/build_function_tests/intermediate-exprs-inplace.jl +++ b/test/build_function_tests/intermediate-exprs-inplace.jl @@ -1,40 +1,14 @@ :(function (ˍ₋out, u) begin - ˍ₋out_input_1 = begin - _out = (zeros)(Float64, (map)(length, (Base.OneTo(5), Base.OneTo(5)))) - var"%_out" = for var"%jj′" = (zip)(Base.OneTo(5), (Symbolics.reset_to_one)(Base.OneTo(5))) - begin - j = var"%jj′"[1] - j′ = var"%jj′"[2] - for var"%ii′" = (zip)(Base.OneTo(5), (Symbolics.reset_to_one)(Base.OneTo(5))) - begin - i = var"%ii′"[1] - i′ = var"%ii′"[2] - begin - _out[(CartesianIndex)(i′, j′)] = (+)((getindex)(_out, i′, j′), (getindex)(u, (Main.limit2)((+)(-1, i), 5), (Main.limit2)((+)(1, j), 5))) - nothing - end - end - end - end - end - _out - end - for var"%jj′" = (zip)(Base.OneTo(5), (Symbolics.reset_to_one)(Base.OneTo(5))) - begin - j = var"%jj′"[1] - j′ = var"%jj′"[2] - for var"%ii′" = (zip)(Base.OneTo(5), (Symbolics.reset_to_one)(Base.OneTo(5))) + _out = ˍ₋out + var"%_out" = for _2 = 1:1:5 + for _1 = 1:1:5 begin - i = var"%ii′"[1] - i′ = var"%ii′"[2] - begin - ˍ₋out[(CartesianIndex)(i′, j′)] = (+)((getindex)(ˍ₋out, i′, j′), (getindex)(ˍ₋out_input_1, j, i)) - nothing - end + _out[(CartesianIndex)(_1, _2)] = (+)((getindex)(_out, _1, _2), (getindex)(u, (Main.limit2)((+)(-1, (getindex)(1:1:5, _2)), 5), (Main.limit2)((+)(1, (getindex)(1:1:5, _1)), 5))) + nothing end end end - end + _out end end) \ No newline at end of file diff --git a/test/build_function_tests/intermediate-exprs-outplace.jl b/test/build_function_tests/intermediate-exprs-outplace.jl index 776af7441..b82681a78 100644 --- a/test/build_function_tests/intermediate-exprs-outplace.jl +++ b/test/build_function_tests/intermediate-exprs-outplace.jl @@ -1,41 +1,11 @@ :(function (u,) begin - _out = (zeros)(Float64, (map)(length, (1:5, 1:5))) - var"%_out" = begin - _out_input_1 = begin - _out = (zeros)(Float64, (map)(length, (Base.OneTo(5), Base.OneTo(5)))) - var"%_out" = for var"%jj′" = (zip)(Base.OneTo(5), (Symbolics.reset_to_one)(Base.OneTo(5))) - begin - j = var"%jj′"[1] - j′ = var"%jj′"[2] - for var"%ii′" = (zip)(Base.OneTo(5), (Symbolics.reset_to_one)(Base.OneTo(5))) - begin - i = var"%ii′"[1] - i′ = var"%ii′"[2] - begin - _out[(CartesianIndex)(i′, j′)] = (+)((getindex)(_out, i′, j′), (getindex)(u, (Main.limit2)((+)(-1, i), 5), (Main.limit2)((+)(1, j), 5))) - nothing - end - end - end - end - end - _out - end - for var"%jj′" = (zip)(Base.OneTo(5), (Symbolics.reset_to_one)(Base.OneTo(5))) + _out = (zeros)(Float64, (5, 5)) + var"%_out" = for _2 = 1:1:5 + for _1 = 1:1:5 begin - j = var"%jj′"[1] - j′ = var"%jj′"[2] - for var"%ii′" = (zip)(Base.OneTo(5), (Symbolics.reset_to_one)(Base.OneTo(5))) - begin - i = var"%ii′"[1] - i′ = var"%ii′"[2] - begin - _out[(CartesianIndex)(i′, j′)] = (+)((getindex)(_out, i′, j′), (getindex)(_out_input_1, j, i)) - nothing - end - end - end + _out[(CartesianIndex)(_1, _2)] = (+)((getindex)(_out, _1, _2), (getindex)(u, (Main.limit2)((+)(-1, (getindex)(1:1:5, _2)), 5), (Main.limit2)((+)(1, (getindex)(1:1:5, _1)), 5))) + nothing end end end diff --git a/test/build_function_tests/manual-limits-inplace.jl b/test/build_function_tests/manual-limits-inplace.jl index 2ff5ae5fb..1a5ba29e0 100644 --- a/test/build_function_tests/manual-limits-inplace.jl +++ b/test/build_function_tests/manual-limits-inplace.jl @@ -1,18 +1,14 @@ :(function (ˍ₋out, u) - for var"%jj′" = (zip)(Base.OneTo(5), (Symbolics.reset_to_one)(Base.OneTo(5))) - begin - j = var"%jj′"[1] - j′ = var"%jj′"[2] - for var"%ii′" = (zip)(Base.OneTo(5), (Symbolics.reset_to_one)(Base.OneTo(5))) - begin - i = var"%ii′"[1] - i′ = var"%ii′"[2] + begin + _out = ˍ₋out + var"%_out" = for _2 = 1:1:5 + for _1 = 1:1:5 begin - ˍ₋out[(CartesianIndex)(i′, j′)] = (+)((getindex)(ˍ₋out, i′, j′), (getindex)(u, (Main.limit2)((+)(-1, i), 5), (Main.limit2)((+)(1, j), 5))) + _out[(CartesianIndex)(_1, _2)] = (+)((getindex)(_out, _1, _2), (getindex)(u, (Main.limit2)((+)(-1, _1), 5), (Main.limit2)((+)(1, _2), 5))) nothing end end end - end + _out end end) \ No newline at end of file diff --git a/test/build_function_tests/manual-limits-outplace.jl b/test/build_function_tests/manual-limits-outplace.jl index 6cc98dcc4..7ab27f643 100644 --- a/test/build_function_tests/manual-limits-outplace.jl +++ b/test/build_function_tests/manual-limits-outplace.jl @@ -1,19 +1,11 @@ :(function (u,) begin - _out = (zeros)(Float64, (map)(length, (Base.OneTo(5), Base.OneTo(5)))) - var"%_out" = for var"%jj′" = (zip)(Base.OneTo(5), (Symbolics.reset_to_one)(Base.OneTo(5))) - begin - j = var"%jj′"[1] - j′ = var"%jj′"[2] - for var"%ii′" = (zip)(Base.OneTo(5), (Symbolics.reset_to_one)(Base.OneTo(5))) - begin - i = var"%ii′"[1] - i′ = var"%ii′"[2] - begin - _out[(CartesianIndex)(i′, j′)] = (+)((getindex)(_out, i′, j′), (getindex)(u, (Main.limit2)((+)(-1, i), 5), (Main.limit2)((+)(1, j), 5))) - nothing - end - end + _out = (zeros)(Float64, (5, 5)) + var"%_out" = for _2 = 1:1:5 + for _1 = 1:1:5 + begin + _out[(CartesianIndex)(_1, _2)] = (+)((getindex)(_out, _1, _2), (getindex)(u, (Main.limit2)((+)(-1, _1), 5), (Main.limit2)((+)(1, _2), 5))) + nothing end end end diff --git a/test/build_function_tests/transpose-inplace.jl b/test/build_function_tests/transpose-inplace.jl index 431a6fd1b..3fad2f2ae 100644 --- a/test/build_function_tests/transpose-inplace.jl +++ b/test/build_function_tests/transpose-inplace.jl @@ -1,18 +1,14 @@ :(function (ˍ₋out, x) - for var"%jj′" = (zip)(Base.OneTo(4), (Symbolics.reset_to_one)(Base.OneTo(4))) - begin - j = var"%jj′"[1] - j′ = var"%jj′"[2] - for var"%ii′" = (zip)(Base.OneTo(4), (Symbolics.reset_to_one)(Base.OneTo(4))) - begin - i = var"%ii′"[1] - i′ = var"%ii′"[2] + begin + _out = ˍ₋out + var"%_out" = for _2 = 1:1:4 + for _1 = 1:1:4 begin - ˍ₋out[(CartesianIndex)(i′, j′)] = (+)((getindex)(ˍ₋out, i′, j′), (getindex)(x, j, i)) + _out[(CartesianIndex)(_1, _2)] = (+)((getindex)(_out, _1, _2), (getindex)(x, _2, _1)) nothing end end end - end + _out end end) \ No newline at end of file diff --git a/test/build_function_tests/transpose-outplace.jl b/test/build_function_tests/transpose-outplace.jl index a2e6d8880..d105f443d 100644 --- a/test/build_function_tests/transpose-outplace.jl +++ b/test/build_function_tests/transpose-outplace.jl @@ -1,19 +1,11 @@ :(function (x,) begin - _out = (zeros)(Float64, (map)(length, (1:4, 1:4))) - var"%_out" = for var"%jj′" = (zip)(Base.OneTo(4), (Symbolics.reset_to_one)(Base.OneTo(4))) - begin - j = var"%jj′"[1] - j′ = var"%jj′"[2] - for var"%ii′" = (zip)(Base.OneTo(4), (Symbolics.reset_to_one)(Base.OneTo(4))) - begin - i = var"%ii′"[1] - i′ = var"%ii′"[2] - begin - _out[(CartesianIndex)(i′, j′)] = (+)((getindex)(_out, i′, j′), (getindex)(x, j, i)) - nothing - end - end + _out = (zeros)(Float64, (4, 4)) + var"%_out" = for _2 = 1:1:4 + for _1 = 1:1:4 + begin + _out[(CartesianIndex)(_1, _2)] = (+)((getindex)(_out, _1, _2), (getindex)(x, _2, _1)) + nothing end end end diff --git a/test/build_function_tests/transpose-term-inplace.jl b/test/build_function_tests/transpose-term-inplace.jl index 431a6fd1b..daac6df2f 100644 --- a/test/build_function_tests/transpose-term-inplace.jl +++ b/test/build_function_tests/transpose-term-inplace.jl @@ -1,18 +1,3 @@ :(function (ˍ₋out, x) - for var"%jj′" = (zip)(Base.OneTo(4), (Symbolics.reset_to_one)(Base.OneTo(4))) - begin - j = var"%jj′"[1] - j′ = var"%jj′"[2] - for var"%ii′" = (zip)(Base.OneTo(4), (Symbolics.reset_to_one)(Base.OneTo(4))) - begin - i = var"%ii′"[1] - i′ = var"%ii′"[2] - begin - ˍ₋out[(CartesianIndex)(i′, j′)] = (+)((getindex)(ˍ₋out, i′, j′), (getindex)(x, j, i)) - nothing - end - end - end - end - end + (copy!)(ˍ₋out, (adjoint)(x)) end) \ No newline at end of file diff --git a/test/cartesianindex.jl b/test/cartesianindex.jl deleted file mode 100644 index c74ae2e2b..000000000 --- a/test/cartesianindex.jl +++ /dev/null @@ -1,40 +0,0 @@ -using Symbolics, Test -using Symbolics: Arr -#using SymbolicUtils: substitute - -@testset "Symbolic CartesianIndex" begin - @syms i::Int j::Int k::Int - I = CartesianIndex(i, j, k) - @test isequal(I[1], i) - @test isequal(I[2], j) - @test isequal(I[3], k) - - J = CartesianIndex(1, 2, 3) + I - @test isequal(J[1], 1 + i) - @test isequal(J[2], 2 + j) - @test isequal(J[3], 3 + k) - - @test isequal(I + I, CartesianIndex(2i, 2j, 2k)) - @test isequal(I + CartesianIndex(1, 2, 3), CartesianIndex(i+1, j+2, k+3)) - @test isequal(CartesianIndex(1, 2, 3) + I, CartesianIndex(1+i, 2+j, 3+k)) - - @test isequal(I - I, CartesianIndex(0, 0, 0)) - @test isequal(I - CartesianIndex(1, 2, 3), CartesianIndex(i-1, j-2, k-3)) - @test isequal(CartesianIndex(1, 2, 3) - I, CartesianIndex(1-i, 2-j, 3-k)) - - @test isequal(2I, CartesianIndex(2i, 2j, 2k)) - - A = rand(2, 4, 6) - - @test substitute(A[J], Dict(i=>1, j=>2, k=>3)) == A[2, 4, 6] - - II = substitute(I, Dict(i=>1, j=>2, k=>3)) - - @test A[II] == A[1, 2, 3] -end - -@testset "Num Index" begin - a = rand(5) - i = Num(1) - a[i] -end diff --git a/test/coeff.jl b/test/coeff.jl index a53440765..30f28be46 100644 --- a/test/coeff.jl +++ b/test/coeff.jl @@ -16,7 +16,7 @@ import Symbolics: coeff @test isequal(coeff(a*x^sqrt(2), x^sqrt(2)), a) @test isequal(coeff(a + x, x), 1) -@test isequal(coeff(2(a + x), x), 2) +@test isequal(unwrap_const(coeff(2(a + x), x)), 2) e = 4 + x + 3x^2 + 2x^4 + a*x^2 + b @test isequal(coeff(e), 4) @@ -51,13 +51,13 @@ e = x*y^2 + 2x + y^3*x^3 @test isequal(coeff(e, x^0), 0) @test isequal(coeff(a*x + 3, x^0), 3) -@test isequal(coeff(x / 5, x), 1//5) +@test isequal(unwrap_const(coeff(x / 5, x)), 1//5) @test isequal(coeff(x / y, x), 1/y) @test isequal(coeff(x * 5y / (1 + y + z) , x), 5y / (1 + y + z)) # issue #1041 - coefficient of cross term in multivariate polynomial -@test isequal(coeff(2*x*y + y, x*y), 2) -@test isequal(coeff(2*x^2*y + y, x^2*y), 2) +@test isequal(unwrap_const(coeff(2*x*y + y, x*y)), 2) +@test isequal(unwrap_const(coeff(2*x^2*y + y, x^2*y)), 2) @test_throws AssertionError coeff(2*x*y + y, 2*x*y) # numerical factors not allowed in second argument of `coeff` @testset "Issue#1610 non-numeric coeff" begin @variables x a b c d diff --git a/test/complex.jl b/test/complex.jl index d293e91ec..e9b77cbb3 100644 --- a/test/complex.jl +++ b/test/complex.jl @@ -1,5 +1,5 @@ using Symbolics, Test -using SymbolicUtils: metadata +using SymbolicUtils: metadata, unwrap_const using Symbolics: unwrap using SymbolicIndexingInterface: getname, hasname @@ -8,7 +8,7 @@ using SymbolicIndexingInterface: getname, hasname @testset "types" begin @test a isa Num @test b isa Num - @test eltype(Z) <: Complex{Real} + @test eltype(Z) <: Complex{Num} for x in [z, Z[1], z+a, z*a, z^2, z/z] # z/z is sus @test x isa Complex{Num} @@ -20,8 +20,8 @@ using SymbolicIndexingInterface: getname, hasname # issue #314 bi = a+a*im bs = substitute(bi, (Dict(a=>1.0))) # returns 1.0 + im - typeof(bs) # Complex{Num} - bv = Symbolics.value.(bs) + @test bs isa Complex{Num} + bv = unwrap_const(Symbolics.value(bs)) @test typeof(bv) == ComplexF64 end diff --git a/test/diff.jl b/test/diff.jl index 6d30183cd..08729aa19 100644 --- a/test/diff.jl +++ b/test/diff.jl @@ -1,6 +1,8 @@ using Symbolics +using SymbolicUtils using Test -using Symbolics: value +using LinearAlgebra, SparseArrays +using Symbolics: value, unwrap # Derivatives @variables t σ ρ β @@ -12,9 +14,9 @@ Dx = Differential(x) @test Symbol(D(D(uu))) === Symbol("uuˍtt(t)") @test Symbol(D(uuˍt)) === Symbol(D(D(uu))) -@test Symbol(D(v[2])) === Symbol("getindex(vˍt(t), 2)") +@test Symbol(D(v[2])) === Symbol("(vˍt(t))[2]") -test_equal(a, b) = @test isequal(simplify(a), simplify(b)) +test_equal(a, b) = @test isequal(unwrap_const(simplify(unwrap(a))), unwrap_const(simplify(unwrap(b)))) @testset "ZeroOperator handling" begin @test_throws ErrorException Differential(0.1)(x) @@ -24,8 +26,8 @@ end #@test @macroexpand(@derivatives D'~t D2''~t) == @macroexpand(@derivatives (D'~t), (D2''~t)) -@test isequal(expand_derivatives(D(t)), 1) -@test isequal(expand_derivatives(D(D(t))), 0) +@test isequal(unwrap_const(expand_derivatives(D(t))), 1) +@test isequal(unwrap_const(expand_derivatives(D(D(t)))), 0) dsin = D(sin(t)) @test isequal(expand_derivatives(dsin), cos(t)) @@ -33,12 +35,12 @@ dsin = D(sin(t)) dcsch = D(csch(t)) @test isequal(expand_derivatives(dcsch), simplify(-coth(t) * csch(t))) -@test isequal(expand_derivatives(D(-7)), 0) +@test isequal(unwrap_const(unwrap(expand_derivatives(D(-7)))), 0) @test isequal(expand_derivatives(D(sin(2t))), simplify(cos(2t) * 2)) @test isequal(expand_derivatives(D2(sin(t))), simplify(-sin(t))) @test isequal(expand_derivatives(D2(sin(2t))), simplify(-sin(2t) * 4)) -@test isequal(expand_derivatives(D2(t)), 0) -@test isequal(expand_derivatives(D2(5)), 0) +@test isequal(unwrap_const(unwrap(expand_derivatives(D2(t)))), 0) +@test isequal(unwrap_const(unwrap(expand_derivatives(D2(5)))), 0) # Chain rule dsinsin = D(sin(sin(t))) @@ -65,23 +67,23 @@ test_equal(jac[3,3], -1β) # issue #545 z = t + t^2 -#test_equal(expand_derivatives(D(z)), 1 + t * 2) +test_equal(expand_derivatives(D(z)), 1 + t * 2) z = t-2t -#test_equal(expand_derivatives(D(z)), -1) +test_equal(expand_derivatives(D(z)), -1) # Variable dependence checking in differentiation @variables a(t) b(a) @test !isequal(D(b), 0) -@test isequal(expand_derivatives(D(t)), 1) -@test isequal(expand_derivatives(Dx(x)), 1) +test_equal(expand_derivatives(D(t)), 1) +test_equal(expand_derivatives(Dx(x)), 1) @variables x(t) y(t) z(t) @test isequal(expand_derivatives(D(x * y)), simplify(y*D(x) + x*D(y))) @test isequal(expand_derivatives(D(x * y)), simplify(D(x)*y + x*D(y))) -@test isequal(expand_derivatives(D(2t)), 2) +test_equal(expand_derivatives(D(2t)), 2) @test isequal(expand_derivatives(D(2x)), 2D(x)) @test isequal(expand_derivatives(D(x^2)), simplify(2 * x * D(x))) @@ -92,7 +94,7 @@ z = t-2t @test iszero(expand_derivatives(D(42))) @test all(iszero, Symbolics.gradient(42, [t, x, y, z])) @test all(iszero, Symbolics.hessian(42, [t, x, y, z])) -@test isequal(Symbolics.jacobian([t, x, 42], [t, x]), +foreach(test_equal, Symbolics.jacobian([t, x, 42], [t, x]), Num[1 0 Differential(t)(x) 1 0 0]) @@ -112,7 +114,6 @@ t1 = Symbolics.gradient(tmp, [x1, x2]) D = Differential(k) @test Symbolics.tosymbol(value(D(x))) === Symbol("xˍk(t)") -using Symbolics @variables t x(t) ∂ₜ = Differential(t) ∂ₓ = Differential(x) @@ -133,10 +134,7 @@ dxyu = Dx(Dy(u(x,y))) dxxu = Dx(Dx(u(x,y))) @test isequal(expand_derivatives(dxxu), dxxu) -using Symbolics, LinearAlgebra, SparseArrays -using Test - -canonequal(a, b) = isequal(simplify(a), simplify(b)) +canonequal(a, b) = isequal(simplify(unwrap_const(unwrap(a))), simplify(unwrap_const(unwrap(b)))) # Calculus @variables t σ ρ β @@ -202,8 +200,8 @@ end input=rand(3) output=rand(8) -findnz(Symbolics.jacobian_sparsity(f!, output, input))[[1,2]] == findnz(reference_jac)[[1,2]] -findnz(Symbolics.jacobian_sparsity(f1!, output, input,1,2,c=3))[[1,2]] == findnz(reference_jac)[[1,2]] +@test findnz(Symbolics.jacobian_sparsity(f!, output, input))[[1,2]] == findnz(reference_jac)[[1,2]] +@test findnz(Symbolics.jacobian_sparsity(f1!, output, input,1,2,c=3))[[1,2]] == findnz(reference_jac)[[1,2]] input = rand(2,2) function f2!(res,u,a,b,c) @@ -211,7 +209,7 @@ function f2!(res,u,a,b,c) res.=[a*x^2, y^3, b*x^4, sin(y), c*x+y, x+z^2, a*z+x, x+y^2+sin(z)] end -findnz(Symbolics.jacobian_sparsity(f!, output, input))[[1,2]] == findnz(reference_jac)[[1,2]] +@test findnz(Symbolics.jacobian_sparsity(f!, output, input))[[1,2]] == findnz(reference_jac)[[1,2]] # Check for failures due to du[4] undefined function f_undef(du,u) @@ -226,9 +224,7 @@ udef_ref = sparse([1 0 0 0 0 1 0 0 0 0 1 1 0 0 0 0]) -findnz(sparsity_pattern)[[1,2]] == findnz(udef_ref)[[1,2]] - -using Symbolics +@test findnz(sparsity_pattern)[[1,2]] == findnz(udef_ref)[[1,2]] rosenbrock(X) = sum(1:length(X)-1) do i 100 * (X[i+1] - X[i]^2)^2 + (1 - X[i])^2 @@ -261,7 +257,7 @@ expression2 = substitute(expression, Dict(collect(Differential(t).(x) .=> ẋ))) @test isequal( Symbolics.derivative(ifelse(signbit(b), b^2, sqrt(b)), b), - ifelse(signbit(b), 2b,(SymbolicUtils.unstable_pow(2Symbolics.unwrap(sqrt(b)), -1))) + ifelse(signbit(b), 2b,(^(2Symbolics.unwrap(sqrt(b)), -1))) ) # Chain rule @@ -291,12 +287,11 @@ sub_eqs = substitute(eqs, Dict([D(x)=>D(x), x=>1])) @test sub_eqs == [D(x) ~ 1, D(y) ~ 1 + y] @variables x y -@test substitute([x + y; x - y], Dict(x=>1, y=>2)) == [3, -1] +@test unwrap_const.(unwrap.(substitute([x + y; x - y], Dict(x=>1, y=>2)))) == [3, -1] # 530#discussion_r825125589 let - using Symbolics @variables u[1:2] y[1:1] t u = collect(u) y = collect(y) @@ -311,9 +306,9 @@ let @variables t a(t) vars = collect(@variables(x(t)[1:1])[1]) ps = collect(@variables(ps[1:1])[1]) - @test Symbolics.derivative(ps[1], vars[1]) == 0 - @test Symbolics.derivative(ps[1], a) == 0 - @test Symbolics.derivative(x[1], a) == 0 + @test unwrap_const(unwrap(Symbolics.derivative(ps[1], vars[1]))) == 0 + @test unwrap_const(unwrap(Symbolics.derivative(ps[1], a))) == 0 + @test unwrap_const(unwrap(Symbolics.derivative(x[1], a))) == 0 end # 580 @@ -332,10 +327,10 @@ end @variables t t2 x(t) D = Differential(t) ex = D(x) -ex2 = substitute(ex, [t=>t2]) +ex2 = substitute(ex, [t=>t2]; filterer = Returns(true)) @test isequal(operation(Symbolics.unwrap(ex2)).x, t2) -ex3 = substitute(D(x) * 2 + x / t, [t=>t2]) -xt2 = substitute(x, [t => t2]) +ex3 = substitute(D(x) * 2 + x / t, [t=>t2]; filterer = Returns(true)) +xt2 = substitute(x, [t => t2]; filterer = Returns(true)) @test isequal(ex3, xt2 / t2 + 2Differential(t2)(xt2)) # 581 @@ -348,9 +343,8 @@ end #908 # let - using Symbolics @variables t - @test isequal(expand_derivatives(Differential(t)(im*t)), im) + @test isequal(unwrap_const(unwrap(expand_derivatives(Differential(t)(im*t)))), im) @test isequal(expand_derivatives(Differential(t)(t^2 + im*t)), 2t + im) end @@ -434,13 +428,13 @@ end # Derivative of a `BasicSymbolic` (#1085) let - x = Symbolics.Sym{Int}(:x) + x = Symbolics.Sym{SymReal}(:x; type = Int) @testset for f in [sqrt, sin, acos, exp] @test isequal( Symbolics.derivative(f, x), Symbolics.derivative( f, - Symbolics.BasicSymbolic(x) + x ) ) end @@ -519,16 +513,16 @@ let @variables p[1:1] x[1:1] p = collect(p) x = collect(x) - @test collect(Symbolics.sparsehessian(p[1] * x[1], x)) == [0;;] + test_equal.(collect(Symbolics.sparsehessian(p[1] * x[1], x)), [0;;]) @test isequal(collect(Symbolics.sparsehessian(p[1] * x[1]^2, x)), [2p[1];;]) # second example @variables a[1:2] a = collect(a) ex = (a[1]+a[2])^2 - @test Symbolics.hessian(ex, [a[1]]) == [2;;] - @test collect(Symbolics.sparsehessian(ex, [a[1]])) == [2;;] - @test collect(Symbolics.sparsehessian(ex, a)) == fill(2, 2, 2) + test_equal.(Symbolics.hessian(ex, [a[1]]), [2;;]) + test_equal.(collect(Symbolics.sparsehessian(ex, [a[1]])), [2;;]) + test_equal.(collect(Symbolics.sparsehessian(ex, a)), fill(2, 2, 2)) end # issue #847 @@ -631,7 +625,7 @@ end f = - 1.5log(5 + p[3]) - 1.5log(7 + p[1]) - 1.5log(8 - p[2]) - 1.5log(9 - p[4]) + 0.08p[4]*p[9] -1.5log(5 + p[3]) - 1.5log(7 + p[1]) - 1.5log(8 - p[2]) - 1.5log(9 - p[4]) + 0.08p[4]*p[9] - @test iszero(Symbolics.unwrap.(Symbolics.gradient(f, vp) .- Symbolics.gradient(f, p))) - @test iszero(Symbolics.unwrap.(Symbolics.hessian(f, vp) .- Symbolics.hessian(f, p))) - @test iszero(Symbolics.unwrap.(Symbolics.jacobian([f], vp) .- Symbolics.jacobian([f], p))) -end \ No newline at end of file + test_equal.(Symbolics.unwrap.(Symbolics.gradient(f, vp) .- Symbolics.gradient(f, p)), 0) + test_equal.(Symbolics.unwrap.(Symbolics.hessian(f, vp) .- Symbolics.hessian(f, p)), 0) + test_equal.(Symbolics.unwrap.(Symbolics.jacobian([f], vp) .- Symbolics.jacobian([f], p)), 0) +end diff --git a/test/diffeqs.jl b/test/diffeqs.jl index b9316f387..d1c564f97 100644 --- a/test/diffeqs.jl +++ b/test/diffeqs.jl @@ -14,13 +14,13 @@ Dt = Symbolics.Differential(t) @test_broken isapprox(solve_linear_ode_system([-1 -2; 2 -1], [1, -1], t), [exp(-t)*(cos(2t) + sin(2t)), exp(-t)*(sin(2t) - cos(2t))]) # can't handle complex eigenvalues (though it should be able to) -@test isapprox(solve_linear_ode_system([1 -1 0; 1 2 1; -2 1 -1], [7, 2, 3], t), (5//3)*exp(-t)*[-1, -2, 7] - 14exp(t)*[-1, 0, 1] + (16//3)*exp(2t)*[-1, 1, 1]) +@test isapprox(expand.(solve_linear_ode_system([1 -1 0; 1 2 1; -2 1 -1], [7, 2, 3], t)), (5//3)*exp(-t)*[-1, -2, 7] - 14exp(t)*[-1, 0, 1] + (16//3)*exp(2t)*[-1, 1, 1]) @test isequal(solve_linear_ode_system([1 0; 0 -1], [1, -1], t), [exp(t), -exp(-t)]) @test isequal(solve_linear_ode_system([-3 4; -2 3], [7, 2], t), [10exp(-t) - 3exp(t), 5exp(-t) - 3exp(t)]) @test isapprox(solve_linear_ode_system([4 -3; 8 -6], [7, 2], t), [18 - 11exp(-2t), 24 - 22exp(-2t)]) -@test isequal(solve_linear_ode_system([1 -1 0; 1 2 1; -2 1 -1], [7, 2, 3], t), (5//3)*exp(-t)*[-1, -2, 7] - 14exp(t)*[-1, 0, 1] + (16//3)*exp(2t)*[-1, 1, 1]) +@test isequal(expand.(solve_linear_ode_system([1 -1 0; 1 2 1; -2 1 -1], [7, 2, 3], t)), (5//3)*exp(-t)*[-1, -2, 7] - 14exp(t)*[-1, 0, 1] + (16//3)*exp(2t)*[-1, 1, 1]) @test_throws ArgumentError solve_linear_ode_system([1 2; 3 4], [1, 2, 3], t) # mismatch between A and x0 @test_throws ArgumentError solve_linear_ode_system([1 2 3; 4 5 6], [1, 2], t) # A isn't square @@ -90,4 +90,4 @@ ys = Symbolics.variables(:y, 1:2) @test isequal(Symbolics.unreduce_order([ys[1], ys[2]], x, t, ys), [x, Dt(x)]) @test Symbolics.is_solution(C[1]*exp(3t) + C[2]*t*exp(3t) + 2(t^2)*exp(3t), SymbolicLinearODE(x, t, [9, -6], 4exp(3t))) -@test Symbolics.is_solution(C[1]*exp(-t) + C[2]*t*exp(-t), (Dt^2)(x) + 2(Dt^1)(x) + x ~ 0, x, t) \ No newline at end of file +@test Symbolics.is_solution(C[1]*exp(-t) + C[2]*t*exp(-t), (Dt^2)(x) + 2(Dt^1)(x) + x ~ 0, x, t) diff --git a/test/domains.jl b/test/domains.jl index 5e29c8b71..74d08b739 100644 --- a/test/domains.jl +++ b/test/domains.jl @@ -31,7 +31,5 @@ var_domain_pair = t ∈ (0,1) @test var_domain_pair.domain isa Interval # Other types -t = Symbolics.Num(:t) -@assert (t ∈ domain) isa VarDomainPairing t = Symbolics.variable(:t) @assert (t ∈ domain) isa VarDomainPairing diff --git a/test/extensions/groebner.jl b/test/extensions/groebner.jl index ac3cc8bfc..0d2d3acff 100644 --- a/test/extensions/groebner.jl +++ b/test/extensions/groebner.jl @@ -11,8 +11,8 @@ syms = [ [x, sin((44 // 31)y) * z] ] for sym in syms - polynoms, pvar2sym, sym2term = Symbolics.symbol_to_poly(sym) - sym2 = Symbolics.poly_to_symbol(polynoms, pvar2sym, sym2term, Real) + polynoms, poly_to_bs = Symbolics.symbol_to_poly(sym) + sym2 = Symbolics.poly_to_symbol(polynoms, poly_to_bs) @test isequal(expand.(sym2), expand.(sym)) end diff --git a/test/forwarddiff_symbolic_dual_ops.jl b/test/forwarddiff_symbolic_dual_ops.jl index eba74b7c1..5aee37d8e 100644 --- a/test/forwarddiff_symbolic_dual_ops.jl +++ b/test/forwarddiff_symbolic_dual_ops.jl @@ -25,7 +25,7 @@ for f ∈ SymbolicUtils.monadic # The polygamma and trigamma functions seem to be missing rules in ForwardDiff. # The abs rule uses conditionals and cannot be used with Symbolics.Num. # acsc, asech, NanMath.log2 and NaNMath.log10 are tested separately - if f ∈ (abs, SF.polygamma, SF.trigamma, acsc, acsch, asech, NaNMath.log2, NaNMath.log10) + if f ∈ (abs, SF.polygamma, SF.trigamma, acsc, acsch, asech, NaNMath.log2, NaNMath.log10, sign, signbit, factorial, expinti, sinint) continue end @@ -56,8 +56,7 @@ for f ∈ SymbolicUtils.basic_diadic fd = ForwardDiff.derivative(fun, x) sym = Symbolics.Differential(x)(fun(x)) |> expand_derivatives - - @test isequal(fd, sym) + @test isequal(fd, unwrap_const(sym)) end for f ∈ SymbolicUtils.diadic diff --git a/test/fuzz-arrays.jl b/test/fuzz-arrays.jl index 46deac37f..53337aae4 100644 --- a/test/fuzz-arrays.jl +++ b/test/fuzz-arrays.jl @@ -1,6 +1,8 @@ using Symbolics +using SymbolicUtils using Test using LinearAlgebra +using SymbolicUtils: symtype # Transpose and multiply @@ -17,8 +19,14 @@ function adjmul_rand_leaf() a = rand(adjmul_leaf_nodes) rand(Bool) ? a' : a end -isadjvec(x::Symbolics.ArrayOp) = x.term.f == adjoint && ndims(Symbolics.arguments(x.term)[1]) == 1 -isadjvec(x) = false +function isadjvec(x) + iscall(x) && (operation(x) === adjoint || operation(x) === transpose) && ( + ndims(arguments(x)[1]) == 1 || + ndims(arguments(x)[1]) == 2 && length(axes(arguments(x)[1], 2)) == 1) +end +isadjvec(x::Adjoint) = ndims(parent(x)) == 1 +isadjvec(x::Transpose) = ndims(parent(x)) == 1 +isdot(A, b) = isadjvec(A) && ndims(b) == 1 function rand_mul_expr(a=adjmul_rand_leaf(), b=adjmul_rand_leaf()) @@ -39,19 +47,15 @@ function rand_mul_expr(a=adjmul_rand_leaf(), sz = try size(a * b) catch err; nothing end if sz !== nothing - try - size(a) - size(b) - @goto test_size - catch err - return # no size known + if size(a) isa SymbolicUtils.Unknown || size(b) isa SymbolicUtils.Unknown + return end - @label test_size - if (isadjvec(Symbolics.unwrap(a)) && ndims(b) == 1) || Symbolics.isdot(a, b) + if (isadjvec(Symbolics.unwrap(a)) && ndims(b) == 1) || isdot(a, b) if size(a*b) != () println("a * b is wrong:") @show a b + @show symtype(a) symtype(b) @show typeof(a) typeof(b) return @test size(a*b) == () else @@ -62,12 +66,20 @@ function rand_mul_expr(a=adjmul_rand_leaf(), ab_sample = rand(stype(eltype(a)), size(a)...) * rand(stype(eltype(b)), size(b)...) if size(a * b) == size(ab_sample) @test true - @test (eltype(a*b) <: Real && eltype(ab_sample) <: Real) || (eltype(ab_sample) <: Complex && eltype(a*b) <: Complex) + @test (eltype(symtype(a*b)) <: Real && eltype(ab_sample) <: Real) || (eltype(ab_sample) <: Complex && eltype(symtype(a*b)) <: Complex) else println("a * b is wrong:") @show a b + @show symtype(a) symtype(b) @show typeof(a) typeof(b) - @test size(a * b) == size(rand(size(a)...) * rand(size(b)...)) + @show a * b + @show isadjvec(Symbolics.unwrap(a)) + target_size = size(rand(size(a)...) * rand(size(b)...)) + @show target_size + if target_size == (1,) && isadjvec(Symbolics.unwrap(a)) + target_size = () + end + @test size(a * b) == target_size end end end diff --git a/test/invalidations.jl b/test/invalidations.jl index 7f0c4054a..d145bd0b9 100644 --- a/test/invalidations.jl +++ b/test/invalidations.jl @@ -1,5 +1,5 @@ using SnoopCompile: @snoop_invalidations -using Symbolics +using Symbolics, Test struct FakeType end diff --git a/test/inverse.jl b/test/inverse.jl index 091fa458a..3f37584c5 100644 --- a/test/inverse.jl +++ b/test/inverse.jl @@ -1,4 +1,5 @@ using Symbolics +using Test @test inverse(sin) == left_inverse(sin) == right_inverse(sin) == asin @test inverse(asin) == left_inverse(asin) == right_inverse(asin) == sin diff --git a/test/linear_solver.jl b/test/linear_solver.jl index 3d00e3431..27483f39b 100644 --- a/test/linear_solver.jl +++ b/test/linear_solver.jl @@ -1,4 +1,5 @@ using Symbolics +using Symbolics: value, unwrap using LinearAlgebra using Test @@ -16,7 +17,7 @@ sol = Symbolics.symbolic_linear_solve(expr, p) a, b, islinear = Symbolics.linear_expansion(expr, p) @test eltype((a, b)) <: Num @test isequal((a, b, islinear), (-(x - y), -y, true)) -@test isequal(Symbolics.symbolic_linear_solve(x * p ~ 0, p), 0) +@test isequal(unwrap_const(unwrap(Symbolics.symbolic_linear_solve(x * p ~ 0, p))), 0) @test_throws Any Symbolics.symbolic_linear_solve(1/x + p * p/x ~ 0, p) @test isequal(Symbolics.symbolic_linear_solve(x * y ~ p, x), p / y) @test isequal(Symbolics.symbolic_linear_solve(x * -y ~ p, y), -p / x) @@ -34,7 +35,7 @@ expr = Dx * x + Dx*t - 2//3*x + y*Dx a, b, islinear = Symbolics.linear_expansion(expr, x) @test iszero(expand(a * x + b - expr)) @test isequal(Symbolics.symbolic_linear_solve(expr ~ Dx, Dx), (-2//3*x)/(1 - t - x - y)) -@test isequal(Symbolics.symbolic_linear_solve(expr ~ Dx, x), (t*Dx + y*Dx - Dx) / ((2//3) - Dx)) +@test isequal(Symbolics.symbolic_linear_solve(expr ~ Dx, x), (t*Dx + y*Dx - Dx) / ((2/3) - Dx)) exprs = [ 3//2*x + 2y + 10 @@ -43,8 +44,8 @@ exprs = [ xs = [x, y] A, b, islinear = Symbolics.linear_expansion(exprs, xs) @test islinear -@test isequal(A, [3//2 2; 7 3]) -@test isequal(b, [10; -8]) +@test isequal(unwrap_const.(unwrap.(A)), [3//2 2; 7 3]) +@test isequal(unwrap_const.(unwrap.(b)), [10; -8]) @variables x y z eqs = [ @@ -52,29 +53,22 @@ eqs = [ 2//1 + y - z ~ 3//1*x 2//1 + y - 2z ~ 3//1*z ] -@test [2 1 -1; -3 1 -1; 0 1 -5] * Symbolics.symbolic_linear_solve(eqs, [x, y, z]) == [2; -2; -2] +@test unwrap_const.(unwrap.([2 1 -1; -3 1 -1; 0 1 -5] * Symbolics.symbolic_linear_solve(eqs, [x, y, z]))) ≈ [2; -2; -2] @test isequal(Symbolics.symbolic_linear_solve(2//1*x + y - 2//1*z ~ 9//1*x, 1//1*x), (1//7)*(y - 2//1*z)) @test isequal(Symbolics.symbolic_linear_solve(x + y ~ 0, x), Symbolics.symbolic_linear_solve([x + y ~ 0], x)) @test isequal(Symbolics.symbolic_linear_solve([x + y ~ 0], [x]), Symbolics.symbolic_linear_solve(x + y ~ 0, [x])) @test isequal(Symbolics.symbolic_linear_solve(2x/z + sin(z), x), sin(z) / (-2 / z)) -@variables t x -D = Symbolics.Difference(t; dt=1) -a, b, islinear = Symbolics.linear_expansion(D(x) - x, x) -@test islinear -@test isequal(a, -1) -@test isequal(b, D(x)) - @testset "linear_expansion with array variables" begin @variables x[1:2] y[1:2] z(..) @test !Symbolics.linear_expansion(z(x) + x[1], x[1])[3] @test !Symbolics.linear_expansion(z(x[1]) + x[1], x[1])[3] a, b, islin = Symbolics.linear_expansion(z(x[2]) + x[1], x[1]) - @test islin && isequal(a, 1) && isequal(b, z(x[2])) + @test islin && isequal(value(a), 1) && isequal(value(b), z(x[2])) a, b, islin = Symbolics.linear_expansion((x + x)[1], x[1]) - @test islin && isequal(a, 2) && isequal(b, 0) + @test islin && isequal(value(a), 2) && isequal(value(b), 0) a, b, islin = Symbolics.linear_expansion(y[1], x[1]) - @test islin && isequal(a, 0) && isequal(b, y[1]) + @test islin && isequal(value(a), 0) && isequal(value(b), y[1]) @test !Symbolics.linear_expansion(z([x...]), x[1])[3] @test !Symbolics.linear_expansion(z(collect(Symbolics.unwrap(x))), x[1])[3] @test !Symbolics.linear_expansion(z([x, 2x]), x[1])[3] diff --git a/test/logexpfunctions.jl b/test/logexpfunctions.jl index 37b7f76ef..8eb0a4213 100644 --- a/test/logexpfunctions.jl +++ b/test/logexpfunctions.jl @@ -1,5 +1,6 @@ using Symbolics using LogExpFunctions +using Test N = 10 @@ -22,4 +23,4 @@ vals = Dict(a => _a, b => _b, c => _c, x => _x) @test substitute(log1pexp(a), vals) ≈ log1pexp(_a) @test substitute(logexpm1(c), vals) ≈ logexpm1(_c) @test substitute(logmxp1(c), vals) ≈ logmxp1(_c) -@test substitute(logsumexp(x), vals) ≈ logsumexp(_x) \ No newline at end of file +@test substitute(logsumexp(x), vals) ≈ logsumexp(_x) diff --git a/test/macro.jl b/test/macro.jl index 2faae0eff..89d378e02 100644 --- a/test/macro.jl +++ b/test/macro.jl @@ -1,15 +1,15 @@ using Symbolics -import Symbolics: CallWithMetadata, getsource, getdefaultval, wrap, unwrap, getname -import SymbolicUtils: Term, symtype, FnType, BasicSymbolic, promote_symtype +import Symbolics: getsource, getdefaultval, wrap, unwrap, getname +import SymbolicUtils: Term, symtype, FnType, BasicSymbolic, promote_symtype, SymReal, Const using LinearAlgebra using Test @variables t Symbolics.@register_symbolic fff(t) -@test isequal(fff(t), Symbolics.Num(Symbolics.Term{Real}(fff, [Symbolics.value(t)]))) +@test isequal(fff(t), Symbolics.Num(Symbolics.Term{SymReal}(fff, [Symbolics.value(t)]; type = Real))) const SymMatrix{T,N} = Symmetric{T, AbstractArray{T, N}} -many_vars = @variables t=0 a=1 x[1:4]=2 y(t)[1:4]=3 w[1:4] = 1:4 z(t)[1:4] = 2:5 p(..)[1:4] +many_vars = @variables t=0 a=1 x[1:4] y(t)[1:4] w[1:4] = 1:4 z(t)[1:4] = 2:5 p(..)[1:4] let @register_array_symbolic ggg(x::AbstractVector) begin @@ -24,14 +24,14 @@ let @test ndims(gg) == 2 @test size(gg) == (8,8) - @test eltype(gg) == Real + @test eltype(gg) == Num @test symtype(unwrap(gg)) == SymMatrix{Real, 2} @test promote_symtype(ggg, symtype(unwrap(x))) == Any # no promote_symtype defined gg = ggg([a, 2a]) @test ndims(gg) == 2 @test size(gg) == (4, 4) - @test eltype(gg) == Real + @test eltype(gg) == Num @test symtype(unwrap(gg)) == SymMatrix{Real, 2} @test promote_symtype(ggg, Vector{symtype(typeof(a))}) == Any @@ -39,7 +39,7 @@ let gg = ggg([_a, 2_a]) @test ndims(gg) == 2 @test size(gg) == (4, 4) - @test eltype(gg) == Real + @test eltype(symtype(gg)) == Real @test symtype(unwrap(gg)) == SymMatrix{Real, 2} @test promote_symtype(ggg, Vector{symtype(typeof(a))}) == Any end @@ -86,29 +86,24 @@ end false # without promote_symtype hh = ccwa(gg, x) @test size(hh) == (8,4,10) -@test eltype(hh) == Real +@test eltype(hh) == Num @test isequal(arguments(unwrap(hh)), unwrap.([gg, x])) _args = [[a 2a; 4a 6a; 3a 5a], [4a, 6a]] hh = ccwa(_args...) @test size(hh) == (3, 2, 10) -@test eltype(hh) == Real -@test isequal(arguments(unwrap(hh)), unwrap.(_args)) +@test eltype(hh) == Num +@test isequal(arguments(unwrap(hh)), Const{SymReal}.(unwrap.(_args))) @test all(t->getsource(t)[1] === :variables, many_vars) @test getdefaultval(t) == 0 @test getdefaultval(a) == 1 -@test getdefaultval(x) == 2 -@test getdefaultval(x[1]) == 2 -@test getdefaultval(y[2]) == 3 @test getdefaultval(w[2]) == 2 @test getdefaultval(w[4]) == 4 @test getdefaultval(z[3]) == 4 @test symtype(p) <: FnType{Tuple, Array{Real,1}} @test promote_symtype(ccwa, symtype(unwrap(gg)), symtype(unwrap(x))) == Any -@test p(t)[1] isa Symbolics.Num - struct CanCallWithArray2{T} params::T @@ -119,7 +114,7 @@ ccwa = CanCallWithArray2((length=10,)) size=(size(x, 1), length(b), c.params.length) eltype=Real end -@test promote_symtype(ccwa, symtype(unwrap(gg)), symtype(unwrap(x))) == AbstractArray{Real} +@test promote_symtype(ccwa, symtype(unwrap(gg)), symtype(unwrap(x))) == Array{Real} struct CanCallWithArray3{T} params::T @@ -132,7 +127,7 @@ ccwa = CanCallWithArray3((length=10,)) eltype=Real ndims = 3 end -@test promote_symtype(ccwa, symtype(unwrap(gg)), symtype(unwrap(x))) == AbstractArray{Real, 3} +@test promote_symtype(ccwa, symtype(unwrap(gg)), symtype(unwrap(x))) == Array{Real, 3} ## Wrapper types @@ -165,9 +160,9 @@ end @test applicable(foo, wrap(x), 2) @test applicable(foo, wrap(x), wrap(2)) -@test foo(x, wrap(2)) isa FooWrap -@test foo(x, wrap(1)) isa Num -@test foo(x, wrap(6)) isa String +# @test foo(x, wrap(2)) isa FooWrap +# @test foo(x, wrap(1)) isa Num +# @test foo(x, wrap(6)) isa String let @@ -199,7 +194,7 @@ Symbolics.@register_symbolic bar(t, x::A) Symbolics.@register_symbolic baz(x, y) if !@isdefined(bar_catchall_defined) @test_throws MethodError bar(0.1, A()) - @test_throws MethodError bar(Num(0.1), A()) + # @test_throws MethodError bar(Num(0.1), A()) else @warn("skipping 2 tests because this file was run more than once") end @@ -239,15 +234,15 @@ let end @variables t y(t) -yy = Symbolics.variable(:y, T = Symbolics.FnType{Tuple{Any}, Real}) +yy = Symbolics.variable(:y, T = Symbolics.FnType{Tuple, Real, Nothing}) yyy = yy(t) @test isequal(yyy, y) @test yyy isa Num @test y isa Num -yy = Symbolics.variable(:y, T = Symbolics.FnType{Tuple, Real}) +yy = Symbolics.variable(:y, T = Symbolics.FnType{Tuple{Real}, Real, Nothing}) yyy = yy(t) @test !isequal(yyy, y) -@variables y(..) +@variables y(::Real) @test isequal(yyy, y(t)) spam(x) = 2x @@ -255,7 +250,8 @@ spam(x) = 2x sym = spam([a, 2a]) @test sym isa Num -@test unwrap(sym) isa BasicSymbolic{Real} +@test unwrap(sym) isa BasicSymbolic{SymReal} +@test symtype(sym) === Real fn_defaults = [print, min, max, identity, (+), (-), max, sum, vcat, (*)] fn_names = [Symbol(:f, i) for i in 1:10] @@ -266,34 +262,34 @@ Symbolics.option_to_metadata_type(::Val{:foo}) = VariableFoo function test_all_functions(fns) f1, f2, f3, f4, f5, f6, f7, f8, f9, f10 = fns @variables x y::Int z::Function w[1:3, 1:3] v[1:3, 1:3]::String - @test f1 isa CallWithMetadata{FnType{Tuple, Real}} + @test symtype(unwrap(f1)) === FnType{Tuple, Real, Nothing} @test all(x -> symtype(x) <: Real, [f1(), f1(1), f1(x), f1(x, y), f1(x, y, x+y)]) - @test f2 isa CallWithMetadata{FnType{Tuple{Any, Vararg}, Int}} + @test symtype(unwrap(f2)) === FnType{Tuple{Any, Vararg{Any}}, Int, Nothing} @test all(x -> symtype(x) <: Int, [f2(1), f2(z), f2(x), f2(x, y), f2(x, y, x+y)]) @test_throws ErrorException f2() - @test f3 isa CallWithMetadata{FnType{Tuple, Real, typeof(max)}} + @test symtype(unwrap(f3)) === FnType{Tuple, Real, typeof(max)} @test all(x -> symtype(x) <: Real, [f3(), f3(1), f3(x), f3(x, y), f3(x, y, x+y)]) - @test f4 isa CallWithMetadata{FnType{Tuple{Int}, Real}} + @test symtype(unwrap(f4)) === FnType{Tuple{Int}, Real, Nothing} @test all(x -> symtype(x) <: Real, [f4(1), f4(y), f4(2y)]) @test_throws ErrorException f4(x) - @test f5 isa CallWithMetadata{FnType{Tuple{Int, Vararg{Int}}, Real}} + @test symtype(unwrap(f5)) === FnType{Tuple{Int, Vararg{Int}}, Real, Nothing} @test all(x -> symtype(x) <: Real, [f5(1), f5(y), f5(y, y), f5(2, 3)]) @test_throws ErrorException f5(x) - @test f6 isa CallWithMetadata{FnType{Tuple{Int, Int}, Int}} + @test symtype(unwrap(f6)) === FnType{Tuple{Int, Int}, Int, Nothing} @test all(x -> symtype(x) <: Int, [f6(1, 1), f6(y, y), f6(1, y), f6(y, 1)]) @test_throws ErrorException f6() @test_throws ErrorException f6(1) @test_throws ErrorException f6(x, y) @test_throws ErrorException f6(y) - @test f7 isa CallWithMetadata{FnType{Tuple{Int, Int}, Int, typeof(max)}} + @test symtype(unwrap(f7)) === FnType{Tuple{Int, Int}, Int, typeof(max)} # call behavior tested by f6 - @test f8 isa CallWithMetadata{FnType{Tuple{Function, Vararg}, Real, typeof(sum)}} + @test symtype(unwrap(f8)) === FnType{Tuple{Function, Vararg}, Real, typeof(sum)} @test all(x -> symtype(x) <: Real, [f8(z), f8(z, x), f8(identity), f8(identity, x)]) @test_throws ErrorException f8(x) @test_throws ErrorException f8(1) - @test f9 isa CallWithMetadata{FnType{Tuple, Vector{Real}}} + @test symtype(unwrap(f9)) === FnType{Tuple, Vector{Real}, Nothing} @test all(x -> symtype(unwrap(x)) <: Vector{Real} && size(x) == (3,), [f9(), f9(1), f9(x), f9(x + y), f9(z), f9(1, x)]) - @test f10 isa CallWithMetadata{FnType{Tuple{Matrix{<:Real}, Matrix{<:Real}}, Matrix{Real}, typeof(*)}} + @test symtype(unwrap(f10)) === FnType{Tuple{Matrix{<:Real}, Matrix{<:Real}}, Matrix{Real}, typeof(*)} @test all(x -> symtype(unwrap(x)) <: Matrix{Real} && size(x) == (3, 3), [f10(w, w), f10(w, ones(3, 3)), f10(ones(3, 3), ones(3, 3)), f10(w + w, w)]) @test_throws ErrorException f10(w, v) end @@ -426,10 +422,10 @@ end @test getdefaultval(x) isa BasicSymbolic @test Symbolics.getmetadata(unwrap(x), VariableFoo, nothing) isa BasicSymbolic @test getdefaultval(y) isa BasicSymbolic - @test Symbolics.getmetadata(unwrap(y), VariableFoo, nothing) isa Vector{<:BasicSymbolic} + @test Symbolics.getmetadata(unwrap(y), VariableFoo, nothing) isa Vector{Num} end -@testset "`hash(::CallWithMetadata)` is consistent with `isequal`" begin +@testset "`hash` of callable is consistent with `isequal`" begin @variables f(..) ff = setmetadata(f, Int, 3) @test isequal(f, ff) diff --git a/test/overloads.jl b/test/overloads.jl index f148da191..3b4340650 100644 --- a/test/overloads.jl +++ b/test/overloads.jl @@ -11,8 +11,8 @@ vars = @variables t $a $b(t) $c(t)[1:3] @test b === :value_b @test c === :value_c @test isequal(vars[1], t) -@test isequal(vars[2], Num(Sym{Real}(a))) -@test isequal(vars[3], Num(Sym{FnType{Tuple{Any},Real}}(b)(value(t)))) +@test isequal(vars[2], Num(Sym{SymReal}(a; type = Real))) +@test isequal(vars[3], Num(Sym{SymReal}(b; type = FnType{Tuple,Real,Nothing})(value(t)))) vars = @variables a,b,c,d,e,f,g,h,i @test isequal(vars, [a,b,c,d,e,f,g,h,i]) @@ -30,7 +30,7 @@ aa = a; # old a @test isequal(a, aa) @test hash(a) == hash(aa) -@test isequal(Symbolics.get_variables(a+aa+1), [a]) +@test isequal(Symbolics.get_variables(a+aa+1), Set([a])) @test hash(a+b ~ c+d) == hash(a+b ~ c+d) @@ -112,26 +112,20 @@ M \ [1, 2] # test det @variables X[1:4,1:4] d1 = det(X, laplace=true) -d2 = det(X, laplace=false) _det1 = eval(build_function(d1,X)) -_det2 = eval(build_function(d2,X)) A = [1 1 1 1 1 0 1 1 1 1 0 1 1 1 1 0] @test _det1(map(Num, A)) == -1 -@test _det2(map(Num, A)) == -1 @variables X[1:3,1:3] d1 = det(X, laplace=true) -d2 = det(X, laplace=false) _det1 = eval(build_function(d1, X)) -_det2 = eval(build_function(d2, X)) A = [1 1 1 1 0 1 1 1 1] @test _det1(map(Num, A)) == 0 -@test _det2(map(Num, A)) == 0 @variables a b c d z1 = a + b * im @@ -163,7 +157,7 @@ z2 = c + d * im @test conj(a) === a @test imag(a) === Num(0) -@test isequal(sign(x), Num(SymbolicUtils.Term{Int}(sign, [Symbolics.value(x)]))) +@test isequal(sign(x), Num(SymbolicUtils.Term{SymReal}(sign, [Symbolics.value(x)]; type = Real))) @test sign(Num(1)) isa Num @test isequal(sign(Num(1)), Num(1)) @test isequal(sign(Num(-1)), Num(-1)) @@ -185,7 +179,7 @@ x = Num.(randn(10)) @test norm(x, 1) == norm(Symbolics.value.(x), 1) @test norm(x, 1.2) == norm(Symbolics.value.(x), 1.2) -@test clamp.(x, 0, 1) == clamp.(Symbolics.value.(x), 0, 1) +@test value.(clamp.(x, 0, 1)) == clamp.(Symbolics.value.(x), 0, 1) @test isequal(Symbolics.derivative(clamp(a, 0, 1), a), ifelse(a < 0, 0, ifelse(a>1, 0, 1))) @variables x[1:2] @@ -205,7 +199,7 @@ x = Num.(randn(10)) @variables x[1:3] ex = x[1]+x[2] -@test isequal(Symbolics.get_variables(ex), Symbolics.scalarize(x[1:2])) +@test issetequal(Symbolics.get_variables(ex), Symbolics.scalarize(x[1:2])) @variables x A = [x[1] 2 diff --git a/test/parsing.jl b/test/parsing.jl index 87f50de55..e908c8309 100644 --- a/test/parsing.jl +++ b/test/parsing.jl @@ -32,4 +32,4 @@ eqs = parse_expr_to_symbolic.(ex, (@__MODULE__,)) ex = [m[2] ~ m[1] m[2] ~ -2m[1] + 3 / m[3] m[3] ~ 2] -@test all(isequal.(eqs,ex)) \ No newline at end of file +@test all(isequal.(eqs,ex)) diff --git a/test/rewrite_helpers.jl b/test/rewrite_helpers.jl index f5fe955c8..7b34bcb35 100644 --- a/test/rewrite_helpers.jl +++ b/test/rewrite_helpers.jl @@ -12,7 +12,7 @@ my_f(x, y) = x^3 + 2y # Check `replacenode` function. let # Simple replacements. - @test isequal(replacenode(X + X + X, X =>1), 3) + @test isequal(unwrap_const(replacenode(X + X + X, X =>1)), 3) @test isequal(replacenode(X + X + X, Y => 1), 3X) res = replacenode(X + X + my_f(X, Z), X => Y) @test isequal(res, Y^3 + 2Y + 2Z) @@ -26,8 +26,8 @@ let @test isequal(replacenode(X + sin(Y + a) + a, rep_func), X + sin(Y + a) + a) # On non-symbolic inputs. - @test isequal(replacenode(1, X =>2.0), 1) - @test isequal(replacenode(1, rep_func), 1) + @test isequal(unwrap_const(replacenode(1, X =>2.0)), 1) + @test isequal(unwrap_const(replacenode(1, rep_func)), 1) end # Test `hasnode` function. @@ -127,7 +127,7 @@ let @test isequal(filterchildren(is_derivative, ex1), []) @test isequal(filterchildren(is_derivative, ex2), []) @test isequal(filterchildren(is_derivative, ex3), []) - @test isequal(filterchildren(is_derivative, ex4), [D(Y), D(my_f(1,Z))]) + @test issetequal(filterchildren(is_derivative, ex4), [D(Y), D(my_f(1,Z))]) end # https://github.com/JuliaSymbolics/Symbolics.jl/issues/1175 diff --git a/test/runtests.jl b/test/runtests.jl index d5725fe40..2c375e856 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -42,7 +42,6 @@ if GROUP == "All" || GROUP == "Core" @safetestset "Differentiation Test" begin include("diff.jl") end @safetestset "Utils Test" begin include("utils.jl") end @safetestset "ADTypes Test" begin include("adtypes.jl") end - @safetestset "Difference Test" begin include("difference.jl") end @safetestset "Degree Test" begin include("degree.jl") end @safetestset "Coeff Test" begin include("coeff.jl") end @safetestset "Parsing Test" begin include("parsing.jl") end @@ -60,9 +59,7 @@ if GROUP == "All" || GROUP == "Core" @safetestset "Domain Test" begin include("domains.jl") end @safetestset "Inequality Test" begin include("inequality.jl") end @safetestset "Integral Test" begin include("integral.jl") end - @safetestset "CartesianIndex Test" begin include("cartesianindex.jl") end @safetestset "LogExpFunctions Test" begin include("logexpfunctions.jl") end - @safetestset "LuxCore extensions Test" begin include("extensions/lux.jl") end @safetestset "Registration without using Test" begin include("registration_without_using.jl") end @safetestset "Show Test" begin include("show.jl") end @safetestset "Utility Function Test" begin include("utils.jl") end @@ -88,19 +85,19 @@ end if GROUP == "All" || GROUP == "Core" || GROUP == "SymbolicIndexingInterface" @safetestset "SymbolicIndexingInterface Trait Test" begin - include("symbolic_indexing_interface_trait.jl") + include("symbolic_indexing_interface_trait.jl") end @safetestset "SymbolicIndexingInterface Parameter Indexing Test" begin - include("symbolic_indexing_interface_parameter_indexing.jl") + include("symbolic_indexing_interface_parameter_indexing.jl") end @safetestset "SymbolicIndexingInterface Symbolic Evaluate Test" begin - include("symbolic_indexing_interface_symbolic_evaluate.jl") + include("symbolic_indexing_interface_symbolic_evaluate.jl") end end if GROUP == "All" || GROUP == "Downstream" activate_downstream_env() - #@time @safetestset "ParameterizedFunctions MATLABDiffEq Regression Test" begin include("downstream/ParameterizedFunctions_MATLAB.jl") end + @time @safetestset "ParameterizedFunctions MATLABDiffEq Regression Test" begin include("downstream/ParameterizedFunctions_MATLAB.jl") end @safetestset "ModelingToolkit Variable Utils Test" begin include("downstream/modeling_toolkit_utils.jl") end @safetestset "DI Test" begin include("downstream/differentiation_interface.jl") end end diff --git a/test/semipoly.jl b/test/semipoly.jl index 208240766..88a2b56fe 100644 --- a/test/semipoly.jl +++ b/test/semipoly.jl @@ -1,25 +1,29 @@ using Symbolics +using Symbolics: unwrap +using SymbolicUtils: Const using Test using Random @variables x y z +const CONST_1 = Const{SymReal}(1) + @testset "simple expressions" begin d, r = semipolynomial_form(x, [x], 1) - @test d == Dict(x => 1) - @test r == 0 + @test isequal(d, Dict(x => CONST_1)) + @test unwrap_const(r) == 0 d, r = semipolynomial_form(x + sin(x) + 1 + y, [x], 1) - @test isequal(d, Dict(1 => 1 + y, x => 1)) + @test isequal(d, Dict(CONST_1 => 1 + y, x => CONST_1)) @test isequal(r, sin(x)) d, r = semipolynomial_form(x^2 + 1 + y, [x], 1) - @test isequal(d, Dict(1 => 1 + y)) + @test isequal(d, Dict(CONST_1 => 1 + y)) @test isequal(r, x^2) d, r = semipolynomial_form((x + 2)^12, [x], 1) - @test isequal(d, Dict(1 => 1 << 12, x => (1 << 11) * 12)) + @test isequal(d, Dict(CONST_1 => Const{SymReal}(1 << 12), x => Const{SymReal}((1 << 11) * 12))) end @testset "maintain SymbolicUtils.Symbolic subtype" begin @@ -48,7 +52,7 @@ end @test SymbolicUtils.isdiv(nl) dict, nl = semipolynomial_form(div_expr, [x], Inf) @test isequal(dict, Dict(x => 1 / y)) - @test iszero(nl) + @test iszero(unwrap_const(nl)) dict, nl = semipolynomial_form(div_expr, [y], Inf) @test isempty(dict) @test isequal(nl, div_expr) @@ -81,68 +85,68 @@ end expr = y_1 * (x + y) # (x + y)*(y^-1) d, r = semipolynomial_form(expr, [x, y], Inf) - @test isequal(d, Dict(1 => 1)) + @test isequal(d, Dict(CONST_1 => CONST_1)) @test isequal(r, x / y) d, r = semipolynomial_form(expr, [x, y], 0) - @test isequal(d, Dict(1 => 1)) + @test isequal(d, Dict(CONST_1 => CONST_1)) @test isequal(r, x / y) d, r = semipolynomial_form(expr, [x, y], 0.3) - @test isequal(d, Dict(1 => 1)) + @test isequal(d, Dict(CONST_1 => CONST_1)) @test isequal(r, x / y) d, r = semipolynomial_form(expr, [x, y], 1 // 2) - @test isequal(d, Dict(1 => 1)) + @test isequal(d, Dict(CONST_1 => CONST_1)) @test isequal(r, x / y) d, r = semipolynomial_form(expr, [x], 1) - @test isequal(d, Dict(1 => 1, x => 1 / y)) || isequal(d, Dict(1 => 1, x => y_1)) - @test iszero(r) + @test isequal(d, Dict(CONST_1 => CONST_1, x => 1 / y)) || isequal(d, Dict(CONST_1 => CONST_1, x => y_1)) + @test iszero(unwrap_const(r)) d, r = semipolynomial_form(expr, [x], 1 // 1) - @test isequal(d, Dict(1 => 1, x => 1 / y)) || isequal(d, Dict(1 => 1, x => y_1)) - @test iszero(r) + @test isequal(d, Dict(CONST_1 => CONST_1, x => 1 / y)) || isequal(d, Dict(CONST_1 => CONST_1, x => y_1)) + @test iszero(unwrap_const(r)) d, r = semipolynomial_form(expr, [x], Int32(1)) - @test isequal(d, Dict(1 => 1, x => 1 / y)) || isequal(d, Dict(1 => 1, x => y_1)) - @test iszero(r) + @test isequal(d, Dict(CONST_1 => CONST_1, x => 1 / y)) || isequal(d, Dict(CONST_1 => CONST_1, x => y_1)) + @test iszero(unwrap_const(r)) d, r = semipolynomial_form(expr, [x], 1.0) - @test isequal(d, Dict(1 => 1, x => 1 / y)) || isequal(d, Dict(1 => 1, x => y_1)) - @test iszero(r) + @test isequal(d, Dict(CONST_1 => CONST_1, x => 1 / y)) || isequal(d, Dict(CONST_1 => CONST_1, x => y_1)) + @test iszero(unwrap_const(r)) d, r = semipolynomial_form(expr, [x], Float32(1.0)) - @test isequal(d, Dict(1 => 1, x => 1 / y)) || isequal(d, Dict(1 => 1, x => y_1)) - @test iszero(r) + @test isequal(d, Dict(CONST_1 => CONST_1, x => 1 / y)) || isequal(d, Dict(CONST_1 => CONST_1, x => y_1)) + @test iszero(unwrap_const(r)) d, r = semipolynomial_form(expr, [x], 1.5) - @test isequal(d, Dict(1 => 1, x => 1 / y)) || isequal(d, Dict(1 => 1, x => y_1)) - @test iszero(r) + @test isequal(d, Dict(CONST_1 => CONST_1, x => 1 / y)) || isequal(d, Dict(CONST_1 => CONST_1, x => y_1)) + @test iszero(unwrap_const(r)) d, r = semipolynomial_form(expr, [x], 0.9) - @test isequal(d, Dict(1 => 1)) + @test isequal(d, Dict(CONST_1 => CONST_1)) @test isequal(r, x / y) || isequal(r, x * y_1) d, r = semipolynomial_form(expr, [x], 99 // 100) - @test isequal(d, Dict(1 => 1)) + @test isequal(d, Dict(CONST_1 => CONST_1)) @test isequal(r, x / y) || isequal(r, x * y_1) d, r = semipolynomial_form(expr, [], 0.9) - @test isequal(d, Dict(1 => 1 + x / y)) || isequal(d, Dict(1 => 1 + x * y_1)) - @test iszero(r) + @test isequal(d, Dict(CONST_1 => expr)) + @test iszero(unwrap_const(r)) d, r = semipolynomial_form(expr, [], 0) - @test isequal(d, Dict(1 => 1 + x / y)) || isequal(d, Dict(1 => 1 + x * y_1)) - @test iszero(r) + @test isequal(d, Dict(CONST_1 => expr)) + @test iszero(unwrap_const(r)) d, r = semipolynomial_form(expr, [], 0.0) - @test isequal(d, Dict(1 => 1 + x / y)) || isequal(d, Dict(1 => 1 + x * y_1)) - @test iszero(r) + @test isequal(d, Dict(CONST_1 => expr)) + @test iszero(unwrap_const(r)) d, r = semipolynomial_form(expr, [], 0 // 1) - @test isequal(d, Dict(1 => 1 + x / y)) || isequal(d, Dict(1 => 1 + x * y_1)) - @test iszero(r) + @test isequal(d, Dict(CONST_1 => expr)) + @test iszero(unwrap_const(r)) end # 682 @@ -159,59 +163,59 @@ end @test isequal(r, y_1 * x^4 + y^3) || isequal(r, x^4 / y + y^3) d, r = semipolynomial_form(expr, [x, y], 3) - @test isequal(d, Dict(y^3 => 1)) + @test isequal(d, Dict(y^3 => CONST_1)) @test isequal(r, y_1 * x^4) || isequal(r, x^4 / y) d, r = semipolynomial_form(expr, [x, y], 4) - @test isequal(d, Dict(y^3 => 1)) + @test isequal(d, Dict(y^3 => CONST_1)) @test isequal(r, y_1 * x^4) || isequal(r, x^4 / y) d, r = semipolynomial_form(expr, [x], 2) - @test isequal(d, Dict(1 => y^3)) + @test isequal(d, Dict(CONST_1 => y^3)) @test isequal(r, y_1 * x^4) || isequal(r, x^4 / y) d, r = semipolynomial_form(expr, [x], 3) - @test isequal(d, Dict(1 => y^3)) + @test isequal(d, Dict(CONST_1 => y^3)) @test isequal(r, y_1 * x^4) || isequal(r, x^4 / y) d, r = semipolynomial_form(expr, [x], 4) - @test isequal(d, Dict(1 => y^3, x^4 => y_1)) - @test iszero(r) + @test isequal(d, Dict(CONST_1 => y^3, x^4 => y_1)) + @test iszero(unwrap_const(r)) d, r = semipolynomial_form(expr, [y], 2) @test isempty(d) @test isequal(r, y_1 * x^4 + y^3) || isequal(r, x^4 / y + y^3) d, r = semipolynomial_form(expr, [y], 3) - @test isequal(d, Dict(y^3 => 1)) + @test isequal(d, Dict(y^3 => CONST_1)) @test isequal(r, y_1 * x^4) || isequal(r, x^4 / y) d, r = semipolynomial_form(expr, [y], 4) - @test isequal(d, Dict(y^3 => 1)) + @test isequal(d, Dict(y^3 => CONST_1)) @test isequal(r, y_1 * x^4) || isequal(r, x^4 / y) d, r = semipolynomial_form(expr, [], 0) - @test isequal(d, Dict(1 => y_1 * x^4 + y^3)) || isequal(d, Dict(1 => x^4 / y + y^3)) - @test iszero(r) + @test isequal(d, Dict(CONST_1 => expr)) + @test iszero(unwrap_const(r)) d, r = semipolynomial_form(expr, [], 2) - @test isequal(d, Dict(1 => y_1 * x^4 + y^3)) || isequal(d, Dict(1 => x^4 / y + y^3)) - @test iszero(r) + @test isequal(d, Dict(CONST_1 => expr)) + @test iszero(unwrap_const(r)) d, r = semipolynomial_form(expr, [], 3) - @test isequal(d, Dict(1 => y_1 * x^4 + y^3)) || isequal(d, Dict(1 => x^4 / y + y^3)) - @test iszero(r) + @test isequal(d, Dict(CONST_1 => expr)) + @test iszero(unwrap_const(r)) d, r = semipolynomial_form(expr, [], 4) - @test isequal(d, Dict(1 => y_1 * x^4 + y^3)) || isequal(d, Dict(1 => x^4 / y + y^3)) - @test iszero(r) + @test isequal(d, Dict(CONST_1 => expr)) + @test iszero(unwrap_const(r)) end @testset "nested ^ exponentiation" begin expr = ((x + 1)^4 + x)^3 d, r = semipolynomial_form(expr, [x], 2) - @test isequal(d, Dict(1 => 1, x => 15, x^2 => 93)) + @test isequal(d, Dict(CONST_1 => CONST_1, x => Const{SymReal}(15), x^2 => Const{SymReal}(93))) @test isequal(r, 317x^3 + 681x^4 + 1014x^5 + 1095x^6 + 876x^7 + 519x^8 + 223x^9 + 66x^10 + 12x^11 + x^12) end @@ -220,20 +224,20 @@ end expr = y^(1//1) d, r = semipolynomial_form(expr, [x], 0) - @test isequal(d, Dict(1 => y)) - @test iszero(r) + @test isequal(d, Dict(CONST_1 => y)) + @test iszero(unwrap_const(r)) d, r = semipolynomial_form(expr, [x], 1) - @test isequal(d, Dict(1 => y)) - @test iszero(r) + @test isequal(d, Dict(CONST_1 => y)) + @test iszero(unwrap_const(r)) d, r = semipolynomial_form(expr, [y], 0) @test isempty(d) @test isequal(r, y) d, r = semipolynomial_form(expr, [y], 1) - @test isequal(d, Dict(y => 1)) - @test iszero(r) + @test isequal(d, Dict(y => CONST_1)) + @test iszero(unwrap_const(r)) expr = (x^(4//3) + y^(5//2))^3 @@ -254,60 +258,60 @@ end @test isequal(r, x^4 + 3x^(8//3) * y^(5//2) + 3x^(4//3) * y^5 + y^(15//2)) d, r = semipolynomial_form(expr, [x, y], 4) - @test isequal(d, Dict(x^4 => 1)) + @test isequal(d, Dict(x^4 => CONST_1)) @test isequal(r, 3x^(8//3) * y^(5//2) + 3x^(4//3) * y^5 + y^(15//2)) d, r = semipolynomial_form(expr, [x, y], 5) - @test isequal(d, Dict(x^4 => 1)) + @test isequal(d, Dict(x^4 => CONST_1)) @test isequal(r, 3x^(8//3) * y^(5//2) + 3x^(4//3) * y^5 + y^(15//2)) d, r = semipolynomial_form(expr, [x, y], 6) - @test isequal(d, Dict(x^4 => 1)) + @test isequal(d, Dict(x^4 =>CONST_1)) @test isequal(r, 3x^(8//3) * y^(5//2) + 3x^(4//3) * y^5 + y^(15//2)) d, r = semipolynomial_form(expr, [y], 3) - @test isequal(d, Dict(1 => x^4)) + @test isequal(d, Dict(CONST_1 => x^4)) @test isequal(r, 3x^(8//3) * y^(5//2) + 3x^(4//3) * y^5 + y^(15//2)) d, r = semipolynomial_form(expr, [y], 4) - @test isequal(d, Dict(1 => x^4)) + @test isequal(d, Dict(CONST_1 => x^4)) @test isequal(r, 3x^(8//3) * y^(5//2) + 3x^(4//3) * y^5 + y^(15//2)) d, r = semipolynomial_form(expr, [y], 5) - @test isequal(d, Dict(1 => x^4, y^5 => 3x^(4//3))) + @test isequal(d, Dict(CONST_1 => x^4, y^5 => 3x^(4//3))) @test isequal(r, 3x^(8//3) * y^(5//2) + y^(15//2)) d, r = semipolynomial_form(expr, [y], 6) - @test isequal(d, Dict(1 => x^4, y^5 => 3x^(4//3))) + @test isequal(d, Dict(CONST_1 => x^4, y^5 => 3x^(4//3))) @test isequal(r, 3x^(8//3) * y^(5//2) + y^(15//2)) d, r = semipolynomial_form(expr, [x], 3) - @test isequal(d, Dict(1 => y^(15//2))) + @test isequal(d, Dict(CONST_1 => y^(15//2))) @test isequal(r, x^4 + 3x^(8//3) * y^(5//2) + 3x^(4//3) * y^5) d, r = semipolynomial_form(expr, [x], 4) - @test isequal(d, Dict(1 => y^(15//2), x^4 => 1)) + @test isequal(d, Dict(CONST_1 => y^(15//2), x^4 => CONST_1)) @test isequal(r, 3x^(8//3) * y^(5//2) + 3x^(4//3) * y^5) d, r = semipolynomial_form(expr, [x], 5) - @test isequal(d, Dict(1 => y^(15//2), x^4 => 1)) + @test isequal(d, Dict(CONST_1 => y^(15//2), x^4 => CONST_1)) @test isequal(r, 3x^(8//3) * y^(5//2) + 3x^(4//3) * y^5) d, r = semipolynomial_form(expr, [x], 6) - @test isequal(d, Dict(1 => y^(15//2), x^4 => 1)) + @test isequal(d, Dict(CONST_1 => y^(15//2), x^4 => CONST_1)) @test isequal(r, 3x^(8//3) * y^(5//2) + 3x^(4//3) * y^5) d, r = semipolynomial_form(expr, [], 0) - @test isequal(d, Dict(1 => x^4 + 3x^(8//3) * y^(5//2) + 3x^(4//3) * y^5 + y^(15//2))) - @test iszero(r) + @test isequal(d, Dict(CONST_1 => x^4 + 3x^(8//3) * y^(5//2) + 3x^(4//3) * y^5 + y^(15//2))) + @test iszero(unwrap_const(r)) d, r = semipolynomial_form(expr, [], 1) - @test isequal(d, Dict(1 => x^4 + 3x^(8//3) * y^(5//2) + 3x^(4//3) * y^5 + y^(15//2))) - @test iszero(r) + @test isequal(d, Dict(CONST_1 => x^4 + 3x^(8//3) * y^(5//2) + 3x^(4//3) * y^5 + y^(15//2))) + @test iszero(unwrap_const(r)) d, r = semipolynomial_form(expr, [], 2) - @test isequal(d, Dict(1 => x^4 + 3x^(8//3) * y^(5//2) + 3x^(4//3) * y^5 + y^(15//2))) - @test iszero(r) + @test isequal(d, Dict(CONST_1 => x^4 + 3x^(8//3) * y^(5//2) + 3x^(4//3) * y^5 + y^(15//2))) + @test iszero(unwrap_const(r)) expr = (x + y)^(1//2) @@ -324,21 +328,21 @@ end @test isequal(r, x + 2x^0.5 * y + y^2) d, r = semipolynomial_form(expr, [x, y], 1) - @test isequal(d, Dict(x => 1)) + @test isequal(d, Dict(x => CONST_1)) @test isequal(r, 2x^0.5 * y + y^2) d, r = semipolynomial_form(expr, [x, y], 2) - @test isequal(d, Dict(x => 1, y^2 => 1)) + @test isequal(d, Dict(x => CONST_1, y^2 => CONST_1)) @test isequal(r, 2x^0.5 * y) d, r = semipolynomial_form(expr, [x, y], 3) - @test isequal(d, Dict(x => 1, y^2 => 1)) + @test isequal(d, Dict(x => CONST_1, y^2 => CONST_1)) @test isequal(r, 2x^0.5 * y) # 680 expr = (x^(1//2) + y^0.5)^2 d, r = semipolynomial_form(expr, [x, y], 4) - @test isequal(d, Dict(x => 1, y => 1)) + @test isequal(d, Dict(x => CONST_1, y => CONST_1)) @test isequal(r, 2x^(1//2) * y^(1//2)) expr = (3x^4 + y)^0.5 @@ -357,7 +361,7 @@ end expr = (x^2 - 1) / (x - 1) d, r = semipolynomial_form(expr, [x, y], Inf) @test isempty(d) - @test isequal(r, expr) + @test isequal(SymbolicUtils.simplify_fractions(r), SymbolicUtils.simplify_fractions(expr)) end @testset "semilinear" begin @@ -368,30 +372,30 @@ end y / z + 5z, x * y + y * z / x] A , c = semilinear_form(exprs, [x, y, z]) - @test A[1, 1] == 3 - @test A[1, 2] == 0 - @test A[1, 3] == 0 - @test A[2, 1] == 0 - @test A[2, 2] == 0 - @test A[2, 3] == 5 - @test A[3, 1] == 0 - @test A[3, 2] == 0 - @test A[3, 3] == 0 + @test unwrap_const(unwrap(A[1, 1])) == 3 + @test unwrap_const(unwrap(A[1, 2])) == 0 + @test unwrap_const(unwrap(A[1, 3])) == 0 + @test unwrap_const(unwrap(A[2, 1])) == 0 + @test unwrap_const(unwrap(A[2, 2])) == 0 + @test unwrap_const(unwrap(A[2, 3])) == 5 + @test unwrap_const(unwrap(A[3, 1])) == 0 + @test unwrap_const(unwrap(A[3, 2])) == 0 + @test unwrap_const(unwrap(A[3, 3])) == 0 @test isequal(c, [tan(z), y / z, x * y + y * z / x]) end @testset "expr = 0" begin d, r = semipolynomial_form(0, [], Inf) @test isempty(d) - @test iszero(r) + @test iszero(unwrap_const(r)) d, r = semipolynomial_form(0//1, [], Inf) @test isempty(d) - @test iszero(r) + @test iszero(unwrap_const(r)) d, r = semipolynomial_form(0.0, [], Inf) @test isempty(d) - @test iszero(r) + @test iszero(unwrap_const(r)) end @testset "sqrt" begin @@ -399,40 +403,40 @@ end d, r = semipolynomial_form(expr, [x], Inf) @test isempty(d) - @test isequal(r, x^(1//2)) + @test isequal(r, sqrt(x)) expr = sqrt(x)^2 d, r = semipolynomial_form(expr, [x], Inf) - @test isequal(d, Dict(x => 1)) - @test iszero(r) + @test isequal(d, Dict(x => CONST_1)) + @test iszero(unwrap_const(r)) expr = (sqrt(x) + sqrt(y))^2 d, r = semipolynomial_form(expr, [x, y], Inf) - @test isequal(d, Dict(x => 1, y => 1)) - @test isequal(r, 2x^(1//2) * y^(1//2)) + @test isequal(d, Dict(x => CONST_1, y => CONST_1)) + @test isequal(r, 2sqrt(x) * sqrt(y)) d, r = semipolynomial_form(expr, [x], Inf) - @test isequal(d, Dict(x => 1, 1 => y)) - @test isequal(r, 2x^(1//2) * y^(1//2)) + @test isequal(d, Dict(x => CONST_1, CONST_1 => y)) + @test isequal(r, 2sqrt(x) * sqrt(y)) d, r = semipolynomial_form(expr, [], Inf) - @test isequal(d, Dict(1 => x + y + 2x^(1//2) * y^(1//2))) - @test iszero(r) + @test isequal(d, Dict(CONST_1 => x + y + 2sqrt(x) * sqrt(y))) + @test iszero(unwrap_const(r)) end -@syms a b c +@variables a::Real b::Real c::Real const components = [2, a, b, c, x, y, z, (1+x), (1+y)^2, z*y, z*x] -function verify(t::Symbolics.BasicSymbolic{Number}, d, wrt, nl) +function verify(t::Symbolics.BasicSymbolic, d, wrt, nl) verify(Num(t), d, wrt, nl) end function verify(t, d, wrt, nl) try - iszero(t - (isempty(d) ? nl : sum(k*v for (k, v) in d) + nl)) + isequal(unwrap_const(expand(unwrap(t - (isempty(d) ? nl : sum(k*v for (k, v) in d) + nl)))), 0) catch err println("""Error verifying semi-pf result for $t wrt = $wrt @@ -463,17 +467,33 @@ function trial() for deg=Any[1,2,3,4,Inf] if deg == 1 A, c = semilinear_form([t], wrt) - res = iszero(A*wrt + c - [t]) + res = isequal(unwrap_const(expand(unwrap.(A*wrt + c - [t])[1])), 0) if !res println("Semi-linear form is wrong: [$t] w.r.t $wrt ") - @show A c + println("A = ") + display(A) + println() + println("c =") + display(c) + println() end elseif deg == 2 A,B,v2, c = semiquadratic_form([t], wrt) - res = iszero(A * wrt + B * v2 + c - [t]) + res = isequal(unwrap_const(expand(unwrap.(A * wrt + B * v2 + c - [t])[1])), 0) if !res println("Semi-quadratic form is wrong: $t w.r.t $wrt") - @show A B v2 c + println("A = ") + display(A) + println() + println("B = ") + display(B) + println() + println("v2 = ") + display(v2) + println() + println("c = ") + display(c) + println() end else if isfinite(deg) @@ -519,15 +539,15 @@ end @variables t x(t)[1:3] expr = x[2] * 4 + 2x[1] + 2x[3] * x[1] + foo(x) mapping, resid = semipolynomial_form(expr, collect(x), 2) - @test mapping[x[2]] == 4 - @test mapping[x[1]] == 2 - @test mapping[x[3] * x[1]] == 2 + @test unwrap_const(mapping[x[2]]) == 4 + @test unwrap_const(mapping[x[1]]) == 2 + @test unwrap_const(mapping[x[3] * x[1]]) == 2 @test isequal(resid, foo(x)) expr = x[2] * 4 + 2x[1] + 2x[3] * x[1] + foo([x[1]]) mapping, resid = semipolynomial_form(expr, collect(x), 2) - @test mapping[x[2]] == 4 - @test mapping[x[1]] == 2 - @test mapping[x[3] * x[1]] == 2 + @test unwrap_const(mapping[x[2]]) == 4 + @test unwrap_const(mapping[x[1]]) == 2 + @test unwrap_const(mapping[x[3] * x[1]]) == 2 @test isequal(resid, foo([x[1]])) end diff --git a/test/show.jl b/test/show.jl index 1320ace0b..1f3efea28 100644 --- a/test/show.jl +++ b/test/show.jl @@ -1,4 +1,5 @@ using Symbolics +using Test # https://github.com/JuliaSymbolics/Symbolics.jl/issues/1206 # test there is no extra white spaces on the left of e or f when displaying an array diff --git a/test/solver.jl b/test/solver.jl index 16089e01f..85ec3c827 100644 --- a/test/solver.jl +++ b/test/solver.jl @@ -1,13 +1,16 @@ using Symbolics import Symbolics: ssqrt, slog, scbrt, symbolic_solve, ia_solve, postprocess_root - -@testset "ia_solve without Nemo" begin - @test Base.get_extension(Symbolics, :SymbolicsNemoExt) === nothing - @variables x - roots = ia_solve(log(2 + x), x) - roots = @test_warn ["Nemo", "required"] ia_solve(log(2 + x^2), x) - @test operation(roots[1]) == Symbolics.RootsOf -end +using SymbolicUtils +using SymbolicUtils: Const +using Test + +# @testset "ia_solve without Nemo" begin +# @test Base.get_extension(Symbolics, :SymbolicsNemoExt) === nothing +# @variables x +# roots = ia_solve(log(2 + x), x) +# roots = @test_warn ["Nemo", "required"] ia_solve(log(2 + x^2), x) +# @test operation(roots[1]) === Symbolics.RootsOf +# end using Groebner, Nemo E = Base.MathConstants.e @@ -56,8 +59,8 @@ function check_approx(arr1, arr2) if !isequal(keys(arr1[i]), keys(arr2[i])) return false end - if !all(isapprox.(values(arr1[i]), values(arr2[i]), atol=1e-6)) - return false + for (k1, v1) in arr1[i] + isapprox(arr2[i][k1], v1; atol = 1e-6) || return false end end return true @@ -99,21 +102,21 @@ end end @testset "Deg 1 univar" begin - @test isequal(symbolic_solve(x+1, x), [-1]) + @test isequal(unwrap_const(only(symbolic_solve(x+1, x))), -1) - @test isequal(symbolic_solve(2x+1, x), [-1/2]) + @test isequal(unwrap_const(only(symbolic_solve(2x+1, x))), -1/2) - @test isequal(symbolic_solve(x, x), [0]) + @test isequal(unwrap_const(only(symbolic_solve(x, x))), 0) - @test isequal(symbolic_solve((x+1)^20, x), [-1]) + @test isequal(unwrap_const(only(symbolic_solve((x+1)^20, x))), -1) - @test isequal(Symbolics.get_roots_deg1(x + y^3, x), [-y^3]) + @test isequal(only(Symbolics.get_roots_deg1(x + y^3, x)), -y^3) expr = x - Symbolics.term(sqrt, 2) @test isequal(symbolic_solve(expr, x)[1], Symbolics.term(sqrt, 2)) expr = x + im - @test symbolic_solve(expr, x)[1] == -im + @test unwrap_const(symbolic_solve(expr, x)[1]) == -im end @testset "Deg 2 univar" begin @@ -209,13 +212,13 @@ end end @testset "Multipoly solver" begin - @test isequal(symbolic_solve([x^2 - 1, x + 1], x)[1], -1) + @test isequal(unwrap_const(symbolic_solve([x^2 - 1, x + 1], x)[1]), -1) @test isequal(symbolic_solve([x^2 - a^2, x + a], x)[1], -a) @test isequal(symbolic_solve([x^20 - a^20, x + a], x)[1], -a) end @testset "Multivar solver" begin @variables x y z - @test isequal(symbolic_solve([x^4 - 1, x - 2], [x]), []) + @test symbolic_solve([x^4 - 1, x - 2], [x]) === nothing # TODO: test this properly sol = symbolic_solve([x^3 + 1, x*y^3 - 1], [x, y]) @@ -263,8 +266,8 @@ end Dict(x=>(complex(-1))^(3/4), y=>0)], [x,y]) @test check_approx(arr_calcd_roots, arr_known_roots) - @test isequal(symbolic_solve([x*y - 1, y], [x,y]), []) - @test isequal(symbolic_solve([x+y+1, x+y+2], [x,y]), []) + @test symbolic_solve([x*y - 1, y], [x,y]) === nothing + @test symbolic_solve([x+y+1, x+y+2], [x,y]) === nothing eqs = [-1 + y + z + x^2, -1 + x + z + y^2, @@ -313,10 +316,10 @@ end @testset "Multivar parametric" begin @variables x y a - @test isequal(symbolic_solve([x + a, a - 1], x), [-1]) + @test isequal(value(only(symbolic_solve([x + a, a - 1], x))), -1) @test isequal(symbolic_solve([x - a, y + a], [x, y]), [Dict(y => -a, x => a)]) - @test isequal(symbolic_solve([x*y - a, x*y + x], [x, y]), [Dict(y => -1, x => -a)]) - @test isequal(symbolic_solve([x*y - a, 1 ~ 3], [x, y]), []) + @test isequal(symbolic_solve([x*y - a, x*y + x], [x, y]), [Dict(y => Const{SymReal}(-1), x => -a)]) + @test symbolic_solve([x*y - a, 1 ~ 3], [x, y]) === nothing @test isnothing(symbolic_solve([x*y - 1, sin(x)], [x, y])) @@ -324,7 +327,7 @@ end @variables t w u v sol = symbolic_solve([t*w - 1 ~ 4, u + v + w ~ 1], [t,w]) - @test isequal(sol, [Dict(t => -5 / (-1 + u + v), w => 1 - u - v)]) + @test isequal(sol, [Dict(t => -5 / (-1//1 + u + v), w => 1//1 - u - v)]) sol = symbolic_solve([x-y, y-z], [x]) @test isequal(sol, [z]) @@ -367,9 +370,9 @@ end @test isequal(Symbolics.postprocess_root(term(^, 0, __x)), 0) @test_broken isequal(Symbolics.postprocess_root(term(/, __x, 0)), Inf) - @test Symbolics.postprocess_root(term(^, __x, 0) ) == 1 - @test Symbolics.postprocess_root(term(^, Base.MathConstants.e, 0) ) == 1 - @test Symbolics.postprocess_root(term(^, Base.MathConstants.pi, 1) ) == Base.MathConstants.pi + @test value(Symbolics.postprocess_root(term(^, __x, 0) )) == 1 + @test value(Symbolics.postprocess_root(term(^, Base.MathConstants.e, 0) )) == 1 + @test value(Symbolics.postprocess_root(term(^, Base.MathConstants.pi, 1) )) ≈ Base.MathConstants.pi @test isequal(Symbolics.postprocess_root(term(^, __x, 1) ), __x) x = Symbolics.term(sqrt, 2) @@ -436,11 +439,11 @@ end rhs = Symbolics.term(^, -c.val/a.val, 1/b.val) @test_broken isequal(lhs, rhs) - @test isequal(symbolic_solve(2/x, x)[1], Inf) - @test isequal(symbolic_solve(x^1.5, x)[1], 0) + @test isequal(value(symbolic_solve(2/x, x)[1]), Inf) + @test isequal(value(symbolic_solve(x^1.5, x)[1]), 0) lhs = symbolic_solve(log(a*x)-b,x)[1] - @test isequal(Symbolics.unwrap(Symbolics.ssubs(lhs, Dict(a=>1, b=>1))), 1E) + @test isequal(Symbolics.value(Symbolics.ssubs(lhs, Dict(a=>1, b=>1))), 1E) expr = x + 2 lhs = eval.(Symbolics.toexpr.(ia_solve(expr, x))) @@ -470,8 +473,8 @@ end lhs = eval.(Symbolics.toexpr.(ia_solve(expr, x))) lhs_solve = eval.(Symbolics.toexpr.(symbolic_solve(expr, x))) rhs = [(-im*Base.MathConstants.pi + log(5) - log(12))/(log(2) - log(5))] - @test lhs[1] ≈ rhs[1] - @test lhs_solve[1] ≈ rhs[1] + @test lhs[1] ≈ rhs[1] || lhs[1] ≈ conj(rhs[1]) + @test lhs_solve[1] ≈ rhs[1] || lhs_solve[1] ≈ conj(rhs[1]) expr = exp(2x)*exp(x^2 + 3) + 3 lhs = sort_roots(eval.(Symbolics.toexpr.(ia_solve(expr, x)))) @@ -524,15 +527,15 @@ end @test length(roots) == 6 # 2 quadratic roots * 3 roots from cbrt(3) @test length(Symbolics.get_variables(roots[1])) == 1 _n = only(Symbolics.get_variables(roots[1])) - vals = substitute.(roots, (Dict(_n => 0),)) - @test all(x -> isapprox(norm(sec(x^2 + 4x + 4) ^ 3 - 3), 0.0, atol = 1e-14), vals) + vals = eval.(Symbolics.toexpr.(substitute.(roots, (Dict(_n => 0),)))) + @test all(x -> isapprox(value(abs(sec(x^2 + 4x + 4) ^ 3 - 3)), 0.0, atol = 1e-14), vals) roots = ia_solve(expr, x; complex_roots = false) @test length(roots) == 2 # the `n` in `θ + n * 2π` @test length(Symbolics.get_variables(roots[1])) == 1 _n = only(Symbolics.get_variables(roots[1])) - vals = substitute.(roots, (Dict(_n => 0),)) + vals = eval.(Symbolics.toexpr.(substitute.(roots, (Dict(_n => 0),)))) @test all(x -> isapprox(norm(sec(x^2 + 4x + 4) ^ 3 - 3), 0.0, atol = 1e-14), vals) roots = ia_solve(expr, x; complex_roots = false, periodic_roots = false) @@ -550,7 +553,7 @@ end lhs_ia = ia_solve(expr, x)[1] lhs_att = Symbolics.attract_and_solve_sqrtpoly(expr, x)[1] lhs_solve = symbolic_solve(expr, x)[1] - @test all(isequal(answer, 3) for answer in [lhs_ia, lhs_att, lhs_solve]) + @test all(isequal(value(answer), 3) for answer in [lhs_ia, lhs_att, lhs_solve]) expr = x^2 + x + sqrt(x) + 2 lhs = sort_roots(eval.(Symbolics.toexpr.(ia_solve(expr, x)))) @@ -624,7 +627,7 @@ using LambertW @test correctAns(symbolic_solve(exp(x^2)~7,x),[-sqrt(log(7.0)),sqrt(log(7.0))]) root = symbolic_solve(sin(x+3)~1//3,x)[1] - var = Symbolics.get_variables(root)[1] + var = first(Symbolics.get_variables(root)) root = Symbolics.ssubs(root, Dict(var=>0)) @test correctAns([root],[asin(1.0/3.0)-3.0]) diff --git a/test/stencils.jl b/test/stencils.jl index 1940dd30c..5e3d99c7b 100644 --- a/test/stencils.jl +++ b/test/stencils.jl @@ -28,35 +28,35 @@ end test_funcs("transpose-term", @arrayop((i, j), x[j, i], term=x'), x) # Partial view set to an arrayop - @makearray y[1:6, 1:6] begin - y[2:end-1, 2:end-1] => @arrayop (i, j) x[j, i] - end - @test isequal(scalarize(y[2,3]), x[2,1]) - @test isequal(scalarize(y[3,2]), x[1,2]) + # @makearray y[1:6, 1:6] begin + # y[2:end-1, 2:end-1] => @arrayop (i, j) x[j, i] + # end + # @test isequal(scalarize(y[2,3]), x[2,1]) + # @test isequal(scalarize(y[3,2]), x[1,2]) - # Test UndefRef is thrown - @test_throws UndefRefError scalarize(y[1,1]) + # # Test UndefRef is thrown + # @test_throws UndefRefError scalarize(y[1,1]) - test_funcs("stencil-transpose-arrayop", y, x) + # test_funcs("stencil-transpose-arrayop", y, x) - # Fill zero - @makearray y[1:6, 1:6] begin - y[:, :] => 0 - y[2:end-1, 2:end-1] => x .+ x' .+ 1 - end + # # Fill zero + # @makearray y[1:6, 1:6] begin + # y[:, :] => 0 + # y[2:end-1, 2:end-1] => x .+ x' .+ 1 + # end - @test iszero(scalarize(y[1,1])) + # @test iszero(scalarize(y[1,1])) - test_funcs("stencil-broadcast", y, x) + # test_funcs("stencil-broadcast", y, x) - @variables x[1:5, 1:5] - @makearray y[1:5, 1:5] begin - y[:, :] => 0 - y[2:end-1, 2:end-1] => @arrayop (i, j) (x[i+1,j] + x[i-1, j] + x[i, j+1] + x[i, j-1])/2 - end + # @variables x[1:5, 1:5] + # @makearray y[1:5, 1:5] begin + # y[:, :] => 0 + # y[2:end-1, 2:end-1] => @arrayop (i, j) (x[i+1,j] + x[i-1, j] + x[i, j+1] + x[i, j-1])/2 + # end - @test iszero(scalarize(y[1,1])) - test_funcs("stencil-extents", y, x) + # @test iszero(scalarize(y[1,1])) + # test_funcs("stencil-extents", y, x) @variables u[1:5, 1:5] n = 5 diff --git a/test/struct.jl b/test/struct.jl deleted file mode 100644 index 8eef42dc1..000000000 --- a/test/struct.jl +++ /dev/null @@ -1,26 +0,0 @@ -using Test, Symbolics -using Symbolics: symstruct, juliatype, symbolic_getproperty, symbolic_setproperty!, symbolic_constructor - -struct Jörgen - a::Int - b::Float64 -end - -S = symstruct(Jörgen) -@variables x::S -xa = Symbolics.unwrap(symbolic_getproperty(x, :a)) -@test Symbolics.symtype(xa) == Int -@test Symbolics.operation(xa) == Symbolics.typed_getfield -@test isequal(Symbolics.arguments(xa), [Symbolics.unwrap(x), Val{:a}()]) -xa = Symbolics.unwrap(symbolic_setproperty!(x, :a, 10)) -@test Symbolics.operation(xa) == setfield! -@test isequal(Symbolics.arguments(xa), [Symbolics.unwrap(x), Meta.quot(:a), 10]) -@test Symbolics.symtype(xa) == Int - -xb = Symbolics.unwrap(symbolic_setproperty!(x, :b, 10)) -@test Symbolics.operation(xb) == setfield! -@test isequal(Symbolics.arguments(xb), [Symbolics.unwrap(x), Meta.quot(:b), 10]) -@test Symbolics.symtype(xb) == Float64 - -s = Symbolics.symbolic_constructor(S, 1, 1.0) -@test Symbolics.symtype(s) == S diff --git a/test/symbolic_indexing_interface_symbolic_evaluate.jl b/test/symbolic_indexing_interface_symbolic_evaluate.jl index 251bef619..2bc7a1a58 100644 --- a/test/symbolic_indexing_interface_symbolic_evaluate.jl +++ b/test/symbolic_indexing_interface_symbolic_evaluate.jl @@ -9,6 +9,7 @@ bar(x, p) = p * x @register_array_symbolic bar(x::AbstractVector, p::AbstractMatrix) begin size = size(x) eltype = promote_type(eltype(x), eltype(p)) + ndims = 1 end D = Differential(t) @@ -17,17 +18,17 @@ expr1 = x + y + D(x) @test isequal(symbolic_evaluate(expr1, Dict(x => 3)), 3 + y + D(3)) @test isequal(symbolic_evaluate(expr1, Dict(x => 3); operator = Operator), 3 + y + D(x)) @test isequal(symbolic_evaluate(expr1, Dict(x => 1, D(x) => 2)), y + 3) -@test symbolic_evaluate(expr1, Dict(x => 1, D(x) => 2, y => 3)) == 6 +@test value(symbolic_evaluate(expr1, Dict(x => 1, D(x) => 2, y => 3))) == 6 @test isequal(symbolic_evaluate(expr1, Dict(x => 3, y => 3x), operator = Operator), 12 + D(x)) -@test symbolic_evaluate(expr1, Dict(x => 3, y => 3x, D(x) => 2)) == 14 +@test value(symbolic_evaluate(expr1, Dict(x => 3, y => 3x, D(x) => 2))) == 14 expr2 = bar(q, p) @test isequal(symbolic_evaluate(expr2, Dict(p => ones(3, 3))), bar(q, ones(3, 3))) -@test symbolic_evaluate(expr2, Dict(p => ones(3, 3), q => ones(3))) == 3ones(3) +@test value(symbolic_evaluate(expr2, Dict(p => ones(3, 3), q => ones(3)))) == 3ones(3) expr3 = bar(3q, 3p) @test isequal(symbolic_evaluate(expr3, Dict(p => ones(3, 3))), bar(3q, 3ones(3, 3))) -@test symbolic_evaluate(expr3, Dict(p => ones(3, 3), q => ones(3))) == 27ones(3) +@test value(symbolic_evaluate(expr3, Dict(p => ones(3, 3), q => ones(3)))) == 27ones(3) expr4 = D(x) ~ 3x + y @test isequal(symbolic_evaluate(expr4, Dict(x => 3)), D(3) ~ 9 + y) diff --git a/test/symbolic_indexing_interface_trait.jl b/test/symbolic_indexing_interface_trait.jl index 3edd20183..6dee427e7 100644 --- a/test/symbolic_indexing_interface_trait.jl +++ b/test/symbolic_indexing_interface_trait.jl @@ -2,10 +2,9 @@ using Symbolics using SymbolicUtils using SymbolicIndexingInterface -@test all(symbolic_type.([SymbolicUtils.BasicSymbolic, Symbolics.Num, Symbolics.CallWithMetadata]) .== +@test all(symbolic_type.([SymbolicUtils.BasicSymbolic, Symbolics.Num, Symbolics.CallAndWrap]) .== (ScalarSymbolic(),)) -@test all(symbolic_type.([Symbolics.ArrayOp, Symbolics.Arr]) .== - (ArraySymbolic(),)) +@test symbolic_type(Symbolics.Arr) == ArraySymbolic() @variables x @test symbolic_type(x) == ScalarSymbolic() @variables y[1:3] @@ -20,9 +19,9 @@ using SymbolicIndexingInterface subs = Dict(x => 0.1, y => 2z) subs2 = merge(subs, Dict(z => 2x+3)) -@test symbolic_evaluate(x, subs) == 0.1 +@test Symbolics.value(symbolic_evaluate(x, subs)) == 0.1 @test isequal(symbolic_evaluate(y, subs), 2z) -@test symbolic_evaluate(y, subs2) == 6.4 +@test Symbolics.value(symbolic_evaluate(y, subs2)) == 6.4 @testset "`hasname` for `getindex`ed trees" begin @variables x[1:2] y[1:2] diff --git a/test/taylor.jl b/test/taylor.jl index d183ac1cc..9f9b2bd8d 100644 --- a/test/taylor.jl +++ b/test/taylor.jl @@ -20,27 +20,27 @@ Y, = @variables y[ns] # https://en.wikipedia.org/wiki/Taylor_series#List_of_Maclaurin_series_of_some_common_functions @variables x -@test taylor(exp(x), x, 0:9) - sum(x^n//factorial(n) for n in 0:9) == 0 -@test taylor(log(1-x), x, 0:9) - sum(-x^n/n for n in 1:9) == 0 -@test taylor(log(1+x), x, 0:9) - sum((-1)^(n+1)*x^n/n for n in 1:9) == 0 +@test expand(taylor(exp(x), x, 0:9) - sum(x^n/factorial(n) for n in 0:9)) == 0 +@test expand(taylor(log(1-x), x, 0:9) - sum(-x^n/n for n in 1:9)) == 0 +@test expand(taylor(log(1+x), x, 0:9) - sum((-1)^(n+1)*x^n/n for n in 1:9)) == 0 -@test taylor(1/(1-x), x, 0:9) - sum(x^n for n in 0:9) == 0 -@test taylor(1/(1-x)^2, x, 0:8) - sum(n * x^(n-1) for n in 1:9) == 0 -@test taylor(1/(1-x)^3, x, 0:7) - sum((n-1)*n*x^(n-2)/2 for n in 2:9) == 0 +@test expand(taylor(1/(1-x), x, 0:9) - sum(x^n for n in 0:9)) == 0 +@test expand(taylor(1/(1-x)^2, x, 0:8) - sum(n * x^(n-1) for n in 1:9)) == 0 +@test expand(taylor(1/(1-x)^3, x, 0:7) - sum((n-1)*n*x^(n-2)/2 for n in 2:9)) == 0 for α in (-1//2, 0, 1//2, 1, 2, 3) - @test taylor((1+x)^α, x, 0:7) - sum(binomial(α, n)*x^n for n in 0:7) == 0 + @test expand(taylor((1+x)^α, x, 0:7) - sum(binomial(α, n)*x^n for n in 0:7)) == 0 end -@test taylor(sin(x), x, 0:7) - sum((-1)^n/factorial(2*n+1) * x^(2*n+1) for n in 0:3) == 0 -@test taylor(cos(x), x, 0:7) - sum((-1)^n/factorial(2*n) * x^(2*n) for n in 0:3) == 0 -@test taylor(tan(x), x, 0:7) - taylor(taylor(sin(x), x, 0:7) / taylor(cos(x), x, 0:7), x, 0:7) == 0 -@test taylor(asin(x), x, 0:7) - sum(factorial(2*n)/(4^n*factorial(n)^2*(2*n+1)) * x^(2*n+1) for n in 0:3) == 0 -@test taylor(acos(x), x, 0:7) - (π/2 - taylor(asin(x), x, 0:7)) == 0 # TODO: make π/2 a proper fraction (like Num(π)/2) -@test taylor(atan(x), x, 0:7) - taylor(asin(x/√(1+x^2)), x, 0:7) == 0 +@test expand(taylor(sin(x), x, 0:7) - sum((-1)^n/factorial(2*n+1) * x^(2*n+1) for n in 0:3)) == 0 +@test expand(taylor(cos(x), x, 0:7) - sum((-1)^n/factorial(2*n) * x^(2*n) for n in 0:3)) == 0 +@test expand(taylor(tan(x), x, 0:7) - taylor(taylor(sin(x), x, 0:7) / taylor(cos(x), x, 0:7), x, 0:7)) == 0 +@test expand(taylor(asin(x), x, 0:7) - sum(factorial(2*n)/(4^n*factorial(n)^2*(2*n+1)) * x^(2*n+1) for n in 0:3)) == 0 +@test expand(taylor(acos(x), x, 0:7) - (π/2 - taylor(asin(x), x, 0:7))) == 0 # TODO: make π/2 a proper fraction (like Num(π)/2) +@test expand(taylor(atan(x), x, 0:7) - taylor(asin(x/√(1+x^2)), x, 0:7)) == 0 -@test taylor(sinh(x), x, 0:7) - sum(1/factorial(2*n+1) * x^(2*n+1) for n in 0:3) == 0 -@test taylor(cosh(x), x, 0:7) - sum(1/factorial(2*n) * x^(2*n) for n in 0:3) == 0 -@test taylor(tanh(x), x, 0:7) - (x - x^3/3 + 2/15*x^5 - 17/315*x^7) == 0 +@test expand(taylor(sinh(x), x, 0:7) - sum(1/factorial(2*n+1) * x^(2*n+1) for n in 0:3)) == 0 +@test expand(taylor(cosh(x), x, 0:7) - sum(1/factorial(2*n) * x^(2*n) for n in 0:3)) == 0 +@test expand(taylor(tanh(x), x, 0:7) - (x - x^3/3 + 2/15*x^5 - 17/315*x^7)) == 0 # around x ≠ 0 @test substitute(taylor(√(x), x, 1, 0:6), x => x + 1) - taylor(√(1+x), x, 0:6) == 0 diff --git a/test/utils.jl b/test/utils.jl index c3c8e659d..4632484fe 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,8 +1,9 @@ using Symbolics import Symbolics: symbolic_to_float, var_from_nested_derivative, unwrap, - isblock, flatten_expr!, build_expr, get_variables, get_differential_vars, + isblock, flatten_expr!, get_variables, get_differential_vars, is_singleton, diff2term, tosymbol, lower_varname, makesubscripts, degree, coeff +using SymbolicUtils: symtype using SparseArrays using Test @@ -14,24 +15,24 @@ using Test @test length(vars1) == 3 @test allunique(vars1) - sorted_vars1 = Symbolics.get_variables(ex1; sort = true) - @test isequal(sorted_vars1, [x, y, z]) + sorted_vars1 = Symbolics.get_variables(ex1) + @test isequal(sorted_vars1, Set([x, y, z])) ex2 = x - y vars2 = Symbolics.get_variables(ex2) @test length(vars2) == 2 @test allunique(vars2) - sorted_vars2 = Symbolics.get_variables(ex2; sort = true) - @test isequal(sorted_vars2, [x, y]) + sorted_vars2 = Symbolics.get_variables(ex2) + @test isequal(sorted_vars2, Set([x, y])) - @variables c(..) + @variables c(::Real) ex3 = c(x) + c(t) - c(c(t) + y) vars3 = Symbolics.get_variables(ex3) @test length(vars3) == 4 - sorted_vars3 = Symbolics.get_variables(ex3; sort = true) - @test isequal(sorted_vars3, [c.f, t, x, y]) + sorted_vars3 = Symbolics.get_variables(ex3) + @test isequal(sorted_vars3, Set([c, t, x, y])) end @testset "symbolic_to_float" begin @@ -46,10 +47,10 @@ end @testset "var_from_nested_derivative" begin @variables t x(t) p(..) D = Differential(t) - @test var_from_nested_derivative(x) == (x, 0) - @test var_from_nested_derivative(D(x)) == (x, 1) - @test var_from_nested_derivative(p) == (p, 0) - @test var_from_nested_derivative(D(p(x))) == (p(x), 1) + @test all(isequal.(var_from_nested_derivative(x), (x, 0))) + @test all(isequal.(var_from_nested_derivative(D(x)), (x, 1))) + @test all(isequal.(var_from_nested_derivative(p), (p, 0))) + @test all(isequal.(var_from_nested_derivative(D(p(x))), (p(x), 1))) end @testset "fixpoint_sub maxiters" begin @@ -64,8 +65,8 @@ end @variables p(..) x y arg = unwrap(substitute(p(x), [p => identity])) @test iscall(arg) && operation(arg) == identity && isequal(only(arguments(arg)), x) - @test unwrap(substitute(p(x), [p => sqrt, x => 4.0])) ≈ 2.0 - arg = Symbolics.fixpoint_sub(p(x), [p => sqrt, x => 2y + 3, y => 1.0 + p(4)]) + @test unwrap_const(unwrap(substitute(p(x), [p => sqrt, x => 4.0]))) ≈ 2.0 + arg = unwrap_const(Symbolics.fixpoint_sub(p(x), [p => sqrt, x => 2y + 3, y => 1.0 + p(4)])) @test arg ≈ 3.0 end @@ -87,12 +88,6 @@ end @test flatten_expr!(expr2.args) == Any[:(x + y), :z] end -@testset "build_expr" begin - expr = build_expr(:block, [:(x + y), :(y + z)]) - @test expr.head == :block - @test expr.args == [:(x + y), :(y + z)] -end - @testset "is_singleton" begin @test is_singleton(x) == false @test is_singleton(sin(x)) == false @@ -134,15 +129,6 @@ end @test coeff(expr2, x^2) == 1 end -@testset "makesubscripts" begin - sub1 = makesubscripts(5) - @test length(sub1) == 5 - @test typeof(sub1[1]) == SymbolicUtils.BasicSymbolic{Int64} - - sub2 = makesubscripts(10) - @test length(sub2) == 10 -end - @testset "diff2term" begin @variables x t u(x, t) z(t) Dt = Differential(t) @@ -154,7 +140,7 @@ end test_nested_derivative = Dx(Dt(Dt(u))) result = diff2term(Symbolics.value(test_nested_derivative)) - @test typeof(result) === Symbolics.BasicSymbolic{Real} + @test symtype(result) === Real @testset "staged diff2term on arrays" begin @variables t x(t)[1:2] @@ -244,7 +230,7 @@ end @test any(isequal(Symbolics.value(Dx(u))), diff_vars_eq) end -@testset "`fast_substitute` inside array symbolics" begin +@testset "`substitute` inside array symbolics" begin @variables x y z @register_symbolic foo(a::AbstractArray, b) ex = foo([x, y], z) @@ -252,25 +238,24 @@ end @test isequal(ex2, foo([x, 1.0], 2.0)) end -@testset "`fast_substitute` of subarray symbolics" begin +@testset "`substitute` of subarray symbolics" begin @variables p[1:4] q[1:5] - @test isequal(p[1:2], Symbolics.fast_substitute(p[1:2], Dict())) - @test isequal(p[1:2], Symbolics.fast_substitute(p[1:2], p => p)) - @test isequal(q[1:2], Symbolics.fast_substitute(p[1:2], Dict(p => q))) - @test isequal(q[1:2], Symbolics.fast_substitute(p[1:2], p => q)) + @test isequal(p[1:2], substitute(p[1:2], Dict())) + @test isequal(p[1:2], substitute(p[1:2], p => p)) + @test isequal(q[1:2], substitute(p[1:2], Dict(p => q))) + @test isequal(q[1:2], substitute(p[1:2], p => q)) end -@testset "`fast_substitute` folding `getindex`" begin +@testset "`substitute` folding `getindex`" begin @variables x[1:3] - @test isequal(Symbolics.fast_substitute(x[1], Dict(unwrap(x) => collect(unwrap(x)))), x[1]) - @test isequal(Symbolics.fast_substitute(x[1], unwrap(x) => collect(unwrap(x))), x[1]) + @test isequal(substitute(x[1], Dict(unwrap(x) => collect(unwrap(x)))), x[1]) + @test isequal(substitute(x[1], unwrap(x) => collect(unwrap(x))), x[1]) end -@testset "`fixpoint_sub` and `fast_substitute` on sparse arrays" begin +@testset "`fixpoint_sub` and `substitute` on sparse arrays" begin @variables x y z mat = Num[x 0 0; 0 y 0; 0 0 z] mat = sparse(mat) - mat = unwrap.(mat) rules = Dict(x => y, y => z, z => 1) res = Symbolics.fixpoint_sub(mat, rules) @test res isa SparseMatrixCSC @@ -288,12 +273,12 @@ end @testset "factors and terms" begin @variables x y z - @test Set(factors(0)) == Set([0]) - @test Set(factors(1)) == Set([1]) + @test Set(factors(0)) == Set(Num[0]) + @test Set(factors(1)) == Set(Num[1]) @test Set(factors(x)) == Set([x]) @test Set(factors(x*y*z)) == Set([x, y, z]) - @test Set(terms(0)) == Set([0]) + @test Set(terms(0)) == Set(Num[0]) @test Set(terms(x)) == Set([x]) @test Set(terms(x + y + z)) == Set([x, y, z]) @test Set(terms(-x - y + z)) == Set([-x, -y, z]) @@ -312,6 +297,3 @@ end @test Symbolics.evaluate(ltr, Dict(x => 1, y => 2)) @test !Symbolics.evaluate(ltr, Dict(x => 2, y => 1)) end - - -