From 57dc9d0df541925a473f49670ad419add75e58c7 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 15 Sep 2025 11:45:55 +0530 Subject: [PATCH 01/28] refactor: remove trivial usages of `Symbolic{T}` --- src/Symbolics.jl | 6 +++--- src/build_function.jl | 2 +- src/diff.jl | 13 +++++------- src/domains.jl | 2 +- src/extra_functions.jl | 41 +++++--------------------------------- src/latexify_recipes.jl | 6 +++--- src/linear_algebra.jl | 14 ++++++------- src/num.jl | 44 +++++++++++------------------------------ src/register.jl | 4 +--- src/rewrite-helpers.jl | 16 +++++++-------- src/semipoly.jl | 13 ++++++------ src/utils.jl | 18 ++++++++--------- src/variable.jl | 23 ++++++--------------- src/wrapper-types.jl | 40 ++++++++++++++++++++++++------------- 14 files changed, 93 insertions(+), 149 deletions(-) diff --git a/src/Symbolics.jl b/src/Symbolics.jl index 17ecb3d80..afc62c9f1 100644 --- a/src/Symbolics.jl +++ b/src/Symbolics.jl @@ -26,7 +26,7 @@ import DomainSets: Domain, DomainSets using TermInterface import TermInterface: maketerm, iscall, operation, arguments, metadata -import SymbolicUtils: Term, Add, Mul, Pow, Sym, Div, BasicSymbolic, +import SymbolicUtils: Term, Add, Mul, Sym, Div, BasicSymbolic, Const, FnType, @rule, Rewriters, substitute, symtype, promote_symtype, isadd, ismul, ispow, isterm, issym, isdiv @@ -177,7 +177,7 @@ export limit # Hacks to make wrappers "nicer" const NumberTypes = Union{AbstractFloat,Integer,Complex{<:AbstractFloat},Complex{<:Integer}} -(::Type{T})(x::SymbolicUtils.Symbolic) where {T<:NumberTypes} = throw(ArgumentError("Cannot convert Sym to $T since Sym is symbolic and $T is concrete. Use `substitute` to replace the symbolic unwraps.")) +(::Type{T})(x::SymbolicUtils.BasicSymbolic) where {T<:NumberTypes} = throw(ArgumentError("Cannot convert Sym to $T since Sym is symbolic and $T is concrete. Use `substitute` to replace the symbolic unwraps.")) for T in [Num, Complex{Num}] @eval begin #(::Type{S})(x::$T) where {S<:Union{NumberTypes,AbstractArray}} = S(Symbolics.unwrap(x))::S @@ -197,7 +197,7 @@ for T in [Num, Complex{Num}] Broadcast.broadcastable(x::$T) = x end - for S in [:(Symbolic{<:FnType}), :CallWithMetadata] + for S in [:(BasicSymbolic{<:FnType}), :CallWithMetadata] @eval (f::$S)(x::$T, y...) = wrap(f(unwrap(x), unwrap.(y)...)) end end diff --git a/src/build_function.jl b/src/build_function.jl index 3045dd389..8a7ee48e9 100644 --- a/src/build_function.jl +++ b/src/build_function.jl @@ -603,7 +603,7 @@ function buildvarnumbercache(args...) return Dict(varnumsdict) end -function numbered_expr(O::Symbolic,varnumbercache,args...;varordering = args[1],offset = 0, +function numbered_expr(O::BasicSymbolic,varnumbercache,args...;varordering = args[1],offset = 0, states = LazyState(), lhsname=:du,rhsnames=[Symbol("MTK$i") for i in 1:length(args)]) O = value(O) diff --git a/src/diff.jl b/src/diff.jl index 08a73268d..ddeff688d 100644 --- a/src/diff.jl +++ b/src/diff.jl @@ -63,9 +63,6 @@ Base.nameof(D::Differential) = :Differential Base.:(==)(D1::Differential, D2::Differential) = isequal(D1.x, D2.x) Base.hash(D::Differential, u::UInt) = hash(D.x, xor(u, 0xdddddddddddddddd)) -_isfalse(occ::Bool) = occ === false -_isfalse(occ::Symbolic) = iscall(occ) && _isfalse(operation(occ)) - """ $(TYPEDSIGNATURES) @@ -237,7 +234,7 @@ passed differential and not any other Differentials it encounters. # Arguments - `D::Differential`: The differential to apply -- `arg::Symbolic`: The symbolic expression to apply the differential on. +- `arg::BasicSymbolic`: The symbolic expression to apply the differential on. - `simplify::Bool=false`: Whether to simplify the resulting expression using [`SymbolicUtils.simplify`](@ref). - `occurrences=nothing`: Information about the occurrences of the independent @@ -349,7 +346,7 @@ function executediff(D, arg, simplify=false; throw_no_derivative=false) if _iszero(x) continue - elseif x isa Symbolic + elseif x isa BasicSymbolic push!(exprs, x) else c += x @@ -376,7 +373,7 @@ This function recursively traverses a symbolic expression, applying the chain ru and other derivative rules to expand any derivatives it encounters. # Arguments -- `O::Symbolic`: The symbolic expression to expand. +- `O::BasicSymbolic`: The symbolic expression to expand. - `simplify::Bool=false`: Whether to simplify the resulting expression using [`SymbolicUtils.simplify`](@ref). @@ -398,7 +395,7 @@ julia> dfx = expand_derivatives(Dx(f)) (k*((2abs(x - y)) / y - 2z)*ifelse(signbit(x - y), -1, 1)) / y ``` """ -function expand_derivatives(O::Symbolic, simplify=false; throw_no_derivative=false) +function expand_derivatives(O::BasicSymbolic, simplify=false; throw_no_derivative=false) if iscall(O) && isa(operation(O), Differential) arg = only(arguments(O)) arg = expand_derivatives(arg, false; throw_no_derivative) @@ -461,7 +458,7 @@ sin(x) ``` """ derivative_idx(O::Any, ::Any) = 0 -function derivative_idx(O::Symbolic, idx) +function derivative_idx(O::BasicSymbolic, idx) iscall(O) ? derivative(operation(O), (arguments(O)...,), Val(idx)) : 0 end diff --git a/src/domains.jl b/src/domains.jl index 02d23f3bb..c3f07c6c6 100644 --- a/src/domains.jl +++ b/src/domains.jl @@ -6,7 +6,7 @@ struct VarDomainPairing domain::Domain end -const DomainedVar = Union{Symbolic{<:Number}, Num} +const DomainedVar = Union{BasicSymbolic, Num} Base.:∈(variable::DomainedVar,domain::Domain) = VarDomainPairing(value(variable),domain) Base.:∈(variable::DomainedVar,domain::Interval) = VarDomainPairing(value(variable),domain) diff --git a/src/extra_functions.jl b/src/extra_functions.jl index 0adc291e3..d7e2359ac 100644 --- a/src/extra_functions.jl +++ b/src/extra_functions.jl @@ -1,39 +1,8 @@ -@register_symbolic Base.binomial(n, k)::Int true -function _binomial(nothing, n, k) - begin - args = [n, k] - unwrapped_args = map(Symbolics.unwrap, args) - res = if !(any((x->begin - SymbolicUtils.issym(x) || SymbolicUtils.iscall(x) - end), unwrapped_args)) - Base.binomial(unwrapped_args...) - else - SymbolicUtils.Term{Int}(Base.binomial, unwrapped_args) - end - if typeof.(args) == typeof.(unwrapped_args) - return res - else - return Symbolics.wrap(res) - end - end -end - -for (T1, T2) in ((Symbolics.SymbolicUtils.Symbolic{<:Real}, Int64), - (Num, Int64), - (Real, Symbolics.SymbolicUtils.Symbolic{<:Int64}), - (Symbolics.SymbolicUtils.Symbolic{<:Real}, Symbolics.SymbolicUtils.Symbolic{<:Int64}), - (Num, Symbolics.SymbolicUtils.Symbolic{<:Int64})) - - @eval function Base.binomial(n::$T1, k::$T2) - if any(Symbolics.iswrapped, (n, k)) - Symbolics.wrap(_binomial(nothing, Symbolics.unwrap(n), Symbolics.unwrap(k))) - else - _binomial(nothing, n, k) - end - end -end +@register_symbolic Base.binomial(n::Number, k::Integer)::Integer false +SymbolicUtils.promote_symtype(::typeof(binomial), ::Type{T}, ::Type{S}) where {T <: Number, S <: Integer} = T -@register_symbolic Base.sign(x)::Int +@register_symbolic Base.sign(x) false +SymbolicUtils.promote_symtype(::typeof(sign), ::Type{T}) where {T <: Number} = T derivative(::typeof(sign), args::NTuple{1,Any}, ::Val{1}) = 0 @register_symbolic Base.signbit(x)::Bool @@ -59,7 +28,7 @@ end @register_symbolic Base.ceil(x) @register_symbolic Base.floor(x) -@register_symbolic Base.factorial(x) +@register_symbolic Base.factorial(x::Integer)::Integer function derivative(::Union{typeof(ceil),typeof(floor),typeof(factorial)}, args::NTuple{1,Any}, ::Val{1}) zero(args[1]) diff --git a/src/latexify_recipes.jl b/src/latexify_recipes.jl index 30d0b5ca6..37dc90676 100644 --- a/src/latexify_recipes.jl +++ b/src/latexify_recipes.jl @@ -109,7 +109,7 @@ end return n.f end -@latexrecipe function f(n::Symbolic) +@latexrecipe function f(n::BasicSymbolic) env --> :equation mult_symbol --> "" index --> :subscript @@ -133,7 +133,7 @@ end env --> :equation index --> :subscript - if hide_lhs(eq.lhs) || !(eq.lhs isa Union{Number, AbstractArray, Symbolic}) + if hide_lhs(eq.lhs) || !(eq.lhs isa Union{Number, AbstractArray, BasicSymbolic}) return eq.rhs else return Expr(:(=), Num(eq.lhs), Num(eq.rhs)) @@ -146,7 +146,7 @@ end end Base.show(io::IO, ::MIME"text/latex", x::RCNum) = print(io, "\$\$ " * latexify(x) * " \$\$") -Base.show(io::IO, ::MIME"text/latex", x::Symbolic) = print(io, "\$\$ " * latexify(x) * " \$\$") +Base.show(io::IO, ::MIME"text/latex", x::BasicSymbolic) = print(io, "\$\$ " * latexify(x) * " \$\$") Base.show(io::IO, ::MIME"text/latex", x::Equation) = print(io, "\$\$ " * latexify(x) * " \$\$") Base.show(io::IO, ::MIME"text/latex", x::Vector{Equation}) = print(io, "\$\$ " * latexify(x) * " \$\$") Base.show(io::IO, ::MIME"text/latex", x::AbstractArray{<:RCNum}) = print(io, "\$\$ " * latexify(x) * " \$\$") diff --git a/src/linear_algebra.jl b/src/linear_algebra.jl index aef33569f..1ada1ba04 100644 --- a/src/linear_algebra.jl +++ b/src/linear_algebra.jl @@ -7,7 +7,7 @@ function nterms(t) end # Soft pivoted -# Note: we call this function with a matrix of Union{SymbolicUtils.Symbolic, Any} +# Note: we call this function with a matrix of Union{SymbolicUtils.BasicSymbolic, Any} function sym_lu(A; check=true) SINGULAR = typemax(Int) m, n = size(A) @@ -28,7 +28,7 @@ function sym_lu(A; check=true) p[k] = kp - if amin == SINGULAR && !(amin isa Symbolic) && (amin isa Number) && iszero(info) + if amin == SINGULAR && !(amin isa BasicSymbolic) && (amin isa Number) && iszero(info) info = k end @@ -132,7 +132,7 @@ function _solve(A::AbstractMatrix, b::AbstractArray, do_simplify) do_simplify ? SymbolicUtils.simplify_fractions.(sol) : sol end -LinearAlgebra.ldiv!(A::UpperTriangular{<:Union{Symbolic,RCNum}}, b::AbstractVector{<:Union{Symbolic,RCNum}}, x::AbstractVector{<:Union{Symbolic,RCNum}} = b) = symsub!(A, b, x) +LinearAlgebra.ldiv!(A::UpperTriangular{<:Union{BasicSymbolic,RCNum}}, b::AbstractVector{<:Union{BasicSymbolic,RCNum}}, x::AbstractVector{<:Union{BasicSymbolic,RCNum}} = b) = symsub!(A, b, x) function symsub!(A::UpperTriangular, b::AbstractVector, x::AbstractVector = b) LinearAlgebra.require_one_based_indexing(A, b, x) n = size(A, 2) @@ -152,7 +152,7 @@ function symsub!(A::UpperTriangular, b::AbstractVector, x::AbstractVector = b) x end -LinearAlgebra.ldiv!(A::UnitLowerTriangular{<:Union{Symbolic,RCNum}}, b::AbstractVector{<:Union{Symbolic,RCNum}}, x::AbstractVector{<:Union{Symbolic,RCNum}} = b) = symsub!(A, b, x) +LinearAlgebra.ldiv!(A::UnitLowerTriangular{<:Union{BasicSymbolic,RCNum}}, b::AbstractVector{<:Union{BasicSymbolic,RCNum}}, x::AbstractVector{<:Union{BasicSymbolic,RCNum}} = b) = symsub!(A, b, x) function symsub!(A::UnitLowerTriangular, b::AbstractVector, x::AbstractVector = b) LinearAlgebra.require_one_based_indexing(A, b, x) n = size(A, 2) @@ -220,7 +220,7 @@ end function LinearAlgebra.norm(x::AbstractArray{<:RCNum}, p::Real=2) p = value(p) - issym = p isa Symbolic + issym = p isa BasicSymbolic if !issym && p == 2 sqrt(sum(x->abs2(x), x)) elseif !issym && isone(p) @@ -258,7 +258,7 @@ function linear_expansion(ts::AbstractArray, xs::AbstractArray) @label FINISH return A, bvec, islinear end -# _linear_expansion always returns `Symbolic` +# _linear_expansion always returns `BasicSymbolic` function _linear_expansion(t::Equation, x) a₂, b₂, islinear = linear_expansion(t.rhs, x) islinear || return (a₂, b₂, false) @@ -272,7 +272,7 @@ is_expansion_leaf(t) = !iscall(t) || (operation(t) isa Operator) @noinline expansion_check(op) = op isa Operator && error("The operation is an Operator. This should never happen.") function _linear_expansion(t, x) t = value(t) - t isa Symbolic || return (0, t, true) + t isa BasicSymbolic || return (0, t, true) x = value(x) is_expansion_leaf(t) && return trivial_linear_expansion(t, x) isequal(t, x) && return (1, 0, true) diff --git a/src/num.jl b/src/num.jl index 9eaf4af7b..1b6795f2d 100644 --- a/src/num.jl +++ b/src/num.jl @@ -110,7 +110,7 @@ Base.nameof(n::Num) = nameof(value(n)) Base.iszero(x::Num) = SymbolicUtils.fraction_iszero(unwrap(x)) Base.isone(x::Num) = SymbolicUtils.fraction_isone(unwrap(x)) -import SymbolicUtils: <ₑ, Symbolic, Term, operation, arguments +import SymbolicUtils: <ₑ, Term, operation, arguments function Base.show(io::IO, n::Num) show_numwrap[] ? print(io, :(Num($(value(n))))) : Base.show(io, value(n)) @@ -118,20 +118,7 @@ end Base.promote_rule(::Type{<:Number}, ::Type{<:Num}) = Num Base.promote_rule(::Type{BigFloat}, ::Type{<:Num}) = Num -Base.promote_rule(::Type{<:Symbolic{<:Number}}, ::Type{<:Num}) = Num -function Base.getproperty(t::Union{Add, Mul, Pow, Term}, f::Symbol) - if f === :op - Base.depwarn( - "`x.op` is deprecated, use `operation(x)` instead", :getproperty) - operation(t) - elseif f === :args - Base.depwarn("`x.args` is deprecated, use `arguments(x)` instead", - :getproperty) - arguments(t) - else - getfield(t, f) - end -end +Base.promote_rule(::Type{<:BasicSymbolic}, ::Type{<:Num}) = Num <ₑ(s::Num, x) = value(s) <ₑ value(x) <ₑ(s, x::Num) = value(s) <ₑ value(x) <ₑ(s::Num, x::Num) = value(s) <ₑ value(x) @@ -179,12 +166,8 @@ end @num_method Base.isequal begin va = value(a) vb = value(b) - if va isa SymbolicUtils.BasicSymbolic{Real} && vb isa SymbolicUtils.BasicSymbolic{Real} - isequal(va, vb)::Bool - else - isequal(va, vb)::Bool - end -end (AbstractFloat, Number, Symbolic) + isequal(va, vb)::Bool +end (AbstractFloat, Number, BasicSymbolic) # Provide better error message for symbolic variables in ranges function Base.:(:)(a::Integer, b::Num) @@ -251,26 +234,23 @@ Base.to_index(x::Num) = Base.to_index(value(x)) Base.hash(x::Num, h::UInt) = hash(value(x), h)::UInt -Base.convert(::Type{Num}, x::Symbolic{<:Number}) = Num(x) -Base.convert(::Type{Num}, x::Number) = Num(x) +function Base.convert(::Type{Num}, x::BasicSymbolic) + symtype(x) <: Real || error("`symtype` must be `<:Real`") + Num(x) +end +# TODO: `Const{T}` instead of `Const{SymReal}` +Base.convert(::Type{Num}, x::Number) = Num(Const{SymReal}(x)) Base.convert(::Type{Num}, x::Num) = x Base.convert(::Type{T}, x::AbstractArray{Num}) where {T <: Array{Num}} = T(map(Num, x)) -function Base.convert(::Type{Sym}, x::Num) - value(x) isa Sym ? value(x) : error("cannot convert $x to Sym") -end function LinearAlgebra.lu( x::Union{Adjoint{<:RCNum}, Transpose{<:RCNum}, Array{<:RCNum}}; check = true, kw...) sym_lu(x; check = check) end -_iszero(x::Number) = iszero(x) -_isone(x::Number) = isone(x) -_iszero(::Symbolic) = false -_isone(::Symbolic) = false -_iszero(x::Num) = _iszero(value(x))::Bool -_isone(x::Num) = _isone(value(x))::Bool +SymbolicUtils._iszero(x::Num) = SymbolicUtils._iszero(value(x)) +SymbolicUtils._isone(x::Num) = SymbolicUtils._isone(value(x)) Code.cse(x::Num) = Code.cse(unwrap(x)) diff --git a/src/register.jl b/src/register.jl index 8d1c01dc1..e946530a5 100644 --- a/src/register.jl +++ b/src/register.jl @@ -1,5 +1,3 @@ -using SymbolicUtils: Symbolic - """ @register_symbolic(expr, define_promotion = true, Ts = [Real]) @@ -92,7 +90,7 @@ function is_symbolic_or_array_of_symbolic(arr::AbstractArray) end symbolic_eltype(x) = eltype(x) -symbolic_eltype(::AbstractArray{symT}) where {eT, symT <: SymbolicUtils.Symbolic{eT}} = eT +symbolic_eltype(x::AbstractArray{BasicSymbolic{T}}) where {T} = eltype(symtype(Const{T}(x))) symbolic_eltype(::AbstractArray{Num}) = Real symbolic_eltype(::AbstractArray{symT}) where {eT, symT <: Arr{eT}} = eT diff --git a/src/rewrite-helpers.jl b/src/rewrite-helpers.jl index efa8b864e..99e0a2b54 100644 --- a/src/rewrite-helpers.jl +++ b/src/rewrite-helpers.jl @@ -1,5 +1,5 @@ """ - replacenode(expr::Symbolic, rules...) + replacenode(expr::BasicSymbolic, rules...) Walk the expression and replacenode subexpressions according to `rules`. `rules` could be rules constructed with `@rule`, a function, or a pair where the @@ -18,12 +18,12 @@ function replacenode(expr::Num, r::Pair, rules::Pair...; fixpoint = false) end # Fix ambiguity replacenode(expr::Num, rules...; fixpoint = false) = _replacenode(unwrap(expr), rules...; fixpoint) -replacenode(expr::Symbolic, rules...; fixpoint = false) = _replacenode(unwrap(expr), rules...; fixpoint) -replacenode(expr::Symbolic, r::Pair, rules::Pair...; fixpoint = false) = _replacenode(expr, r, rules...; fixpoint) +replacenode(expr::BasicSymbolic, rules...; fixpoint = false) = _replacenode(unwrap(expr), rules...; fixpoint) +replacenode(expr::BasicSymbolic, r::Pair, rules::Pair...; fixpoint = false) = _replacenode(expr, r, rules...; fixpoint) replacenode(expr::Number, rules...; fixpoint = false) = expr replacenode(expr::Number, r::Pair, rules::Pair...; fixpoint = false) = expr -function _replacenode(expr::Symbolic, rules...; fixpoint = false) +function _replacenode(expr::BasicSymbolic, rules...; fixpoint = false) rs = map(r -> r isa Pair ? (x -> isequal(x, unwrap(r[1])) ? unwrap(r[2]) : nothing) : r, rules) R = Prewalk(Chain(rs)) if fixpoint @@ -52,12 +52,12 @@ D = Differential(t) hasnode(Symbolics.is_derivative, X + D(X) + D(X^2)) # returns `true`. ``` """ -function hasnode(r::Function, y::Union{Num, Symbolic}) +function hasnode(r::Function, y::Union{Num, BasicSymbolic}) _hasnode(r, y) end -hasnode(r::Num, y::Union{Num, Symbolic}) = occursin(unwrap(r), unwrap(y)) -hasnode(r::Symbolic, y::Union{Num, Symbolic}) = occursin(unwrap(r), unwrap(y)) -hasnode(r::Union{Num, Symbolic, Function}, y::Number) = false +hasnode(r::Num, y::Union{Num, BasicSymbolic}) = occursin(unwrap(r), unwrap(y)) +hasnode(r::BasicSymbolic, y::Union{Num, BasicSymbolic}) = occursin(unwrap(r), unwrap(y)) +hasnode(r::Union{Num, BasicSymbolic, Function}, y::Number) = false function _hasnode(r, y) y = unwrap(y) diff --git a/src/semipoly.jl b/src/semipoly.jl index 121a23c0d..ffa17eb88 100644 --- a/src/semipoly.jl +++ b/src/semipoly.jl @@ -3,8 +3,6 @@ using DataStructures export semipolynomial_form, semilinear_form, semiquadratic_form, polynomial_coeffs -import SymbolicUtils: unsorted_arguments - """ $(TYPEDEF) @@ -13,7 +11,7 @@ $(TYPEDFIELDS) """ struct SemiMonomial "monomial" - p::Union{S, N} where {S <: Symbolic, N <: Real} + p::Union{BasicSymbolic, Real} "coefficient" coeff::Any end @@ -35,7 +33,8 @@ function Base.:*(a::SemiMonomial, b::SemiMonomial) SemiMonomial(a.p * b.p, a.coeff * b.coeff) end Base.:*(m::SemiMonomial, n::Number) = SemiMonomial(m.p, m.coeff * n) -function Base.:*(m::SemiMonomial, t::Symbolic) +function Base.:*(m::SemiMonomial, t::BasicSymbolic) + isconst(t) && return m * unwrap_const(t) if iscall(t) op = operation(t) if op == (+) @@ -72,7 +71,7 @@ function pdegrees(x) Dict(keys(dict) .=> degrees) elseif issym(x) || iscall(x) return Dict(x=>1) - elseif x isa Number + elseif unwrap_const(x) isa Number return Dict() else error("pdegrees for $x unknown") @@ -80,7 +79,7 @@ function pdegrees(x) end pdegree(x::Number) = 0 -function pdegree(x::Symbolic) +function pdegree(x::BasicSymbolic) degree_dict = pdegrees(x) if isempty(degree_dict) return 0 @@ -116,7 +115,7 @@ end # Return true if the degrees of `m` are all 0s and its coefficient is a `Real`. Base.:isreal(m::SemiMonomial) = m.p isa Number && isone(m.p) && unwrap(m.coeff) isa Real -Base.:isreal(::Symbolic) = false +Base.:isreal(::BasicSymbolic) = false # Transform `m` to a `Real`. # Assume `isreal(m) == true`, otherwise calling this function does not make sense. diff --git a/src/utils.jl b/src/utils.jl index 8359430e9..9c8b3a9a2 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -70,7 +70,7 @@ end get_variables!(vars, e::Number, varlist=nothing) = vars -function get_variables!(vars, e::Symbolic, varlist=nothing) +function get_variables!(vars, e::BasicSymbolic, varlist=nothing) if is_singleton(e) if isnothing(varlist) || any(isequal(e), varlist) push!(vars, e) @@ -126,7 +126,7 @@ get_differential_vars!(vars, e, varlist=nothing) = vars get_differential_vars!(vars, e::Number, varlist=nothing) = vars -function get_differential_vars!(vars, e::Symbolic, varlist=nothing) +function get_differential_vars!(vars, e::BasicSymbolic, varlist=nothing) if is_derivative(e) if isnothing(varlist) || any(isequal(e), varlist) push!(vars, e) @@ -148,11 +148,11 @@ function get_differential_vars!(vars, e::Equation, varlist=nothing) end # Sym / Term --> Symbol -Base.Symbol(x::Union{Num,Symbolic}) = tosymbol(x) +Base.Symbol(x::Union{Num,BasicSymbolic}) = tosymbol(x) tosymbol(t::Num; kwargs...) = tosymbol(value(t); kwargs...) """ - diff2term(x, x_metadata::Dict{Datatype, Any}) -> Symbolic + diff2term(x, x_metadata::Dict{Datatype, Any}) -> BasicSymbolic Convert a differential variable to a `Term`. Note that it only takes a `Term` not a `Num`. @@ -215,7 +215,7 @@ end setname(v, name) = setmetadata(v, Symbolics.VariableSource, (:variables, name)) """ - tosymbol(x::Union{Num,Symbolic}; states=nothing, escape=true) -> Symbol + tosymbol(x::Union{Num,BasicSymbolic}; states=nothing, escape=true) -> Symbol Convert `x` to a symbol. `states` are the states of a system, and `escape` means if the target has escapes like `val"y(t)"`. If `escape` is false, then @@ -265,7 +265,7 @@ function tosymbol(t; states=nothing, escape=true) end end -function lower_varname(var::Symbolic, idv, order) +function lower_varname(var::BasicSymbolic, idv, order) order == 0 && return var D = Differential(idv) for _ in 1:order @@ -431,7 +431,7 @@ function coeff(p, sym=nothing) end else p isa Number && return sym === nothing ? p : 0 - p isa Symbolic && return coeff(p, sym) + p isa BasicSymbolic && return coeff(p, sym) throw(DomainError(p, "Datatype $(typeof(p)) not accepted.")) end end @@ -516,7 +516,7 @@ julia> numerator(x/y) x ``` """ -function Base.numerator(x::Union{Num, Symbolic}) +function Base.numerator(x::Union{Num, BasicSymbolic}) x = unwrap(x) if iscall(x) && operation(x) == / x = arguments(x)[1] # get numerator @@ -536,7 +536,7 @@ julia> denominator(x/y) y ``` """ -function Base.denominator(x::Union{Num, Symbolic}) +function Base.denominator(x::Union{Num, BasicSymbolic}) x = unwrap(x) if iscall(x) && operation(x) == / x = arguments(x)[2] # get denominator diff --git a/src/variable.jl b/src/variable.jl index fa751a6c2..b27e05eca 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -542,7 +542,7 @@ const _fail = Dict() _getname(x, _) = nameof(x) _getname(x::Symbol, _) = x -function _getname(x::Symbolic, val) +function _getname(x::BasicSymbolic, val) issym(x) && return nameof(x) if iscall(x) && issym(operation(x)) return nameof(operation(x)) @@ -562,7 +562,7 @@ getsource(x, val=_fail) = getmetadata(unwrap(x), VariableSource, val) SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.Num}) = ScalarSymbolic() SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.Arr}) = ArraySymbolic() -function SymbolicIndexingInterface.symbolic_type(::Type{T}) where {S <: AbstractArray, T <: Symbolic{S}} +function SymbolicIndexingInterface.symbolic_type(::Type{T}) where {S <: AbstractArray, T <: BasicSymbolic{S}} ArraySymbolic() end # need this otherwise the `::Type{<:BasicSymbolic}` method in SymbolicUtils is @@ -573,18 +573,7 @@ end SymbolicIndexingInterface.hasname(x::Union{Num,Arr,Complex{Num}}) = hasname(unwrap(x)) -function SymbolicIndexingInterface.hasname(x::Symbolic) - issym(x) || !iscall(x) || iscall(x) && (issym(operation(x)) || operation(x) == getindex && hasname(arguments(x)[1])) -end - -# This is type piracy, but changing it breaks precompilation for MTK because it relies on this falling back to -# `_getname` which calls `nameof` which returns the name of the system, when `x::AbstractSystem`. -# FIXME: In a breaking release -function SymbolicIndexingInterface.getname(x, val = _fail) - _getname(unwrap(x), val) -end - -function SymbolicIndexingInterface.symbolic_evaluate(ex::Union{Num, Arr, Symbolic, Equation, Inequality}, d::Dict; kwargs...) +function SymbolicIndexingInterface.symbolic_evaluate(ex::Union{Num, Arr, BasicSymbolic, Equation, Inequality}, d::Dict; kwargs...) val = fixpoint_sub(ex, d; kwargs...) return _recursive_unwrap(val) end @@ -704,7 +693,7 @@ function fast_substitute(expr, subs; operator = Nothing) op = fast_substitute(operation(expr), subs; operator) args = SymbolicUtils.arguments(expr) if !(op isa operator) - canfold = Ref(!(op isa Symbolic)) + canfold = Ref(!(op isa BasicSymbolic)) args = let canfold = canfold map(args) do x symbolic_type(x) == NotSymbolic() && !is_array_of_symbolics(x) && return x @@ -735,7 +724,7 @@ function fast_substitute(expr, pair::Pair; operator = Nothing) op = fast_substitute(operation(expr), pair; operator) args = SymbolicUtils.arguments(expr) if !(op isa operator) - canfold = Ref(!(op isa Symbolic)) + canfold = Ref(!(op isa BasicSymbolic)) args = let canfold = canfold map(args) do x symbolic_type(x) == NotSymbolic() && !is_array_of_symbolics(x) && return x @@ -879,7 +868,7 @@ function rename(x::CallWithMetadata, name) rename_metadata(x, CallWithMetadata(rename(x.f, name), x.metadata), name) end -function rename(x::Symbolic, name) +function rename(x::BasicSymbolic, name) if issym(x) xx = @set! x.name = name xx = rename_metadata(x, xx, name) diff --git a/src/wrapper-types.jl b/src/wrapper-types.jl index bf30f8913..a669bfdfa 100644 --- a/src/wrapper-types.jl +++ b/src/wrapper-types.jl @@ -112,30 +112,30 @@ function wrap_func_expr(mod, expr, wrap_arrays = true) # for every argument find the types that # should be allowed as argument. These are: # - # (1) T (2) wrapper_type(T) (3) Symbolic{T} + # (1) T (2) wrapper_type(T) (3) BasicSymbolic # # However later while emitting methods we omit the one # method where all arguments are (1) since those are # expected to be defined outside Symbolics if arg isa Expr && arg.head == :(::) T = Base.eval(mod, arg.args[2]) - Ts = has_symwrapper(T) ? (T, :(Symbolics.SymbolicUtils.Symbolic{<:$T}), wrapper_type(T)) : - (T,:(Symbolics.SymbolicUtils.Symbolic{<:$T})) + Ts = has_symwrapper(T) ? (T, BasicSymbolic, wrapper_type(T)) : + (T, BasicSymbolic) if T <: AbstractArray && wrap_arrays eT = eltype(T) if eT == Any eT = Real end _arr_type_fn = if hasmethod(ndims, Tuple{Type{T}}) - (elT) -> :(AbstractArray{T, $(ndims(T))} where {T <: $elT}) + (elT) -> AbstractArray{T, ndims(T)} where {T <: elT} else - (elT) -> :(AbstractArray{T} where {T <: $elT}) + (elT) -> AbstractArray{T} where {T <: elT} end if has_symwrapper(eT) - Ts = (Ts..., _arr_type_fn(:(Symbolics.SymbolicUtils.Symbolic{<:$eT})), + Ts = (Ts..., _arr_type_fn(BasicSymbolic), _arr_type_fn(wrapper_type(eT))) else - Ts = (Ts..., _arr_type_fn(:(Symbolics.SymbolicUtils.Symbolic{<:$eT}))) + Ts = (Ts..., _arr_type_fn(BasicSymbolic)) end end Ts @@ -158,19 +158,31 @@ function wrap_func_expr(mod, expr, wrap_arrays = true) :($n::$T) end - fbody = :(if any($iswrapped, ($(names...),)) - $wrap($impl_name($self, $([:($unwrap($arg)) for arg in names]...))) - else - $impl_name($self, $(names...)) - end) + impl_args = map(enumerate(names)) do (i, name) + is_wrapper_type(Ts[i]) ? :($unwrap($name)) : name + end + implcall = :($impl_name($self, $(impl_args...))) + if any(is_wrapper_type, Ts) + implcall = :($wrap($implcall)) + end + + body = Expr(:block) + for (i, T) in enumerate(Ts) + if T === BasicSymbolic + push!(body.args, :(@assert $symtype($(names[i])) <: $(types[i][1]))) + else T === AbstractArray{T} where {T <: BasicSymbolic} + push!(body.args, :(@assert $symtype($(names[i])[1]) <: $(types[i][1]))) + end + end + push!(body.args, implcall) if isempty(kwargs) :(function $fname($(method_args...)) - $fbody + $body end) else :(function $fname($(method_args...); $(kwargs...)) - $fbody + $body end) end end From 48299c5519bc9f716e7f6c3225a4ed15baa11e98 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 15 Sep 2025 12:13:47 +0530 Subject: [PATCH 02/28] refactor: remove `ComplexTerm` --- src/complex.jl | 91 ++----------------------------------- src/diff.jl | 7 +-- src/integral.jl | 2 +- src/solver/preprocess.jl | 22 --------- src/solver/solve_helpers.jl | 4 +- 5 files changed, 10 insertions(+), 116 deletions(-) diff --git a/src/complex.jl b/src/complex.jl index b221fa30f..3525cef98 100644 --- a/src/complex.jl +++ b/src/complex.jl @@ -1,12 +1,3 @@ -abstract type AbstractComplexTerm{T} <: Symbolic{Complex{T}} -end - -struct ComplexTerm{T} <: AbstractComplexTerm{T} - re - im -end - -Base.imag(c::Symbolic{Complex{T}}) where {T} = term(imag, c) SymbolicUtils.promote_symtype(::typeof(imag), ::Type{Complex{T}}) where {T} = T Base.promote_rule(::Type{Complex{T}}, ::Type{S}) where {T<:Real, S<:Num} = Complex{S} # 283 @@ -17,84 +8,8 @@ function wrapper_type(::Type{Complex{T}}) where T Symbolics.has_symwrapper(T) ? Complex{wrapper_type(T)} : Complex{T} end -symtype(a::ComplexTerm{T}) where T = Complex{T} -iscall(a::ComplexTerm) = true -operation(a::ComplexTerm{T}) where T = Complex{T} -arguments(a::ComplexTerm) = [a.re, a.im] -metadata(a::ComplexTerm) = metadata(a.re) - -function maketerm(T::Type{<:ComplexTerm}, f, args, metadata) - if f <: Complex - ComplexTerm{real(f)}(args...) - else - maketerm(typeof(first(args)), f, args, metadata) - end -end - -function Base.show(io::IO, mime::MIME"text/plain", a::ComplexTerm) - print(io, "ComplexTerm(") - show(io, mime, wrap(a)) - print(io, ")") -end - -function Base.show(io::IO, a::Complex{Num}) - rr = unwrap(real(a)) - ii = unwrap(imag(a)) - - if iscall(rr) && (operation(rr) === real) && - iscall(ii) && (operation(ii) === imag) && - isequal(arguments(rr)[1], arguments(ii)[1]) - - return print(io, arguments(rr)[1]) - end - - i = Sym{Real}(:im) - show(io, real(a) + i * imag(a)) -end - -function unwrap(a::Complex{<:Num}) +function SymbolicUtils.unwrap(a::Complex{<:Num}) re, img = unwrap(real(a)), unwrap(imag(a)) - if re isa Real && img isa Real - return re + im * img - end - T = promote_type(symtype(re), symtype(img)) - ComplexTerm{T}(re, img) -end -wrap(a::ComplexTerm) = Complex(wrap.(arguments(a))...) -wrap(a::Symbolic{<:Complex}) = Complex(wrap(real(a)), wrap(imag(a))) - -SymbolicUtils.@number_methods( - ComplexTerm, - unwrap(f(wrap(a))), - unwrap(f(wrap(a), wrap(b))), - ) - -function Base.isequal(a::ComplexTerm{T}, b::ComplexTerm{S}) where {T,S} - T === S && isequal(a.re, b.re) && isequal(a.im, b.im) -end - -function Base.hash(a::ComplexTerm{T}, h::UInt) where T - hash(hash(a.im, hash(a.re, hash(T, hash(h ⊻ 0x1af5d7582250ac4a))))) -end - -Base.iszero(x::Complex{<:Num}) = iszero(real(x)) && iszero(imag(x)) -Base.isone(x::Complex{<:Num}) = isone(real(x)) && iszero(imag(x)) -_iszero(x::Complex{<:Num}) = _iszero(unwrap(x)) -_isone(x::Complex{<:Num}) = _isone(unwrap(x)) - -function SymbolicIndexingInterface.hasname(x::ComplexTerm) - a = arguments(unwrap(x.im))[1] - b = arguments(unwrap(x.re))[1] - return isequal(a, b) && hasname(a) -end - -function _getname(x::ComplexTerm, val) - a = arguments(unwrap(x.im))[1] - b = arguments(unwrap(x.re))[1] - if isequal(a, b) - return _getname(a, val) - end - if val == _fail - throw(ArgumentError("Variable $x doesn't have a name.")) - end + sT = promote_type(symtype(re), symtype(img)) + return Term{vartype(re)}(complex, SymbolicUtils.ArgsT{vartype(re)}((re, img)); type = Complex{sT}, shape = SymbolicUtils.ShapeVecT()) end diff --git a/src/diff.jl b/src/diff.jl index ddeff688d..4be9f46f6 100644 --- a/src/diff.jl +++ b/src/diff.jl @@ -46,7 +46,7 @@ end (D::Differential)(x::Union{AbstractFloat, Integer}) = wrap(0) (D::Differential)(x::Union{Num, Arr}) = wrap(D(unwrap(x))) -(D::Differential)(x::Complex{Num}) = wrap(ComplexTerm{Real}(D(unwrap(real(x))), D(unwrap(imag(x))))) +(D::Differential)(x::Complex{Num}) = Complex{Num}(wrap(D(unwrap(real(x)))), wrap(D(unwrap(imag(x))))) SymbolicUtils.promote_symtype(::Differential, T) = T SymbolicUtils.isbinop(f::Differential) = false @@ -414,8 +414,9 @@ function expand_derivatives(n::Num, simplify=false; kwargs...) wrap(expand_derivatives(value(n), simplify; kwargs...)) end function expand_derivatives(n::Complex{Num}, simplify=false; kwargs...) - wrap(ComplexTerm{Real}(expand_derivatives(real(n), simplify; kwargs...), - expand_derivatives(imag(n), simplify; kwargs...))) + re = expand_derivatives(real(n), simplify; kwargs...) + img = expand_derivatives(imag(n), simplify; kwargs...) + Complex{Num}(wrap(re), wrap(img)) end expand_derivatives(x, simplify=false; kwargs...) = x diff --git a/src/integral.jl b/src/integral.jl index 0995330c5..c89e51fb1 100644 --- a/src/integral.jl +++ b/src/integral.jl @@ -30,7 +30,7 @@ function (I::Integral)(x::Union{Rational, AbstractIrrational, AbstractFloat, Int a, b = value.(DomainSets.endpoints(domain)) wrap((b - a)*x) end -(I::Integral)(x::Complex) = wrap(ComplexTerm{Real}(I(unwrap(real(x))), I(unwrap(imag(x))))) +(I::Integral)(x::Complex) = Complex{Num}(wrap(I(unwrap(real(x)))), wrap(I(unwrap(imag(x))))) (I::Integral)(x) = Term{SymbolicUtils.symtype(x)}(I, [x]) (I::Integral)(x::Num) = Num(I(Symbolics.value(x))) SymbolicUtils.promote_symtype(::Integral, x) = x diff --git a/src/solver/preprocess.jl b/src/solver/preprocess.jl index 5f11a3422..39e062c5b 100644 --- a/src/solver/preprocess.jl +++ b/src/solver/preprocess.jl @@ -116,28 +116,6 @@ function _filter_poly(expr, var) end args = copy(parent(arguments(expr))) - if expr isa ComplexTerm - subs1, subs2 = Dict(), Dict() - expr1, expr2 = 0, 0 - - if !isequal(expr.re, 0) - subs1, expr1 = _filter_poly(expr.re, var) - end - if !isequal(expr.im, 0) - subs2, expr2 = _filter_poly(expr.im, var) - end - - subs = merge(subs1, subs2) - i_var = gensym() - i_var = (@variables $i_var)[1] - - subs[i_var] = im - expr = unwrap(expr1 + i_var * expr2) - - args = map(unwrap, arguments(expr)) - oper = operation(expr) - return subs, term(oper, args...) - end subs = Dict{Any, Any}() for (i, arg) in enumerate(args) diff --git a/src/solver/solve_helpers.jl b/src/solver/solve_helpers.jl index f70ebdb05..35827d898 100644 --- a/src/solver/solve_helpers.jl +++ b/src/solver/solve_helpers.jl @@ -87,7 +87,7 @@ function check_expr_validity(expr) valid_type = false if type_expr <: Number || type_expr == Num || type_expr == SymbolicUtils.BasicSymbolic{Real} || - type_expr == Complex{Num} || type_expr == ComplexTerm{Real} || type_expr == SymbolicUtils.BasicSymbolic{Complex{Real}} + type_expr == Complex{Num} || type_expr == SymbolicUtils.BasicSymbolic{Complex{Real}} valid_type = true end iscall(unwrap(expr)) && @assert !hasderiv(unwrap(expr)) "Differential equations are not currently supported" @@ -110,7 +110,7 @@ end # converts everything to BIG function bigify(n) n = unwrap(n) - if n isa ComplexTerm || n isa Float64 || n isa Irrational + if n isa Float64 || n isa Irrational return n end From da22d7c125753b31756240bca16bf712699ba14b Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 15 Sep 2025 12:32:22 +0530 Subject: [PATCH 03/28] refactor: remove `CallWithMetadata` --- src/Symbolics.jl | 5 +--- src/build_function.jl | 3 +- src/latexify_recipes.jl | 8 ----- src/num.jl | 6 +--- src/utils.jl | 2 +- src/variable.jl | 65 +++++------------------------------------ 6 files changed, 11 insertions(+), 78 deletions(-) diff --git a/src/Symbolics.jl b/src/Symbolics.jl index afc62c9f1..ebd7834c3 100644 --- a/src/Symbolics.jl +++ b/src/Symbolics.jl @@ -197,9 +197,6 @@ for T in [Num, Complex{Num}] Broadcast.broadcastable(x::$T) = x end - for S in [:(BasicSymbolic{<:FnType}), :CallWithMetadata] - @eval (f::$S)(x::$T, y...) = wrap(f(unwrap(x), unwrap.(y)...)) - end end for sType in [Pair, Vector, Dict] @@ -545,7 +542,7 @@ include("inverse.jl") export rootfunction, left_continuous_function, right_continuous_function, @register_discontinuity include("discontinuities.jl") -@public Arr, CallWithMetadata, NAMESPACE_SEPARATOR, Unknown, VariableDefaultValue, VariableSource +@public Arr, NAMESPACE_SEPARATOR, Unknown, VariableDefaultValue, VariableSource @public _parse_vars, derivative, gradient, jacobian, sparsejacobian, hessian, sparsehessian @public get_variables, get_variables!, get_differential_vars, getparent, option_to_metadata_type, scalarize, shape @public unwrap, variable, wrap diff --git a/src/build_function.jl b/src/build_function.jl index 8a7ee48e9..cc00280cc 100644 --- a/src/build_function.jl +++ b/src/build_function.jl @@ -86,7 +86,6 @@ end # Scalar output unwrap_nometa(x) = unwrap(x) -unwrap_nometa(x::CallWithMetadata) = unwrap(x.f) function destructure_arg(arg::Union{AbstractArray, Tuple,NamedTuple}, inbounds, name) if !(arg isa Arr) DestructuredArgs(map(unwrap_nometa, arg), name, inbounds=inbounds, create_bindings=false) @@ -107,7 +106,7 @@ SymbolicUtils.Code.cse_inside_expr(sym, ::Symbolics.Operator, args...) = false # don't CSE inside `getindex` of things created via `@variables` # EXCEPT called variables function SymbolicUtils.Code.cse_inside_expr(sym, ::typeof(getindex), x::BasicSymbolic, idxs...) - return !hasmetadata(sym, VariableSource) || hasmetadata(sym, CallWithParent) + return !hasmetadata(sym, VariableSource) || SymbolicUtils.is_called_function_symbolic(x) end function _build_function(target::JuliaTarget, op, args...; diff --git a/src/latexify_recipes.jl b/src/latexify_recipes.jl index 37dc90676..ad2de9f48 100644 --- a/src/latexify_recipes.jl +++ b/src/latexify_recipes.jl @@ -101,14 +101,6 @@ end return unwrap(n) end -@latexrecipe function f(n::CallWithMetadata) - env --> :equation - mult_symbol --> "" - index --> :subscript - - return n.f -end - @latexrecipe function f(n::BasicSymbolic) env --> :equation mult_symbol --> "" diff --git a/src/num.jl b/src/num.jl index 1b6795f2d..8a8cf25e6 100644 --- a/src/num.jl +++ b/src/num.jl @@ -89,15 +89,11 @@ end substitute(expr, s::Pair; kw...) = substituter([s[1] => s[2]])(expr; kw...) substitute(expr, s::Vector; kw...) = substituter(s)(expr; kw...) -function _unwrap_callwithmeta(x) - x = value(x) - return x isa CallWithMetadata ? x.f : x -end function subrules_to_dict(pairs) if pairs isa Pair pairs = (pairs,) end - return Dict(_unwrap_callwithmeta(k) => value(v) for (k, v) in pairs) + return Dict(k => value(v) for (k, v) in pairs) end function substituter(pairs) dict = subrules_to_dict(pairs) diff --git a/src/utils.jl b/src/utils.jl index 9c8b3a9a2..98aa732a5 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -301,7 +301,7 @@ end function var_from_nested_derivative(x,i=0) x = unwrap(x) - if issym(x) || x isa CallWithMetadata + if issym(x) (x, i) elseif iscall(x) operation(x) isa Differential ? diff --git a/src/variable.jl b/src/variable.jl index b27e05eca..ed7f54b32 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -68,11 +68,7 @@ function scalarize_getindex(x, parent=Ref{Any}(x)) else xx = unwrap(scalarize(x)) xx = metadata(xx, metadata(x)) - if symtype(xx) <: FnType - setmetadata(CallWithMetadata(xx, metadata(xx)), GetindexParent, parent[]) - else - setmetadata(xx, GetindexParent, parent[]) - end + setmetadata(xx, GetindexParent, parent[]) end end @@ -212,7 +208,7 @@ function construct_dep_array_vars(macroname, lhs, type, call_args, indices, val, end end argtypes = arg_types_from_call_args(call_args) - ex = :($CallWithMetadata($Sym{$FnType{$argtypes, Array{$type, $ndim}, $(fntype...)}}($_vname))) + ex = :($Sym{$FnType{$argtypes, Array{$type, $ndim}, $(fntype...)}}($_vname)) else vname = lhs if isruntime @@ -333,45 +329,6 @@ function setprops_expr(expr, props, macroname, varname) expr end -struct CallWithMetadata{T,M} <: Symbolic{T} - f::Symbolic{T} - metadata::M -end - -for f in [:iscall, :operation, :arguments] - @eval SymbolicUtils.$f(x::CallWithMetadata) = $f(x.f) -end - -SymbolicUtils.Code.toexpr(x::CallWithMetadata, st) = SymbolicUtils.Code.toexpr(x.f, st) - -CallWithMetadata(f) = CallWithMetadata(f, nothing) - -SymbolicIndexingInterface.symbolic_type(::Type{<:CallWithMetadata}) = ScalarSymbolic() - -# HACK: -# A `DestructuredArgs` with `create_bindings = false` doesn't create a `Let` block, and -# instead adds the assignments to the rewrites dictionary. This is problematic, because -# if the `DestructuredArgs` contains a `CallWithMetadata` the key in the `Dict` will be -# a `CallWithMetadata` which won't match against the operation of the called symbolic. -# This is the _only_ hook we have and relies on the `DestructuredArgs` being converted -# into a list of `Assignment`s before being added to the `Dict` inside `toexpr(::Let, st)`. -# The callable symbolic is unwrapped so it matches the operation of the called version. -SymbolicUtils.Code.Assignment(f::CallWithMetadata, x) = SymbolicUtils.Code.Assignment(f.f, x) - -function Base.show(io::IO, c::CallWithMetadata) - show(io, c.f) - print(io, "⋆") -end - -struct CallWithParent end - -function (f::CallWithMetadata)(args...) - setmetadata(metadata(unwrap(f.f(map(unwrap, args)...)), metadata(f)), CallWithParent, f) -end - -Base.isequal(a::CallWithMetadata, b::CallWithMetadata) = isequal(a.f, b.f) -Base.hash(x::CallWithMetadata, h::UInt) = hash(x.f, h) - function arg_types_from_call_args(call_args) if length(call_args) == 1 && only(call_args) == :.. return Tuple @@ -411,7 +368,7 @@ function construct_var(macroname, var_name, type, call_args, val, prop) # (..)::ArgT is Vararg{ArgT} var_name, fntype = function_name_and_type(var_name) argtypes = arg_types_from_call_args(call_args) - :($CallWithMetadata($Sym{$FnType{$argtypes, $type, $(fntype...)}}($var_name))) + :($Sym{$FnType{$argtypes, $type, $(fntype...)}}($var_name)) # This elseif handles the special case with e.g. variables on the form # @variables X(deps...) where deps is a vector (which length might be unknown). elseif (call_args isa Vector) && (length(call_args) == 1) && (call_args[1] isa Expr) && @@ -447,8 +404,7 @@ function _construct_array_vars(macroname, var_name, type, call_args, val, prop, var_name, fntype = function_name_and_type(var_name) argtypes = arg_types_from_call_args(call_args) ex = :($Sym{Array{$FnType{$argtypes, $type, $(fntype...)}, $ndim}}($var_name)) - ex = :($setmetadata($ex, $ArrayShapeCtx, ($(indices...),))) - :($map($CallWithMetadata, $ex)) + :($setmetadata($ex, $ArrayShapeCtx, ($(indices...),))) else # [(R -> R)(R) ....] need_scalarize = true @@ -818,10 +774,7 @@ Also see `variables`. function variable(name, idx...; T=Real) name_ij = Symbol(name, join(map_subscripts.(idx), "ˏ")) v = Sym{T}(name_ij) - if T <: FnType - v = CallWithMetadata(v) - end - Num(setmetadata(v, VariableSource, (:variables, name_ij))) + wrap(setmetadata(v, VariableSource, (:variables, name_ij))) end ##### Renaming ##### @@ -857,17 +810,13 @@ function rename(x::ArrayOp, name) t = x.term args = arguments(t) # Hack: - @assert operation(t) === (map) && (args[1] isa CallWith || args[1] == CallWithMetadata) + @assert operation(t) === (map) && args[1] isa CallWith rn = rename(args[2], name) xx = metadata(operation(t)(args[1], rn), metadata(x)) rename_getindex_source(rename_metadata(x, xx, name)) end -function rename(x::CallWithMetadata, name) - rename_metadata(x, CallWithMetadata(rename(x.f, name), x.metadata), name) -end - function rename(x::BasicSymbolic, name) if issym(x) xx = @set! x.name = name @@ -875,7 +824,7 @@ function rename(x::BasicSymbolic, name) symtype(xx) <: AbstractArray ? rename_getindex_source(xx) : xx elseif iscall(x) && operation(x) === getindex rename(arguments(x)[1], name)[arguments(x)[2:end]...] - elseif iscall(x) && symtype(operation(x)) <: FnType || operation(x) isa CallWithMetadata + elseif iscall(x) && symtype(operation(x)) <: FnType xx = @set x.f = rename(operation(x), name) @set! xx.hash = Ref{UInt}(0) return rename_metadata(x, xx, name) From b0336cffebb5baf19ce171977cb577e625d9093d Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 15 Sep 2025 18:27:38 +0530 Subject: [PATCH 04/28] refactor: remove array variants --- src/array-lib.jl | 377 ++----------- src/arrays.jl | 1111 ++------------------------------------- src/build_function.jl | 2 +- src/latexify_recipes.jl | 9 - src/utils.jl | 4 +- src/variable.jl | 24 +- 6 files changed, 102 insertions(+), 1425 deletions(-) diff --git a/src/array-lib.jl b/src/array-lib.jl index 7748347e8..02a15401c 100644 --- a/src/array-lib.jl +++ b/src/array-lib.jl @@ -14,201 +14,23 @@ struct GetindexPosthookCtx end setmetadata(x, GetindexPosthookCtx, f) end end -function Base.getindex(x::SymArray, idx::CartesianIndex) - return x[Tuple(idx)...] -end - -function Base.getindex(x::SymArray, idx...) - idx = unwrap.(idx) - meta = metadata(unwrap(x)) - if iscall(x) && (op = operation(x)) isa Operator - args = arguments(x) - return op(only(args)[idx...]) - elseif shape(x) !== Unknown() && all(i -> i isa Integer, idx) - II = CartesianIndices(axes(x)) - ii = CartesianIndex(idx) - @boundscheck begin - if !in(ii, II) - throw(BoundsError(x, idx)) - end - end - res = Term{eltype(symtype(x))}(getindex, [x, Tuple(ii)...]; metadata = meta) - elseif all(i -> symtype(i) <: Integer, idx) - shape(x) !== Unknown() && @boundscheck begin - if length(idx) > 1 - for (a, i) in zip(axes(x), idx) - if i isa Integer && !(i in a) - throw(BoundsError(x, idx)) - end - end - end - end - res = Term{eltype(symtype(x))}(getindex, [x, idx...]; metadata = meta) - elseif length(idx) == 1 && symtype(first(idx)) <: CartesianIndex - i = first(idx) - ii = i isa CartesianIndex ? Tuple(i) : arguments(i) - - return getindex(x, ii...) - else - input_idx = [] - output_idx = [] - ranges = Dict{BasicSymbolic,AbstractRange}() - subscripts = makesubscripts(length(idx)) - for (j, i) in enumerate(idx) - if symtype(i) <: Integer - push!(input_idx, i) - elseif i isa Colon - push!(output_idx, subscripts[j]) - push!(input_idx, subscripts[j]) - elseif i isa AbstractVector - isym = subscripts[j] - push!(output_idx, isym) - push!(input_idx, isym) - ranges[isym] = i - else - error("Don't know how to index by $i") - end - end - - term = Term{Any}(getindex, [x, idx...]; metadata = meta) - T = eltype(symtype(x)) - N = ndims(x) - count(i -> symtype(i) <: Integer, idx) - res = ArrayOp(atype(symtype(x)){T,N}, - (output_idx...,), - x[input_idx...], - +, - term, - ranges) - end - - if hasmetadata(x, GetindexPosthookCtx) - f = getmetadata(x, GetindexPosthookCtx, identity) - f(res, x, idx...) - else - res - end -end # Wrapped array should wrap the elements too function Base.getindex(x::Arr, idx...) wrap(unwrap(x)[idx...]) end -function Base.getindex(x::Arr, idx::Symbolic{<:Integer}...) +function Base.getindex(x::Arr, idx::BasicSymbolic...) wrap(unwrap(x)[idx...]) end -function Base.getindex(x::Arr, I::Symbolic{CartesianIndex}) - wrap(unwrap(x)[tup(I)...]) -end -Base.getindex(I::Symbolic{CartesianIndex}, i::Integer) = tup(I)[i] - -function Base.getindex(A::AbstractArray{T}, I::Symbolic{CartesianIndex}) where {T} - term(getindex, A, tup(I)..., type=T) -end - -function Base.CartesianIndex(x::Symbolic{<:Integer}, xs::Symbolic{<:Integer}...) - term(CartesianIndex, x, xs..., type=CartesianIndex) -end - import Base: +, -, * -tup(c::CartesianIndex) = Tuple(c) -tup(c::Symbolic{CartesianIndex}) = iscall(c) ? arguments(c) : error("Cartesian index not found") - -@wrapped function -(x::CartesianIndex, y::CartesianIndex) - CartesianIndex((tup(x) .- tup(y))...) -end -@wrapped function +(x::CartesianIndex, y::CartesianIndex) - CartesianIndex((tup(x) .+ tup(y))...) -end - -@wrapped function *(x::CartesianIndex, y::CartesianIndex) - CartesianIndex((tup(x) .* tup(y))...) -end - -@wrapped function *(a::Integer, x::CartesianIndex) - CartesianIndex((a * tup(x))...) -end - -@wrapped function *(x::CartesianIndex, b::Integer) - CartesianIndex((tup(x) * b)...) -end - - -function propagate_ndims(::typeof(getindex), x, idx...) - ndims(x) - count(x -> symtype(x) <: Integer, idx) -end - -function propagate_shape(::typeof(getindex), x, idx...) - @oops axes = shape(x) - - idx1 = to_indices(CartesianIndices(axes), axes, idx) - ([1:length(x) for x in idx1 if !(symtype(x) <: Number)]...,) -end - -propagate_eltype(::typeof(getindex), x, idx...) = geteltype(x) - -function SymbolicUtils.promote_symtype(::typeof(getindex), X, ii...) - @assert all(i -> i <: Integer, ii) - eltype(X) -end - #### Broadcast #### -# - -using Base.Broadcast - -Base.broadcastable(s::SymArray) = s -struct SymBroadcast <: Broadcast.BroadcastStyle end -Broadcast.BroadcastStyle(::Type{<:SymArray}) = SymBroadcast() -Broadcast.result_style(::SymBroadcast) = SymBroadcast() -Broadcast.BroadcastStyle(::SymBroadcast, ::Broadcast.BroadcastStyle) = SymBroadcast() - -isonedim(x, i) = shape(x) == Unknown() ? false : isone(size(x, i)) - -function Broadcast.copy(bc::Broadcast.Broadcasted{SymBroadcast}) - # Do the thing here - args = inner_unwrap.(bc.args) - ndim = mapfoldl(ndims, max, args, init=0) - subscripts = makesubscripts(ndim) - - onedim_count = mapreduce(+, args) do x - if ndims(x) != 0 - map(i -> isonedim(x, i) ? 1 : 0, 1:ndim) - else - map(i -> 1, 1:ndim) - end - end - - extruded = map(x -> x < length(args), onedim_count) - - expr_args′ = map(args) do x - if ndims(x) != 0 - subs = map(i -> extruded[i] && isonedim(x, i) ? - 1 : subscripts[i], 1:ndims(x)) - x[subs...] - elseif x isa Base.RefValue - x[] - else - x - end - end - expr = term(bc.f, expr_args′...) # Imagine x .=> y -- if you don't have a term - # then you get pairs, and index matcher cannot - # recurse into pairs - Atype = propagate_atype(broadcast, bc.f, args...) - args = map(x -> x isa Base.RefValue ? Term{Any}(Ref, [x[]]) : x, args) - ArrayOp(Atype{symtype(expr),ndim}, - (subscripts...,), - expr, - +, - Term{Any}(broadcast, [bc.f, args...])) -end # On wrapper: struct SymWrapBroadcast <: Broadcast.BroadcastStyle end -Base.broadcastable(s::Arr) = s +Broadcast.broadcastable(s::Arr) = s Broadcast.BroadcastStyle(::Type{<:Arr}) = SymWrapBroadcast() @@ -216,185 +38,64 @@ Broadcast.result_style(::SymWrapBroadcast) = SymWrapBroadcast() Broadcast.BroadcastStyle(::SymWrapBroadcast, ::Broadcast.BroadcastStyle) = SymWrapBroadcast() -Broadcast.BroadcastStyle(::SymBroadcast, +Broadcast.BroadcastStyle(::SymbolicUtils.SymBroadcast, ::SymWrapBroadcast) = Broadcast.Unknown() -function Broadcast.copy(bc::Broadcast.Broadcasted{SymWrapBroadcast}) - args = map(bc.args) do arg - if arg isa Broadcast.Broadcasted - return Broadcast.copy(arg) - else - return arg - end - end - wrap(broadcast(bc.f, map(unwrap, args)...)) +unwrap_broadcasts(head, args...) = (unwrap_broadcast(head), unwrap_broadcasts(args...)...) +unwrap_broadcasts() = () +unwrap_broadcast(x) = unwrap(x) +function unwrap_broadcast(bc::Broadcast.Broadcasted{SymWrapBroadcast}) + Broadcast.Broadcasted{SymbolicUtils.SymBroadcast{VartypeT}}(bc.f, unwrap_broadcasts(bc.args...), bc.axes) end +function Broadcast.copy(bc::Broadcast.Broadcasted{SymWrapBroadcast}) + return wrap(copy(unwrap_broadcast(bc))) +end -#################### TRANSPOSE ################ -# -@wrapped function Base.adjoint(A::AbstractMatrix) - @syms i::Int j::Int - @arrayop (i, j) A[j, i] term = A' -end false - -@wrapped function Base.adjoint(b::AbstractVector) - @syms i::Int - @arrayop (1, i) b[i] term = b' -end false - -import Base: *, \ - -using LinearAlgebra - -isdot(A, b) = isadjointvec(A) && ndims(b) == 1 - -isadjointvec(A::Adjoint) = ndims(parent(A)) == 1 -isadjointvec(A::Transpose) = ndims(parent(A)) == 1 +#################### POLYADIC ################ -function isadjointvec(A) - if iscall(A) - (operation(A) === (adjoint) || - operation(A) == (transpose)) && ndims(arguments(A)[1]) == 1 - else - false - end +function *(x::Arr, args...) + return *(unwrap(x), args...) end -isadjointvec(A::ArrayOp) = isadjointvec(A.term) - -__symtype(x::Type{<:Symbolic{T}}) where T = T -function symeltype(A) - T = eltype(A) - T <: Symbolic ? __symtype(T) : T +function *(a::SymbolicUtils.PolyadicNumericOpFirstArgT, b::Arr, bs...) + return *(a, unwrap(b), bs...) end -# TODO: add more such methods -function getindex(A::AbstractArray, i::Symbolic{<:Integer}, ii::Symbolic{<:Integer}...) - Term{symeltype(A)}(getindex, [A, i, ii...]) +function *(a::LinearAlgebra.Adjoint{T, <: AbstractVector}, b::Arr, bs...) where {T} + return *(a, unwrap(b), bs...) end - -function getindex(A::AbstractArray, i::Int, j::Symbolic{<:Integer}) - Term{symeltype(A)}(getindex, [A, i, j]) +function *(a::LinearAlgebra.Adjoint{T, <: AbstractVector}, b::Arr, c::AbstractVector, bs...) where {T} + return *(a, unwrap(b), unwrap(c), bs...) end - -function getindex(A::AbstractArray, j::Symbolic{<:Integer}, i::Int) - Term{symeltype(A)}(getindex, [A, j, i]) +function *(a::Number, b::Arr, bs...) + return *(a, unwrap(b), bs...) end - -function getindex(A::Arr, i::Int, j::Symbolic{<:Integer}) - wrap(unwrap(A)[i, j]) +function *(x1::Arr, x2::BasicSymbolic{VartypeT}, args...) + return *(unwrap(x1), x2, args...) end - -function getindex(A::Arr, j::Symbolic{<:Integer}, i::Int) - wrap(unwrap(A)[j, i]) +function *(x1::Arr, x2::Arr, args...) + return *(unwrap(x1), unwrap(x2), args...) end - -function _matmul(A, B) - A = inner_unwrap(A) - B = inner_unwrap(B) - @syms i::Int j::Int k::Int - if isadjointvec(A) - op = operation(A.term) - return op(op(B) * first(arguments(A.term))) - end - return @arrayop (i, j) A[i, k] * B[k, j] term = (A * B) +function *(x1::Arr, x2::Arr, x3::Arr, args...) + return *(unwrap(x1), unwrap(x2), unwrap(x3), args...) end - -@wrapped (*)(A::AbstractMatrix, B::AbstractMatrix) = _matmul(A, B) false -@wrapped (*)(A::AbstractVector, B::AbstractMatrix) = _matmul(A, B) false -# Resolve ambiguity with Base.*(::Adjoint{T, <:AbstractVector} where T, ::AbstractMatrix) -(*)(x::Adjoint{T, <:AbstractVector} where {T}, A::Symbolics.Arr{<:Any, 2}) = wrap(_matmul(unwrap(x), unwrap(A))) - -function _matvec(A, b) - A = inner_unwrap(A) - b = inner_unwrap(b) - @syms i::Int k::Int - sym_res = @arrayop (i,) A[i, k] * b[k] term=(A*b) - if isdot(A, b) - return sym_res[1] - else - return sym_res - end +function *(x1::Arr, x2::Arr, x3::Arr, x4::Arr, args...) + return *(unwrap(x1), unwrap(x2), unwrap(x3), unwrap(x4), args...) end -@wrapped (*)(A::AbstractMatrix, b::AbstractVector) = _matvec(A, b) false - -# specialize `dot` to dispatch on `Symbolic{<:Number}` to eventually work for -# arrays of (possibly unwrapped) Symbolic types, see issue #831 -@wrapped LinearAlgebra.dot(x::Number, y::Number) = conj(x) * y false - -#################### MAP-REDUCE ################ -# -@wrapped Base.map(f, x::AbstractArray) = _map(f, x) false -@wrapped Base.map(f, x::AbstractArray, xs...) = _map(f, x, xs...) false -@wrapped Base.map(f, x, y::AbstractArray, z...) = _map(f, x, y, z...) false -@wrapped Base.map(f, x, y, z::AbstractArray, w...) = _map(f, x, y, z, w...) false - -function _map(f, x, xs...) - N = ndims(x) - idx = makesubscripts(N) - x = inner_unwrap(x) - xs = inner_unwrap.(xs) - - expr = f(map(a -> a[idx...], [x, xs...])...) - - Atype = propagate_atype(map, f, x, xs...) - ArrayOp(Atype{symtype(expr),N}, - (idx...,), - expr, - +, - Term{Any}(map, [f, x, xs...])) +function +(x::Arr, args...) + return +(unwrap(x), args...) end -@inline _mapreduce(f, g, x, dims, kw) = mapreduce(f, g, x; dims=dims, kw...) - -function scalarize_op(::typeof(_mapreduce), t) - f, g, x, dims, kw = arguments(t) - # we wrap and unwrap to make things work smoothly. - # we need the result unwrapped to allow recursive scalarize to work. - unwrap(_mapreduce(f, g, collect(wrap(x)), dims, kw)) +function +(x1::Arr, x2::Arr, args...) + return +(unwrap(x1), unwrap(x2), args...) end -@wrapped function Base.mapreduce(f, g, x::AbstractArray; dims=:, kw...) - idx = makesubscripts(ndims(x)) - out_idx = [dims == (:) || i in dims ? 1 : idx[i] for i = 1:ndims(x)] - expr = f(x[idx...]) - T = symtype(g(expr, expr)) - if dims === (:) - return Term{T}(_mapreduce, [f, g, x, dims, (kw...,)]) - end - - Atype = propagate_atype(_mapreduce, f, g, x, dims, (kw...,)) - ArrayOp(Atype{T,ndims(x)}, - (out_idx...,), - expr, - g, - Term{Any}(_mapreduce, [f, g, x, dims, (kw...,)])) -end false - -for (ff, opts) in [ - any => (identity, (|), false), - all => (identity, (&), true)] - - f, g, init = opts - @eval @wrapped function (::$(typeof(ff)))(x::AbstractArray; - dims=:, init=$init) - mapreduce($f, $g, x, dims=dims, init=init) - end false - @eval @wrapped function (::$(typeof(ff)))(f::Function, x::AbstractArray; - dims=:, init=$init) - mapreduce(f, $g, x, dims=dims, init=init) - end false +function +(a::SymbolicUtils.PolyadicNumericOpFirstArgT, b::Arr, bs...) + return +(a, unwrap(b), bs...) end -for (ff, opts) in [sum => (identity, +), - prod => (identity, *)] +#################### MAP-REDUCE ################ - f, g = opts - @eval @wrapped function (::$(typeof(ff)))(x::AbstractArray; kw...) - mapreduce($f, $g, x; kw...) - end false - @eval @wrapped function (::$(typeof(ff)))(f::Function, x::AbstractArray; kw...) - mapreduce(f, $g, x; kw...) - end false -end +SymbolicUtils.@map_methods Arr unwrap wrap +SymbolicUtils.@mapreduce_methods Arr unwrap wrap diff --git a/src/arrays.jl b/src/arrays.jl index 4d60edd3e..2e83d47b6 100644 --- a/src/arrays.jl +++ b/src/arrays.jl @@ -1,491 +1,7 @@ using SymbolicUtils -using SymbolicUtils: @capture using StaticArraysCore import Base: eltype, length, ndims, size, axes, eachindex -export @arrayop, ArrayMaker, @makearray, @setview, @setview! - -### Store Shape as a metadata in Term{<:AbstractArray} objects -struct ArrayShapeCtx end - - -#= - There are 2 types of array terms: - `ArrayOp{T<:AbstractArray}` and `Term{<:AbstractArray}` - - - ArrayOp represents a Einstein-notation-inspired array operation. - it contains a field `term` which is a `Term` that represents the - operation that resulted in the `ArrayOp`. - I.e. will be `A*b` for the operation `(i,) => A[i,j] * b[j]` for example. - It can be `nothing` if not known. - - calling `shape` on an `ArrayOp` will return the shape of the array or `Unknown()` - - do not rely on the `symtype` or `shape` information of the `.term` when looking at an `ArrayOp`. - call `shape`, `symtype` and `ndims` directly on the `ArrayOp`. - - `Term{<:AbstractArray}` - - calling `shape` on it will return the shape of the array or `Unknown()`, and uses - the `ArrayShapeCtx` metadata context to store this. - - The Array type parameter must contain the dimension. -=# - -#### ArrayOp #### - -""" - ArrayOp(output_idx, expr, reduce) - -A tensor expression where `output_idx` are the output indices -`expr`, is the tensor expression and `reduce` is the function -used to reduce over contracted dimensions. -""" -struct ArrayOp{T<:AbstractArray} <: Symbolic{T} - output_idx # output indices - expr # Used in pattern matching - # Useful to infer eltype - reduce - term - shape - ranges::Dict{BasicSymbolic, AbstractRange} # index range each index symbol can take, - # optional for each symbol - metadata -end - -function ArrayOp(T::Type, output_idx, expr, reduce, term, ranges=Dict(); metadata=nothing) - sh = make_shape(output_idx, unwrap(expr), ranges) - ArrayOp{T}(output_idx, unwrap(expr), reduce, term, sh, ranges, metadata) -end - -function ArrayOp(a::AbstractArray) - i = makesubscripts(ndims(a)) - # TODO: formalize symtype(::Type) then! - ArrayOp(symtype(a), (i...,), a[i...], +, a) -end - -ConstructionBase.constructorof(s::Type{<:ArrayOp{T}}) where {T} = ArrayOp{T} - -function SymbolicUtils.maketerm(::Type{<:ArrayOp}, f, args, m) - args = map(args) do arg - if iscall(arg) && operation(arg) == Ref - inner = only(arguments(arg)) - if symbolic_type(inner) == NotSymbolic() - return Ref(inner) - else - return inner - end - else - return arg - end - end - - t = f(args...) - t isa Symbolic && !isnothing(m) ? - metadata(t, m) : t -end - -SymbolicUtils.sorted_arguments(s::ArrayOp) = sorted_arguments(s.term) - -shape(aop::ArrayOp) = aop.shape - -SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.ArrayOp}) = ArraySymbolic() - -const show_arrayop = Ref{Bool}(false) -function Base.show(io::IO, aop::ArrayOp) - if iscall(aop.term) && !show_arrayop[] - show(io, aop.term) - else - print(io, "@arrayop") - print(io, "(_[$(join(string.(aop.output_idx), ","))] := $(aop.expr))") - if aop.reduce != + - print(io, " ($(aop.reduce))") - end - - if !isempty(aop.ranges) - print(io, " ", join(["$k in $v" for (k, v) in aop.ranges], ", ")) - end - end -end - -Base.summary(io::IO, aop::ArrayOp) = Base.array_summary(io, aop, shape(aop)) -function Base.showarg(io::IO, aop::ArrayOp, toplevel) - show(io, aop) - toplevel && print(io, "::", typeof(aop)) - return nothing -end - -symtype(a::ArrayOp{T}) where {T} = T -iscall(a::ArrayOp) = true -function operation(a::ArrayOp) - isnothing(a.term) ? typeof(a) : operation(a.term) -end -function arguments(a::ArrayOp) - isnothing(a.term) ? [a.output_idx, a.expr, a.reduce, - a.term, a.shape, a.ranges, metadata(a)] : - arguments(a.term) -end - -function Base.isequal(a::ArrayOp, b::ArrayOp) - a === b && return true - isequal(a.shape, b.shape) && - isequal(a.ranges, b.ranges) && - isequal(a.output_idx, b.output_idx) && - isequal(a.reduce, b.reduce) && - isequal(operation(a), operation(b)) && - isequal(a.expr, b.expr) -end - -function Base.hash(a::ArrayOp, u::UInt) - hash(a.shape, hash(a.ranges, hash(a.expr, hash(a.output_idx, hash(operation(a), u))))) -end - -macro arrayop(output_idx, expr, options...) - rs = [] - reduce = + - call = nothing - - extra = [] - isexpr = MacroTools.isexpr - for o in options - if isexpr(o, :call) && o.args[1] == :in - push!(rs, :($(o.args[2]) => $(o.args[3]))) - elseif isexpr(o, :(=)) && o.args[1] == :reduce - reduce = o.args[2] - elseif isexpr(o, :(=)) && o.args[1] == :term - call = o.args[2] - else - push!(extra, o) - end - end - if length(extra) == 1 - @warn("@arrayop is deprecated, use @arrayop term= instead") - call = output_idx - output_idx = expr - expr = extra[1] - end - @assert output_idx.head == :tuple - - oidxs = filter(x->x isa Symbol, output_idx.args) - iidxs = find_indices(expr) - idxs = union(oidxs, iidxs) - fbody = call2term(deepcopy(expr)) - oftype(x,T) = :($x::$T) - aop = gensym("aop") - quote - let - @syms $(map(x->oftype(x, Int), idxs)...) - - expr = $fbody - #TODO: proper Atype - $ArrayOp(Array{$symtype(expr), - $(length(output_idx.args))}, - $output_idx, - expr, - $reduce, - $(call2term(call)), - Dict($(rs...))) - - end - end |> esc -end - -const SymArray = Union{ArrayOp, Symbolic{<:AbstractArray}} -const SymMat = Union{ArrayOp{<:AbstractMatrix}, Symbolic{<:AbstractMatrix}} -const SymVec = Union{ArrayOp{<:AbstractVector}, Symbolic{<:AbstractVector}} - -### Propagate ### -# -## Shape ## - - -function axis_in(a, b) - first(a) >= first(b) && last(a) <= last(b) -end - -function make_shape(output_idx, expr, ranges=Dict()) - matches = idx_to_axes(expr) - for (sym, ms) in matches - to_check = filter(m->!(shape(m.A) isa Unknown), ms) - # Only check known dimensions. It may be "known symbolically" - isempty(to_check) && continue - restricted = false - if haskey(ranges, sym) - ref_axis = ranges[sym] - restricted = true - else - ref_axis = axes(first(to_check).A, first(to_check).dim) - end - reference = ref_axis - for i in (restricted ? 1 : 2):length(ms) - m = ms[i] - s=shape(m.A) - if s !== Unknown() - if restricted - if !axis_in(ref_axis, axes(m.A, m.dim)) - throw(DimensionMismatch("expected $(ref_axis) to be within axes($(m.A), $(m.dim))")) - end - elseif !isequal(axes(m.A, m.dim), reference) - throw(DimensionMismatch("expected axes($(m.A), $(m.dim)) = $(reference)")) - end - end - end - end - - sz = map(output_idx) do i - if issym(i) - if haskey(ranges, i) - return axes(ranges[i], 1) - end - if !haskey(matches, i) - error("There was an error processing arrayop expression $expr.\n" * - "Dimension of output index $i in $output_idx could not be inferred") - end - mi = matches[i] - @assert !isempty(mi) - ext = get_extents(mi) - ext isa Unknown && return Unknown() - return 1:(length(ext)) - elseif i isa Integer - return 1:(1) - end - end - # TODO: maybe we can remove this restriction? - if any(x->x isa Unknown, sz) - Unknown() - else - sz - end -end - - -function ranges(a::ArrayOp) - rs = Dict{BasicSymbolic, Any}() - ax = idx_to_axes(a.expr) - for i in keys(ax) - if haskey(a.ranges, i) - rs[i] = a.ranges[i] - else - rs[i] = ax[i] #get_extents(ax[i]) - end - end - return rs -end - -## Eltype ## - -eltype(aop::ArrayOp) = symtype(aop.expr) - -## Ndims ## -function ndims(aop::ArrayOp) - length(aop.output_idx) -end - - -### Utils ### - - -# turn `f(x...)` into `term(f, x...)` -# -function call2term(expr, arrs=[]) - !(expr isa Expr) && return :($unwrap($expr)) - if expr.head == :call - if expr.args[1] == :(:) - return expr - end - return Expr(:call, term, map(call2term, expr.args)...) - elseif expr.head == :ref - return Expr(:ref, call2term(expr.args[1]), expr.args[2:end]...) - elseif expr.head == Symbol("'") - return Expr(:call, term, adjoint, map(call2term, expr.args)...) - end - - return Expr(expr.head, map(call2term, expr.args)...) -end - -# Find all symbolic indices in expr -function find_indices(expr, idxs=[]) - !(expr isa Expr) && return idxs - if expr.head == :ref - return append!(idxs, filter(x->x isa Symbol, expr.args[2:end])) - elseif expr.head == :call && expr.args[1] == :getindex || expr.args[1] == getindex - return append!(idxs, filter(x->x isa Symbol, expr.args[3:end])) - else - foreach(x->find_indices(x, idxs), expr.args) - return idxs - end -end - -struct AxisOf - A - dim - boundary -end - -function Base.get(a::AxisOf) - @oops shape(a.A) - axes(a.A, a.dim) -end - -function get_extents(xs) - boundaries = map(x->x.boundary, xs) - if all(iszero∘wrap, boundaries) - get(first(xs)) - else - ii = findfirst(x->issym(x) || iscall(x), boundaries) - if !isnothing(ii) - error("Could not find the boundary from symbolic index $(xs[ii]). Please manually specify the range of indices.") - end - extent = get(first(xs)) - start_offset = -reduce(min, filter(x->x<0, boundaries), init=0) - end_offset = reduce(max, filter(x->x>0, boundaries), init=0) - - (first(extent) + start_offset):(last(extent) - end_offset) - end -end - -get_extents(x::AbstractRange) = x - -## Walk expr looking for symbols used in getindex expressions -# Returns a dictionary of Sym to a vector of AxisOf objects. -# The vector has as many elements as the number of times the symbol -# appears in the expr. AxisOf has three fields: -# A: the array whose indexing it appears in -# dim: The dimension of the array indexed -# boundary: how much padding is this indexing requiring, for example -# boundary is 2 for x[i + 2], and boundary = -2 for x[i - 2] -function idx_to_axes(expr, dict=Dict{Any, Vector}(), ranges=Dict()) - if iscall(expr) - if operation(expr) === (getindex) - args = arguments(expr) - for (axis, idx_expr) in enumerate(@views args[2:end]) - if issym(idx_expr) || iscall(idx_expr) - vs = get_variables(idx_expr) - isempty(vs) && continue - sym = only(get_variables(idx_expr)) - axesvec = Base.get!(() -> [], dict, sym) - push!(axesvec, AxisOf(first(args), axis, idx_expr - sym)) - end - end - else - idx_to_axes(operation(expr), dict) - foreach(ex->idx_to_axes(ex, dict), arguments(expr)) - end - end - dict -end - - -#### Term{<:AbstractArray} -# - -""" - array_term(f, args...; - container_type = propagate_atype(f, args...), - eltype = propagate_eltype(f, args...), - size = map(length, propagate_shape(f, args...)), - ndims = propagate_ndims(f, args...)) - -Create a term of `Term{<: AbstractArray}` which -is the representation of `f(args...)`. - -Default arguments: -- `container_type=propagate_atype(f, args...)` - the container type, - i.e. `Array` or `StaticArray` etc. -- `eltype=propagate_eltype(f, args...)` - the output element type. -- `size=map(length, propagate_shape(f, args...))` - the - output array size. `propagate_shape` returns a tuple of index ranges. -- `ndims=propagate_ndims(f, args...)` the output dimension. - -`propagate_shape`, `propagate_atype`, `propagate_eltype` may -return `Unknown()` to say that the output cannot be determined -""" -function array_term(f, args...; - container_type = propagate_atype(f, args...), - eltype = propagate_eltype(f, args...), - size = Unknown(), - ndims = size !== Unknown() ? length(size) : propagate_ndims(f, args...), - shape = size !== Unknown() ? Tuple(map(x->1:x, size)) : propagate_shape(f, args...)) - - if container_type == Unknown() - # There's always a fallback for this - container_type = propagate_atype(f, args...) - end - - if eltype == Unknown() - eltype = Base.propagate_eltype(container_type) - end - - if ndims == Unknown() - ndims = if shape == Unknown() - Any - else - length(shape) - end - end - S = container_type{eltype, ndims} - setmetadata(Term{S}(f, Any[args...]), ArrayShapeCtx, shape) -end - -""" - shape(s::Any) - -Returns `axes(s)` or `Unknown()`. -""" -shape(s) = axes(s) - -""" - shape(s::SymArray) - -Extract the shape metadata from a SymArray. -If not known, returns `Unknown()` -""" -function shape(s::Symbolic{<:AbstractArray}) - if hasmetadata(s, ArrayShapeCtx) - getmetadata(s, ArrayShapeCtx) - else - Unknown() - end -end - -## `propagate_` interface: -# used in the `array_term` construction. - -atype(::Type{<:Array}) = Array -atype(::Type{<:SArray}) = SArray -atype(::Type) = AbstractArray - -_propagate_atype(::Type{T}, ::Type{T}) where {T} = T -_propagate_atype(::Type{<:Array}, ::Type{<:SArray}) = Array -_propagate_atype(::Type{<:SArray}, ::Type{<:Array}) = Array -_propagate_atype(::Any, ::Any) = AbstractArray -_propagate_atype(T) = T -_propagate_atype() = AbstractArray - -function propagate_atype(f, args...) - As = [atype(symtype(T)) - for T in Iterators.filter(x->x <: Symbolic{<:AbstractArray}, typeof.(args))] - if length(As) <= 1 - _propagate_atype(As...) - else - foldl(_propagate_atype, As) - end -end - -function propagate_eltype(f, args...) - As = [eltype(symtype(T)) - for T in Iterators.filter(x->symtype(x) <: AbstractArray, args)] - promote_type(As...) -end - -function propagate_ndims(f, args...) - if propagate_shape(f, args...) == Unknown() - error("Could not determine the output dimension of $f$args") - else - length(propagate_shape(f, args...)) - end -end - -function propagate_shape(f, args...) - error("Don't know how to propagate shape for $f$args") -end - ### Wrapper type for dispatch @symbolic_wrap struct Arr{T,N} <: AbstractArray{T, N} @@ -494,10 +10,8 @@ end Base.hash(x::Arr, u::UInt) = hash(unwrap(x), u) Base.isequal(a::Arr, b::Arr) = isequal(unwrap(a), unwrap(b)) -Base.isequal(a::Arr, b::Symbolic) = isequal(unwrap(a), b) -Base.isequal(a::Symbolic, b::Arr) = isequal(b, a) - -ArrayOp(x::Arr) = unwrap(x) +Base.isequal(a::Arr, b::BasicSymbolic) = isequal(unwrap(a), b) +Base.isequal(a::BasicSymbolic, b::Arr) = isequal(b, a) function Arr(x) A = symtype(x) @@ -505,14 +19,8 @@ function Arr(x) Arr{maybewrap(eltype(A)), ndims(A)}(x) end -const ArrayLike{T,N} = Union{ - ArrayOp{AbstractArray{T,N}}, - Symbolic{AbstractArray{T,N}}, - Arr{T,N}, - SymbolicUtils.Term{AbstractArray{T, N}} -} # Like SymArray but includes Arr and Term{Arr} - -unwrap(x::Arr) = x.value +SymbolicUtils.unwrap(x::Arr) = x.value +SymbolicUtils.symtype(x::Arr) = symtype(unwrap(x)) maybewrap(T) = has_symwrapper(T) ? wrapper_type(T) : T # These methods allow @wrapped methods to be more specific and not overwrite @@ -528,494 +36,74 @@ function Base.show(io::IO, arr::Arr) iscall(x) && print(io, "(") print(io, unwrap(arr)) iscall(x) && print(io, ")") - if !(shape(x) isa Unknown) - print(io, "[", join(string.(axes(arr)), ","), "]") + print(io, "[") + if shape(x) isa SymbolicUtils.Unknown + print(io, shape(x)) + else + print(io, join(string.(axes(arr)), ",")) end + print(io, "]") end Base.show(io::IO, ::MIME"text/plain", arr::Arr) = show(io, arr) ################# Base array functions -# - -# basic -# these methods are not symbolic but work if we know this info. - -geteltype(s::SymArray) = geteltype(symtype(s)) +geteltype(s::BasicSymbolic) = geteltype(symtype(s)) geteltype(::Type{<:AbstractArray{T}}) where {T} = T geteltype(::Type{<:AbstractArray}) = Unknown() -ndims(s::SymArray) = ndims(symtype(s)) ndims(::Type{<:Arr{<:Any, N}}) where N = N -function eltype(A::Union{Arr, SymArray}) - T = geteltype(unwrap(A)) - T === Unknown() && error("eltype of $A not known") - return T -end - -function length(A::Union{Arr, SymArray}) - s = shape(unwrap(A)) - s === Unknown() && error("length of $A not known") - return prod(length, s) -end - -function size(A::Union{Arr, SymArray}) - s = shape(unwrap(A)) - s === Unknown() && error("size of $A not known") - return length.(s) -end - -function size(A::SymArray, i::Integer) - @assert(i > 0) - i > ndims(A) ? 1 : size(A)[i] -end - -function axes(A::Union{Arr, SymArray}) - s = shape(unwrap(A)) - s === Unknown() && error("axes of $A not known") - return s -end - - -function axes(A::SymArray, i) - s = shape(A) - s === Unknown() && error("axes of $A not known") - return i <= length(s) ? s[i] : 1:(1) -end - -function eachindex(A::Union{Arr, SymArray}) - s = shape(unwrap(A)) - s === Unknown() && error("eachindex of $A not known") - return CartesianIndices(s) -end +Base.eltype(A::Arr) = geteltype(unwrap(A)) +Base.length(A::Arr) = length(unwrap(A)) +Base.size(A::Arr) = size(unwrap(A)) +Base.axes(A::Arr) = axes(unwrap(A)) +Base.eachindex(A::Arr) = eachindex(unwrap(A)) function get_variables!(vars, e::Arr, varlist=nothing) foreach(x -> get_variables!(vars, x, varlist), collect(e)) vars end - -### Scalarize - -scalarize(a::Array) = map(scalarize, a) -scalarize(term::Symbolic{<:AbstractArray}, idx) = term[idx...] -val2num(::Val{n}) where n = n - -function replace_by_scalarizing(ex, dict) - rule = @rule(getindex(~x, ~~i) => - scalarize(~x, (map(j->substitute(j, dict), ~~i)...,))) - - function rewrite_operation(x) - if iscall(x) && iscall(operation(x)) - f = operation(x) - ff = replace_by_scalarizing(f, dict) - if metadata(x) !== nothing - maketerm(typeof(x), ff, arguments(x), metadata(x)) - else - ff(arguments(x)...) - end - else - nothing - end - end - - prewalk_if(x->!(x isa ArrayOp || x isa ArrayMaker), - Rewriters.PassThrough(Chain([rewrite_operation, rule])), - ex) -end - -function prewalk_if(cond, f, t) - t′ = cond(t) ? f(t) : return t - if iscall(t′) - if metadata(t′) !== nothing - return maketerm(typeof(t′), TermInterface.head(t′), - map(x->prewalk_if(cond, f, x), children(t′)), metadata(t′)) - else - TermInterface.head(t′)(map(x->prewalk_if(cond, f, x), children(t′))...) - end - else - return t′ - end -end - -function scalarize(arr::AbstractArray, idx) - arr[idx...] -end - -function scalarize(arr, idx) - if iscall(arr) - scalarize_op(operation(arr), arr, idx) - else - error("scalarize is not defined for $arr at idx=$idx") - end -end - -scalarize_op(f, arr) = arr - -struct ScalarizeCache end - -function scalarize_op(f, arr, idx) - if hasmetadata(arr, ScalarizeCache) && getmetadata(arr, ScalarizeCache)[] !== nothing - getmetadata(arr, ScalarizeCache)[][idx...] - else - # wrap and unwrap to call generic methods - thing = unwrap(f(scalarize.(map(wrap, arguments(arr)))...)) - if metadata(arr) != nothing - # forward any metadata - try - thing = metadata(thing, metadata(arr)) - catch err - @warn "could not attach metadata of subexpression $arr to the scalarized form at idx" - end - end - if !hasmetadata(arr, ScalarizeCache) - arr = setmetadata(arr, ScalarizeCache, Ref{Any}(nothing)) - end - getmetadata(arr, ScalarizeCache)[] = thing - thing[idx...] - end -end - -@wrapped function Base.:(\)(A::AbstractMatrix, b::AbstractVecOrMat) - t = array_term(\, A, b) - setmetadata(t, ScalarizeCache, Ref{Any}(nothing)) -end false - -@wrapped function Base.inv(A::AbstractMatrix) - t = array_term(inv, A) - setmetadata(t, ScalarizeCache, Ref{Any}(nothing)) -end false - -_det(x, lp) = det(x, laplace=lp) - -function scalarize_op(f::typeof(_det), arr) - unwrap(det(map(wrap, collect(arguments(arr)[1])), laplace=arguments(arr)[2])) -end - -@wrapped function LinearAlgebra.det(x::AbstractMatrix; laplace=true) - Term{eltype(x)}(_det, [x, laplace]) -end false - - -# A * x = b -# A ∈ R^(m x n) x ∈ R^(n, k) = b ∈ R^(m, k) -propagate_ndims(::typeof(\), A, b) = ndims(b) -propagate_ndims(::typeof(inv), A) = ndims(A) - -# A(m,k) * B(k,n) = C(m,n) -# A(m,k) \ C(m,n) = B(k,n) -function propagate_shape(::typeof(\), A, b) - if ndims(b) == 1 - (axes(A,2),) - else - (axes(A,2), axes(b, 2)) - end -end - -function propagate_shape(::typeof(inv), A) - @oops shp = shape(A) - @assert ndims(A) == 2 && reverse(shp) == shp "Inv called on a non-square matrix" - shp -end - -function scalarize(arr::ArrayOp, idx) - @assert length(arr.output_idx) == length(idx) - - axs = ranges(arr) - - iidx = collect(keys(axs)) - contracted = setdiff(iidx, arr.output_idx) - - axes = [get_extents(axs[c]) for c in contracted] - summed = if isempty(contracted) - arr.expr - else - mapreduce(arr.reduce, Iterators.product(axes...)) do idx - replace_by_scalarizing(arr.expr, Dict(contracted .=> idx)) - end - end - - dict = Dict(oi => (unwrap(i) isa Symbolic ? unwrap(i) : get_extents(axs[oi])[i]) - for (oi, i) in zip(arr.output_idx, idx) if unwrap(oi) isa Symbolic) - - replace_by_scalarizing(summed, dict) -end - -scalarize(arr::Arr, idx) = wrap(scalarize(unwrap(arr), - unwrap.(idx))) - - -eval_array_term(op, arr) = arr -eval_array_term(op::typeof(inv), arr) = inv(scalarize(wrap(arguments(arr)[1]))) #issue 653 -eval_array_term(op::Arr) = wrap(eval_array_term(unwrap(op))) -eval_array_term(op) = eval_array_term(operation(op), op) - -""" - $(TYPEDSIGNATURES) - -Replace all occurrences of array symbolics (variables and subexpressions) in expression -`arr` with scalarized variants. -""" -function scalarize(arr) - if arr isa Arr || arr isa Symbolic{<:AbstractArray} - if iscall(arr) - arr = eval_array_term(arr) - end - map(Iterators.product(axes(arr)...)) do i - scalarize(arr[i...]) # Use arr[i...] here to trigger any getindex hooks - end - elseif iscall(arr) && operation(arr) == getindex - args = arguments(arr) - scalarize(args[1], (args[2:end]...,)) - elseif arr isa Num - wrap(scalarize(unwrap(arr))) - elseif iscall(arr) && symtype(arr) <: Number - t = maketerm(typeof(arr), operation(arr), map(scalarize, arguments(arr)), metadata(arr)) - iscall(t) ? scalarize_op(operation(t), t) : t - else - arr - end -end - -@wrapped Base.isempty(x::AbstractArray) = shape(unwrap(x)) !== Unknown() && _iszero(length(x)) false -Base.collect(x::Arr) = scalarize(x) -Base.collect(x::SymArray) = scalarize(x) -isarraysymbolic(x) = unwrap(x) isa Symbolic && SymbolicUtils.symtype(unwrap(x)) <: AbstractArray +# cannot use `@wrapped` since it will define `\(::BasicSymbolic, ::BasicSymbolic)` +# and because `\(::Arr, ::BasicSymbolic)` will be ambiguous. +for (T1, T2) in [ + (Arr{<:Any, 2}, Arr{<:Any, 1}), + (Arr{<:Any, 2}, Arr{<:Any, 2}), + (Arr{<:Any, 2}, BasicSymbolic{SymReal}), + (Arr{<:Any, 2}, BasicSymbolic{SafeReal}), + (Arr{<:Any, 2}, BasicSymbolic{TreeReal}), + (BasicSymbolic{SymReal}, Arr{<:Any, 1}), + (BasicSymbolic{SafeReal}, Arr{<:Any, 1}), + (BasicSymbolic{TreeReal}, Arr{<:Any, 1}), + (BasicSymbolic{SymReal}, Arr{<:Any, 2}), + (BasicSymbolic{SafeReal}, Arr{<:Any, 2}), + (BasicSymbolic{TreeReal}, Arr{<:Any, 2}), +] + @eval function Base.:(\)(A::$T1, b::$T2) + unwrap(A) \ unwrap(b) + end +end + +Base.inv(A::Arr{<:Any, 2}) = wrap(inv(unwrap(A))) +LinearAlgebra.det(A::Arr{<:Any, 2}) = wrap(det(unwrap(A))) +LinearAlgebra.adjoint(A::Arr{<:Any, 2}) = wrap(adjoint(unwrap(A))) +LinearAlgebra.adjoint(A::Arr{<:Any, 1}) = wrap(adjoint(unwrap(A))) + +SymbolicUtils.scalarize(x::Arr) = SymbolicUtils.scalarize(unwrap(x)) + +Base.isempty(x::Arr) = isempty(unwrap(x)) +Base.collect(x::Arr) = wrap.(collect(unwrap(x))) +isarraysymbolic(x) = false +# this should be validated in the constructor +isarraysymbolic(x::Arr) = true +isarraysymbolic(x::BasicSymbolic) = symtype(x) <: AbstractArray Base.convert(::Type{<:Array{<:Any, N}}, arr::Arr{<:Any, N}) where {N} = scalarize(arr) - -### Stencils - -struct ArrayMaker{T, AT<:AbstractArray} <: Symbolic{AT} - shape - sequence - metadata -end - -function ArrayMaker(a::ArrayLike; eltype=eltype(a)) - ArrayMaker{eltype}(size(a), Any[axes(a) => a]) -end - -function arraymaker(T, shape, views, seq...) - ArrayMaker{T}(shape, [(views .=> seq)...]) -end - -iscall(x::ArrayMaker) = true -operation(x::ArrayMaker) = arraymaker -arguments(x::ArrayMaker) = [eltype(x), shape(x), map(first, x.sequence), map(last, x.sequence)...] - -shape(am::ArrayMaker) = am.shape - -function ArrayMaker{T}(sz::NTuple{N, Integer}, seq::Array=[]; atype=Array, metadata=nothing) where {N,T} - ArrayMaker{T, atype{T, N}}(map(x->1:x, sz), seq, metadata) -end - -(::Type{ArrayMaker{T}})(i::Int...; atype=Array) where {T} = ArrayMaker{T}(i, atype=atype) - -function Base.show(io::IO, ac::ArrayMaker) - print(io, Expr(:call, :ArrayMaker, ac.shape, - Expr(:block, ac.sequence...))) -end - -function get_indexers(ex) - @assert ex.head == :ref - arr = ex.args[1] - args = map(((i,x),)->x == Symbol(":") ? :(1:lastindex($arr, $i)) : x, enumerate(ex.args[2:end])) - replace_ends(arr, args) -end - -function search_and_replace(expr, key, val) - isequal(expr, key) && return val - - expr isa Expr ? - Expr(expr.head, map(x->search_and_replace(x, key,val), expr.args)...) : - expr -end - -function replace_ends(arr, idx) - [search_and_replace(ix, :end, :(lastindex($arr, $i))) - for (i, ix) in enumerate(idx)] -end - -macro setview!(definition, arrayop) - setview(definition, arrayop, true) -end - -macro setview(definition, arrayop) - setview(definition, arrayop, false) -end - -output_index_ranges(c::CartesianIndices) = c.indices -output_index_ranges(ix...) = ix - -function setview(definition, arrayop, inplace) - output_view = get_indexers(definition) - output_ref = definition.args[1] - - function check_assignment(vw, op) - try Base.Broadcast.broadcast_shape(map(length, vw), size(op)) - catch err - if err isa DimensionMismatch - throw(DimensionMismatch("setview did not work while assigning indices " * - "$vw to $op. LHS has size $(map(length, vw)) "* - "and RHS has size $(size(op)) " * - "-- they need to be broadcastable.")) - else - rethrow(err) - end - end - end - - function push(inplace) - if inplace - function (am, vw, op) - check_assignment(vw, op) - # assert proper size match - push!(am.sequence, vw => op) - am - end - else - function (am, vw, op) - check_assignment(vw, op) - if am isa ArrayMaker - typeof(am)(am.shape, vcat(am.sequence, vw => op)) - else - am = ArrayMaker(am) - push!(am.sequence, vw => op) - am - end - end - end - end - quote - $(push(inplace))($output_ref, - $output_index_ranges($(output_view...)), $unwrap($arrayop)) - end |> esc -end - -macro makearray(definition, sequence) - output_shape = get_indexers(definition) - output_name = definition.args[1] - - seq = map(filter(x->!(x isa LineNumberNode), sequence.args)) do pair - @assert pair.head == :call && pair.args[1] == :(=>) - # TODO: make sure the same symbol is used for the lhs array - :(@setview! $(pair.args[2]) $(pair.args[3])) - end - - quote - $output_name = $ArrayMaker{Real}(map(length, ($(output_shape...),))) - $(seq...) - $output_name = $wrap($output_name) - end |> esc -end - -function best_order(output_idx, ks, rs) - unique!(filter(issym, vcat(reverse(output_idx)..., collect(ks)))) -end - -function _cat(x, xs...; dims) - arrays = (x, xs...) - if dims isa Integer - sz = Base.cat_size_shape(Base.dims2cat(dims), arrays...) - T = reduce(promote_type, eltype.(xs), init=eltype(x)) - newdim = cumsum(map(a->size(a, dims), arrays)) - start = 1 - A = ArrayMaker{T}(sz...) - for (dim, array) in zip(newdim, arrays) - idx = CartesianIndices(ntuple(n -> n==dims ? - (start:dim) : (1:sz[n]), length(sz))) - start = dim + 1 - - @setview! A[idx] array - end - return A - else - error("Block diagonal concatenation not supported") - end -end - -# Base.cat(x::Arr, xs...; dims) = _cat(x, xs...; dims) -# Base.cat(x::AbstractArray, y::Arr, xs...; dims) = _cat(x, y, xs...; dims) - -# vv uncomment these for a major release -# Base.vcat(x::Arr, xs::AbstractVecOrMat...) = cat(x, xs..., dims=1) -# Base.hcat(x::Arr, xs::AbstractVecOrMat...) = cat(x, xs..., dims=2) -# Base.vcat(x::AbstractVecOrMat, y::Arr, xs::AbstractVecOrMat...) = _cat(x, y, xs..., dims=1) -# Base.hcat(x::AbstractVecOrMat, y::Arr, xs::AbstractVecOrMat...) = _cat(x, y, xs..., dims=2) -# Base.vcat(x::Arr, y::Arr) = _cat(x, y, dims=1) # plug ambiguity -# Base.hcat(x::Arr, y::Arr) = _cat(x, y, dims=2) - -function scalarize(x::ArrayMaker) - T = eltype(x) - A = Array{wrapper_type(T)}(undef, size(x)) - for (vw, arr) in x.sequence - if any(x->x isa AbstractArray, vw) - A[vw...] .= scalarize(arr) - else - A[vw...] = scalarize(arr) - end - end - A -end - -function scalarize(x::ArrayMaker, idx) - for (vw, arr) in reverse(x.sequence) # last one wins - if any(x->issym(x) || iscall(x), idx) - return term(getindex, x, idx...) - end - if all(in.(idx, vw)) - if symtype(arr) <: AbstractArray - # Filter out non-array indices because the RHS will be one dim less - el = [searchsortedfirst(v, i) - for (v, i) in zip(vw, idx) if v isa AbstractArray] - return scalarize(arr[el...]) - else - return arr - end - end - end - if !any(x->issym(x) || iscall(x), idx) && all(in.(idx, axes(x))) - throw(UndefRefError()) - end - - throw(BoundsError(x, idx)) -end - - -### Codegen - -function SymbolicUtils.Code.toexpr(x::ArrayOp, st) - haskey(st.rewrites, x) && return st.rewrites[x] - - if iscall(x.term) - toexpr(x.term, st) - else - _array_toexpr(x, st) - end -end - function SymbolicUtils.Code.toexpr(x::Arr, st) toexpr(unwrap(x), st) end -function SymbolicUtils.Code.toexpr(x::ArrayMaker, st) - _array_toexpr(x, st) -end - -function _array_toexpr(x, st) - outsym = Symbol("_out") - N = length(shape(x)) - ex = Let( - [ - Assignment(outsym, term(zeros, Float64, term(map, length, shape(x)))), - Assignment(Symbol("%$outsym"), inplace_expr(x, outsym)) - ], outsym, false) - - toexpr(ex, st) -end - """ $(TYPEDSIGNATURES) @@ -1027,31 +115,15 @@ end function inplace_expr(x, out_array, intermediates = nothing) x = unwrap(x) - if symtype(x) <: Number + if SymbolicUtils.isarrayop(x) + return x + elseif symtype(x) <: Number return term(broadcast_assign!, out_array, x) else return term(copy!, out_array, x) end end -function inplace_expr(x::ArrayMaker, out, intermediates = nothing) - steps = Assignment[] - - _intermediates = something(intermediates, Dict()) - for (i, (vw, op)) in enumerate(x.sequence) - out′ = Symbol(out, "_", i) - push!(steps, Assignment(out′, term(view, out, vw...))) - push!(steps, Assignment(Symbol("%$out′"), inplace_expr(unwrap(op), out′, _intermediates))) - end - - expr = Let(steps, nothing, false) - if intermediates === nothing && !isempty(_intermediates) - steps = [map(k -> Assignment(_intermediates[k], k), collect(keys(_intermediates))); steps] - expr = Let(steps, nothing, false) - end - return expr -end - function inplace_expr(x::AbstractArray, out, intermediates = nothing) expr = SetArray(false, out, x, true) if intermediates !== nothing @@ -1069,85 +141,8 @@ function inplace_builtin(term, outsym) return nothing end -function find_inter(acc, expr) - if !issym(expr) && symtype(expr) <: AbstractArray - push!(acc, expr) - elseif iscall(expr) - foreach(x -> find_inter(acc, x), arguments(expr)) - end - acc -end - -function get_inputs(x::ArrayOp) - unique(find_inter([], x.expr)) -end - -function similar_arrayvar(ex, name) - Sym{symtype(ex)}(name) #TODO: shape? -end - -function reset_to_one(range) - @assert step(range) == 1 - 1:(length(range)) -end - -function reset_sym(i) - Sym{Int}(Symbol(nameof(i), "′")) -end - -function inplace_expr(x::ArrayOp, outsym = :_out, intermediates = nothing) - if x.term !== nothing - ex = inplace_builtin(x.term, outsym) - if ex !== nothing - return ex - end - end - - inters = filter(!issym, get_inputs(x)) - intermediate_exprs = map(enumerate(inters)) do (i, ex) - if intermediates !== nothing - if haskey(intermediates, ex) - return ex => intermediates[ex] - else - sym = similar_arrayvar(ex, Symbol(outsym, :_input_, i)) - intermediates[ex] = sym - return ex => sym - end - else - return ex => similar_arrayvar(ex, Symbol(outsym, :_input_, i)) - end - end - - - rs = copy(ranges(x)) - loops = best_order(x.output_idx, keys(rs), rs) - expr = substitute(unwrap(x.expr), Dict(intermediate_exprs)) - - out_idxs = map(reset_sym, x.output_idx) - inner_expr = SetArray(false, outsym, [AtIndex(term(CartesianIndex, out_idxs...), term(x.reduce, term(getindex, outsym, out_idxs...), expr))]) - - loops = foldl(reverse(loops), init=inner_expr) do acc, k - if any(isequal(k), x.output_idx) - k′ = reset_sym(k) - loopvar = Symbol("%$k$k′") - ext = get_extents(rs[k]) - if isone(first(ext)) && isone(step(ext)) - ext = Base.OneTo(last(ext)) - end - ForLoop(loopvar, term(zip, ext, term(reset_to_one, ext)), Let([DestructuredArgs([k, reset_sym(k)], loopvar)], acc, false)) - else - ForLoop(k, get_extents(rs[k]), acc) - end - end - - if intermediates === nothing && !isempty(intermediate_exprs) - return Let(map(x -> Assignment(x[2], x[1]), intermediate_exprs), loops, false) - end - return loops -end - hasnode(r::Function, y::Arr) = _hasnode(r, y) -hasnode(r::Union{Num, Symbolic, Arr}, y::Arr) = occursin(unwrap(r), unwrap(y)) +hasnode(r::Union{Num, BasicSymbolic, Arr}, y::Arr) = occursin(unwrap(r), unwrap(y)) #= """ diff --git a/src/build_function.jl b/src/build_function.jl index cc00280cc..c58feea06 100644 --- a/src/build_function.jl +++ b/src/build_function.jl @@ -148,7 +148,7 @@ end SymbolicUtils.Code.get_rewrites(x::Arr) = SymbolicUtils.Code.get_rewrites(unwrap(x)) -function _build_function(target::JuliaTarget, op::Union{Arr, ArrayOp, SymbolicUtils.BasicSymbolic{<:AbstractArray}}, args...; +function _build_function(target::JuliaTarget, op::Union{Arr, SymbolicUtils.BasicSymbolic{<:AbstractArray}}, args...; conv = toexpr, expression = Val{true}, expression_module = @__MODULE__(), diff --git a/src/latexify_recipes.jl b/src/latexify_recipes.jl index ad2de9f48..22e70323e 100644 --- a/src/latexify_recipes.jl +++ b/src/latexify_recipes.jl @@ -77,13 +77,6 @@ end return :($(recipe(z.re)) + $(recipe(z.im)) * $im) end -@latexrecipe function f(n::ArrayOp) - env --> :equation - mult_symbol --> "" - index --> :subscript - return recipe(n.term) -end - @latexrecipe function f(n::Function) env --> :equation mult_symbol --> "" @@ -143,8 +136,6 @@ Base.show(io::IO, ::MIME"text/latex", x::Equation) = print(io, "\$\$ " * latexif Base.show(io::IO, ::MIME"text/latex", x::Vector{Equation}) = print(io, "\$\$ " * latexify(x) * " \$\$") Base.show(io::IO, ::MIME"text/latex", x::AbstractArray{<:RCNum}) = print(io, "\$\$ " * latexify(x) * " \$\$") -_toexpr(O::ArrayOp; latexwrapper = default_latex_wrapper) = _toexpr(O.term; latexwrapper) - # `_toexpr` is only used for latexify function _toexpr(O; latexwrapper = default_latex_wrapper) if ismul(O) diff --git a/src/utils.jl b/src/utils.jl index 98aa732a5..dfc84cd56 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -639,4 +639,6 @@ function evaluate(ineq::Inequality, subs) end end - +vartype_from_args(::BasicSymbolic{T}, args...) where {T} = T +vartype_from_args(_, args...) = vartype_from_args(args...) +vartype_from_args() = error("Cannot infer `vartype`.") diff --git a/src/variable.jl b/src/variable.jl index ed7f54b32..f34af7171 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -192,6 +192,7 @@ end function construct_dep_array_vars(macroname, lhs, type, call_args, indices, val, prop, transform, isruntime) ndim = :($length(($(indices...),))) + shape = :($(SymbolicUtils.ShapeVecT)(($(indices...),))) if call_args_are_function(call_args) vname, fntype = function_name_and_type(lhs) # name was already unwrapped before calling this function and is of the form $x @@ -208,7 +209,7 @@ function construct_dep_array_vars(macroname, lhs, type, call_args, indices, val, end end argtypes = arg_types_from_call_args(call_args) - ex = :($Sym{$FnType{$argtypes, Array{$type, $ndim}, $(fntype...)}}($_vname)) + ex = :($Sym{$FnType{$argtypes, Array{$type, $ndim}, $(fntype...)}}($_vname; shape = $shape)) else vname = lhs if isruntime @@ -218,7 +219,6 @@ function construct_dep_array_vars(macroname, lhs, type, call_args, indices, val, end ex = :($Sym{$FnType{Tuple, Array{$type, $ndim}}}($_vname)(map($unwrap, ($(call_args...),))...)) end - ex = :($setmetadata($ex, $ArrayShapeCtx, ($(indices...),))) if val !== nothing ex = :($setdefaultval($ex, $val)) @@ -394,22 +394,20 @@ end function _construct_array_vars(macroname, var_name, type, call_args, val, prop, indices...) # TODO: just use Sym here ndim = :($length(($(indices...),))) + shape = :($(SymbolicUtils.ShapeVecT)(($(indices...),))) need_scalarize = false expr = if call_args === nothing - ex = :($Sym{Array{$type, $ndim}}($var_name)) - :($setmetadata($ex, $ArrayShapeCtx, ($(indices...),))) + ex = :($Sym{Array{$type, $ndim}}($var_name; shape = $shape)) elseif call_args_are_function(call_args) need_scalarize = true var_name, fntype = function_name_and_type(var_name) argtypes = arg_types_from_call_args(call_args) - ex = :($Sym{Array{$FnType{$argtypes, $type, $(fntype...)}, $ndim}}($var_name)) - :($setmetadata($ex, $ArrayShapeCtx, ($(indices...),))) + ex = :($Sym{Array{$FnType{$argtypes, $type, $(fntype...)}, $ndim}}($var_name; shape = $shape)) else # [(R -> R)(R) ....] need_scalarize = true - ex = :($Sym{Array{$FnType{Tuple, $type}, $ndim}}($var_name)) - ex = :($setmetadata($ex, $ArrayShapeCtx, ($(indices...),))) + ex = :($Sym{Array{$FnType{Tuple, $type}, $ndim}}($var_name; shape = $shape)) :($map($CallWith(($(call_args...),)), $ex)) end @@ -806,16 +804,6 @@ function rename_metadata(from, to, name) end rename(x::Union{Num, Arr}, name) = wrap(rename(unwrap(x), name)) -function rename(x::ArrayOp, name) - t = x.term - args = arguments(t) - # Hack: - @assert operation(t) === (map) && args[1] isa CallWith - rn = rename(args[2], name) - - xx = metadata(operation(t)(args[1], rn), metadata(x)) - rename_getindex_source(rename_metadata(x, xx, name)) -end function rename(x::BasicSymbolic, name) if issym(x) From dde5ad0adf4078f35092bb6ce3db7ca2f6e1fe87 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 15 Sep 2025 19:56:34 +0530 Subject: [PATCH 05/28] refactor: remove redundant `symbolic_type` methods --- src/variable.jl | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/variable.jl b/src/variable.jl index f34af7171..ef0e06575 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -516,14 +516,6 @@ getsource(x, val=_fail) = getmetadata(unwrap(x), VariableSource, val) SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.Num}) = ScalarSymbolic() SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.Arr}) = ArraySymbolic() -function SymbolicIndexingInterface.symbolic_type(::Type{T}) where {S <: AbstractArray, T <: BasicSymbolic{S}} - ArraySymbolic() -end -# need this otherwise the `::Type{<:BasicSymbolic}` method in SymbolicUtils is -# more specific -function SymbolicIndexingInterface.symbolic_type(::Type{T}) where {S <: AbstractArray, T <: BasicSymbolic{S}} - ArraySymbolic() -end SymbolicIndexingInterface.hasname(x::Union{Num,Arr,Complex{Num}}) = hasname(unwrap(x)) From 1cc3fd6bc93bb99ac20f08f33b742f7c79930201 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Mon, 15 Sep 2025 19:56:44 +0530 Subject: [PATCH 06/28] refactor: remove usages of `Polyform` --- src/solver/main.jl | 1 - src/utils.jl | 25 ++++++++++--------------- 2 files changed, 10 insertions(+), 16 deletions(-) diff --git a/src/solver/main.jl b/src/solver/main.jl index 9082a9518..74b3c84d6 100644 --- a/src/solver/main.jl +++ b/src/solver/main.jl @@ -281,7 +281,6 @@ function solve_univar(expression, x; dropmultiplicity=true, strict=true) args = [] mult_n = 1 expression = unwrap(expression) - expression = expression isa PolyForm ? SymbolicUtils.toterm(expression) : expression # handle multiplicities (repeated roots), i.e. (x+1)^20 if iscall(expression) diff --git a/src/utils.jl b/src/utils.jl index dfc84cd56..3d1ab044b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -441,9 +441,7 @@ end const DP = DynamicPolynomials # extracting underlying polynomial and coefficient type from Polyforms underlyingpoly(x::Number) = x -underlyingpoly(pf::PolyForm) = pf.p coefftype(x::Number) = typeof(x) -coefftype(pf::PolyForm) = DP.coefficient_type(underlyingpoly(pf)) #= Converts an array of symbolic polynomials @@ -456,31 +454,28 @@ function symbol_to_poly(sympolys::AbstractArray) stdsympolys = map(unwrap, sympolys) sort!(stdsympolys, lt=(<ₑ)) - pvar2sym = Bijections.Bijection{Any,Any}() - sym2term = Dict{BasicSymbolic,Any}() - polyforms = map(f -> PolyForm(f, pvar2sym, sym2term), stdsympolys) + symidx = findfirst(x -> x isa BasicSymbolic, stdsympolys) + varT = vartype(stdsympolys[symidx]) + + poly_to_bs = Bijections.Bijection{SymbolicUtils.PolyVarT, BasicSymbolic{varT}}() + bs_to_poly = Bijections.active_inv(poly_to_bs) + polyforms = Vector{SymbolicUtils.PolynomialT}(map(f -> SymbolicUtils.to_poly!(poly_to_bs, bs_to_poly, f), stdsympolys)) # Discover common coefficient type commontype = mapreduce(coefftype, promote_type, polyforms, init=Int) @assert commontype <: Union{Integer,Rational} "Only integer and rational coefficients are supported as input." - # Convert all to DP.Polynomial, so that coefficients are of same type, - # and constants are treated as polynomials - # We also need this because Groebner does not support abstract types as input - polynoms = Vector{DP.Polynomial{DP.Commutative{DP.CreationOrder},DP.Graded{DP.LexOrder},commontype}}(undef, length(sympolys)) - for (i, pf) in enumerate(polyforms) - polynoms[i] = underlyingpoly(pf) - end + polynoms = polyforms - polynoms, pvar2sym, sym2term + polynoms, poly_to_bs end #= Converts an array of AbstractPolynomialLike`s into an array of symbolic expressions mapping variables w.r.t pvar2sym =# -function poly_to_symbol(polys, pvar2sym, sym2term, ::Type{T}) where {T} - map(f -> PolyForm{T}(f, pvar2sym, sym2term), polys) +function poly_to_symbol(polys, poly_to_bs, ::Type{T}) where {T} + map(Base.Fix1(SymbolicUtils.from_poly, poly_to_bs), polys) end """ From 7950e983cd7dd626adb3cf9050b49b5531e52f4b Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 16 Sep 2025 16:46:48 +0530 Subject: [PATCH 07/28] feat: make `vartype` of variables a preference --- Project.toml | 2 ++ src/Symbolics.jl | 15 +++++++++++++++ 2 files changed, 17 insertions(+) diff --git a/Project.toml b/Project.toml index 7d1f81085..387e02df1 100644 --- a/Project.toml +++ b/Project.toml @@ -25,6 +25,7 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" +Preferences = "21216c6a-2e73-6563-6e65-726566657250" Primes = "27ebfcd6-29c5-5fa9-bf4b-fb8fc14df3ae" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" @@ -90,6 +91,7 @@ OffsetArrays = "1.15.0" PkgBenchmark = "0.2" PreallocationTools = "0.4" PrecompileTools = "1" +Preferences = "1.5.0" Primes = "0.5" RecipesBase = "1.1" Reexport = "1" diff --git a/src/Symbolics.jl b/src/Symbolics.jl index ebd7834c3..c2ddce89e 100644 --- a/src/Symbolics.jl +++ b/src/Symbolics.jl @@ -54,6 +54,21 @@ RuntimeGeneratedFunctions.init(@__MODULE__) import SciMLPublic: @public +import Preferences: @load_preference + +const DEFAULT_VARTYPE_PREF = @load_preference("vartype", "SymReal") +const VartypeT = @static if DEFAULT_VARTYPE_PREF == "SymReal" + SymReal +elseif DEFAULT_VARTYPE_PREF == "SafeReal" + SafeReal +elseif DEFAULT_VARTYPE_PREF == "TreeReal" + TreeReal +else + error(""" + Invalid vartype preference: $DEFAULT_VARTYPE_PREF. Must be one of "SymReal", \ + "SafeReal" or "TreeReal". + """) +end # re-export export simplify, substitute From 741c711a5dab392ca20d10f5cd48c3222c8ff6f8 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 16 Sep 2025 16:47:57 +0530 Subject: [PATCH 08/28] refactor: use `@syms` parsing for `@variables` --- Project.toml | 2 + src/Symbolics.jl | 6 +- src/utils.jl | 6 +- src/variable.jl | 463 +++++++++++------------------------------------ 4 files changed, 114 insertions(+), 363 deletions(-) diff --git a/Project.toml b/Project.toml index 387e02df1..a7230190d 100644 --- a/Project.toml +++ b/Project.toml @@ -22,6 +22,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" +Moshi = "2e0e35c7-a2e4-4343-998d-7ef72827ed2d" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" @@ -85,6 +86,7 @@ Latexify = "0.16" LogExpFunctions = "0.3" Lux = "1" MacroTools = "0.5" +Moshi = "0.3.7" NaNMath = "1" Nemo = "0.46, 0.47, 0.48, 0.49, 0.52" OffsetArrays = "1.15.0" diff --git a/src/Symbolics.jl b/src/Symbolics.jl index c2ddce89e..b3d84fe64 100644 --- a/src/Symbolics.jl +++ b/src/Symbolics.jl @@ -27,8 +27,8 @@ using TermInterface import TermInterface: maketerm, iscall, operation, arguments, metadata import SymbolicUtils: Term, Add, Mul, Sym, Div, BasicSymbolic, Const, -FnType, @rule, Rewriters, substitute, symtype, -promote_symtype, isadd, ismul, ispow, isterm, issym, isdiv + FnType, @rule, Rewriters, substitute, symtype, shape, + promote_symtype, isadd, ismul, ispow, isterm, issym, isdiv, BSImpl using SymbolicUtils.Code @@ -54,6 +54,8 @@ RuntimeGeneratedFunctions.init(@__MODULE__) import SciMLPublic: @public +using Moshi.Match: @match + import Preferences: @load_preference const DEFAULT_VARTYPE_PREF = @load_preference("vartype", "SymReal") diff --git a/src/utils.jl b/src/utils.jl index 3d1ab044b..e1f9d09cf 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -13,11 +13,7 @@ function flatten_expr!(x) end xs end -function build_expr(head::Symbol, args) - ex = Expr(head) - append!(ex.args, args) - ex -end + """ get_variables(e, varlist = nothing; sort::Bool = false) diff --git a/src/variable.jl b/src/variable.jl index ef0e06575..776c3718d 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -28,50 +28,7 @@ Symbolic metadata key for storing the macro used to create a symbolic variable. """ struct VariableSource <: AbstractVariableMetadata end -function recurse_and_apply(f, x) - if symtype(x) <: AbstractArray - getindex_posthook(x) do r,x,i... - recurse_and_apply(f, r) - end - else - f(x) - end -end - -function set_scalar_metadata(x, V, val) - val = unwrap(val) - if val isa AbstractArray - val = unwrap.(val) - end - if symtype(x) <: AbstractArray - x = if val isa AbstractArray - getindex_posthook(x) do r,x,i... - set_scalar_metadata(r, V, val[i...]) - end - else - getindex_posthook(x) do r,x,i... - set_scalar_metadata(r, V, val) - end - end - end - setmetadata(x, V, val) -end -setdefaultval(x, val) = set_scalar_metadata(x, VariableDefaultValue, val) - -struct GetindexParent end - -function scalarize_getindex(x, parent=Ref{Any}(x)) - if symtype(x) <: AbstractArray - parent[] = getindex_posthook(x) do r,x,i... - scalarize_getindex(r, parent) - end - else - xx = unwrap(scalarize(x)) - xx = metadata(xx, metadata(x)) - setmetadata(xx, GetindexParent, parent[]) - end -end - +setdefaultval(x, val) = setmetadata(x, VariableDefaultValue, val) function map_subscripts(indices) str = string(indices) @@ -79,12 +36,6 @@ function map_subscripts(indices) end -function unwrap_runtime_var(v) - isruntime = Meta.isexpr(v, :$) && length(v.args) == 1 - isruntime && (v = v.args[1]) - return isruntime, v -end - # Build variables more easily """ $(TYPEDSIGNATURES) @@ -97,11 +48,9 @@ macro. `transform` is an optional function that takes constructed variables and custom postprocessing to them, returning the created variables. This function returns the `Expr` for constructing the parsed variables. """ -parse_vars(macroname, type, x, transform=identity) = _parse_vars(macroname, type, x, transform=identity) - -function _parse_vars(macroname, type, x, transform=identity) +function parse_vars(macroname, type, x, transform = identity) ex = Expr(:block) - var_names = Symbol[] + var_names = Expr(:vect) # if parsing things in the form of # begin # x @@ -114,167 +63,98 @@ function _parse_vars(macroname, type, x, transform=identity) isoption(ex) = Meta.isexpr(ex, [:vect, :vcat, :hcat]) while cursor < length(x) cursor += 1 - v = x[cursor] - - # We need lookahead to the next `v` to parse - # `@variables x [connect=Flow,unit=u]` - nv = cursor < length(x) ? x[cursor+1] : nothing - val = unit = connect = options = nothing + var_expr = x[cursor] - # x = 1, [connect = flow; unit = u"m^3/s"] - if Meta.isexpr(v, :(=)) - v, val = v.args + default = nothing + options = nothing + if Meta.isexpr(var_expr, :(=)) + var_expr, default = var_expr.args # defaults with metadata for function variables - if Meta.isexpr(val, :block) - Base.remove_linenums!(val) - val = only(val.args) + if Meta.isexpr(default, :block) + Base.remove_linenums!(default) + default = only(default.args) end - if Meta.isexpr(val, :tuple) && length(val.args) == 2 && isoption(val.args[2]) - options = val.args[2].args - val = val.args[1] + if Meta.isexpr(default, :tuple) && length(default.args) == 2 && isoption(default.args[2]) + options = default.args[2].args + default = default.args[1] end + default = esc(default) end - - type′ = type - - if Meta.isexpr(v, :(::)) - v, type′ = v.args - type′ = type′ === :Complex ? Complex{type} : type′ + parse_result = SymbolicUtils.parse_variable(var_expr; default_type = type) + + sym = SymbolicUtils.sym_from_parse_result(parse_result, VartypeT) + + # is a function call and the function doesn't have a type and all arguments + # are named + if parse_result_is_dependent_variable(parse_result) + args = parse_result[:args] + argnames = Any[get(arg, :name, nothing) for arg in args] + # if the last arg is a `Vararg`, splat it + if !isempty(args) && Meta.isexpr(args[end][:type], :curly) && args[end][:type].args[1] == :Vararg + argnames[end] = Expr(:..., argnames[end]) + end + # Turn the result into something of the form `@variables x(..)`. + # This makes it so that the `FnType` is recognized as a dependent variable + # according to `SymbolicUtils.is_function_symtype` + parse_result[:args] = [SymbolicUtils.parse_variable(:(..); default_type = type)] + parse_result[:type].args[2] = Tuple + # Re-create the `Sym` + sym = SymbolicUtils.sym_from_parse_result(parse_result, VartypeT) + # Call the `Sym` with the arguments to create a dependent variable. + map!(esc, argnames, argnames) + sym = Expr(:call, sym) + append!(sym.args, argnames) end - - # x [connect = flow; unit = u"m^3/s"] - if isoption(nv) - options = nv.args + if options === nothing && cursor < length(x) && isoption(x[cursor + 1]) + options = x[cursor + 1].args cursor += 1 end + sym = _add_metadata(parse_result, sym, default, macroname, options) + sym = Expr(:call, wrap, sym) - isruntime, v = unwrap_runtime_var(v) - iscall = Meta.isexpr(v, :call) - isarray = Meta.isexpr(v, :ref) - if iscall && Meta.isexpr(v.args[1], :ref) && !call_args_are_function(map(last∘unwrap_runtime_var, @view v.args[2:end])) - @warn("The variable syntax $v is deprecated. Use $(Expr(:ref, Expr(:call, v.args[1].args[1], v.args[2]), v.args[1].args[2:end]...)) instead. - The former creates an array of functions, while the latter creates an array valued function. - The deprecated syntax will cause an error in the next major release of Symbolics. - This change will facilitate better implementation of various features of Symbolics.") - end - issym = v isa Symbol - @assert iscall || isarray || issym "@$macroname expects a tuple of expressions or an expression of a tuple (`@$macroname x y z(t) v[1:3] w[1:2,1:4]` or `@$macroname x y z(t) v[1:3] w[1:2,1:4] k=1.0`)" - - if isarray && Meta.isexpr(v.args[1], :call) - # This is the new syntax - isruntime, fname = unwrap_runtime_var(v.args[1].args[1]) - call_args = map(last∘unwrap_runtime_var, @view v.args[1].args[2:end]) - size = v.args[2:end] - var_name, expr = construct_dep_array_vars(macroname, fname, type′, call_args, size, val, options, transform, isruntime) - elseif iscall - isruntime, fname = unwrap_runtime_var(v.args[1]) - call_args = map(last∘unwrap_runtime_var, @view v.args[2:end]) - var_name, expr = construct_vars(macroname, fname, type′, call_args, val, options, transform, isruntime) - elseif isarray - var_name, expr = construct_vars(macroname, v, type′, nothing, val, options, transform, isruntime) + if parse_result[:isruntime] + varname = Symbol(parse_result[:name]) else - var_name, expr = construct_vars(macroname, v, type′, nothing, val, options, transform, isruntime) + varname = esc(parse_result[:name]) end - - push!(var_names, var_name) - push!(ex.args, expr) + push!(var_names.args, varname) + push!(ex.args, Expr(:(=), varname, sym)) end - rhs = build_expr(:vect, var_names) - push!(ex.args, rhs) + push!(ex.args, var_names) return ex end -call_args_are_function(_) = false -function call_args_are_function(call_args::AbstractArray) - !isempty(call_args) && (call_args[end] == :(..) || all(Base.Fix2(Meta.isexpr, :(::)), call_args)) +function parse_result_is_dependent_variable(parse_result) + # This means it is a function call + return haskey(parse_result, :args) && + # This checks `fnT` in `FnType{argsT, retT, fnT}` + parse_result[:type].args[4] === Nothing && + # This ensures all arguments have defined names + all(n -> n !== nothing && n != :.., + (get(arg, :name, nothing) for arg in parse_result[:args])) end -function construct_dep_array_vars(macroname, lhs, type, call_args, indices, val, prop, transform, isruntime) - ndim = :($length(($(indices...),))) - shape = :($(SymbolicUtils.ShapeVecT)(($(indices...),))) - if call_args_are_function(call_args) - vname, fntype = function_name_and_type(lhs) - # name was already unwrapped before calling this function and is of the form $x - if isruntime - _vname = vname - else - # either no ::fnType or $x::fnType - vname, fntype = function_name_and_type(lhs) - isruntime, vname = unwrap_runtime_var(vname) - if isruntime - _vname = vname - else - _vname = Meta.quot(vname) - end - end - argtypes = arg_types_from_call_args(call_args) - ex = :($Sym{$FnType{$argtypes, Array{$type, $ndim}, $(fntype...)}}($_vname; shape = $shape)) - else - vname = lhs - if isruntime - _vname = vname - else - _vname = Meta.quot(vname) - end - ex = :($Sym{$FnType{Tuple, Array{$type, $ndim}}}($_vname)(map($unwrap, ($(call_args...),))...)) - end - - if val !== nothing - ex = :($setdefaultval($ex, $val)) +function _add_metadata(parse_result, var::Expr, default, macroname::Symbol, metadata::Union{Nothing, Vector{Any}}) + @nospecialize var default metadata + if default !== nothing + var = Expr(:call, setdefaultval, var, default) end - ex = setprops_expr(ex, prop, macroname, Meta.quot(vname)) - #ex = :($scalarize_getindex($ex)) - - ex = :($wrap($ex)) - - ex = :($transform($ex)) - if isruntime - vname = gensym(Symbol(vname)) - end - vname, :($vname = $ex) -end - -function construct_vars(macroname, v, type, call_args, val, prop, transform, isruntime) - issym = v isa Symbol - isarray = !isruntime && Meta.isexpr(v, :ref) - if isarray - # this can't be an array of functions, since that was handled by `construct_dep_array_vars` - var_name = v.args[1] - if Meta.isexpr(var_name, :(::)) - var_name, type′ = var_name.args - type = type′ === :Complex ? Complex{type} : type′ - end - isruntime, var_name = unwrap_runtime_var(var_name) - indices = v.args[2:end] - expr = _construct_array_vars(macroname, isruntime ? var_name : Meta.quot(var_name), type, call_args, val, prop, indices...) - elseif call_args_are_function(call_args) - var_name, fntype = function_name_and_type(v) - # name was already unwrapped before calling this function and is of the form $x - if isruntime - vname = var_name - else - # either no ::fnType or $x::fnType - var_name, fntype = function_name_and_type(v) - isruntime, var_name = unwrap_runtime_var(var_name) - if isruntime - vname = var_name - else - vname = Meta.quot(var_name) - end - end - expr = construct_var(macroname, fntype == () ? vname : Expr(:(::), vname, fntype[1]), type, call_args, val, prop) + varname = parse_result[:name] + if parse_result[:isruntime] + varname = esc(varname) else - var_name = v - if Meta.isexpr(v, :(::)) - var_name, type′ = v.args - type = type′ === :Complex ? Complex{type} : type′ - end - expr = construct_var(macroname, isruntime ? var_name : Meta.quot(var_name), type, call_args, val, prop) + varname = Meta.quot(varname) + end + var = Expr(:call, setmetadata, var, VariableSource, Expr(:tuple, Meta.quot(macroname), varname)) + metadata === nothing && return var + for ex in metadata + Meta.isexpr(ex, :(=)) || error("Metadata must of the form of `key = value`") + key, value = ex.args + key_type = option_to_metadata_type(Val{key}())::DataType + var = Expr(:call, setmetadata, var, key_type, esc(value)) end - lhs = isruntime ? gensym(Symbol(var_name)) : var_name - rhs = :($transform($expr)) - lhs, :($lhs = $rhs) + return var end """ @@ -309,122 +189,6 @@ option_to_metadata_type(::Val{:_____!_internal_3}) = error("Invalid option") option_to_metadata_type(::Val{:_____!_internal_4}) = error("Invalid option") option_to_metadata_type(::Val{:_____!_internal_5}) = error("Invalid option") -function setprops_expr(expr, props, macroname, varname) - expr = :($setmetadata($expr, $VariableSource, ($(Meta.quot(macroname)), $varname,))) - isnothing(props) && return expr - for opt in props - if !Meta.isexpr(opt, :(=)) - throw(Base.Meta.ParseError( - "Variable properties must be in " * - "the form of `a = b`. Got $opt.")) - end - - lhs, rhs = opt.args - - @assert lhs isa Symbol "the lhs of an option must be a symbol" - expr = :($set_scalar_metadata($expr, - $(option_to_metadata_type(Val{lhs}())), - $rhs)) - end - expr -end - -function arg_types_from_call_args(call_args) - if length(call_args) == 1 && only(call_args) == :.. - return Tuple - end - Ts = map(call_args) do arg - if arg == :.. - Vararg - elseif arg isa Expr && arg.head == :(::) - if length(arg.args) == 1 - arg.args[1] - elseif arg.args[1] == :.. - :(Vararg{$(arg.args[2])}) - else - arg.args[2] - end - else - error("Invalid call argument $arg") - end - end - return :(Tuple{$(Ts...)}) -end - -function function_name_and_type(var_name) - if var_name isa Expr && Meta.isexpr(var_name, :(::), 2) - var_name.args[1], (var_name.args[2],) - else - var_name, () - end -end - -function construct_var(macroname, var_name, type, call_args, val, prop) - expr = if call_args === nothing - :($Sym{$type}($var_name)) - elseif call_args_are_function(call_args) - # function syntax is (x::TFunc)(.. or ::TArg1, ::TArg2)::TRet - # .. is Vararg - # (..)::ArgT is Vararg{ArgT} - var_name, fntype = function_name_and_type(var_name) - argtypes = arg_types_from_call_args(call_args) - :($Sym{$FnType{$argtypes, $type, $(fntype...)}}($var_name)) - # This elseif handles the special case with e.g. variables on the form - # @variables X(deps...) where deps is a vector (which length might be unknown). - elseif (call_args isa Vector) && (length(call_args) == 1) && (call_args[1] isa Expr) && - call_args[1].head == :(...) && (length(call_args[1].args) == 1) - :($Sym{$FnType{NTuple{$length($(call_args[1].args[1])), Any}, $type}}($var_name)($value.($(call_args[1].args[1]))...)) - else - :($Sym{$FnType{NTuple{$(length(call_args)), Any}, $type}}($var_name)($(map(x->:($value($x)), call_args)...))) - end - - if val !== nothing - expr = :($setdefaultval($expr, $val)) - end - - :($wrap($(setprops_expr(expr, prop, macroname, var_name)))) -end - -struct CallWith - args -end - -(c::CallWith)(f) = unwrap(f(c.args...)) - -function _construct_array_vars(macroname, var_name, type, call_args, val, prop, indices...) - # TODO: just use Sym here - ndim = :($length(($(indices...),))) - shape = :($(SymbolicUtils.ShapeVecT)(($(indices...),))) - - need_scalarize = false - expr = if call_args === nothing - ex = :($Sym{Array{$type, $ndim}}($var_name; shape = $shape)) - elseif call_args_are_function(call_args) - need_scalarize = true - var_name, fntype = function_name_and_type(var_name) - argtypes = arg_types_from_call_args(call_args) - ex = :($Sym{Array{$FnType{$argtypes, $type, $(fntype...)}, $ndim}}($var_name; shape = $shape)) - else - # [(R -> R)(R) ....] - need_scalarize = true - ex = :($Sym{Array{$FnType{Tuple, $type}, $ndim}}($var_name; shape = $shape)) - :($map($CallWith(($(call_args...),)), $ex)) - end - - if val !== nothing - expr = :($setdefaultval($expr, $val)) - end - expr = setprops_expr(expr, prop, macroname, var_name) - if need_scalarize - expr = :($scalarize_getindex($expr)) - end - - expr = :($wrap($expr)) - - return expr -end - - """ Define one or more unknown variables. @@ -489,7 +253,7 @@ julia> (t, a, b, c) ``` """ macro variables(xs...) - esc(_parse_vars(:variables, Real, xs)) + parse_vars(:variables, Real, xs) end const _fail = Dict() @@ -777,55 +541,42 @@ end # x_t # sys.x -function rename_getindex_source(x, parent=x) - getindex_posthook(x) do r,x,i... - hasmetadata(r, GetindexParent) ? setmetadata(r, GetindexParent, parent) : r - end -end - -function rename_metadata(from, to, name) - if hasmetadata(from, VariableSource) - s = getmetadata(from, VariableSource) - to = setmetadata(to, VariableSource, (s[1], name)) - end - if hasmetadata(from, GetindexParent) - s = getmetadata(from, GetindexParent) - to = setmetadata(to, GetindexParent, rename(s, name)) +function renamed_metadata(metadata::Union{Nothing, SymbolicUtils.MetadataT}, name::Symbol) + @nospecialize metadata + if metadata === nothing + return metadata + elseif metadata isa Base.ImmutableDict{DataType, Any} + newmeta = Base.ImmutableDict{DataType, Any}() + for (k, v) in metadata + if k === VariableSource + v = v::NTuple{2, Symbol} + v = (v[1], name) + end + newmeta = Base.ImmutableDict(newmeta, k, v) + end + return newmeta end - return to + error() end rename(x::Union{Num, Arr}, name) = wrap(rename(unwrap(x), name)) -function rename(x::BasicSymbolic, name) - if issym(x) - xx = @set! x.name = name - xx = rename_metadata(x, xx, name) - symtype(xx) <: AbstractArray ? rename_getindex_source(xx) : xx - elseif iscall(x) && operation(x) === getindex - rename(arguments(x)[1], name)[arguments(x)[2:end]...] - elseif iscall(x) && symtype(operation(x)) <: FnType - xx = @set x.f = rename(operation(x), name) - @set! xx.hash = Ref{UInt}(0) - return rename_metadata(x, xx, name) - else - error("can't rename $x to $name") +function rename(x::BasicSymbolic{T}, newname::Symbol) where {T} + @match x begin + BSImpl.Sym(; name, type, shape, metadata) => begin + metadata = renamed_metadata(metadata, newname) + return BSImpl.Sym{T}(newname; type, shape, metadata) + end + BSImpl.Term(; f, args, type, shape, metadata) && if f === getindex end => begin + newargs = copy(parent(args)) + newargs[1] = rename(newargs[1], newname) + return BSImpl.Term{T}(f, newargs; type, shape, metadata) + end + BSImpl.Term(; f, args, type, shape, metadata) && if f isa BasicSymbolic{T} end => begin + f = rename(f, newname) + metadata = renamed_metadata(metadata, newname) + return BSImpl.Term{T}(f, args; type, shape, metadata) + end + _ => error("Cannot rename $x.") end end - -# Deprecation below - -struct Variable{T} end - -function (::Type{Variable{T}})(s, i...) where {T} - Base.depwarn("Variable{T}(name, idx...) is deprecated, use variable(name, idx...; T=T)", :Variable) - variable(s, i...; T=T) -end - -(::Type{Variable})(s, i...) = Variable{Real}(s, i...) - -function (::Type{Sym{T}})(s, x, i...) where {T} - Base.depwarn("Sym{T}(name, x, idx...) is deprecated, use variable(name, x, idx...; T=T)", :Variable) - variable(s, x, i...; T=T) -end -(::Type{Sym})(s, x, i...) = Sym{Real}(s, x, i...) From e49881fcbe10b3b1c79b00147daf911c18670a38 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 16 Sep 2025 16:48:38 +0530 Subject: [PATCH 09/28] refactor: update `@wrapped` macro --- src/wrapper-types.jl | 38 ++++++++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/src/wrapper-types.jl b/src/wrapper-types.jl index a669bfdfa..350f31f20 100644 --- a/src/wrapper-types.jl +++ b/src/wrapper-types.jl @@ -17,8 +17,9 @@ function set_where(subt, supert) Expr(:where, supert, Ts...) end -function SymbolicIndexingInterface.getname(x::Expr) - @assert x.head == :curly +function _get_type_name(x::Union{Symbol, Expr}) + x isa Symbol && return x + @assert Meta.isexpr(x, :curly) return x.args[1] end @@ -26,7 +27,7 @@ macro symbolic_wrap(expr) @assert expr isa Expr && expr.head == :struct @assert expr.args[2].head == :(<:) sig = expr.args[2] - T = getname(sig.args[1]) + T = _get_type_name(sig.args[1]) supertype = set_where(sig.args[1], sig.args[2]) quote @@ -119,23 +120,23 @@ function wrap_func_expr(mod, expr, wrap_arrays = true) # expected to be defined outside Symbolics if arg isa Expr && arg.head == :(::) T = Base.eval(mod, arg.args[2]) - Ts = has_symwrapper(T) ? (T, BasicSymbolic, wrapper_type(T)) : - (T, BasicSymbolic) + Ts = has_symwrapper(T) ? (T, BasicSymbolic{VartypeT}, wrapper_type(T)) : + (T, BasicSymbolic{VartypeT}) if T <: AbstractArray && wrap_arrays eT = eltype(T) if eT == Any eT = Real end _arr_type_fn = if hasmethod(ndims, Tuple{Type{T}}) - (elT) -> AbstractArray{T, ndims(T)} where {T <: elT} + (elT) -> AbstractArray{S, ndims(T)} where {S <: elT} else - (elT) -> AbstractArray{T} where {T <: elT} + (elT) -> AbstractArray{S} where {S <: elT} end if has_symwrapper(eT) - Ts = (Ts..., _arr_type_fn(BasicSymbolic), + Ts = (Ts..., AbstractArray{BasicSymbolic{VartypeT}}, _arr_type_fn(wrapper_type(eT))) else - Ts = (Ts..., _arr_type_fn(BasicSymbolic)) + Ts = (Ts..., AbstractArray{BasicSymbolic{VartypeT}}) end end Ts @@ -158,20 +159,29 @@ function wrap_func_expr(mod, expr, wrap_arrays = true) :($n::$T) end + any_wrapper = false impl_args = map(enumerate(names)) do (i, name) - is_wrapper_type(Ts[i]) ? :($unwrap($name)) : name + if is_wrapper_type(Ts[i]) + any_wrapper = true + :($unwrap($name)) + elseif Ts[i] <: AbstractArray && is_wrapper_type(Ts[i].var.ub) + any_wrapper = true + :($_recursive_unwrap($name)) + else + name + end end implcall = :($impl_name($self, $(impl_args...))) - if any(is_wrapper_type, Ts) + if any_wrapper implcall = :($wrap($implcall)) end body = Expr(:block) for (i, T) in enumerate(Ts) - if T === BasicSymbolic + if T === BasicSymbolic{VartypeT} push!(body.args, :(@assert $symtype($(names[i])) <: $(types[i][1]))) - else T === AbstractArray{T} where {T <: BasicSymbolic} - push!(body.args, :(@assert $symtype($(names[i])[1]) <: $(types[i][1]))) + elseif T === AbstractArray{BasicSymbolic{VartypeT}} && eltype(types[i][1]) !== Any + push!(body.args, :(@assert $symtype($(names[i])[1]) <: $(eltype(types[i][1])))) end end push!(body.args, implcall) From a50ad29899f30c7ae5d90ef01c5d47f2a5ff3a22 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 16 Sep 2025 16:48:57 +0530 Subject: [PATCH 10/28] refactor: update `@register_symbolic` --- src/register.jl | 39 ++++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/src/register.jl b/src/register.jl index e946530a5..b473ff0e8 100644 --- a/src/register.jl +++ b/src/register.jl @@ -21,34 +21,37 @@ overwriting. ``` See `@register_array_symbolic` to register functions which return arrays. """ -macro register_symbolic(expr, define_promotion = true, Ts = :([]), wrap_arrays = true) - f, ftype, argnames, Ts, ret_type = destructure_registration_expr(expr, Ts) +macro register_symbolic(expr, define_promotion = true, wrap_arrays = true) + f, ftype, argnames, Ts, ret_type = destructure_registration_expr(expr) args′ = map((a, T) -> :($a::$T), argnames, Ts) ret_type = isnothing(ret_type) ? Real : ret_type - + N = length(args′) + symbolicT = Union{BasicSymbolic{VartypeT}, AbstractArray{BasicSymbolic{VartypeT}}} fexpr = :(Symbolics.@wrapped function $f($(args′...)) - args = [$(argnames...),] - unwrapped_args = map($nested_unwrap, args) - res = if !any($is_symbolic_or_array_of_symbolic, unwrapped_args) - $f(unwrapped_args...) # partial-eval if all args are unwrapped - else - $Term{$ret_type}($f, unwrapped_args) - end - if typeof.(args) == typeof.(unwrapped_args) - return res - else - return $wrap(res) - end - end $wrap_arrays) + args = ($(argnames...),) + if Base.Cartesian.@nany $N i -> args[i] isa $symbolicT + args = Base.Cartesian.@ntuple $N i -> $Const{$VartypeT}(args[i]) + $Term{$VartypeT}($f, $(SymbolicUtils.ArgsT){$VartypeT}(args); type = $ret_type, shape = $(SymbolicUtils.ShapeVecT())) + else + $f($(argnames...)) + end + end $wrap_arrays) if define_promotion fexpr = :($fexpr; (::$typeof($promote_symtype))(::$ftype, args...) = $ret_type) + promote_expr = quote + function (::$(typeof(SymbolicUtils.promote_shape)))(::$ftype, args...) + @nospecialize args + $(SymbolicUtils.ShapeVecT)() + end + end + fexpr = :($fexpr; $promote_expr) end esc(fexpr) end -function destructure_registration_expr(expr, Ts) +function destructure_registration_expr(expr) if expr.head === :(::) ret_type = expr.args[2] expr = expr.args[1] @@ -56,8 +59,6 @@ function destructure_registration_expr(expr, Ts) ret_type = nothing end @assert expr.head === :call - @assert Ts.head === :vect - Ts = Ts.args f = expr.args[1] args = expr.args[2:end] From 9120aa61a0e1ef50ba4b52cdb24f2cfde7dceb08 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 16 Sep 2025 16:49:11 +0530 Subject: [PATCH 11/28] refactor: update `@register_array_symbolic` --- src/register.jl | 72 +++++++++++++++++++++++++++++++------------------ 1 file changed, 46 insertions(+), 26 deletions(-) diff --git a/src/register.jl b/src/register.jl index b473ff0e8..1b772e969 100644 --- a/src/register.jl +++ b/src/register.jl @@ -102,33 +102,46 @@ function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs ex.args[1] => ex.args[2] end |> Dict + shape_expr = if haskey(defs, :size) + quote + sz = $(defs[:size]) + nd = length(sz) + sh = $(SymbolicUtils.ShapeVecT)(map(Base.UnitRange{Int} ∘ Base.OneTo, sz)) + end + else + quote + nd = $(get(defs, :ndims, -1)) + sh = $(SymbolicUtils.Unknown)(sh) + end + end + eltype_expr = get(defs, :eltype, Any) + container_type = get(defs, :container_type, Array) args′ = map((a, T) -> :($a::$T), argnames, Ts) + N = length(args′) + symbolicT = Union{BasicSymbolic{VartypeT}, AbstractArray{BasicSymbolic{VartypeT}}} + assigns = macroexpand(@__MODULE__, :(Base.Cartesian.@nexprs $N i -> ($argnames[i] = args[i]))) fexpr = quote @wrapped function $f($(args′...)) - args = [$(argnames...),] - unwrapped_args = map($nested_unwrap, args) - eltype = $symbolic_eltype - res = if !any($is_symbolic_or_array_of_symbolic, unwrapped_args) - $f(unwrapped_args...) # partial-eval if all args are unwrapped - elseif $ret_type == nothing || ($ret_type <: AbstractArray) - $array_term($(Expr(:parameters, [Expr(:kw, k, v) for (k, v) in defs]...)), $f, unwrapped_args...) + args = ($(argnames...),) + if Base.Cartesian.@nany $N i -> args[i] isa $symbolicT + args = Base.Cartesian.@ntuple $N i -> $Const{$VartypeT}(args[i]) + $assigns + $shape_expr + eltype = $eltype ∘ $symtype + type = if nd == -1 + $container_type{$eltype_expr} + else + $container_type{$eltype_expr, nd} + end + $Term{$VartypeT}($f, $(SymbolicUtils.ArgsT){$VartypeT}(args); type, shape = sh) else - $Term{$ret_type}($f, unwrapped_args) - end - - if typeof.(args) == typeof.(unwrapped_args) - return res - else - return $wrap(res) + $f($(argnames...)) end end $wrap_arrays end |> esc if define_promotion - container_type = get(defs, :container_type, :($propagate_atype(f, $(argnames...)))) - etype = get(defs, :eltype, :($propagate_eltype(f, $(argnames...)))) - ndim = get(defs, :ndims, nothing) is_callable_struct = f isa Expr && f.head == :(::) fn_arg = if is_callable_struct f @@ -140,18 +153,25 @@ function register_array_symbolic(f, ftype, argnames, Ts, ret_type, partial_defs else :f end + + shape_args = [:($name::$(SymbolicUtils.ShapeT)) for name in argnames] promote_expr = quote function (::$typeof($promote_symtype))($fn_arg, $(argnames...)) f = $fn_arg_name container_type = $container_type - etype = $etype - $( - if ndim === nothing - :(return container_type{etype}) - else - :(ndim = $ndim; return container_type{etype, ndim}) - end - ) + nd = $(get(defs, :ndims, -1)) + etype = $eltype_expr + if nd == -1 + return container_type{etype} + else + return container_type{etype, nd} + end + end + function (::$(typeof(SymbolicUtils.promote_shape)))($fn_arg, $(shape_args...)) + @nospecialize $(argnames...) + size = identity + $shape_expr + return sh end end |> esc fexpr = :($fexpr; $promote_expr) @@ -193,6 +213,6 @@ overloads for one function, all the rest of the registers must set overwriting. """ macro register_array_symbolic(expr, block, define_promotion = true, wrap_arrays = true) - f, ftype, argnames, Ts, ret_type = destructure_registration_expr(expr, :([])) + f, ftype, argnames, Ts, ret_type = destructure_registration_expr(expr) register_array_symbolic(f, ftype, argnames, Ts, ret_type, block, define_promotion, wrap_arrays) end From 54cb468f3bb18996ceb8dc2bb5146611f12b8c6a Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 16 Sep 2025 17:33:33 +0530 Subject: [PATCH 12/28] fix: use `SymbolicUtils.unwrap` --- src/Symbolics.jl | 2 +- src/num.jl | 2 +- src/wrapper-types.jl | 6 ------ 3 files changed, 2 insertions(+), 8 deletions(-) diff --git a/src/Symbolics.jl b/src/Symbolics.jl index b3d84fe64..b273010d8 100644 --- a/src/Symbolics.jl +++ b/src/Symbolics.jl @@ -27,7 +27,7 @@ using TermInterface import TermInterface: maketerm, iscall, operation, arguments, metadata import SymbolicUtils: Term, Add, Mul, Sym, Div, BasicSymbolic, Const, - FnType, @rule, Rewriters, substitute, symtype, shape, + FnType, @rule, Rewriters, substitute, symtype, shape, unwrap, unwrap_const, promote_symtype, isadd, ismul, ispow, isterm, issym, isdiv, BSImpl using SymbolicUtils.Code diff --git a/src/num.jl b/src/num.jl index 8a8cf25e6..c5e73756f 100644 --- a/src/num.jl +++ b/src/num.jl @@ -4,7 +4,7 @@ end const RCNum = Union{Num, Complex{Num}} -unwrap(x::Num) = x.val +SymbolicUtils.unwrap(x::Num) = x.val """ Num(val) diff --git a/src/wrapper-types.jl b/src/wrapper-types.jl index 350f31f20..db6ae09f9 100644 --- a/src/wrapper-types.jl +++ b/src/wrapper-types.jl @@ -42,12 +42,6 @@ macro symbolic_wrap(expr) end iswrapped(x) = false -""" - $(TYPEDSIGNATURES) - -Return the symbolic or non-symbolic value wrapped by a type such as `Num`. -""" -unwrap(x) = x """ $(TYPEDSIGNATURES) From 667f2d9c8d39b53c42b72d3965af026c4167d45c Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 16 Sep 2025 18:58:44 +0530 Subject: [PATCH 13/28] fix: validate default size in `setdefaultval` --- src/variable.jl | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/src/variable.jl b/src/variable.jl index 776c3718d..6b860b17e 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -28,7 +28,21 @@ Symbolic metadata key for storing the macro used to create a symbolic variable. """ struct VariableSource <: AbstractVariableMetadata end -setdefaultval(x, val) = setmetadata(x, VariableDefaultValue, val) +function setdefaultval(x, val) + sh = shape(x) + if sh isa SymbolicUtils.Unknown + @assert sh.ndims == -1 || ndims(val) == sh.ndims """ + Variable $x must have default of matching `ndims`. Got $val with `ndims` \ + $(ndims(val)). + """ + else + @assert isempty(sh) || symtype(x) <: FnType || size(x) == size(val) """ + Variable $x must have default of matching size. Got $val with size \ + $(size(val)). + """ + end + setmetadata(x, VariableDefaultValue, val) +end function map_subscripts(indices) str = string(indices) From 563a84dbee4d90ee9478854164dcd08bcd9dadb4 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 16 Sep 2025 18:59:09 +0530 Subject: [PATCH 14/28] refactor: remove old `_getname`, implement `SII.getname` properly --- src/variable.jl | 21 +++------------------ 1 file changed, 3 insertions(+), 18 deletions(-) diff --git a/src/variable.jl b/src/variable.jl index 6b860b17e..99d3c5f8e 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -272,30 +272,15 @@ end const _fail = Dict() -_getname(x, _) = nameof(x) -_getname(x::Symbol, _) = x -function _getname(x::BasicSymbolic, val) - issym(x) && return nameof(x) - if iscall(x) && issym(operation(x)) - return nameof(operation(x)) - end - if !hasmetadata(x, Symbolics.GetindexParent) && iscall(x) && operation(x) == getindex - return _getname(arguments(x)[1], val) - end - ss = getsource(x, nothing) - if ss === nothing - ss = getsource(getparent(x), val) - end - ss === _fail && throw(ArgumentError("Variable $x doesn't have a source defined.")) - ss[2] -end - getsource(x, val=_fail) = getmetadata(unwrap(x), VariableSource, val) SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.Num}) = ScalarSymbolic() SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.Arr}) = ArraySymbolic() SymbolicIndexingInterface.hasname(x::Union{Num,Arr,Complex{Num}}) = hasname(unwrap(x)) +function SymbolicIndexingInterface.getname(x::Union{Num, Arr, Complex{Num}}) + SymbolicIndexingInterface.getname(unwrap(x)) +end function SymbolicIndexingInterface.symbolic_evaluate(ex::Union{Num, Arr, BasicSymbolic, Equation, Inequality}, d::Dict; kwargs...) val = fixpoint_sub(ex, d; kwargs...) From 3af56f56f9a05cf1203483c0fb981adcb229d986 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 16 Sep 2025 18:59:40 +0530 Subject: [PATCH 15/28] fix: handle indexed symbolics in `getdefaultval` --- src/variable.jl | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/variable.jl b/src/variable.jl index 99d3c5f8e..0d862eed5 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -484,7 +484,13 @@ function getdefaultval(x, val=_fail) if val !== _fail return val else - error("$x has no default value") + @match x begin + BSImpl.Term(; f, args) && if f === getindex end => begin + idxs = Iterators.map(unwrap_const, Iterators.drop(args, 1)) + return getdefaultval(args[1], val)[idxs...] + end + _ => error("$x has no default value") + end end end From 949e9e730d5de8e8db36aa540b43e0ccfaced9b9 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 16 Sep 2025 18:59:50 +0530 Subject: [PATCH 16/28] fix: fix `Symbolics.variable` --- src/variable.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/variable.jl b/src/variable.jl index 0d862eed5..f816c7d3e 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -532,7 +532,7 @@ Also see `variables`. """ function variable(name, idx...; T=Real) name_ij = Symbol(name, join(map_subscripts.(idx), "ˏ")) - v = Sym{T}(name_ij) + v = Sym{VartypeT}(name_ij; type = T) wrap(setmetadata(v, VariableSource, (:variables, name_ij))) end From 9caff1bd14feed6bef175df52efc2daeffdb52a7 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 17 Sep 2025 01:23:13 +0530 Subject: [PATCH 17/28] refactor: remove `get_variables!` implementations --- src/arrays.jl | 3 --- src/equations.jl | 1 - src/utils.jl | 57 ------------------------------------------------ 3 files changed, 61 deletions(-) diff --git a/src/arrays.jl b/src/arrays.jl index 2e83d47b6..694611be1 100644 --- a/src/arrays.jl +++ b/src/arrays.jl @@ -59,9 +59,6 @@ Base.size(A::Arr) = size(unwrap(A)) Base.axes(A::Arr) = axes(unwrap(A)) Base.eachindex(A::Arr) = eachindex(unwrap(A)) -function get_variables!(vars, e::Arr, varlist=nothing) - foreach(x -> get_variables!(vars, x, varlist), collect(e)) - vars end # cannot use `@wrapped` since it will define `\(::BasicSymbolic, ::BasicSymbolic)` diff --git a/src/equations.jl b/src/equations.jl index 4f7f300b5..209c337eb 100644 --- a/src/equations.jl +++ b/src/equations.jl @@ -201,7 +201,6 @@ end canonical_form(eq::Equation) = eq.lhs - eq.rhs ~ 0 -get_variables(eq::Equation) = unique(vcat(get_variables(eq.lhs), get_variables(eq.rhs))) struct ConstrainedEquation constraints diff --git a/src/utils.jl b/src/utils.jl index e1f9d09cf..c15855232 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -14,45 +14,6 @@ function flatten_expr!(x) xs end - -""" - get_variables(e, varlist = nothing; sort::Bool = false) - -Return a vector of variables appearing in `e`, optionally restricting to variables in `varlist`. - -Note that the returned variables are not wrapped in the `Num` type. - -# Examples -```jldoctest -julia> @variables t x y z(t); - -julia> Symbolics.get_variables(x + y + sin(z); sort = true) -3-element Vector{SymbolicUtils.BasicSymbolic}: - x - y - z(t) - -julia> Symbolics.get_variables(x - y; sort = true) -2-element Vector{SymbolicUtils.BasicSymbolic}: - x - y -``` -""" -function get_variables(e::Num, varlist = nothing; sort::Bool = false) - get_variables(value(e), varlist; sort) -end -function get_variables(e, varlist = nothing; sort::Bool = false) - vars = Vector{BasicSymbolic}() - get_variables!(vars, e, varlist) - if sort - sort!(vars; by = SymbolicUtils.get_degrees) - end - vars -end - -get_variables!(vars, e::Num, varlist=nothing) = get_variables!(vars, value(e), varlist) -get_variables!(vars, e, varlist=nothing) = vars - function is_singleton(e) if iscall(e) op = operation(e) @@ -64,24 +25,6 @@ function is_singleton(e) end end -get_variables!(vars, e::Number, varlist=nothing) = vars - -function get_variables!(vars, e::BasicSymbolic, varlist=nothing) - if is_singleton(e) - if isnothing(varlist) || any(isequal(e), varlist) - push!(vars, e) - end - else - get_variables!(vars, operation(e), varlist) - foreach(x -> get_variables!(vars, x, varlist), arguments(e)) - end - return (vars isa AbstractVector) ? unique!(vars) : vars -end - -function get_variables!(vars, e::Equation, varlist=nothing) - get_variables!(vars, e.rhs, varlist) -end - """ get_differential_vars(e, varlist = nothing; sort::Bool = false) From 986bb09ee943e9952ab7c7c49f428c3ba1bb85c4 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Wed, 17 Sep 2025 01:23:47 +0530 Subject: [PATCH 18/28] feat: implement `SymbolicUtils.search_variables!` --- src/arrays.jl | 2 ++ src/equations.jl | 4 ++++ src/num.jl | 4 ++++ 3 files changed, 10 insertions(+) diff --git a/src/arrays.jl b/src/arrays.jl index 694611be1..7c1ab85a5 100644 --- a/src/arrays.jl +++ b/src/arrays.jl @@ -59,6 +59,8 @@ Base.size(A::Arr) = size(unwrap(A)) Base.axes(A::Arr) = axes(unwrap(A)) Base.eachindex(A::Arr) = eachindex(unwrap(A)) +function SymbolicUtils.search_variables!(buffer, expr::Arr; kw...) + SymbolicUtils.search_variables!(buffer, unwrap(expr); kw...) end # cannot use `@wrapped` since it will define `\(::BasicSymbolic, ::BasicSymbolic)` diff --git a/src/equations.jl b/src/equations.jl index 209c337eb..248c6177a 100644 --- a/src/equations.jl +++ b/src/equations.jl @@ -201,6 +201,10 @@ end canonical_form(eq::Equation) = eq.lhs - eq.rhs ~ 0 +function SymbolicUtils.search_variables!(buffer, eq::Equation; kw...) + SymbolicUtils.search_variables!(buffer, eq.lhs; kw...) + SymbolicUtils.search_variables!(buffer, eq.rhs; kw...) +end struct ConstrainedEquation constraints diff --git a/src/num.jl b/src/num.jl index c5e73756f..0c9d1b111 100644 --- a/src/num.jl +++ b/src/num.jl @@ -31,6 +31,10 @@ Base.typemin(::Type{Num}) = Num(-Inf) Base.typemax(::Type{Num}) = Num(Inf) Base.float(x::Num) = x +function SymbolicUtils.search_variables!(buffer, expr::Num; kw...) + SymbolicUtils.search_variables!(buffer, unwrap(expr); kw...) +end + """ ifelse(cond::Num, x, y) From 73e446159f91386dfc34a1d05ef8ef389b932a02 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 18 Sep 2025 13:06:37 +0530 Subject: [PATCH 19/28] refactor: make `Arr` store `BasicSymbolic{T}` --- src/arrays.jl | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/arrays.jl b/src/arrays.jl index 7c1ab85a5..55603cdfc 100644 --- a/src/arrays.jl +++ b/src/arrays.jl @@ -5,7 +5,16 @@ import Base: eltype, length, ndims, size, axes, eachindex ### Wrapper type for dispatch @symbolic_wrap struct Arr{T,N} <: AbstractArray{T, N} - value + value::BasicSymbolic{VartypeT} + + function Arr{T, N}(ex) where {T, N} + if is_wrapper_type(T) + @assert symtype(ex) <: AbstractArray{<:wraps_type(T), N} + else + @assert symtype(ex) <: AbstractArray{T, N} + end + new{T, N}(Const{VartypeT}(ex)) + end end Base.hash(x::Arr, u::UInt) = hash(unwrap(x), u) From a0e9fd3ff80548ff16eda6d895789c4c198bafc7 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 18 Sep 2025 13:09:19 +0530 Subject: [PATCH 20/28] fix: fix `build_function` codegen for arrayop --- src/build_function.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/build_function.jl b/src/build_function.jl index c58feea06..bcb9c70b7 100644 --- a/src/build_function.jl +++ b/src/build_function.jl @@ -171,6 +171,9 @@ function _build_function(target::JuliaTarget, op::Union{Arr, SymbolicUtils.Basic outsym = DEFAULT_OUTSYM if iip_config[2] + if SymbolicUtils.isarrayop(op) && !haskey(states.rewrites, :arrayop_output) + states.rewrites[:arrayop_output] = outsym + end body = inplace_expr(op, outsym) iip_expr = wrap_code[2](Func(vcat(outsym, dargs), [], body)) else From 44080e83b543c7281f0c5df1656532fd64a984cf Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 18 Sep 2025 13:09:37 +0530 Subject: [PATCH 21/28] fix: fix `scalarize` for `Equation` --- src/equations.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/equations.jl b/src/equations.jl index 248c6177a..199a3c491 100644 --- a/src/equations.jl +++ b/src/equations.jl @@ -127,7 +127,7 @@ function Base.show(io::IO, eq::Equation) end end -scalarize(eq::Equation) = scalarize(eq.lhs) .~ scalarize(eq.rhs) +SymbolicUtils.scalarize(eq::Equation) = scalarize(eq.lhs) .~ scalarize(eq.rhs) SymbolicUtils.simplify(x::Equation; kw...) = simplify(x.lhs; kw...) ~ simplify(x.rhs; kw...) # ambiguity for T in [:Pair, :Any] From 45b25ab414261e7f61e90a2eb172dc1aacfe0b5b Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 18 Sep 2025 13:09:49 +0530 Subject: [PATCH 22/28] fix: fix `scalarize` for `Inequality` --- src/inequality.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/inequality.jl b/src/inequality.jl index c38296e46..88bda32d5 100644 --- a/src/inequality.jl +++ b/src/inequality.jl @@ -23,7 +23,7 @@ Base.hash(a::Inequality, salt::UInt) = hash(a.lhs, hash(a.rhs, hash(a.relational @enum RelationalOperator leq geq # strict less than or strict greater than are not supported by any solver -function scalarize(ineq::Inequality) +function SymbolicUtils.scalarize(ineq::Inequality) if ineq.relational_op == leq scalarize(ineq.lhs) ≲ scalarize(ineq.rhs) else From 2df01000a59c8b5d0c13406d3412f1d6489db921 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 18 Sep 2025 13:09:59 +0530 Subject: [PATCH 23/28] fix: store `BasicSymbolic{T}` in `Num` --- src/num.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/num.jl b/src/num.jl index 0c9d1b111..d065029d5 100644 --- a/src/num.jl +++ b/src/num.jl @@ -1,5 +1,10 @@ @symbolic_wrap struct Num <: Real - val::Any + val::BasicSymbolic{VartypeT} + + function Num(ex) + @assert symtype(ex) <: Real + return new(Const{VartypeT}(ex)) + end end const RCNum = Union{Num, Complex{Num}} From 588b0385b63c09b39bf87ae8edd3dd4c9683b450 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 18 Sep 2025 13:10:54 +0530 Subject: [PATCH 24/28] refactor: do not pirate `Base.Symbol(::BasicSymbolic)` --- src/utils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils.jl b/src/utils.jl index c15855232..7ab02ec1a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -87,7 +87,7 @@ function get_differential_vars!(vars, e::Equation, varlist=nothing) end # Sym / Term --> Symbol -Base.Symbol(x::Union{Num,BasicSymbolic}) = tosymbol(x) +Base.Symbol(x::Num) = Symbol(unwrap(x)) tosymbol(t::Num; kwargs...) = tosymbol(value(t); kwargs...) """ From 8150ecf1cb9ec4bb03ea7774157fd52d4f284424 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 18 Sep 2025 13:11:09 +0530 Subject: [PATCH 25/28] fix: remove use of deprecated `children` --- src/utils.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index 7ab02ec1a..cdecfb195 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -131,7 +131,7 @@ function diff2term(O, O_metadata::Union{Dict, Nothing, Base.ImmutableDict}=nothi d_separator = 'ˍ' if ds === nothing - return maketerm(typeof(O), TermInterface.head(O), map(diff2term, children(O)), + return maketerm(typeof(O), TermInterface.head(O), map(diff2term, arguments(O)), O_metadata isa Nothing ? metadata(O) : Base.ImmutableDict(metadata(O)..., O_metadata...)) else @@ -146,7 +146,7 @@ function diff2term(O, O_metadata::Union{Dict, Nothing, Base.ImmutableDict}=nothi error("diff2term case not handled: $oldop") end newname = occursin(d_separator, opname) ? Symbol(opname, ds) : Symbol(opname, d_separator, ds) - return setname(maketerm(typeof(O), rename(oldop, newname), children(O), O_metadata isa Nothing ? + return setname(maketerm(typeof(O), rename(oldop, newname), arguments(O), O_metadata isa Nothing ? metadata(O) : Base.ImmutableDict(metadata(O)..., O_metadata...)), newname) end end From 531c2940f603ea98a7840c9c8eef9c0bd78b6e17 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 18 Sep 2025 13:11:20 +0530 Subject: [PATCH 26/28] refactor: remove `Symbolics.Unknown` --- src/utils.jl | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/src/utils.jl b/src/utils.jl index cdecfb195..4fea03cd9 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -213,21 +213,6 @@ function lower_varname(var::BasicSymbolic, idv, order) return diff2term(var) end -### OOPS - -struct Unknown end - -macro oops(ex) - quote - tmp = $(esc(ex)) - if tmp === Unknown() - return Unknown() - else - tmp - end - end -end - function makesubscripts(n) set = 'i':'z' m = length(set) From 376a8fa0c4e6500845c9e6454bca6a4d511e9a55 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 18 Sep 2025 13:11:41 +0530 Subject: [PATCH 27/28] refactor: move `Operator` to `SymbolicUtils` --- src/Symbolics.jl | 3 ++- src/diff.jl | 14 +++++--------- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/src/Symbolics.jl b/src/Symbolics.jl index b273010d8..82b5c7095 100644 --- a/src/Symbolics.jl +++ b/src/Symbolics.jl @@ -28,7 +28,8 @@ import TermInterface: maketerm, iscall, operation, arguments, metadata import SymbolicUtils: Term, Add, Mul, Sym, Div, BasicSymbolic, Const, FnType, @rule, Rewriters, substitute, symtype, shape, unwrap, unwrap_const, - promote_symtype, isadd, ismul, ispow, isterm, issym, isdiv, BSImpl + promote_symtype, isadd, ismul, ispow, isterm, issym, isdiv, BSImpl, scalarize, + Operator using SymbolicUtils.Code diff --git a/src/diff.jl b/src/diff.jl index 4be9f46f6..c65afa573 100644 --- a/src/diff.jl +++ b/src/diff.jl @@ -1,6 +1,3 @@ -abstract type Operator end -propagate_shape(::Operator, x) = axes(x) - """ $(TYPEDEF) @@ -37,19 +34,18 @@ struct Differential <: Operator end function (D::Differential)(x) x = unwrap(x) - if isarraysymbolic(x) - array_term(D, x) - else - term(D, x) - end + term(D, x) end (D::Differential)(x::Union{AbstractFloat, Integer}) = wrap(0) (D::Differential)(x::Union{Num, Arr}) = wrap(D(unwrap(x))) (D::Differential)(x::Complex{Num}) = Complex{Num}(wrap(D(unwrap(real(x)))), wrap(D(unwrap(imag(x))))) -SymbolicUtils.promote_symtype(::Differential, T) = T SymbolicUtils.isbinop(f::Differential) = false +function SymbolicUtils.operator_to_term(d::Differential, ex::BasicSymbolic{T}) where {T} + return diff2term(ex) +end + is_derivative(x) = iscall(x) ? operation(x) isa Differential : false Base.:*(D1, D2::Differential) = D1 ∘ D2 From e79bdfa7a96c5986b6f54e6edf63ca56af779057 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 18 Sep 2025 13:12:08 +0530 Subject: [PATCH 28/28] refactor: fix `scalarize` for `Num`, `Complex{Num}` --- src/Symbolics.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/Symbolics.jl b/src/Symbolics.jl index 82b5c7095..9d906c5f4 100644 --- a/src/Symbolics.jl +++ b/src/Symbolics.jl @@ -214,6 +214,7 @@ for T in [Num, Complex{Num}] SymbolicUtils.hasmetadata(x::$T, t) = SymbolicUtils.hasmetadata(unwrap(x), t) Broadcast.broadcastable(x::$T) = x + SymbolicUtils.scalarize(x::$T) = scalarize(unwrap(x)) end end