Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
5834681
fix: fix promote rules
AayushSabharwal Sep 12, 2025
2f69b08
test: make some tests independent of hash order
AayushSabharwal Sep 12, 2025
9b0f27e
refactor: do not print extra information in `inspect`
AayushSabharwal Sep 12, 2025
e6dbacb
refactor: avoid unnecessary global state in `basicsymbolic_to_polyvar`
AayushSabharwal Sep 12, 2025
8df0348
build: remove `ConcurrentUtilities` dependency
AayushSabharwal Sep 12, 2025
5b29590
refactor: remove outdated empty functions
AayushSabharwal Sep 12, 2025
bb00e5a
chore: remove redundant file
AayushSabharwal Sep 12, 2025
dc902cc
test: update reference tests
AayushSabharwal Sep 12, 2025
d21a13e
docs: turn doctests into `@example` blocks
AayushSabharwal Sep 12, 2025
6fe4a45
test: add note for intermittently failing test
AayushSabharwal Sep 12, 2025
92d4a17
test: update rewrite test
AayushSabharwal Sep 12, 2025
dc752bc
refactor: use locked WCS instead of TaskLocalValue for hashconsing
AayushSabharwal Sep 12, 2025
2965827
feat: allow filtering in `substitute`
AayushSabharwal Sep 15, 2025
0a73bd3
feat: handle `SparseMatrixCSC` in `substitute`
AayushSabharwal Sep 15, 2025
47865f7
feat: special-case `complex(re, img)` term in complex methods
AayushSabharwal Sep 15, 2025
f1462e0
feat: propagate metadata when calling `FnType`
AayushSabharwal Sep 15, 2025
bc06f10
feat: add symbolic function checking methods
AayushSabharwal Sep 15, 2025
10154cb
feat: support `LinearAlgebra.dot`
AayushSabharwal Sep 15, 2025
9c1473c
feat: support `LinearAlgebra.det`
AayushSabharwal Sep 15, 2025
e5bbf7a
feat: support `Base.isempty`
AayushSabharwal Sep 15, 2025
28dfca1
feat: support `Base.CartesianIndex`
AayushSabharwal Sep 15, 2025
6635306
refactor: enable easily extending polyadic methods to wrapper types
AayushSabharwal Sep 15, 2025
817ad4a
feat: support `Base.map`
AayushSabharwal Sep 15, 2025
cb04dfb
feat: support `Base.mapreduce`
AayushSabharwal Sep 15, 2025
2248eb6
refactor: generalize `to_poly!` type bounds
AayushSabharwal Sep 16, 2025
77f791b
refactor: add and use `from_poly`
AayushSabharwal Sep 16, 2025
ae6acba
docs: better describe `@syms` syntax, refactor parsing
AayushSabharwal Sep 16, 2025
9565723
feat: implement 2-arg `size`
AayushSabharwal Sep 16, 2025
252c039
fix: more `@syms` modularity and parsing updates
AayushSabharwal Sep 16, 2025
5cbd451
feat: implement `SII.getname`
AayushSabharwal Sep 16, 2025
6a76ef8
feat: add bounds checks to symbolic `getindex`
AayushSabharwal Sep 16, 2025
ea3ac6a
feat: support `ArrayOp` in `query!`
AayushSabharwal Sep 16, 2025
8222a54
feat: improve `search_variables!`, support `ArrayOp`
AayushSabharwal Sep 16, 2025
8193b61
feat: add `search_variables`
AayushSabharwal Sep 16, 2025
18207c4
feat: add `@map_methods` and `@mapreduce_methods`
AayushSabharwal Sep 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ version = "3.32.0"
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
ConcurrentUtilities = "f0e56b4a-5159-44fe-b623-3e5288b988bb"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Expand Down Expand Up @@ -51,7 +50,6 @@ AbstractTrees = "0.4"
ArrayInterface = "7.8"
ChainRulesCore = "1"
Combinatorics = "1 - 1.0.2"
ConcurrentUtilities = "2.5.0"
ConstructionBase = "1.5.7"
DataStructures = "0.18, 0.19"
DocStringExtensions = "0.8, 0.9"
Expand Down
9 changes: 0 additions & 9 deletions bench.jl

This file was deleted.

116 changes: 21 additions & 95 deletions docs/src/manual/rewrite.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Rewrite rules match and transform an expression. A rule is written using either

Here is a simple rewrite rule, that uses formula for the double angle of the sine function:

```jldoctest rewrite
```@example rewrite
using SymbolicUtils

@syms w z α::Real β::Real
Expand All @@ -18,9 +18,6 @@ using SymbolicUtils
r1 = @rule sin(2(~x)) => 2sin(~x)*cos(~x)

r1(sin(2z))

# output
2cos(z)*sin(z)
```

The `@rule` macro takes a pair of patterns -- the _matcher_ and the _consequent_ (`@rule matcher => consequent`). If an expression matches the matcher pattern, it is rewritten to the consequent pattern. `@rule` returns a callable object that applies the rule to an expression.
Expand All @@ -30,47 +27,32 @@ The `@rule` macro takes a pair of patterns -- the _matcher_ and the _consequent_
If you try to apply this rule to an expression with triple angle, it will return `nothing` -- this is the way a rule signifies failure to match.
```julia
r1(sin(3z))

# output
nothing
```

Slot variable (matcher) is not necessary a single variable:
```jldoctest rewrite
```@example rewrite
r1(sin(2*(w-z)))

# output
2sin(w - z)*cos(w - z)
```

And can also match a function:
```julia
r = @rule (~f)(z+1) => ~f

r(sin(z+1))

# output
sin (generic function with 20 methods)

```
Rules are of course not limited to single slot variable

```jldoctest rewrite
```@example rewrite
r2 = @rule sin(~x + ~y) => sin(~x)*cos(~y) + cos(~x)*sin(~y);

r2(sin(α+β))

# output
cos(β)*sin(α) + sin(β)*cos(α)
```

Now let's say you want to catch the coefficients of a second degree polynomial in z. You can do that with:
```jldoctest rewrite
```@example rewrite
c2d = @rule ~a + ~b*z + ~c*z^2 => (~a, ~b, ~c)

c2d(3 + 2z + 5z^2)
# output
(3, 2, 5)
2d(3 + 2z + 5z^2)
```
Great! But if you try:
```julia
Expand All @@ -80,12 +62,10 @@ c2d(3 + 2z + z^2)
nothing
```
the rule is not applied. This is because in the input polynomial there isn't a multiplication in front of the `z^2`. For this you can use **defslot variables**, with syntax `~!a`:
```jldoctest rewrite
```@example rewrite
c2d = @rule ~!a + ~!b*z + ~!c*z^2 => (~a, ~b, ~c)

c2d(3 + 2z + z^2)
# output
(3, 2, 1)
2d(3 + 2z + z^2)
```
They work like normal slot variables, but if they are not present they take a default value depending on the operation they are in, in the above example `~b = 1`. Currently defslot variables can be defined in:

Expand All @@ -97,52 +77,31 @@ addition `+` | 0

If you want to match a variable number of subexpressions at once, you will need a **segment variable**. `~~xs` in the following example is a segment variable:

```jldoctest rewrite
```@example rewrite
@syms x y z
@rule(+(~~xs) => ~~xs)(x + y + z)

# output
3-element view(::ReadOnlyArrays.ReadOnlyVector{Any, SymbolicUtils.SmallVec{Any, Vector{Any}}}, 1:3) with eltype Any:
x
y
z
```

`~~xs` is a vector of subexpressions matched. You can use it to construct something more useful:

```jldoctest rewrite
```@example rewrite
r3 = @rule ~x * +(~~ys) => sum(map(y-> ~x * y, ~~ys));

r3(2 * (w+w+α+β))

# output
4w + 2α + 2β
```

Notice that the expression was autosimplified before application of the rule.

```jldoctest rewrite
```@example rewrite
2 * (w+w+α+β)

# output
2(2w + α + β)
```

Note that writing a single tilde `~` as consequent, will make the rule return a dictionary of [slot variable, expression matched].

```jldoctest rewrite
```@example rewrite
r = @rule (~x + (~y)^(~m)) => ~

r(z+w^α)

# output
Base.ImmutableDict{Symbol, Any} with 5 entries:
:MATCH => z + w^α
:m => α
:y => w
:x => z
:____ => nothing

```

### Predicates for matching
Expand All @@ -153,7 +112,7 @@ Similarly `~~x::g` is a way of attaching a predicate `g` to a segment variable.

For example,

```jldoctest pred
```@example pred
using SymbolicUtils
@syms a b c d

Expand All @@ -163,69 +122,51 @@ r = @rule ~x + ~~y::(ys->iseven(length(ys))) => "odd terms";
@show r(b + c + d)
@show r(b + c + b)
@show r(a + b)

# output
r(a + b + c + d) = nothing
r(b + c + d) = "odd terms"
r(b + c + b) = nothing
r(a + b) = nothing
```


### Associative-Commutative Rules

Given an expression `f(x, f(y, z, u), v, w)`, a `f` is said to be associative if the expression is equivalent to `f(x, y, z, u, v, w)` and commutative if the order of arguments does not matter. SymbolicUtils has a special `@acrule` macro meant for rules on functions which are associate and commutative such as addition and multiplication of real and complex numbers.

```jldoctest acr
```@example acr
using SymbolicUtils
@syms x y z

acr = @acrule((~a)^(~x) * (~a)^(~y) => (~a)^(~x + ~y))

acr(x^y * x^z)

# output
x^(y + z)
```

although in case of `Number` it also works the same way with regular `@rule` since autosimplification orders and applies associativity and commutativity to the expression.

### Example of applying the rules to simplify expression

Consider expression `(cos(x) + sin(x))^2` that we would like simplify by applying some trigonometric rules. First, we need rule to expand square of `cos(x) + sin(x)`. First we try the simplest rule to expand square of the sum and try it on simple expression
```jldoctest rewriteex
```@example rewriteex
using SymbolicUtils

@syms x::Real y::Real

sqexpand = @rule (~x + ~y)^2 => (~x)^2 + (~y)^2 + 2 * ~x * ~y

sqexpand((cos(x) + sin(x))^2)

# output
sin(x)^2 + 2sin(x)*cos(x) + cos(x)^2
```

It works. This can be further simplified using Pythagorean identity and check it

```jldoctest rewriteex
```@example rewriteex
pyid = @rule sin(~x)^2 + cos(~x)^2 => 1

pyid(sin(x)^2 + 2sin(x)*cos(x) + cos(x)^2)===nothing

# output
true
```

Why does it return `nothing`? If we look at the expression, we see that we have an additional addend `+ 2sin(x)*cos(x)`. Therefore, in order to work, the rule needs to be associative-commutative.

```jldoctest rewriteex
```@example rewriteex
acpyid = @acrule sin(~x)^2 + cos(~x)^2 => 1

acpyid(cos(x)^2 + sin(x)^2 + 2cos(x)*sin(x))

# output
1 + 2sin(x)*cos(x)
```

It has been some work. Fortunately rules may be [chained together](#chaining rewriters) into more sophisticated rewriters to avoid manual application of the rules.
Expand Down Expand Up @@ -270,7 +211,7 @@ Several rules may be chained to give chain of rules. Chain is an array of rules

To check that, we will combine rules from [previous example](#example of applying the rules to simplify expression) into a chain

```jldoctest composing
```@example composing
using SymbolicUtils
using SymbolicUtils.Rewriters

Expand All @@ -282,52 +223,37 @@ acpyid = @acrule sin(~x)^2 + cos(~x)^2 => 1
csa = Chain([sqexpand, acpyid])

csa((cos(x) + sin(x))^2)

# output
1 + 2sin(x)*cos(x)
```

Important feature of `Chain` is that it returns the expression instead of `nothing` if it doesn't change the expression

```jldoctest composing
```@example composing
Chain([@acrule sin(~x)^2 + cos(~x)^2 => 1])((cos(x) + sin(x))^2)

# output
(sin(x) + cos(x))^2
```

it's important to notice, that chain is ordered, so if rules are in different order it wouldn't work the same as in earlier example

```jldoctest composing
```@example composing
cas = Chain([acpyid, sqexpand])

cas((cos(x) + sin(x))^2)

# output
sin(x)^2 + 2sin(x)*cos(x) + cos(x)^2
```
since Pythagorean identity is applied before square expansion, so it is unable to match squares of sine and cosine.

One way to circumvent the problem of order of applying rules in chain is to use `RestartedChain`

```jldoctest composing
```@example composing
using SymbolicUtils.Rewriters: RestartedChain

rcas = RestartedChain([acpyid, sqexpand])

rcas((cos(x) + sin(x))^2)

# output
1 + 2sin(x)*cos(x)
```

It restarts the chain after each successful application of a rule, so after `sqexpand` is hit it (re)starts again and successfully applies `acpyid` to resulting expression.

You can also use `Fixpoint` to apply the rules until there are no changes.

```jldoctest composing
```@example composing
Fixpoint(cas)((cos(x) + sin(x))^2)

# output
1 + 2sin(x)*cos(x)
```
6 changes: 1 addition & 5 deletions src/SymbolicUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,8 @@ import MacroTools
import MultivariatePolynomials as MP
import DynamicPolynomials as DP
import MutableArithmetics as MA
import ConcurrentUtilities: ReadWriteLock, readlock, readunlock
import LinearAlgebra
import SparseArrays: SparseMatrixCSC, findnz

function hash2 end
function isequal_with_metadata end
import SparseArrays: SparseMatrixCSC, findnz, sparse

macro manually_scope(val, expr, is_forced = false)
@assert Meta.isexpr(val, :call)
Expand Down
2 changes: 1 addition & 1 deletion src/cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ The key stored in the cache for a particular value. Returns a `SymbolicKey` for
# can't dispatch because `BasicSymbolic` isn't defined here
function get_cache_key(x)
if x isa BasicSymbolic
id = x.id[2]
id = x.id
if id === nothing
return CacheSentinel()
end
Expand Down
17 changes: 14 additions & 3 deletions src/code.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import SymbolicUtils: @matchable, BasicSymbolic, Sym, Term, iscall, operation, a
symtype, sorted_arguments, metadata, isterm, term, maketerm, unwrap_const,
ArgsT, Const, SymVariant, _is_array_of_symbolics, _is_tuple_of_symbolics,
ArrayOp, isarrayop, IdxToAxesT, ROArgsT, shape, Unknown, ShapeVecT,
search_variables!, _is_index_variable, RangesT, IDXS_SYM
search_variables!, _is_index_variable, RangesT, IDXS_SYM, _is_array_shape
import SymbolicIndexingInterface: symbolic_type, NotSymbolic

##== state management ==##
Expand Down Expand Up @@ -151,7 +151,13 @@ function function_to_expr(::Type{ArrayOp{T}}, O::BasicSymbolic{T}, st) where {T}

# TODO: better infer default eltype from `O`
output_eltype = get(st.rewrites, :arrayop_eltype, Float64)
output_buffer = get(st.rewrites, :arrayop_output, term(zeros, output_eltype, size(O)))
sh = shape(O)
default_output_buffer = if _is_array_shape(sh)
term(zeros, output_eltype, size(O))
else
term(zero, output_eltype)
end
output_buffer = get(st.rewrites, :arrayop_output, default_output_buffer)
toexpr(Let(
[
Assignment(ARRAYOP_OUTSYM, output_buffer),
Expand Down Expand Up @@ -212,12 +218,17 @@ function inplace_expr(x::BasicSymbolic{T}, outsym) where {T}
if outsym isa Symbol
outsym = Sym{T}(outsym; type = Array{Any}, shape = Unknown(-1))
end
sh = shape(x)
ranges = x.ranges
new_ranges = RangesT{T}()
new_expr = unidealize_indices(x.expr, ranges, new_ranges)
loopvar_order = unique!(filter(x -> x isa BasicSymbolic{T}, vcat(reverse(x.output_idx), collect(keys(ranges)), collect(keys(new_ranges)))))

inner_expr = SetArray(false, outsym, [AtIndex(term(CartesianIndex, x.output_idx...), term(x.reduce, term(getindex, outsym, x.output_idx...), new_expr))])
if _is_array_shape(sh)
inner_expr = SetArray(false, outsym, [AtIndex(term(CartesianIndex, x.output_idx...), term(x.reduce, term(getindex, outsym, x.output_idx...), new_expr))])
else
inner_expr = Assignment(outsym, term(x.reduce, outsym, new_expr))
end
merge!(new_ranges, ranges)
loops = foldl(reverse(loopvar_order), init=inner_expr) do acc, k
ForLoop(k, new_ranges[k], acc)
Expand Down
4 changes: 2 additions & 2 deletions src/inspect.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@ function AbstractTrees.nodevalue(x::BSImpl.Type)
string(T, "(", x, ")")
elseif isadd(x)
string(T,
(variant=string(x.variant), scalar=x.coeff, coeffs=Tuple(k=>v for (k,v) in x.dict)))
(variant=string(x.variant),))
elseif ismul(x)
string(T,
(variant=string(x.variant), scalar=x.coeff, powers=Tuple(k=>v for (k,v) in x.dict)))
(variant=string(x.variant),))
elseif isdiv(x) || ispow(x)
string(T)
else
Expand Down
Loading
Loading