diff --git a/src/EGraphs/EGraphs.jl b/src/EGraphs/EGraphs.jl index 0720cbb8..914d4b14 100644 --- a/src/EGraphs/EGraphs.jl +++ b/src/EGraphs/EGraphs.jl @@ -19,9 +19,6 @@ include("egraph.jl") export ENode export EClassId export EClass -export hasdata -export getdata -export setdata! export find export lookup export arity @@ -31,14 +28,11 @@ export in_same_class export addexpr! export rebuild! -include("analysis.jl") -export analyze! +include("extract.jl") export extract! export astsize export astsize_inv -export getcost! -export Sub include("Schedulers.jl") export Schedulers diff --git a/src/EGraphs/analysis.jl b/src/EGraphs/analysis.jl deleted file mode 100644 index f599c7b4..00000000 --- a/src/EGraphs/analysis.jl +++ /dev/null @@ -1,201 +0,0 @@ -analysis_reference(x::Symbol) = Val(x) -analysis_reference(x::Function) = x -analysis_reference(x) = error("$x is not a valid analysis reference") - -""" - islazy(::Val{analysis_name}) - -Should return `true` if the EGraph Analysis `an` is lazy -and false otherwise. A *lazy* EGraph Analysis is computed -only when [analyze!](@ref) is called. *Non-lazy* -analyses are instead computed on-the-fly every time ENodes are added to the EGraph or -EClasses are merged. -""" -islazy(::Val{analysis_name}) where {analysis_name} = false -islazy(analysis_name) = islazy(analysis_reference(analysis_name)) - -""" - modify!(::Val{analysis_name}, g, id) - -The `modify!` function for EGraph Analysis can optionally modify the eclass -`g[id]` after it has been analyzed, typically by adding an ENode. -It should be **idempotent** if no other changes occur to the EClass. -(See the [egg paper](https://dl.acm.org/doi/pdf/10.1145/3434304)). 
-""" -modify!(::Val{analysis_name}, g, id) where {analysis_name} = nothing -modify!(an, g, id) = modify!(analysis_reference(an), g, id) - - -""" - join(::Val{analysis_name}, a, b) - -Joins two analyses values into a single one, used by [analyze!](@ref) -when two eclasses are being merged or the analysis is being constructed. -""" -join(analysis::Val{analysis_name}, a, b) where {analysis_name} = - error("Analysis $analysis_name does not implement join") -join(an, a, b) = join(analysis_reference(an), a, b) - -""" - make(::Val{analysis_name}, g, n) - -Given an ENode `n`, `make` should return the corresponding analysis value. -""" -make(::Val{analysis_name}, g, n) where {analysis_name} = error("Analysis $analysis_name does not implement make") -make(an, g, n) = make(analysis_reference(an), g, n) - -analyze!(g::EGraph, analysis_ref, id::EClassId) = analyze!(g, analysis_ref, reachable(g, id)) -analyze!(g::EGraph, analysis_ref) = analyze!(g, analysis_ref, collect(keys(g.classes))) - - -""" - analyze!(egraph, analysis_name, [ECLASS_IDS]) - -Given an [EGraph](@ref) and an `analysis` identified by name `analysis_name`, -do an automated bottom up trasversal of the EGraph, associating a value from the -domain of analysis to each ENode in the egraph by the [make](@ref) function. -Then, for each [EClass](@ref), compute the [join](@ref) of the children ENodes analyses values. -After `analyze!` is called, an analysis value will be associated to each EClass in the EGraph. -One can inspect and retrieve analysis values by using [hasdata](@ref) and [getdata](@ref). 
-""" -function analyze!(g::EGraph, analysis_ref, ids::Vector{EClassId}) - addanalysis!(g, analysis_ref) - ids = sort(ids) - # @assert isempty(g.dirty) - - did_something = true - while did_something - did_something = false - - for id in ids - eclass = g[id] - id = eclass.id - pass = mapreduce(x -> make(analysis_ref, g, x), (x, y) -> join(analysis_ref, x, y), eclass) - - if !isequal(pass, getdata(eclass, analysis_ref, missing)) - setdata!(eclass, analysis_ref, pass) - did_something = true - modify!(analysis_ref, g, id) - push!(g.analysis_pending, (eclass[1] => id)) - end - end - end - - for id in ids - eclass = g[id] - id = eclass.id - if !hasdata(eclass, analysis_ref) - error("failed to compute analysis for eclass ", id) - end - end - - rebuild!(g) - return true -end - -""" -A basic cost function, where the computed cost is the size -(number of children) of the current expression. -""" -function astsize(n::ENode, g::EGraph) - n.istree || return 1 - cost = 2 + arity(n) - for id in arguments(n) - eclass = g[id] - !hasdata(eclass, astsize) && (cost += Inf; break) - cost += last(getdata(eclass, astsize)) - end - return cost -end - -""" -A basic cost function, where the computed cost is the size -(number of children) of the current expression, times -1. -Strives to get the largest expression -""" -function astsize_inv(n::ENode, g::EGraph) - n.istree || return -1 - cost = -(1 + arity(n)) # minus sign here is the only difference vs astsize - for id in arguments(n) - eclass = g[id] - !hasdata(eclass, astsize_inv) && (cost += Inf; break) - cost += last(getdata(eclass, astsize_inv)) - end - return cost -end - -""" -When passing a function to analysis functions it is considered as a cost function -""" -make(f::Function, g::EGraph, n::ENode) = (n, f(n, g)) - -join(f::Function, from, to) = last(from) <= last(to) ? 
from : to - -islazy(::Function) = true -modify!(::Function, g, id) = nothing - -function rec_extract(g::EGraph, costfun, id::EClassId; cse_env = nothing) - eclass = g[id] - if !isnothing(cse_env) && haskey(cse_env, id) - (sym, _) = cse_env[id] - return sym - end - (n, ck) = getdata(eclass, costfun, (nothing, Inf)) - ck == Inf && error("Infinite cost when extracting enode") - - n.istree || return n.operation - children = map(arg -> rec_extract(g, costfun, arg; cse_env = cse_env), n.args) - meta = getdata(eclass, :metadata_analysis, nothing) - h = head(n) - children = head_symbol(h) == :call ? [operation(n); children...] : children - maketerm(h, children; metadata = meta) -end - -""" -Given a cost function, extract the expression -with the smallest computed cost from an [`EGraph`](@ref) -""" -function extract!(g::EGraph, costfun::Function; root = g.root, cse = false) - analyze!(g, costfun, root) - if cse - # TODO make sure there is no assignments/stateful code!! - cse_env = Dict{EClassId,Tuple{Symbol,Any}}() # - collect_cse!(g, costfun, root, cse_env, Set{EClassId}()) - - body = rec_extract(g, costfun, root; cse_env = cse_env) - - assignments = [Expr(:(=), name, val) for (id, (name, val)) in cse_env] - # return body - Expr(:let, Expr(:block, assignments...), body) - else - return rec_extract(g, costfun, root) - end -end - - -# Builds a dict e-class id => (symbol, extracted term) of common subexpressions in an e-graph -function collect_cse!(g::EGraph, costfun, id, cse_env, seen) - eclass = g[id] - (cn, ck) = getdata(eclass, costfun, (nothing, Inf)) - ck == Inf && error("Error when computing CSE") - - cn.istree || return - if id in seen - cse_env[id] = (gensym(), rec_extract(g, costfun, id))#, cse_env=cse_env)) # todo generalize symbol? 
- return - end - for child_id in arguments(cn) - collect_cse!(g, costfun, child_id, cse_env, seen) - end - push!(seen, id) -end - - -function getcost!(g::EGraph, costfun; root = -1) - if root == -1 - root = g.root - end - analyze!(g, costfun, root) - bestnode, cost = getdata(g[root], costfun) - return cost -end diff --git a/src/EGraphs/egraph.jl b/src/EGraphs/egraph.jl index 90492dd8..ad816eba 100644 --- a/src/EGraphs/egraph.jl +++ b/src/EGraphs/egraph.jl @@ -1,12 +1,34 @@ # Functional implementation of https://egraphs-good.github.io/ # https://dl.acm.org/doi/10.1145/3434304 +import Metatheory: maybelock! -# abstract type AbstractENode end +""" + modify!(eclass::EClass{Analysis}) -import Metatheory: maybelock! +The `modify!` function for EGraph Analysis can optionally modify the eclass +`eclass` after it has been analyzed, typically by adding an ENode. +It should be **idempotent** if no other changes occur to the EClass. +(See the [egg paper](https://dl.acm.org/doi/pdf/10.1145/3434304)). +""" +function modify! end + + +""" + join(a::AnalysisType, b::AnalysisType)::AnalysisType + +Joins two analyses values into a single one, used by [analyze!](@ref) +when two eclasses are being merged or the analysis is being constructed. +""" +function join end + +""" + make(g::EGraph{Head, AnalysisType}, n::ENode)::AnalysisType where Head + +Given an ENode `n`, `make` should return the corresponding analysis value. 
+""" +function make end -const AnalysisData = NamedTuple{N,<:Tuple{Vararg{Ref}}} where {N} const EClassId = Int64 const TermTypes = Dict{Tuple{Any,Int},Type} # TODO document bindings @@ -49,12 +71,12 @@ function Base.:(==)(a::ENode, b::ENode) hash(a) == hash(b) && a.operation == b.operation end -function toexpr(n::ENode) +function to_expr(n::ENode) n.istree || return n.operation Expr(:call, :ENode, head(n), operation(n), arguments(n)) end -Base.show(io::IO, x::ENode) = print(io, toexpr(x)) +Base.show(io::IO, x::ENode) = print(io, to_expr(x)) function op_key(n) op = operation(n) @@ -62,16 +84,13 @@ function op_key(n) end # parametrize metadata by M -mutable struct EClass - g # EGraph +mutable struct EClass{D} id::EClassId nodes::Vector{ENode} parents::Vector{Pair{ENode,EClassId}} - data::AnalysisData + data::Union{D,Missing} end -EClass(g, id, nodes, parents) = EClass(g, id, nodes, parents, NamedTuple()) - # Interface for indexing EClass Base.getindex(a::EClass, i) = a.nodes[i] @@ -84,31 +103,26 @@ function Base.show(io::IO, a::EClass) print(io, "EClass $(a.id) (") print(io, "[", Base.join(a.nodes, ", "), "], ") + print(io, a.data) print(io, ")") end -function addparent!(a::EClass, n::ENode, id::EClassId) +function addparent!(@nospecialize(a::EClass), n::ENode, id::EClassId) push!(a.parents, (n => id)) end -function merge_analysis_data!(g, a::EClass, b::EClass)::Tuple{Bool,Bool} - if !isempty(a.data) && !isempty(b.data) - new_a_data = Base.merge(a.data, b.data) - for analysis_name in keys(b.data) - analysis_ref = g.analyses[analysis_name] - if hasproperty(a.data, analysis_name) - ref = getproperty(new_a_data, analysis_name) - ref[] = join(analysis_ref, ref[], getproperty(b.data, analysis_name)[]) - end - end + +function merge_analysis_data!(@nospecialize(a::EClass), @nospecialize(b::EClass))::Tuple{Bool,Bool} + if !ismissing(a.data) && !ismissing(b.data) + new_a_data = join(a.data, b.data) merged_a = (a.data == new_a_data) a.data = new_a_data (merged_a, b.data == 
new_a_data) - elseif isempty(a.data) && !isempty(b.data) + elseif ismissing(a.data) && !ismissing(b.data) a.data = b.data # a merged, b not merged (true, false) - elseif !isempty(a.data) && isempty(b.data) + elseif !ismissing(a.data) && ismissing(b.data) b.data = a.data (false, true) else @@ -116,48 +130,27 @@ function merge_analysis_data!(g, a::EClass, b::EClass)::Tuple{Bool,Bool} end end -# Thanks to Shashi Gowda -hasdata(a::EClass, analysis_name::Symbol) = hasproperty(a.data, analysis_name) -hasdata(a::EClass, f::Function) = hasproperty(a.data, nameof(f)) -getdata(a::EClass, analysis_name::Symbol) = getproperty(a.data, analysis_name)[] -getdata(a::EClass, f::Function) = getproperty(a.data, nameof(f))[] -getdata(a::EClass, analysis_ref::Union{Symbol,Function}, default) = - hasdata(a, analysis_ref) ? getdata(a, analysis_ref) : default - - -setdata!(a::EClass, f::Function, value) = setdata!(a, nameof(f), value) -function setdata!(a::EClass, analysis_name::Symbol, value) - if hasdata(a, analysis_name) - ref = getproperty(a.data, analysis_name) - ref[] = value - else - a.data = merge(a.data, NamedTuple{(analysis_name,)}((Ref{Any}(value),))) - end -end """ A concrete type representing an [`EGraph`]. See the [egg paper](https://dl.acm.org/doi/pdf/10.1145/3434304) for implementation details. """ -mutable struct EGraph +mutable struct EGraph{Head,Analysis} "stores the equality relations over e-class ids" uf::UnionFind "map from eclass id to eclasses" - classes::IdDict{EClassId,EClass} + classes::Dict{EClassId,EClass{Analysis}} "hashcons" memo::Dict{ENode,EClassId} "Nodes which need to be processed for rebuilding. The id is the id of the enode, not the canonical id of the eclass."
pending::Vector{Pair{ENode,EClassId}} analysis_pending::UniqueQueue{Pair{ENode,EClassId}} root::EClassId - "A vector of analyses associated to the EGraph" - analyses::Dict{Union{Symbol,Function},Union{Symbol,Function}} "a cache mapping function symbols and their arity to e-classes that contain e-nodes with that function symbol." classes_by_op::Dict{Pair{Any,Int},Vector{EClassId}} - head_type::Type clean::Bool - "If we use global buffers we may need to lock. Defaults to true." + "If we use global buffers we may need to lock. Defaults to false." needslock::Bool "Buffer for e-matching which defaults to a global. Use a local buffer for generated functions." buffer::Vector{Bindings} @@ -171,17 +164,15 @@ end EGraph(expr) Construct an EGraph from a starting symbolic expression `expr`. """ -function EGraph(; needslock::Bool = false, head_type = ExprHead) - EGraph( +function EGraph{Head,Analysis}(; needslock::Bool = false) where {Head,Analysis} + EGraph{Head,Analysis}( UnionFind(), Dict{EClassId,EClass}(), Dict{ENode,EClassId}(), Pair{ENode,EClassId}[], UniqueQueue{Pair{ENode,EClassId}}(), -1, - Dict{Union{Symbol,Function},Union{Symbol,Function}}(), - Dict{Any,Vector{EClassId}}(), - head_type, + Dict{Pair{Any,Int},Vector{EClassId}}(), false, needslock, Bindings[], @@ -189,37 +180,35 @@ function EGraph(; needslock::Bool = false, head_type = ExprHead) ReentrantLock(), ) end +EGraph(; kwargs...) = EGraph{ExprHead,Missing}(; kwargs...) +EGraph{Head}(; kwargs...) where {Head} = EGraph{Head,Missing}(; kwargs...) -function maybelock!(f::Function, g::EGraph) - g.needslock ? lock(f, g.buffer_lock) : f() -end - -function EGraph(e; keepmeta = false, kwargs...) - g = EGraph(; kwargs...) - keepmeta && addanalysis!(g, :metadata_analysis) - g.root = addexpr!(g, e, keepmeta) +function EGraph{Head,Analysis}(e; kwargs...) where {Head,Analysis} + g = EGraph{Head,Analysis}(; kwargs...) 
+ g.root = addexpr!(g, e) g end -function addanalysis!(g::EGraph, costfun::Function) - g.analyses[nameof(costfun)] = costfun - g.analyses[costfun] = costfun -end +EGraph{Head}(e; kwargs...) where {Head} = EGraph{Head,Missing}(e; kwargs...) +EGraph(e; kwargs...) = EGraph{typeof(head(e)),Missing}(e; kwargs...) -function addanalysis!(g::EGraph, analysis_name::Symbol) - g.analyses[analysis_name] = analysis_name -end +# Fallback implementation for analysis methods make and modify +@inline make(::EGraph, ::ENode) = missing +@inline modify!(::EGraph, ::EClass{Analysis}) where {Analysis} = nothing -total_size(g::EGraph) = length(g.memo) +function maybelock!(f::Function, g::EGraph) + g.needslock ? lock(f, g.buffer_lock) : f() +end + """ Returns the canonical e-class id for a given e-class. """ -find(g::EGraph, a::EClassId)::EClassId = find(g.uf, a) -find(g::EGraph, a::EClass)::EClassId = find(g, a.id) +@inline find(g::EGraph, a::EClassId)::EClassId = find(g.uf, a) +@inline find(@nospecialize(g::EGraph), @nospecialize(a::EClass))::EClassId = find(g, a.id) -Base.getindex(g::EGraph, i::EClassId) = g.classes[find(g, i)] +@inline Base.getindex(g::EGraph, i::EClassId) = g.classes[find(g, i)] function canonicalize(g::EGraph, n::ENode)::ENode n.istree || return n @@ -259,7 +248,7 @@ end """ Inserts an e-node in an [`EGraph`](@ref) """ -function add!(g::EGraph, n::ENode)::EClassId +function add!(g::EGraph{Head,Analysis}, n::ENode)::EClassId where {Head,Analysis} n = canonicalize(g, n) haskey(g.memo, n) && return g.memo[n] @@ -274,16 +263,11 @@ function add!(g::EGraph, n::ENode)::EClassId g.memo[n] = id add_class_by_op(g, n, id) - classdata = EClass(g, id, ENode[n], Pair{ENode,EClassId}[]) - g.classes[id] = classdata + eclass = EClass{Analysis}(id, ENode[n], Pair{ENode,EClassId}[], make(g, n)) + g.classes[id] = eclass + modify!(g, eclass) push!(g.pending, n => id) - for an in values(g.analyses) - if !islazy(an) && an !== :metadata_analysis - setdata!(classdata, an, make(an, g, n)) - 
modify!(an, g, id) - end - end return id end @@ -351,7 +335,7 @@ function Base.union!(g::EGraph, enode_id1::EClassId, enode_id2::EClassId)::Bool append!(g.pending, eclass_2.parents) - (merged_1, merged_2) = merge_analysis_data!(g, eclass_1, eclass_2) + (merged_1, merged_2) = merge_analysis_data!(eclass_1, eclass_2) merged_1 && append!(g.analysis_pending, eclass_1.parents) merged_2 && append!(g.analysis_pending, eclass_2.parents) @@ -398,7 +382,7 @@ function rebuild_classes!(g::EGraph) end end -function process_unions!(g::EGraph)::Int +function process_unions!(@nospecialize(g::EGraph))::Int n_unions = 0 while !isempty(g.pending) || !isempty(g.analysis_pending) @@ -420,26 +404,20 @@ function process_unions!(g::EGraph)::Int eclass_id = find(g, eclass_id) eclass = g[eclass_id] - for an in values(g.analyses) - - an === :metadata_analysis && continue + node_data = make(g, node) + if !ismissing(eclass.data) + joined_data = join(eclass.data, node_data) - node_data = make(an, g, node) - if hasdata(eclass, an) - class_data = getdata(eclass, an) - - joined_data = join(an, class_data, node_data) - - if joined_data != class_data - setdata!(eclass, an, joined_data) - modify!(an, g, eclass_id) - append!(g.analysis_pending, eclass.parents) - end - elseif !islazy(an) - setdata!(eclass, an, node_data) - modify!(an, g, eclass_id) + if joined_data != eclass.data + eclass.data = joined_data + modify!(g, eclass) + append!(g.analysis_pending, eclass.parents) end + else + eclass.data = node_data + modify!(g, eclass) end + end end n_unions @@ -468,13 +446,9 @@ end function check_analysis(g) for (id, eclass) in g.classes - for an in values(g.analyses) - an == :metadata_analysis && continue - islazy(an) || (@assert hasdata(eclass, an)) - hasdata(eclass, an) || continue - pass = mapreduce(x -> make(an, g, x), (x, y) -> join(an, x, y),
eclass) + @assert eclass.data == pass end true end @@ -488,8 +462,8 @@ for more details. function rebuild!(g::EGraph) n_unions = process_unions!(g) trimmed_nodes = rebuild_classes!(g) - # @assert check_memo(g) - # @assert check_analysis(g) + @assert check_memo(g) + @assert check_analysis(g) g.clean = true @debug "REBUILT" n_unions trimmed_nodes @@ -529,19 +503,19 @@ end import Metatheory: lookup_pat -function lookup_pat(g::EGraph, p::PatTerm)::EClassId +function lookup_pat(g::EGraph{Head}, p::PatTerm)::EClassId where {Head} @assert isground(p) op = operation(p) args = arguments(p) ar = arity(p) - eh = g.head_type(head_symbol(head(p))) + eh = Head(head_symbol(head(p))) ids = map(x -> lookup_pat(g, x), args) !all((>)(0), ids) && return -1 - if g.head_type == ExprHead && op isa Union{Function,DataType} + if Head == ExprHead && op isa Union{Function,DataType} id = lookup(g, ENode(eh, op, ids)) id < 0 ? lookup(g, ENode(eh, nameof(op), ids)) : id else diff --git a/src/EGraphs/extract.jl b/src/EGraphs/extract.jl new file mode 100644 index 00000000..899952ba --- /dev/null +++ b/src/EGraphs/extract.jl @@ -0,0 +1,97 @@ +struct Extractor{CostFun,Cost} + g::EGraph + cost_function::CostFun + costs::Dict{EClassId,Tuple{Cost,Int64}} # Cost and index in eclass +end + +""" +Given a cost function, extract the expression +with the smallest computed cost from an [`EGraph`](@ref) +""" +function Extractor(g::EGraph, cost_function::Function, cost_type = Float64) + extractor = Extractor{typeof(cost_function),cost_type}(g, cost_function, Dict{EClassId,Tuple{cost_type,Int64}}()) + find_costs!(extractor) + extractor +end + +function extract_expr_recursive(n::ENode, get_node::Function) + n.istree || return n.operation + children = extract_expr_recursive.(get_node.(n.args), get_node) + h = head(n) + # TODO style of operation? + head_symbol(h) == :call && (children = [operation(n); children]) + # TODO metadata?
+ maketerm(h, children) +end + + +function (extractor::Extractor)(root = extractor.g.root) + get_node(eclass_id::EClassId) = find_best_node(extractor, eclass_id) + # TODO check if infinite cost? + extract_expr_recursive(find_best_node(extractor, root), get_node) +end + +# costs dict stores index of enode. get this enode from the eclass +function find_best_node(extractor::Extractor, eclass_id::EClassId) + eclass = extractor.g[eclass_id] + (_, node_index) = extractor.costs[eclass.id] + eclass.nodes[node_index] +end + +function find_costs!(extractor::Extractor{CF,CT}) where {CF,CT} + function enode_cost(n::ENode)::CT + if all(x -> haskey(extractor.costs, x), arguments(n)) + extractor.cost_function(n, map(child_id -> extractor.costs[child_id][1], n.args)) + else + typemax(CT) + end + end + + + did_something = true + while did_something + did_something = false + + for (id, eclass) in extractor.g.classes + costs = enode_cost.(eclass.nodes) + pass = (minimum(costs), argmin(costs)) + + if pass[1] != typemax(CT) && (!haskey(extractor.costs, id) || (pass[1] < extractor.costs[id][1])) + extractor.costs[id] = pass + did_something = true + end + end + end + + for (id, _) in extractor.g.classes + if !haskey(extractor.costs, id) + error("failed to compute extraction costs for eclass ", id) + end + end +end + +""" +A basic cost function, where the computed cost is the size +(number of children) of the current expression. +""" +function astsize(n::ENode, costs::Vector{Float64})::Float64 + n.istree || return 1 + cost = 2 + arity(n) + cost + sum(costs) +end + +""" +A basic cost function, where the computed cost is the size +(number of children) of the current expression, times -1.
+Strives to get the largest expression +""" +function astsize_inv(n::ENode, costs::Vector{Float64})::Float64 + n.istree || return -1 + cost = -(1 + arity(n)) # minus sign here is the only difference vs astsize + cost + sum(costs) +end + + +function extract!(g::EGraph, costfun) + Extractor(g, costfun, Float64)() +end \ No newline at end of file diff --git a/src/EGraphs/saturation.jl b/src/EGraphs/saturation.jl index bb8ec81d..734e2d95 100644 --- a/src/EGraphs/saturation.jl +++ b/src/EGraphs/saturation.jl @@ -98,14 +98,14 @@ function eqsat_search!( return n_matches end -instantiate_enode!(bindings::Bindings, g::EGraph, p::Any)::EClassId = add!(g, ENode(p)) -instantiate_enode!(bindings::Bindings, g::EGraph, p::PatVar)::EClassId = bindings[p.idx][1] -function instantiate_enode!(bindings::Bindings, g::EGraph, p::PatTerm)::EClassId +instantiate_enode!(bindings::Bindings, @nospecialize(g::EGraph), p::Any)::EClassId = add!(g, ENode(p)) +instantiate_enode!(bindings::Bindings, @nospecialize(g::EGraph), p::PatVar)::EClassId = bindings[p.idx][1] +function instantiate_enode!(bindings::Bindings, g::EGraph{Head}, p::PatTerm)::EClassId where {Head} op = operation(p) args = arguments(p) - # TODO add predicate check `quotes_operation` - new_op = g.head_type == ExprHead && op isa Union{Function,DataType} ? nameof(op) : op - eh = g.head_type(head_symbol(head(p))) + # TODO handle this situation better + new_op = Head == ExprHead && op isa Union{Function,DataType} ? 
nameof(op) : op + eh = Head(head_symbol(head(p))) nargs = Vector{EClassId}(undef, length(args)) for i in 1:length(args) @inbounds nargs[i] = instantiate_enode!(bindings, g, args[i]) @@ -189,7 +189,7 @@ function eqsat_apply!(g::EGraph, theory::Vector{<:AbstractRule}, rep::Saturation return end - if params.enodelimit > 0 && total_size(g) > params.enodelimit + if params.enodelimit > 0 && length(g.memo) > params.enodelimit @debug "Too many enodes" rep.reason = :enodelimit break diff --git a/test/egraphs/analysis.jl b/test/egraphs/analysis.jl index afd3de16..54e09e2b 100644 --- a/test/egraphs/analysis.jl +++ b/test/egraphs/analysis.jl @@ -5,50 +5,44 @@ using Metatheory using Metatheory.Library +struct NumberFoldAnalysis + n::Number +end + +Base.:(*)(a::NumberFoldAnalysis, b::NumberFoldAnalysis) = NumberFoldAnalysis(a.n * b.n) +Base.:(+)(a::NumberFoldAnalysis, b::NumberFoldAnalysis) = NumberFoldAnalysis(a.n + b.n) + # This should be auto-generated by a macro -function EGraphs.make(::Val{:numberfold}, g::EGraph, n::ENode) - istree(n) || return operation(n) +function EGraphs.make(g::EGraph{Head,NumberFoldAnalysis}, n::ENode) where {Head} + istree(n) || return operation(n) isa Number ? 
NumberFoldAnalysis(operation(n)) : missing if head_symbol(head(n)) == :call && arity(n) == 2 op = operation(n) args = arguments(n) - l = g[args[1]] - r = g[args[2]] - ldata = getdata(l, :numberfold, nothing) - rdata = getdata(r, :numberfold, nothing) + l, r = g[args[1]], g[args[2]] - if ldata isa Number && rdata isa Number + if l.data isa NumberFoldAnalysis && r.data isa NumberFoldAnalysis if op == :* - return ldata * rdata + return l.data * r.data elseif op == :+ - return ldata + rdata + return l.data + r.data end end end - return nothing + # Could not analyze + return missing end -function EGraphs.join(an::Val{:numberfold}, from, to) - if from isa Number - if to isa Number - @assert from == to - else - return from - end - end - return to +function EGraphs.join(from::NumberFoldAnalysis, to::NumberFoldAnalysis) + @assert from == to + from end -function EGraphs.modify!(::Val{:numberfold}, g::EGraph, id::Int64) - eclass = g.classes[id] - d = getdata(eclass, :numberfold, nothing) - if d isa Number - union!(g, addexpr!(g, d), id) - end +# Add the number to the eclass. 
+function EGraphs.modify!(g::EGraph{Head,NumberFoldAnalysis}, eclass::EClass{NumberFoldAnalysis}) where {Head} + ismissing(eclass.data) || union!(g, addexpr!(g, eclass.data.n), find(g, eclass.id)) end -EGraphs.islazy(::Val{:numberfold}) = false - comm_monoid = @theory begin ~a * ~b --> ~b * ~a @@ -56,8 +50,7 @@ comm_monoid = @theory begin ~a * (~b * ~c) --> (~a * ~b) * ~c end -g = EGraph(:(3 * 4)) -analyze!(g, :numberfold) +g = EGraph{ExprHead,NumberFoldAnalysis}(:(3 * 4)) @testset "Basic Constant Folding Example - Commutative Monoid" begin @@ -68,28 +61,25 @@ end @testset "Basic Constant Folding Example 2 - Commutative Monoid" begin ex = :(a * 3 * b * 4) - G = EGraph(ex) - analyze!(G, :numberfold) - addexpr!(G, :(12 * a)) - @test (true == @areequalg G comm_monoid (12 * a) * b ((6 * 2) * b) * a) - @test (true == @areequalg G comm_monoid (3 * a) * (4 * b) (12 * a) * b ((6 * 2) * b) * a) + g = EGraph{ExprHead,NumberFoldAnalysis}(ex) + addexpr!(g, :(12 * a)) + @test (true == @areequalg g comm_monoid (12 * a) * b ((6 * 2) * b) * a) + @test (true == @areequalg g comm_monoid (3 * a) * (4 * b) (12 * a) * b ((6 * 2) * b) * a) end @testset "Basic Constant Folding Example - Adding analysis after saturation" begin - G = EGraph(:(3 * 4)) - # addexpr!(G, 12) - saturate!(G, comm_monoid) - addexpr!(G, :(a * 2)) - analyze!(G, :numberfold) - saturate!(G, comm_monoid) + g = EGraph{ExprHead,NumberFoldAnalysis}(:(3 * 4)) + # addexpr!(g, 12) + saturate!(g, comm_monoid) + addexpr!(g, :(a * 2)) + saturate!(g, comm_monoid) - @test (true == areequal(G, comm_monoid, :(3 * 4), 12, :(4 * 3), :(6 * 2))) + @test (true == areequal(g, comm_monoid, :(3 * 4), 12, :(4 * 3), :(6 * 2))) ex = :(a * 3 * b * 4) - G = EGraph(ex) - analyze!(G, :numberfold) + g = EGraph{ExprHead,NumberFoldAnalysis}(ex) params = SaturationParams(timeout = 15) - @test areequal(G, comm_monoid, :((3 * a) * (4 * b)), :((12 * a) * b), :(((6 * 2) * b) * a); params = params) + @test areequal(g, comm_monoid, :((3 * a) * (4 * b)), 
:((12 * a) * b), :(((6 * 2) * b) * a); params = params) end @testset "Infinite Loops analysis" begin @@ -98,10 +88,10 @@ end end - G = EGraph(:(1 * x)) + g = EGraph(:(1 * x)) params = SaturationParams(timeout = 100) - saturate!(G, boson, params) - ex = extract!(G, astsize) + saturate!(g, boson, params) + ex = extract!(g, astsize) boson = @theory begin @@ -119,227 +109,3 @@ end end -@testset "Extraction" begin - comm_monoid = @commutative_monoid (*) 1 - - fold_mul = @theory begin - ~a::Number * ~b::Number => ~a * ~b - end - - t = comm_monoid ∪ fold_mul - - - @testset "Extraction 1 - Commutative Monoid" begin - g = EGraph(:(3 * 4)) - saturate!(g, t) - @test (12 == extract!(g, astsize)) - - ex = :(a * 3 * b * 4) - g = EGraph(ex) - params = SaturationParams(timeout = 15) - saturate!(g, t, params) - extr = extract!(g, astsize) - @test extr == :((12 * a) * b) || - extr == :(12 * (a * b)) || - extr == :(a * (b * 12)) || - extr == :((a * b) * 12) || - extr == :((12a) * b) || - extr == :(a * (12b)) || - extr == :((b * (12a))) || - extr == :((b * 12) * a) || - extr == :((b * a) * 12) || - extr == :(b * (a * 12)) || - extr == :((12b) * a) - end - - fold_add = @theory begin - ~a::Number + ~b::Number => ~a + ~b - end - - @testset "Extraction 2" begin - comm_group = @commutative_group (+) 0 inv - - - t = comm_monoid ∪ comm_group ∪ (@distrib (*) (+)) ∪ fold_mul ∪ fold_add - - # for i ∈ 1:20 - # sleep(0.3) - ex = :((x * (a + b)) + (y * (a + b))) - G = EGraph(ex) - saturate!(G, t) - # end - - extract!(G, astsize) == :((y + x) * (b + a)) - end - - @testset "Extraction - Adding analysis after saturation" begin - G = EGraph(:(3 * 4)) - addexpr!(G, 12) - addexpr!(G, :(a * 2)) - # saturate!(G, t) - # saturate!(G, t) - - saturate!(G, t) - - @test (12 == extract!(G, astsize)) - - ex = :(a * 3 * b * 4) - G = EGraph(ex) - analyze!(G, :numberfold) - params = SaturationParams(timeout = 15) - saturate!(G, comm_monoid, params) - - extr = extract!(G, astsize) - - @test extr ∈ ( - :((12 * a) * 
b), - :(12 * (a * b)), - :(a * (b * 12)), - :((a * b) * 12), - :((12a) * b), - :(a * (12b)), - :((b * (12a))), - :((b * 12) * a), - :((b * a) * 12), - :(b * (a * 12)), - ) - end - - - comm_monoid = @commutative_monoid (*) 1 - - comm_group = @commutative_group (+) 0 inv - - powers = @theory begin - ~a * ~a → (~a)^2 - ~a → (~a)^1 - (~a)^~n * (~a)^~m → (~a)^(~n + ~m) - end - logids = @theory begin - log((~a)^~n) --> ~n * log(~a) - log(~x * ~y) --> log(~x) + log(~y) - log(1) --> 0 - log(:e) --> 1 - :e^(log(~x)) --> ~x - end - - G = EGraph(:(log(e))) - params = SaturationParams(timeout = 9) - saturate!(G, logids, params) - @test extract!(G, astsize) == 1 - - - t = comm_monoid ∪ comm_group ∪ (@distrib (*) (+)) ∪ powers ∪ logids ∪ fold_mul ∪ fold_add - - @testset "Complex Extraction" begin - G = EGraph(:(log(e) * log(e))) - params = SaturationParams(timeout = 9) - saturate!(G, t, params) - @test extract!(G, astsize) == 1 - - G = EGraph(:(log(e) * (log(e) * e^(log(3))))) - params = SaturationParams(timeout = 7) - saturate!(G, t, params) - @test extract!(G, astsize) == 3 - - - G = EGraph(:(a^3 * a^2)) - saturate!(G, t) - ex = extract!(G, astsize) - @test ex == :(a^5) - - G = EGraph(:(a^3 * a^2)) - saturate!(G, t) - ex = extract!(G, astsize) - @test ex == :(a^5) - - function cust_astsize(n::ENode, g::EGraph) - n.istree || return 1 - cost = 1 + arity(n) - - if operation(n) == :^ - cost += 2 - end - - for id in arguments(n) - eclass = g[id] - !hasdata(eclass, cust_astsize) && (cost += Inf; break) - cost += last(getdata(eclass, cust_astsize)) - end - return cost - end - - G = EGraph(:((log(e) * log(e)) * (log(a^3 * a^2)))) - saturate!(G, t) - ex = extract!(G, cust_astsize) - @test ex == :(5 * log(a)) || ex == :(log(a) * 5) - end - - function costfun(n::ENode, g::EGraph) - n.istree || return 1 - arity(n) != 2 && (return 1) - left = arguments(n)[1] - left_class = g[left] - ENode(:a) ∈ left_class.nodes ? 
1 : 100 - end - - - moveright = @theory begin - (:b * (:a * ~c)) --> (:a * (:b * ~c)) - end - - expr = :(a * (a * (b * (a * b)))) - res = rewrite(expr, moveright) - - g = EGraph(expr) - saturate!(g, moveright) - resg = extract!(g, costfun) - - @testset "Symbols in Right hand" begin - @test resg == res == :(a * (a * (a * (b * b)))) - end - - function ⋅ end - co = @theory begin - sum(~x ⋅ :bazoo ⋅ :woo) --> sum(:n * ~x) - end - @testset "Consistency with classical backend" begin - ex = :(sum(wa(rio) ⋅ bazoo ⋅ woo)) - g = EGraph(ex) - saturate!(g, co) - - res = extract!(g, astsize) - - resclassic = rewrite(ex, co) - - @test res == resclassic - end - - - @testset "No arguments" begin - ex = :(f()) - g = EGraph(ex) - @test :(f()) == extract!(g, astsize) - - ex = :(sin() + cos()) - - t = @theory begin - sin() + cos() --> tan() - end - - gg = EGraph(ex) - saturate!(gg, t) - res = extract!(gg, astsize) - - @test res == :(tan()) - end - - - @testset "Symbol or function object operators in expressions in EGraphs" begin - ex = :(($+)(x, y)) - t = [@rule a b a + b => 2] - g = EGraph(ex) - saturate!(g, t) - @test extract!(g, astsize) == 2 - end -end diff --git a/test/egraphs/egraphs.jl b/test/egraphs/egraphs.jl index 9443ef57..9bd20711 100644 --- a/test/egraphs/egraphs.jl +++ b/test/egraphs/egraphs.jl @@ -38,7 +38,7 @@ end apply(n, f, x) = n == 0 ? 
x : apply(n - 1, f, f(x)) f(x) = Expr(:call, :f, x) - g = EGraph(:a) + g = EGraph{ExprHead}(:a) t1 = addexpr!(g, apply(6, f, :a)) t2 = addexpr!(g, apply(9, f, :a)) diff --git a/test/egraphs/ematch.jl b/test/egraphs/ematch.jl index 3e455b19..1ad46f43 100644 --- a/test/egraphs/ematch.jl +++ b/test/egraphs/ematch.jl @@ -104,7 +104,7 @@ export t end -g = EGraph(:woo); +g = EGraph{ExprHead}(:woo); saturate!(g, Bar.t); saturate!(g, Foo.t); foo = 12 @@ -171,7 +171,7 @@ end :bazoo --> :wazoo end - g = EGraph(:foo) + g = EGraph{ExprHead}(:foo) report = saturate!(g, failme) @test report.reason === :contradiction end diff --git a/test/egraphs/extract.jl b/test/egraphs/extract.jl new file mode 100644 index 00000000..a94bc2d4 --- /dev/null +++ b/test/egraphs/extract.jl @@ -0,0 +1,186 @@ + +using Metatheory +using Metatheory.Library + +comm_monoid = @commutative_monoid (*) 1 + +fold_mul = @theory begin + ~a::Number * ~b::Number => ~a * ~b +end + + + +@testset "Extraction 1 - Commutative Monoid" begin + t = comm_monoid ∪ fold_mul + g = EGraph(:(3 * 4)) + saturate!(g, t) + @test (12 == extract!(g, astsize)) + + ex = :(a * 3 * b * 4) + g = EGraph(ex) + params = SaturationParams(timeout = 15) + saturate!(g, t, params) + extr = extract!(g, astsize) + @test extr == :((12 * a) * b) || + extr == :(12 * (a * b)) || + extr == :(a * (b * 12)) || + extr == :((a * b) * 12) || + extr == :((12a) * b) || + extr == :(a * (12b)) || + extr == :((b * (12a))) || + extr == :((b * 12) * a) || + extr == :((b * a) * 12) || + extr == :(b * (a * 12)) || + extr == :((12b) * a) +end + +fold_add = @theory begin + ~a::Number + ~b::Number => ~a + ~b +end + +@testset "Extraction 2" begin + comm_group = @commutative_group (+) 0 inv + + + t = comm_monoid ∪ comm_group ∪ (@distrib (*) (+)) ∪ fold_mul ∪ fold_add + + ex = :((x * (a + b)) + (y * (a + b))) + g = EGraph(ex) + saturate!(g, t) + extract!(g, astsize) == :((y + x) * (b + a)) +end + +comm_monoid = @commutative_monoid (*) 1 + +comm_group = @commutative_group 
(+) 0 inv + +powers = @theory begin + ~a * ~a → (~a)^2 + ~a → (~a)^1 + (~a)^~n * (~a)^~m → (~a)^(~n + ~m) +end +logids = @theory begin + log((~a)^~n) --> ~n * log(~a) + log(~x * ~y) --> log(~x) + log(~y) + log(1) --> 0 + log(:e) --> 1 + :e^(log(~x)) --> ~x +end + +@testset "Extraction 3" begin + g = EGraph(:(log(e))) + params = SaturationParams(timeout = 9) + saturate!(g, logids, params) + @test extract!(g, astsize) == 1 +end + +t = comm_monoid ∪ comm_group ∪ (@distrib (*) (+)) ∪ powers ∪ logids ∪ fold_mul ∪ fold_add + +@testset "Complex Extraction" begin + g = EGraph(:(log(e) * log(e))) + params = SaturationParams(timeout = 9) + saturate!(g, t, params) + @test extract!(g, astsize) == 1 + + g = EGraph(:(log(e) * (log(e) * e^(log(3))))) + params = SaturationParams(timeout = 7) + saturate!(g, t, params) + @test extract!(g, astsize) == 3 + + + g = EGraph(:(a^3 * a^2)) + saturate!(g, t) + ex = extract!(g, astsize) + @test ex == :(a^5) + + g = EGraph(:(a^3 * a^2)) + saturate!(g, t) + ex = extract!(g, astsize) + @test ex == :(a^5) +end + +@testset "Custom Cost Function 1" begin + function cust_astsize(n::ENode, children_costs::Vector{Float64})::Float64 + istree(n) || return 1 + cost = 1 + arity(n) + + if operation(n) == :^ + cost += 2 + end + + cost + sum(children_costs) + end + + g = EGraph(:((log(e) * log(e)) * (log(a^3 * a^2)))) + saturate!(g, t) + ex = extract!(g, cust_astsize) + @test ex == :(5 * log(a)) || ex == :(log(a) * 5) +end + +@testset "Symbols in Right hand" begin + expr = :(a * (a * (b * (a * b)))) + g = EGraph(expr) + + function costfun(n::ENode, children_costs::Vector{Float64})::Float64 + n.istree || return 1 + arity(n) != 2 && (return 1) + left = arguments(n)[1] + left_class = g[left] + ENode(:a) ∈ left_class.nodes ? 
1 : 100 + end + + + moveright = @theory begin + (:b * (:a * ~c)) --> (:a * (:b * ~c)) + end + + res = rewrite(expr, moveright) + + saturate!(g, moveright) + resg = extract!(g, costfun) + + @test resg == res == :(a * (a * (a * (b * b)))) +end + +@testset "Consistency with classical backend" begin + co = @theory begin + sum(~x ⋅ :bazoo ⋅ :woo) --> sum(:n * ~x) + end + + ex = :(sum(wa(rio) ⋅ bazoo ⋅ woo)) + g = EGraph(ex) + saturate!(g, co) + + res = extract!(g, astsize) + resclassic = rewrite(ex, co) + + @test res == resclassic +end + + +@testset "No arguments" begin + ex = :(f()) + g = EGraph(ex) + @test :(f()) == extract!(g, astsize) + + ex = :(sin() + cos()) + + t = @theory begin + sin() + cos() --> tan() + end + + gg = EGraph(ex) + saturate!(gg, t) + res = extract!(gg, astsize) + + @test res == :(tan()) +end + + +@testset "Symbol or function object operators in expressions in EGraphs" begin + ex = :(($+)(x, y)) + t = [@rule a b a + b => 2] + g = EGraph(ex) + saturate!(g, t) + @test extract!(g, astsize) == 2 +end diff --git a/test/integration/kb_benchmark.jl b/test/integration/kb_benchmark.jl index dd1d1583..711d095a 100644 --- a/test/integration/kb_benchmark.jl +++ b/test/integration/kb_benchmark.jl @@ -51,14 +51,14 @@ another_expr = :(a * a * a * a) g = EGraph(another_expr) some_eclass = addexpr!(g, another_expr) saturate!(g, G) -ex = extract!(g, astsize; root = some_eclass) +ex = extract!(g, astsize) @test ex == :ε another_expr = :(((((((a * b) * (a * b)) * (a * b)) * (a * b)) * (a * b)) * (a * b)) * (a * b)) g = EGraph(another_expr) some_eclass = addexpr!(g, another_expr) saturate!(g, G) -ex = extract!(g, astsize; root = some_eclass) +ex = extract!(g, astsize) @test ex == :ε diff --git a/test/integration/lambda_theory.jl b/test/integration/lambda_theory.jl index 010bc845..982ad7ce 100644 --- a/test/integration/lambda_theory.jl +++ b/test/integration/lambda_theory.jl @@ -1,6 +1,4 @@ -using Metatheory -using Metatheory.EGraphs -using Test +using Metatheory, Test 
abstract type LambdaExpr end @@ -138,11 +136,11 @@ end λT = open_term ∪ subst_intro ∪ subst_prop ∪ subst_elim ex = λ(:x, Add(4, Apply(λ(:y, Variable(:y)), 4))) -g = EGraph(ex; head_type = LambdaHead) +g = EGraph{LambdaHead}(ex) saturate!(g, λT) @test λ(:x, Add(4, 4)) == extract!(g, astsize) # expected: :(λ(x, 4 + 4)) #%% -g = EGraph(; head_type = LambdaHead) +g = EGraph{LambdaHead}() @test areequal(g, λT, 2, Apply(λ(:x, Variable(:x)), 2)) \ No newline at end of file diff --git a/test/integration/stream_fusion.jl b/test/integration/stream_fusion.jl index def7648e..4ca3e093 100644 --- a/test/integration/stream_fusion.jl +++ b/test/integration/stream_fusion.jl @@ -75,18 +75,11 @@ normalize_theory = @theory x y z f g begin end -function stream_fusion_cost(n::ENode, g::EGraph) +function stream_fusion_cost(n::ENode, costs::Vector{Float64})::Float64 n.istree || return 1 cost = 1 + arity(n) - for id in arguments(n) - eclass = g[id] - !hasdata(eclass, stream_fusion_cost) && (cost += Inf; break) - cost += last(getdata(eclass, stream_fusion_cost)) - end - operation(n) ∈ (:map, :filter) && (cost += 10) - - return cost + cost + sum(costs) end function stream_optimize(ex) diff --git a/test/integration/while_superinterpreter.jl b/test/integration/while_superinterpreter.jl index c73671e5..3679fae6 100644 --- a/test/integration/while_superinterpreter.jl +++ b/test/integration/while_superinterpreter.jl @@ -34,44 +34,41 @@ end end @testset "If Semantics" begin - @test areequal(if_language, 2, :(if true + @test areequal(if_language, :(if true x else 0 - end, $(Mem(:x => 2)))) - @test areequal(if_language, 0, :(if false + end, $(Mem(:x => 2))), 2) + @test areequal(if_language, :(if false x else 0 - end, $(Mem(:x => 2)))) - @test areequal(if_language, 2, :(if !(false) + end, $(Mem(:x => 2))), 0) + @test areequal(if_language, :(if !(false) x else 0 - end, $(Mem(:x => 2)))) + end, $(Mem(:x => 2))), 2) params = SaturationParams(timeout = 10) - @test areequal(if_language, 0, :(if !(2 < x) + 
@test areequal(if_language, :(if !(2 < x) x else 0 - end, $(Mem(:x => 3))); params = params) + end, $(Mem(:x => 3))), 0; params = params) end @testset "While Semantics" begin exx = :((x = 3), $(Mem(:x => 2))) g = EGraph(exx) saturate!(g, while_language) - ex = extract!(g, astsize) + @test Mem(:x => 3) == extract!(g, astsize) - @test areequal(while_language, Mem(:x => 3), exx) exx = :((x = 4; x = x + 1), $(Mem(:x => 3))) g = EGraph(exx) saturate!(g, while_language) - ex = extract!(g, astsize) + @test Mem(:x => 5) == extract!(g, astsize) - params = SaturationParams(timeout = 10) - @test areequal(while_language, Mem(:x => 5), exx; params = params) params = SaturationParams(timeout = 14, timer = false) exx = :(( @@ -81,7 +78,7 @@ end skip end ), $(Mem(:x => 3))) - @test areequal(while_language, Mem(:x => 4), exx; params = params) + @test areequal(while_language, exx, Mem(:x => 4); params = params) exx = :((while x < 10 x = x + 1 diff --git a/test/tutorials/calculational_logic.jl b/test/tutorials/calculational_logic.jl index 213219d9..14c4a54f 100644 --- a/test/tutorials/calculational_logic.jl +++ b/test/tutorials/calculational_logic.jl @@ -9,20 +9,20 @@ include(joinpath(dirname(pathof(Metatheory)), "../examples/calculational_logic_t saturate!(g, calculational_logic_theory) extract!(g, astsize) - @test @areequal calculational_logic_theory true ((!p == p) == false) - @test @areequal calculational_logic_theory true ((!p == !p) == true) - @test @areequal calculational_logic_theory true ((!p || !p) == !p) (!p || p) !(!p && p) - @test @areequal calculational_logic_theory true ((p ⟹ (p || p)) == true) + @test @areequal calculational_logic_theory ((!p == p) == false) true + @test @areequal calculational_logic_theory ((!p == !p) == true) true + @test @areequal calculational_logic_theory ((!p || !p) == !p) (!p || p) !(!p && p) true + @test @areequal calculational_logic_theory ((p ⟹ (p || p)) == true) true params = SaturationParams(timeout = 12, eclasslimit = 10000, 
schedulerparams = (1000, 5)) - @test areequal(calculational_logic_theory, true, :(((p ⟹ (p || p)) == ((!(p) && q) ⟹ q))); params = params) + @test areequal(calculational_logic_theory, :(((p ⟹ (p || p)) == ((!(p) && q) ⟹ q))), true; params = params) ex = :((p ⟹ (q ⟹ r)) ⟹ ((p ⟹ q) ⟹ (p ⟹ r))) # Frege's theorem - res = areequal(calculational_logic_theory, true, ex; params = params) + res = areequal(calculational_logic_theory, ex, true; params = params) @test_broken !ismissing(res) && res - @test @areequal calculational_logic_theory true (!(p || q) == (!p && !q)) # Demorgan's + @test @areequal calculational_logic_theory (!(p || q) == (!p && !q)) true # Demorgan's areequal(calculational_logic_theory, :((x && y) || (!x && z) || (y && z)), :((x && y) || (!x && z)); params = params) # Consensus theorem end diff --git a/test/tutorials/custom_types.jl b/test/tutorials/custom_types.jl index 32597c61..9f8e1bc2 100644 --- a/test/tutorials/custom_types.jl +++ b/test/tutorials/custom_types.jl @@ -72,7 +72,13 @@ ex = :(a[b]) # `metadata` should return the extra metadata. If you have many fields, i suggest using a `NamedTuple`. -TermInterface.metadata(e::MyExpr) = e.foo +# TermInterface.metadata(e::MyExpr) = e.foo + +# struct MetadataAnalysis +# metadata +# end + +# function EGraphs.make(g::EGraph{MyExprHead,MetadataAnalysis}, n::ENode) = # Additionally, you can override `EGraphs.preprocess` on your custom expression # to pre-process any expression before insertion in the E-Graph. @@ -99,13 +105,15 @@ end # Let's create an example expression and e-graph hcall = MyExpr(:h, [4], "hello") ex = MyExpr(:f, [MyExpr(:z, [2]), hcall]) -# We use `head_type` kwarg on an existing e-graph to inform the system about +# We use the first type parameter an existing e-graph to inform the system about # the *default* type of expressions that we want newly added expressions to have. -g = EGraph(ex; keepmeta = true, head_type = MyExprHead) +g = EGraph{MyExprHead}(ex) # Now let's test that it works. 
saturate!(g, t) -expected = MyExpr(:f, [MyExpr(:h, [4], "HELLO")], "") +# expected = MyExpr(:f, [MyExpr(:h, [4], "HELLO")], "") +expected = MyExpr(:f, [MyExpr(:h, [4], "")], "") + extracted = extract!(g, astsize) @test expected == extracted diff --git a/test/tutorials/propositional_logic.jl b/test/tutorials/propositional_logic.jl index 5b77f0ee..0f36db85 100644 --- a/test/tutorials/propositional_logic.jl +++ b/test/tutorials/propositional_logic.jl @@ -9,18 +9,18 @@ include(joinpath(dirname(pathof(Metatheory)), "../examples/propositional_logic_t @test prove(propositional_logic_theory, ex, 5, 10, 5000) - @test @areequal propositional_logic_theory true ((!p == p) == false) - @test @areequal propositional_logic_theory true ((!p == !p) == true) - @test @areequal propositional_logic_theory true ((!p || !p) == !p) (!p || p) !(!p && p) - @test @areequal propositional_logic_theory p (p || p) - @test @areequal propositional_logic_theory true ((p ⟹ (p || p))) - @test @areequal propositional_logic_theory true ((p ⟹ (p || p)) == ((!(p) && q) ⟹ q)) == true + @test @areequal propositional_logic_theory ((!p == p) == false) true + @test @areequal propositional_logic_theory ((!p == !p) == true) true + @test @areequal propositional_logic_theory ((!p || !p) == !p) (!p || p) !(!p && p) true + @test @areequal propositional_logic_theory (p || p) p + @test @areequal propositional_logic_theory ((p ⟹ (p || p))) true + @test @areequal propositional_logic_theory ((p ⟹ (p || p)) == ((!(p) && q) ⟹ q)) == true true - @test @areequal propositional_logic_theory true (p ⟹ (q ⟹ r)) ⟹ ((p ⟹ q) ⟹ (p ⟹ r)) # Frege's theorem + @test @areequal propositional_logic_theory (p ⟹ (q ⟹ r)) ⟹ ((p ⟹ q) ⟹ (p ⟹ r)) true # Frege's theorem - @test @areequal propositional_logic_theory true (!(p || q) == (!p && !q)) # Demorgan's + @test @areequal propositional_logic_theory (!(p || q) == (!p && !q)) true # Demorgan's end # Consensus theorem -# @test_broken @areequal propositional_logic_theory true ((x && y) || (!x && 
z) || (y && z)) ((x && y) || (!x && z)) \ No newline at end of file +# @test_broken @areequal propositional_logic_theory ((x && y) || (!x && z) || (y && z)) ((x && y) || (!x && z)) true \ No newline at end of file From 7d140c3eafb93ab019096d1b80467b8c54d111d7 Mon Sep 17 00:00:00 2001 From: a Date: Tue, 9 Jan 2024 21:56:01 +0100 Subject: [PATCH 41/47] use nothing instead of missing for analysis --- docs/src/egraphs.md | 83 ++++++++++++++++------------------------ src/EGraphs/egraph.jl | 24 ++++++------ test/egraphs/analysis.jl | 8 ++-- 3 files changed, 47 insertions(+), 68 deletions(-) diff --git a/docs/src/egraphs.md b/docs/src/egraphs.md index 0463167c..c466355f 100644 --- a/docs/src/egraphs.md +++ b/docs/src/egraphs.md @@ -237,23 +237,18 @@ using Metatheory # This is a cost function that behaves like `astsize` but increments the cost # of nodes containing the `^` operation. This results in a tendency to avoid # extraction of expressions containing '^'. -function cost_function(n::ENode, g::EGraph) +# TODO: add example extraction +function cost_function(n::ENode, children_costs::Vector{Float64})::Float64 # All literal expressions (e.g `a`, 123, 0.42, "hello") have cost 1 istree(n) || return 1 cost = 1 + arity(n) - # This is where the custom cost is computed operation(n) == :^ && (cost += 2) - for id in arguments(n) - eclass = g[id] - # if the child e-class has not yet been analyzed, return +Inf - !hasdata(eclass, cost_function) && (cost += Inf; break) - cost += last(getdata(eclass, cost_function)) - end - return cost + cost + sum(children_costs) end + ``` ## EGraph Analyses @@ -272,7 +267,6 @@ In Metatheory.jl, **EGraph Analyses are uniquely identified** by either If you are specifying a custom analysis by its `Symbol` name, the following functions define an interface for analyses based on multiple dispatch on `Val{analysis_name::Symbol}`: -* [islazy(an)](@ref) should return true if the analysis name `an` should NOT be computed on-the-fly during egraphs operation, 
but only when inspected. * [make(an, egraph, n)](@ref) should take an ENode `n` and return a value from the analysis domain. * [join(an, x,y)](@ref) should return the semilattice join of `x` and `y` in the analysis domain (e.g. *given two analyses value from ENodes in the same EClass, which one should I choose?*). If `an` is a `Function`, it is treated as a cost function analysis, it is automatically defined to be the minimum analysis value between `x` and `y`. Typically, the domain value of cost functions are real numbers, but if you really do want to have your own cost type, make sure that `Base.isless` is defined. * [modify!(an, egraph, eclassid)](@ref) Can be optionally implemented. This can be used modify an EClass `egraph[eclassid]` on-the-fly during an e-graph saturation iteration, given its analysis value. @@ -298,12 +292,15 @@ associate an analysis value only to the *literals* contained in the EGraph (the ```@example custom_analysis using Metatheory +struct OddEvenAnalysis + s::Symbol # :odd or :even +end + function odd_even_base_case(n::ENode) # Should be called only if istree(n) is false - return if operation(n) isa Integer - iseven(operation(n)) ? :even : :odd - else - nothing - end + if operation(n) isa Integer + OddEvenAnalysis(iseven(operation(n)) ? :even : :odd) + end + # It's ok to return nothing end # ... Rest of code defined below ``` @@ -326,44 +323,32 @@ From the definition of an [ENode](@ref), we know that children of ENodes are alw to EClasses in the EGraph. ```@example custom_analysis -function EGraphs.make(::Val{:OddEvenAnalysis}, g::EGraph, n::ENode) - if !istree(n) - return odd_even_base_case(n) - end +function EGraphs.make(g::EGraph{Head,OddEvenAnalysis}, n::ENode) where {Head} + istree(n) || return odd_even_base_case(n) # The e-node is not a literal value, # Let's consider only binary function call terms. 
if head_symbol(head(n)) == :call && arity(n) == 2 op = operation(n) # Get the left and right child eclasses child_eclasses = arguments(n) - l = g[child_eclasses[1]] - r = g[child_eclasses[2]] - - # Get the corresponding OddEvenAnalysis value of the children - # defaulting to nothing - ldata = getdata(l, :OddEvenAnalysis, nothing) - rdata = getdata(r, :OddEvenAnalysis, nothing) + l,r = g[child_eclasses[1]], g[child_eclasses[2]] - if ldata isa Symbol && rdata isa Symbol + if !isnothing(l.data) && !isnothing(r.data) if op == :* - if ldata == rdata - ldata - elseif (ldata == :even || rdata == :even) - :even - else - nothing + if l.data == r.data + l.data + elseif (l.data.s == :even || r.data.s == :even) + OddEvenAnalysis(:even) end elseif op == :+ - (ldata == rdata) ? :even : :odd + (l.data == r.data) ? OddEvenAnalysis(:even) : OddEvenAnalysis(:odd) end - elseif isnothing(ldata) && rdata isa Symbol && op == :* - rdata - elseif ldata isa Symbol && isnothing(rdata) && op == :* - ldata + elseif isnothing(l.data) && !isnothing(r.data) && op == :* + r.data + elseif !isnothing(l.data) && isnothing(r.data) && op == :* + l.data end end - - return nothing end ``` @@ -375,14 +360,11 @@ how to extract a single value out of the many analyses values contained in an EG We do this by defining a method for [join](@ref). ```@example custom_analysis -function EGraphs.join(::Val{:OddEvenAnalysis}, a, b) - if a == b - return a - else - # an expression cannot be odd and even at the same time! - # this is contradictory, so we ignore the analysis value - error("contradiction") - end +function EGraphs.join(a::OddEvenAnalysis, b::OddEvenAnalysis) + # an expression cannot be odd and even at the same time! 
+ # this is contradictory, so we ignore the analysis value + a != b && error("contradiction") + a end ``` @@ -400,10 +382,9 @@ t = @theory a b c begin end function custom_analysis(expr) - g = EGraph(expr) + g = EGraph{ExprHead, OddEvenAnalysis}(expr) saturate!(g, t) - analyze!(g, :OddEvenAnalysis) - return getdata(g[g.root], :OddEvenAnalysis) + return g[g.root].data end custom_analysis(:(2*a)) # :even diff --git a/src/EGraphs/egraph.jl b/src/EGraphs/egraph.jl index ad816eba..0f7c8e28 100644 --- a/src/EGraphs/egraph.jl +++ b/src/EGraphs/egraph.jl @@ -88,7 +88,7 @@ mutable struct EClass{D} id::EClassId nodes::Vector{ENode} parents::Vector{Pair{ENode,EClassId}} - data::Union{D,Missing} + data::Union{D,Nothing} end # Interface for indexing EClass @@ -113,16 +113,16 @@ end function merge_analysis_data!(@nospecialize(a::EClass), @nospecialize(b::EClass))::Tuple{Bool,Bool} - if !ismissing(a.data) && !ismissing(b.data) + if !isnothing(a.data) && !isnothing(b.data) new_a_data = join(a.data, b.data) merged_a = (a.data == new_a_data) a.data = new_a_data (merged_a, b.data == new_a_data) - elseif !ismissing(a.data) && !ismissing(b.data) + elseif !isnothing(a.data) && !isnothing(b.data) a.data = b.data # a merged, b not merged (true, false) - elseif !ismissing(a.data) && !ismissing(b.data) + elseif !isnothing(a.data) && !isnothing(b.data) b.data = a.data (false, true) else @@ -180,8 +180,8 @@ function EGraph{Head,Analysis}(; needslock::Bool = false) where {Head,Analysis} ReentrantLock(), ) end -EGraph(; kwargs...) = EGraph{ExprHead,Missing}(; kwargs...) -EGraph{Head}(; kwargs...) where {Head} = EGraph{Head,Missing}(; kwargs...) +EGraph(; kwargs...) = EGraph{ExprHead,Nothing}(; kwargs...) +EGraph{Head}(; kwargs...) where {Head} = EGraph{Head,Nothing}(; kwargs...) function EGraph{Head,Analysis}(e; kwargs...) where {Head,Analysis} g = EGraph{Head,Analysis}(; kwargs...) @@ -189,11 +189,11 @@ function EGraph{Head,Analysis}(e; kwargs...) 
where {Head,Analysis} g end -EGraph{Head}(e; kwargs...) where {Head} = EGraph{Head,Missing}(e; kwargs...) -EGraph(e; kwargs...) = EGraph{typeof(head(e)),Missing}(e; kwargs...) +EGraph{Head}(e; kwargs...) where {Head} = EGraph{Head,Nothing}(e; kwargs...) +EGraph(e; kwargs...) = EGraph{typeof(head(e)),Nothing}(e; kwargs...) # Fallback implementation for analysis methods make and modify -@inline make(::EGraph, ::ENode) = missing +@inline make(::EGraph, ::ENode) = nothing @inline modify!(::EGraph, ::EClass{Analysis}) where {Analysis} = nothing @@ -405,10 +405,10 @@ function process_unions!(@nospecialize(g::EGraph))::Int eclass = g[eclass_id] node_data = make(g, node) - if !ismissing(eclass.data) + if !isnothing(eclass.data) joined_data = join(eclass.data, node_data) - if joined_data != class_data + if joined_data != eclass.data setdata!(eclass, an, joined_data) modify!(g, eclass) append!(g.analysis_pending, eclass.parents) @@ -446,7 +446,7 @@ end function check_analysis(g) for (id, eclass) in g.classes - ismissing(eclass.data) && continue + isnothing(eclass.data) && continue pass = mapreduce(x -> make(g, x), (x, y) -> join(x, y), eclass) @assert eclass.data == pass end diff --git a/test/egraphs/analysis.jl b/test/egraphs/analysis.jl index 54e09e2b..4c9c8b04 100644 --- a/test/egraphs/analysis.jl +++ b/test/egraphs/analysis.jl @@ -14,7 +14,7 @@ Base.:(+)(a::NumberFoldAnalysis, b::NumberFoldAnalysis) = NumberFoldAnalysis(a.n # This should be auto-generated by a macro function EGraphs.make(g::EGraph{Head,NumberFoldAnalysis}, n::ENode) where {Head} - istree(n) || return operation(n) isa Number ? NumberFoldAnalysis(operation(n)) : missing + istree(n) || return operation(n) isa Number ? 
NumberFoldAnalysis(operation(n)) : nothing if head_symbol(head(n)) == :call && arity(n) == 2 op = operation(n) args = arguments(n) @@ -28,9 +28,7 @@ function EGraphs.make(g::EGraph{Head,NumberFoldAnalysis}, n::ENode) where {Head} end end end - - # Could not analyze - return missing + # Could not analyze, returns nothing end function EGraphs.join(from::NumberFoldAnalysis, to::NumberFoldAnalysis) @@ -40,7 +38,7 @@ end # Add the number to the eclass. function EGraphs.modify!(g::EGraph{Head,NumberFoldAnalysis}, eclass::EClass{NumberFoldAnalysis}) where {Head} - ismissing(eclass.data) || union!(g, addexpr!(g, eclass.data.n), find(g, eclass.id)) + isnothing(eclass.data) || union!(g, addexpr!(g, eclass.data.n), find(g, eclass.id)) end From 0bd3021f2d574dcf8d080371bb56241e7ecdc5d0 Mon Sep 17 00:00:00 2001 From: a Date: Tue, 9 Jan 2024 22:05:52 +0100 Subject: [PATCH 42/47] remove assertions --- src/EGraphs/egraph.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/EGraphs/egraph.jl b/src/EGraphs/egraph.jl index 0f7c8e28..04086606 100644 --- a/src/EGraphs/egraph.jl +++ b/src/EGraphs/egraph.jl @@ -462,8 +462,8 @@ for more details. 
function rebuild!(g::EGraph) n_unions = process_unions!(g) trimmed_nodes = rebuild_classes!(g) - @assert check_memo(g) - @assert check_analysis(g) + # @assert check_memo(g) + # @assert check_analysis(g) g.clean = true @debug "REBUILT" n_unions trimmed_nodes From 329595ef0cce30ac275ef10a167ec67d31f8f3c3 Mon Sep 17 00:00:00 2001 From: a Date: Wed, 10 Jan 2024 18:08:08 +0100 Subject: [PATCH 43/47] use UInt as id --- src/EGraphs/egraph.jl | 44 +++++++++++++++++++-------------------- src/EGraphs/saturation.jl | 18 +++++++++------- src/EGraphs/unionfind.jl | 10 ++++----- src/Patterns.jl | 32 +++++++++++++++------------- src/TermInterface.jl | 9 ++++++-- src/ematch_compiler.jl | 28 ++++++++++++------------- test/egraphs/egraphs.jl | 20 +++++++++--------- test/egraphs/unionfind.jl | 18 +++++++--------- 8 files changed, 94 insertions(+), 85 deletions(-) diff --git a/src/EGraphs/egraph.jl b/src/EGraphs/egraph.jl index 04086606..71e4461c 100644 --- a/src/EGraphs/egraph.jl +++ b/src/EGraphs/egraph.jl @@ -29,11 +29,10 @@ Given an ENode `n`, `make` should return the corresponding analysis value. 
""" function make end -const EClassId = Int64 -const TermTypes = Dict{Tuple{Any,Int},Type} +const EClassId = UInt64 # TODO document bindings -const Bindings = Base.ImmutableDict{Int,Tuple{Int,Int}} -const UNDEF_ARGS = Vector{EClassId}(undef, 0) +const Bindings = Base.ImmutableDict{Int,Tuple{EClassId,Int}} +const UNDEF_ID_VEC = Vector{EClassId}(undef, 0) # @compactify begin struct ENode @@ -44,7 +43,7 @@ struct ENode args::Vector{EClassId} hash::Ref{UInt} ENode(head, operation, args) = new(true, head, operation, args, Ref{UInt}(0)) - ENode(literal) = new(false, nothing, literal, UNDEF_ARGS, Ref{UInt}(0)) + ENode(literal) = new(false, nothing, literal, UNDEF_ID_VEC, Ref{UInt}(0)) end TermInterface.istree(n::ENode) = n.istree @@ -52,7 +51,7 @@ TermInterface.head(n::ENode) = n.head TermInterface.operation(n::ENode) = n.operation TermInterface.arguments(n::ENode) = n.args TermInterface.children(n::ENode) = [n.operation; n.args...] -TermInterface.arity(n::ENode) = length(n.args) +TermInterface.arity(n::ENode)::Int = length(n.args) # This optimization comes from SymbolicUtils @@ -78,7 +77,7 @@ end Base.show(io::IO, x::ENode) = print(io, to_expr(x)) -function op_key(n) +function op_key(n)::Pair{Any,Int} op = operation(n) (op isa Union{Function,DataType} ? nameof(op) : op) => (istree(n) ? arity(n) : -1) end @@ -155,7 +154,7 @@ mutable struct EGraph{Head,Analysis} "Buffer for e-matching which defaults to a global. Use a local buffer for generated functions." buffer::Vector{Bindings} "Buffer for rule application which defaults to a global. Use a local buffer for generated functions." - merges_buffer::Vector{Tuple{Int,Int}} + merges_buffer::Vector{EClassId} lock::ReentrantLock end @@ -167,16 +166,16 @@ Construct an EGraph from a starting symbolic expression `expr`. 
function EGraph{Head,Analysis}(; needslock::Bool = false) where {Head,Analysis} EGraph{Head,Analysis}( UnionFind(), - Dict{EClassId,EClass}(), + Dict{EClassId,EClass{Analysis}}(), Dict{ENode,EClassId}(), Pair{ENode,EClassId}[], UniqueQueue{Pair{ENode,EClassId}}(), - -1, + 0, Dict{Pair{Any,Int},Vector{EClassId}}(), false, needslock, Bindings[], - Tuple{Int,Int}[], + EClassId[], ReentrantLock(), ) end @@ -232,7 +231,7 @@ end function lookup(g::EGraph, n::ENode)::EClassId cc = canonicalize(g, n) - haskey(g.memo, cc) ? find(g, g.memo[cc]) : -1 + haskey(g.memo, cc) ? find(g, g.memo[cc]) : 0 end @@ -288,26 +287,22 @@ Recursively traverse an type satisfying the `TermInterface` and insert terms int [`EGraph`](@ref). If `e` has no children (has an arity of 0) then directly insert the literal into the [`EGraph`](@ref). """ -function addexpr!(g::EGraph, se, keepmeta = false)::EClassId +function addexpr!(g::EGraph, se)::EClassId se isa EClass && return se.id e = preprocess(se) n = if istree(se) args = arguments(e) - ar = length(args) + ar = arity(e) class_ids = Vector{EClassId}(undef, ar) for i in 1:ar - @inbounds class_ids[i] = addexpr!(g, args[i], keepmeta) + @inbounds class_ids[i] = addexpr!(g, args[i]) end ENode(head(e), operation(e), class_ids) else # constant enode ENode(e) end id = add!(g, n) - if keepmeta - meta = TermInterface.metadata(e) - !isnothing(meta) && setdata!(g.classes[id], :metadata_analysis, meta) - end return id end @@ -512,15 +507,18 @@ function lookup_pat(g::EGraph{Head}, p::PatTerm)::EClassId where {Head} eh = Head(head_symbol(head(p))) - ids = map(x -> lookup_pat(g, x), args) - !all((>)(0), ids) && return -1 + ids = Vector{EClassId}(undef, ar) + for i in 1:ar + @inbounds ids[i] = lookup_pat(g, args[i]) + ids[i] <= 0 && return 0 + end if Head == ExprHead && op isa Union{Function,DataType} id = lookup(g, ENode(eh, op, ids)) - id < 0 ? lookup(g, ENode(eh, nameof(op), ids)) : id + id <= 0 ? 
lookup(g, ENode(eh, nameof(op), ids)) : id else lookup(g, ENode(eh, op, ids)) end end -lookup_pat(g::EGraph, p::Any) = lookup(g, ENode(p)) +lookup_pat(g::EGraph, p::Any)::EClassId = lookup(g, ENode(p)) diff --git a/src/EGraphs/saturation.jl b/src/EGraphs/saturation.jl index 734e2d95..6586dbcf 100644 --- a/src/EGraphs/saturation.jl +++ b/src/EGraphs/saturation.jl @@ -38,12 +38,12 @@ Base.@kwdef mutable struct SaturationParams timer::Bool = true end -function cached_ids(g::EGraph, p::PatTerm)# ::Vector{Int64} +function cached_ids(g::EGraph, p::PatTerm)::Vector{EClassId} if isground(p) id = lookup_pat(g, p) !isnothing(id) && return [id] else - get(g.classes_by_op, op_key(p), ()) + get(g.classes_by_op, op_key(p), UNDEF_ID_VEC) end end @@ -115,13 +115,15 @@ function instantiate_enode!(bindings::Bindings, g::EGraph{Head}, p::PatTerm)::EC end function apply_rule!(buf, g::EGraph, rule::RewriteRule, id, direction) - push!(g.merges_buffer, (id, instantiate_enode!(buf, g, rule.right))) + push!(g.merges_buffer, id) + push!(g.merges_buffer, instantiate_enode!(buf, g, rule.right)) nothing end function apply_rule!(bindings::Bindings, g::EGraph, rule::EqualityRule, id::EClassId, direction::Int) pat_to_inst = direction == 1 ? rule.right : rule.left - push!(g.merges_buffer, (id, instantiate_enode!(bindings, g, pat_to_inst))) + push!(g.merges_buffer, id) + push!(g.merges_buffer, instantiate_enode!(bindings, g, pat_to_inst)) nothing end @@ -156,7 +158,8 @@ function apply_rule!(bindings::Bindings, g::EGraph, rule::DynamicRule, id::EClas r = f(id, g, (instantiate_actual_param!(bindings, g, i) for i in 1:length(rule.patvars))...) 
isnothing(r) && return nothing rcid = addexpr!(g, r) - push!(g.merges_buffer, (id, rcid)) + push!(g.merges_buffer, id) + push!(g.merges_buffer, rcid) return nothing end @@ -177,7 +180,7 @@ function eqsat_apply!(g::EGraph, theory::Vector{<:AbstractRule}, rep::Saturation end bindings = pop!(g.buffer) - rule_idx, id = bindings[0] + id, rule_idx = bindings[0] direction = sign(rule_idx) rule_idx = abs(rule_idx) rule = theory[rule_idx] @@ -198,7 +201,8 @@ function eqsat_apply!(g::EGraph, theory::Vector{<:AbstractRule}, rep::Saturation end maybelock!(g) do while !isempty(g.merges_buffer) - (l, r) = pop!(g.merges_buffer) + l = pop!(g.merges_buffer) + r = pop!(g.merges_buffer) union!(g, l, r) end end diff --git a/src/EGraphs/unionfind.jl b/src/EGraphs/unionfind.jl index 0e19aa31..e2aa6ada 100644 --- a/src/EGraphs/unionfind.jl +++ b/src/EGraphs/unionfind.jl @@ -1,10 +1,10 @@ struct UnionFind - parents::Vector{Int} + parents::Vector{UInt} end -UnionFind() = UnionFind(Int[]) +UnionFind() = UnionFind(UInt[]) -function Base.push!(uf::UnionFind) +function Base.push!(uf::UnionFind)::UInt l = length(uf.parents) + 1 push!(uf.parents, l) l @@ -12,12 +12,12 @@ end Base.length(uf::UnionFind) = length(uf.parents) -function Base.union!(uf::UnionFind, i::Int, j::Int) +function Base.union!(uf::UnionFind, i::UInt, j::UInt) uf.parents[j] = i i end -function find(uf::UnionFind, i::Int) +function find(uf::UnionFind, i::UInt) while i != uf.parents[i] i = uf.parents[i] end diff --git a/src/Patterns.jl b/src/Patterns.jl index 2179cf65..546864b7 100644 --- a/src/Patterns.jl +++ b/src/Patterns.jl @@ -13,7 +13,7 @@ abstract type AbstractPat end struct PatHead head end -TermInterface.head_symbol(p::PatHead) = p.head +TermInterface.head_symbol(p::PatHead)::Symbol = p.head PatHead(p::PatHead) = error("recursive!") @@ -83,34 +83,38 @@ symbol `operation` and expression head `head.head`. 
struct PatTerm <: AbstractPat head::PatHead children::Vector - PatTerm(h, t::Vector) = new(h, t) + isground::Bool + PatTerm(h, t::Vector) = new(h, t, all(isground, t)) end PatTerm(eh, op) = PatTerm(eh, [op]) PatTerm(eh, children...) = PatTerm(eh, collect(children)) + +isground(p::PatTerm)::Bool = p.isground + TermInterface.istree(::PatTerm) = true TermInterface.head(p::PatTerm)::PatHead = p.head TermInterface.children(p::PatTerm) = p.children function TermInterface.operation(p::PatTerm) hs = head_symbol(head(p)) - hs == :call && return first(p.children) + hs in (:call, :macrocall) && return first(p.children) # hs == :ref && return getindex hs end function TermInterface.arguments(p::PatTerm) hs = head_symbol(head(p)) - hs == :call ? @view(p.children[2:end]) : p.children + hs in (:call, :macrocall) ? @view(p.children[2:end]) : p.children +end +function TermInterface.arity(p::PatTerm) + hs = head_symbol(head(p)) + l = length(p.children) + hs in (:call, :macrocall) ? l - 1 : l end -TermInterface.arity(p::PatTerm) = length(arguments(p)) TermInterface.metadata(p::PatTerm) = nothing TermInterface.maketerm(head::PatHead, children; type = Any, metadata = nothing) = PatTerm(head, children...) -isground(p::PatTerm) = all(isground, p.children) - - -# ============================================== -# ================== PATTERN VARIABLES ========= -# ============================================== +# --------------------- +# # Pattern Variables. """ Collects pattern variables appearing in a pattern into a vector of symbols @@ -122,9 +126,9 @@ patvars(x, s) = s patvars(p) = unique!(patvars(p, Symbol[])) -# ============================================== -# ================== DEBRUJIN INDEXING ========= -# ============================================== +# --------------------- +# # Debrujin Indexing. 
+ function setdebrujin!(p::Union{PatVar,PatSegment}, pvars) p.idx = findfirst((==)(p.name), pvars) diff --git a/src/TermInterface.jl b/src/TermInterface.jl index cc17e5a9..d2fe5e59 100644 --- a/src/TermInterface.jl +++ b/src/TermInterface.jl @@ -114,7 +114,7 @@ export unsorted_arguments Returns the number of arguments of `x`. Implicitly defined if `arguments(x)` is defined. """ -arity(x) = length(arguments(x)) +arity(x)::Int = length(arguments(x)) export arity @@ -220,7 +220,7 @@ struct ExprHead end export ExprHead -head_symbol(eh::ExprHead) = eh.head +head_symbol(eh::ExprHead)::Symbol = eh.head istree(x::Expr) = true head(e::Expr) = ExprHead(e.head) @@ -247,6 +247,11 @@ function arguments(e::Expr) end end +function arity(e::Expr)::Int + l = length(e.args) + e.head in (:call, :macrocall) ? l - 1 : l +end + function maketerm(head::ExprHead, children; type = Any, metadata = nothing) if !isempty(children) && first(children) isa Union{Function,DataType} Expr(head.head, nameof(first(children)), @view(children[2:end])...) diff --git a/src/ematch_compiler.jl b/src/ematch_compiler.jl index 0e5088fd..b6b19d04 100644 --- a/src/ematch_compiler.jl +++ b/src/ematch_compiler.jl @@ -32,7 +32,7 @@ end function predicate_ematcher(p::PatVar, pred) function predicate_ematcher(next, g, data, bindings) !islist(data) && return - id::Int = car(data) + id::UInt = car(data) eclass = g[id] if pred(eclass) enode_idx = 0 @@ -122,27 +122,27 @@ function ematcher(p::PatTerm) end -const EMPTY_ECLASS_DICT = Base.ImmutableDict{Int,Tuple{Int,Int}}() +const EMPTY_BINDINGS = Base.ImmutableDict{Int,Tuple{UInt,Int}}() """ -Substitutions are efficiently represented in memory as vector of tuples of two integers. -This should allow for static allocation of matches and use of LoopVectorization.jl -The buffer has to be fairly big when e-matching. -The size of the buffer should double when there's too many matches. 
-The format is as follows -* The first pair denotes the index of the rule in the theory and the e-class id - of the node of the e-graph that is being substituted. The rule number should be negative if it's a bidirectional - the direction is right-to-left. -* From the second pair on, it represents (e-class id, literal position) at the position of the pattern variable -* The end of a substitution is delimited by (0,0) +Substitutions are efficiently represented in memory as immutable dictionaries of tuples of two integers. + +The format is as follows: + +bindings[0] holds + 1. e-class-id of the node of the e-graph that is being substituted. + 2. the index of the rule in the theory. The rule number should be negative + if it's a bidirectional rule and the direction is right-to-left. + +The rest of the immutable dictionary bindings[n>0] represents (e-class id, literal position) at the position of the pattern variable `n`. """ function ematcher_yield(p, npvars::Int, direction::Int) em = ematcher(p) function ematcher_yield(g, rule_idx, id)::Int n_matches = 0 - em(g, (id,), EMPTY_ECLASS_DICT) do b, n + em(g, (id,), EMPTY_BINDINGS) do b, n maybelock!(g) do - push!(g.buffer, assoc(b, 0, (rule_idx * direction, id))) + push!(g.buffer, assoc(b, 0, (id, rule_idx * direction))) n_matches += 1 end end diff --git a/test/egraphs/egraphs.jl b/test/egraphs/egraphs.jl index 9bd20711..493066bb 100644 --- a/test/egraphs/egraphs.jl +++ b/test/egraphs/egraphs.jl @@ -7,8 +7,8 @@ using Metatheory testmatch = :(a << 1) g = EGraph(testexpr) t2 = addexpr!(g, testmatch) - union!(g, t2, 3) - @test find(g, t2) == find(g, 3) + union!(g, t2, EClassId(3)) + @test find(g, t2) == find(g, EClassId(3)) # DOES NOT UPWARD MERGE end @@ -43,8 +43,8 @@ end t1 = addexpr!(g, apply(6, f, :a)) t2 = addexpr!(g, apply(9, f, :a)) - c_id = union!(g, t1, 1) # a == apply(6,f,a) - c2_id = union!(g, t2, 1) # a == apply(9,f,a) + c_id = union!(g, t1, EClassId(1)) # a == apply(6,f,a) + c2_id = union!(g, t2, EClassId(1)) # a 
== apply(9,f,a) rebuild!(g) @@ -52,10 +52,10 @@ end t4 = addexpr!(g, apply(7, f, :a)) # f^m(a) = a = f^n(a) ⟹ f^(gcd(m,n))(a) = a - @test find(g, t1) == find(g, 1) - @test find(g, t2) == find(g, 1) - @test find(g, t3) == find(g, 1) - @test find(g, t4) != find(g, 1) + @test find(g, t1) == find(g, EClassId(1)) + @test find(g, t2) == find(g, EClassId(1)) + @test find(g, t3) == find(g, EClassId(1)) + @test find(g, t4) != find(g, EClassId(1)) # if m or n is prime, f(a) = a t5 = addexpr!(g, apply(11, f, :a)) @@ -64,6 +64,6 @@ end rebuild!(g) - @test find(g, t5) == find(g, 1) - @test find(g, t6) == find(g, 1) + @test find(g, t5) == find(g, EClassId(1)) + @test find(g, t6) == find(g, EClassId(1)) end diff --git a/test/egraphs/unionfind.jl b/test/egraphs/unionfind.jl index 24fc4013..cf151e30 100644 --- a/test/egraphs/unionfind.jl +++ b/test/egraphs/unionfind.jl @@ -8,17 +8,15 @@ for _ in 1:n push!(uf) end -union!(uf, 1, 2) -union!(uf, 1, 3) -union!(uf, 1, 4) +union!(uf, UInt(1), UInt(2)) +union!(uf, UInt(1), UInt(3)) +union!(uf, UInt(1), UInt(4)) -union!(uf, 6, 8) -union!(uf, 6, 9) -union!(uf, 6, 10) +union!(uf, UInt(6), UInt(8)) +union!(uf, UInt(6), UInt(9)) +union!(uf, UInt(6), UInt(10)) for i in 1:n - find(uf, i) + find(uf, UInt(i)) end -@test uf.parents == [1, 1, 1, 1, 5, 6, 7, 6, 6, 6] - -# TODO test path compression \ No newline at end of file +@test uf.parents == UInt[1, 1, 1, 1, 5, 6, 7, 6, 6, 6] From 6a256d0143d82726d90ee01dec148472e3a075ac Mon Sep 17 00:00:00 2001 From: a Date: Thu, 11 Jan 2024 19:01:12 +0100 Subject: [PATCH 44/47] better readme --- README.md | 114 +++++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 100 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 1c4d68f9..816a7c90 100644 --- a/README.md +++ b/README.md @@ -12,17 +12,27 @@ [](https://joss.theoj.org/papers/3266e8a08a75b9be2f194126a9c6f0e9) [](https://julialang.zulipchat.com/#narrow/stream/277860-metatheory.2Ejl) -**Metatheory.jl** is a general purpose 
term rewriting, metaprogramming and algebraic computation library for the Julia programming language, designed to take advantage of the powerful reflection capabilities to bridge the gap between symbolic mathematics, abstract interpretation, equational reasoning, optimization, composable compiler transforms, and advanced -homoiconic pattern matching features. The core features of Metatheory.jl are a powerful rewrite rule definition language, a vast library of functional combinators for classical term rewriting and an *e-graph rewriting*, a fresh approach to term rewriting achieved through an equality saturation algorithm. Metatheory.jl can manipulate any kind of -Julia symbolic expression type, ~~as long as it satisfies the [TermInterface.jl](https://github.com/JuliaSymbolics/TermInterface.jl)~~. +**Metatheory.jl** is a general purpose term rewriting, metaprogramming and +algebraic computation library for the Julia programming language, designed to +take advantage of the powerful reflection capabilities to bridge the gap between +symbolic mathematics, abstract interpretation, equational reasoning, +optimization, composable compiler transforms, and advanced homoiconic pattern +matching features. The core features of Metatheory.jl are a powerful rewrite +rule definition language, a vast library of functional combinators for classical +term rewriting and an *[e-graph](https://en.wikipedia.org/wiki/E-graph) +rewriting*, a fresh approach to term rewriting achieved through an equality +saturation algorithm. Metatheory.jl can manipulate any kind of Julia symbolic +expression type, ~~as long as it satisfies the [TermInterface.jl](https://github.com/JuliaSymbolics/TermInterface.jl)~~. ### NOTE: TermInterface.jl has been temporarily deprecated. Its functionality has moved to module [Metatheory.TermInterface](https://github.com/JuliaSymbolics/Metatheory.jl/blob/master/src/TermInterface.jl) until consensus for a shared symbolic term interface is reached by the community. 
+ + Metatheory.jl provides: -- An eDSL (domain specific language) to define different kinds of symbolic rewrite rules. +- An eDSL (embedded domain specific language) to define different kinds of symbolic rewrite rules. - A classical rewriting backend, derived from the [SymbolicUtils.jl](https://github.com/JuliaSymbolics/SymbolicUtils.jl) pattern matcher, supporting associative-commutative rules. It is based on the pattern matcher in the [SICM book](https://mitpress.mit.edu/sites/default/files/titles/content/sicm_edition_2/book.html). - A flexible library of rewriter combinators. -- An e-graph rewriting (equality saturation) engine, based on the [egg](https://egraphs-good.github.io/) library, supporting a backtracking pattern matcher and non-deterministic term rewriting by using a data structure called *e-graph*, efficiently incorporating the notion of equivalence in order to reduce the amount of user effort required to achieve optimization tasks and equational reasoning. +- An [e-graph](https://en.wikipedia.org/wiki/E-graph) rewriting (equality saturation) engine, based on the [egg](https://egraphs-good.github.io/) library, supporting a backtracking pattern matcher and non-deterministic term rewriting by using a data structure called [e-graph](https://en.wikipedia.org/wiki/E-graph), efficiently incorporating the notion of equivalence in order to reduce the amount of user effort required to achieve optimization tasks and equational reasoning. - `@capture` macro for flexible metaprogramming. Intuitively, Metatheory.jl transforms Julia expressions @@ -33,14 +43,7 @@ This allows users to perform customized and composable compiler optimizations sp Our library provides a simple, algebraically composable interface to help scientists in implementing and reasoning about semantics and all kinds of formal systems, by defining concise rewriting rules in pure, syntactically valid Julia on a high level of abstraction. 
Our implementation of equality saturation on e-graphs is based on the excellent, state-of-the-art technique implemented in the [egg](https://egraphs-good.github.io/) library, reimplemented in pure Julia. - - - - - - - -## 3.0 WORK IN PROGRESS! +## 3.0 Alpha - Many tests have been rewritten in [Literate.jl](https://github.com/fredrikekre/Literate.jl) format and are thus narrative tutorials available in the docs. - Many performance optimizations. - Comprehensive benchmarks are available. @@ -48,8 +51,91 @@ Our library provides a simple, algebraically composable interface to help scient - Lots of bugfixes. +--- + +## We need your help! - Practical and Research Contributions + +There's lot of room for improvement for Metatheory.jl, by making it more performant and by extending its features. +Any contribution is welcome! + +**Performance**: +- Improving the speed of the e-graph pattern matcher. [(Useful paper)](https://arxiv.org/abs/2108.02290) +- Reducing allocations used by Equality Saturation. +- Goal-informed [rule schedulers](https://github.com/JuliaSymbolics/Metatheory.jl/blob/master/src/EGraphs/Schedulers.jl): develop heuristic algorithms that choose what rules to apply at each equality saturation iteration to prune space of possible rewrites. + +**Features**: +- Introduce proof production capabilities for e-graphs. This can be based on the [egg implementation](https://github.com/egraphs-good/egg/blob/main/src/explain.rs). +- Common Subexpression Elimination when extracting from an e-graph +- Integer Linear Programming extraction of expressions. + +**Documentation**: +- Port more [integration tests]() to [tutorials]() that are rendered with [Literate.jl](https://github.com/fredrikekre/Literate.jl) +- Document [Functional Rewrite Combinators](https://github.com/JuliaSymbolics/Metatheory.jl/blob/master/src/Rewriters.jl) and add a tutorial. + +Most importantly, there are many **practical real world applications** where Metatheory.jl could be used. 
Let's +work together to turn this list into some new Julia packages: + +#### Integration with Symbolics.jl + +Many features of this package, such as the classical rewriting system, have been ported from [SymbolicUtils.jl], and are technically the same. Integration between Metatheory.jl with Symbolics.jl **is currently +paused**, as we are planning to [reach consensus in the development of a common Julia symbolic term interface](https://github.com/JuliaSymbolics/TermInterface.jl). + +An integration between Metatheory.jl and [Symbolics.jl](https://github.com/JuliaSymbolics/Symbolics.jl) is possible and has previously been shown in the ["High-performance symbolic-numerics via multiple dispatch"](https://arxiv.org/abs/2105.03949) paper. Once we reach consensus for a shared symbolic term interface, Metatheory.jl can be used to: + +- Rewrite Symbolics.jl expressions with **bi-directional equations** instead of simple directed rewrite rules. +- Search for the space of mathematically equivalent Symbolics.jl expressions for more computationally efficient forms to speed various packages like [ModelingToolkit.jl](https://github.com/SciML/ModelingToolkit.jl) that numerically evaluate Symbolics.jl expressions. +- When proof production is introduced in Metatheory.jl, automatically search the space of a domain-specific equational theory to prove that Symbolics.jl expressions are equal in that theory. +- Other scientific domains extending Symbolics.jl for system modeling. + +#### Simplifying Quantum Algebras + +[QuantumCumulants.jl](https://github.com/qojulia/QuantumCumulants.jl/) automates +the symbolic derivation of mean-field equations in quantum mechanics, expanding +them in cumulants and generating numerical solutions using state-of-the-art +solvers like [ModelingToolkit.jl](https://github.com/SciML/ModelingToolkit.jl) +and +[DifferentialEquations.jl](https://github.com/SciML/DifferentialEquations.jl). 
A +potential application for Metatheory.jl is domain-specific code optimization for +QuantumCumulants.jl, aiming to be the first symbolic simplification engine for +Fock algebras. + + +#### Automatic Floating Point Error Fixer + + +[Herbie](https://herbie.uwplse.org/) is a tool using equality saturation to automatically rewrite mathematical expressions to enhance +floating-point accuracy. Recently, Herbie's core has been rewritten using +[egg](https://egraphs-good.github.io/), with the tool originally implemented in +a mix of Racket, Scheme, and Rust. While effective, its usage involves multiple +languages, making it impractical for non-experts. It is theoretically +possible to port this technique to a pure Julia solution, seamlessly +integrating with the language, in a single macro `@fp_optimize` that fixes +floating-point errors in expressions just before code compilation and execution. + +#### Automatic Theorem Proving in Julia + +Metatheory.jl can be used to make a pure Julia Automated Theorem Prover (ATP) +inspired by the use of E-graphs in existing ATP environments like +[Z3](https://github.com/Z3Prover/z3), [Simplify](https://dl.acm.org/doi/10.1145/1066100.1066102) and [CVC4](https://en.wikipedia.org/wiki/CVC4), +in the context of [Satisfiability Modulo Theories (SMT)](https://en.wikipedia.org/wiki/Satisfiability_modulo_theories). + +The two-language problem in program verification can be addressed by allowing users to define high-level +theories about their code, that are statically verified before executing the program. This holds potential for various applications in +software verification, offering a flexible and generic environment for proving +formulae in different logics, and statically verifying such constraints on Julia +code before it gets compiled (see +[Mixtape.jl](https://github.com/JuliaCompilerPlugins/Mixtape.jl)). + +**Some concrete steps**: + +- Introduce Proof Production in equality saturation.
+- Test using Metatheory for SMT in conjunction with a SAT solver like [PicoSAT.jl](https://github.com/sisl/PicoSAT.jl) +- Test out various logic theories and software verification applications. + +#### And much more -Many features have been ported from SymbolicUtils.jl. Metatheory.jl can be used in place of SymbolicUtils.jl when you have no need of manipulating mathematical expressions. Integration between Metatheory.jl with Symbolics.jl, as it has been shown in the ["High-performance symbolic-numerics via multiple dispatch"](https://arxiv.org/abs/2105.03949) paper. +Many projects that could potentially be ported to Julai are listed on the [egg website]. +A simple search for ["equality saturation" on Google Scholar](https://scholar.google.com/scholar?hl=en&q="equality+saturation") shows. ## Recommended Readings - Selected Publications From 0d9c532f74cc38080b5bb15e315137cee384d494 Mon Sep 17 00:00:00 2001 From: a Date: Thu, 11 Jan 2024 19:33:25 +0100 Subject: [PATCH 45/47] remove TODO and enhance readme --- README.md | 38 +++++++++++++++++++++++++++----------- src/EGraphs/uniquequeue.jl | 1 - 2 files changed, 27 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 816a7c90..a13b51e9 100644 --- a/README.md +++ b/README.md @@ -61,20 +61,25 @@ Any contribution is welcome! **Performance**: - Improving the speed of the e-graph pattern matcher. [(Useful paper)](https://arxiv.org/abs/2108.02290) - Reducing allocations used by Equality Saturation. -- Goal-informed [rule schedulers](https://github.com/JuliaSymbolics/Metatheory.jl/blob/master/src/EGraphs/Schedulers.jl): develop heuristic algorithms that choose what rules to apply at each equality saturation iteration to prune space of possible rewrites. 
+- [#50](https://github.com/JuliaSymbolics/Metatheory.jl/issues/50) - Goal-informed [rule schedulers](https://github.com/JuliaSymbolics/Metatheory.jl/blob/master/src/EGraphs/Schedulers.jl): develop heuristic algorithms that choose what rules to apply at each equality saturation iteration to prune space of possible rewrites. **Features**: -- Introduce proof production capabilities for e-graphs. This can be based on the [egg implementation](https://github.com/egraphs-good/egg/blob/main/src/explain.rs). -- Common Subexpression Elimination when extracting from an e-graph +- [#111](https://github.com/JuliaSymbolics/Metatheory.jl/issues/111) Introduce proof production capabilities for e-graphs. This can be based on the [egg implementation](https://github.com/egraphs-good/egg/blob/main/src/explain.rs). +- Common Subexpression Elimination when extracting from an e-graph [#158](https://github.com/JuliaSymbolics/Metatheory.jl/issues/158) - Integer Linear Programming extraction of expressions. +- Pattern matcher enhancements: [#43 Better parsing of blocks](https://github.com/JuliaSymbolics/Metatheory.jl/issues/43), [#3 Support `...` variables in e-graphs](https://github.com/JuliaSymbolics/Metatheory.jl/issues/3), [#89 syntax for vectors](https://github.com/JuliaSymbolics/Metatheory.jl/issues/89) +- [#75 E-Graph intersection algorithm](https://github.com/JuliaSymbolics/Metatheory.jl/issues/75) **Documentation**: -- Port more [integration tests]() to [tutorials]() that are rendered with [Literate.jl](https://github.com/fredrikekre/Literate.jl) +- Port more [integration tests](https://github.com/JuliaSymbolics/Metatheory.jl/tree/master/test/integration) to [tutorials](https://github.com/JuliaSymbolics/Metatheory.jl/tree/master/test/tutorials) that are rendered with [Literate.jl](https://github.com/fredrikekre/Literate.jl) - Document [Functional Rewrite Combinators](https://github.com/JuliaSymbolics/Metatheory.jl/blob/master/src/Rewriters.jl) and add a tutorial. 
+## Real World Applications + Most importantly, there are many **practical real world applications** where Metatheory.jl could be used. Let's work together to turn this list into some new Julia packages: + #### Integration with Symbolics.jl Many features of this package, such as the classical rewriting system, have been ported from [SymbolicUtils.jl], and are technically the same. Integration between Metatheory.jl with Symbolics.jl **is currently @@ -126,16 +131,27 @@ formulae in different logics, and statically verifying such constraints on Julia code before it gets compiled (see [Mixtape.jl](https://github.com/JuliaCompilerPlugins/Mixtape.jl)). -**Some concrete steps**: +To develop such a package, Metatheory.jl needs: + +- Introduction of Proof Production in equality saturation. +- SMT in conjunction with a SAT solver like [PicoSAT.jl](https://github.com/sisl/PicoSAT.jl) +- Experiments with various logic theories and software verification applications. -- Introduce Proof Production in equality saturation. -- Test using Metatheory for SMT in conjunction with a SAT solver like [PicoSAT.jl](https://github.com/sisl/PicoSAT.jl) -- Test out various logic theories and software verification applications. +#### Other potential applications -#### And much more +Many projects that could potentially be ported to Julia are listed on the [egg website](https://egraphs-good.github.io/). +A simple search for ["equality saturation" on Google Scholar](https://scholar.google.com/scholar?hl=en&q="equality+saturation") shows many new articles that leverage the techniques used in this packages. -Many projects that could potentially be ported to Julai are listed on the [egg website]. -A simple search for ["equality saturation" on Google Scholar](https://scholar.google.com/scholar?hl=en&q="equality+saturation") shows. 
+PLDI is a premier academic forum in the field of programming languages and programming systems research, which organizes an [e-graph symposium](https://pldi23.sigplan.org/home/egraphs-2023) where many interesting research and projects have been presented. + +--- + +## Theoretical Developments + +TODO +https://effect.systems/blog/ta-completion.html + +--- ## Recommended Readings - Selected Publications diff --git a/src/EGraphs/uniquequeue.jl b/src/EGraphs/uniquequeue.jl index 079916bf..aade15d6 100644 --- a/src/EGraphs/uniquequeue.jl +++ b/src/EGraphs/uniquequeue.jl @@ -25,7 +25,6 @@ function Base.append!(uq::UniqueQueue{T}, xs::Vector{T}) where {T} end function Base.pop!(uq::UniqueQueue{T}) where {T} - # TODO maybe popfirst? v = pop!(uq.vec) delete!(uq.set, v) v From ec7a2e458cfe4ecdd1e83550a1803e97417f1ca8 Mon Sep 17 00:00:00 2001 From: a Date: Thu, 11 Jan 2024 19:54:55 +0100 Subject: [PATCH 46/47] ideas --- README.md | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index a13b51e9..a1b606e8 100644 --- a/README.md +++ b/README.md @@ -148,7 +148,14 @@ PLDI is a premier academic forum in the field of programming languages and progr ## Theoretical Developments -TODO +TODO write + +Associative-Commutative matching: + +`@acrule` in SymbolicUtils.jl + +[Why reasonable rules can create infinite loops](https://github.com/egraphs-good/egg/discussions/60) + https://effect.systems/blog/ta-completion.html --- From c3a62c8bddfc4d98eab3153c18511c07a47d7bfb Mon Sep 17 00:00:00 2001 From: a Date: Fri, 12 Jan 2024 13:55:17 +0100 Subject: [PATCH 47/47] more readme --- README.md | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index a1b606e8..8828f8d3 100644 --- a/README.md +++ b/README.md @@ -82,8 +82,10 @@ work together to turn this list into some new Julia packages: #### Integration with Symbolics.jl -Many features of this package, such as the classical rewriting system, 
have been ported from [SymbolicUtils.jl], and are technically the same. Integration between Metatheory.jl with Symbolics.jl **is currently -paused**, as we are planning to [reach consensus in the development of a common Julia symbolic term interface](https://github.com/JuliaSymbolics/TermInterface.jl). +Many features of this package, such as the classical rewriting system, have been ported from [SymbolicUtils.jl](https://github.com/JuliaSymbolics/SymbolicUtils.jl), and are technically the same. Integration between Metatheory.jl and Symbolics.jl **is currently +paused**, as we are waiting to reach consensus for the redesign of a common Julia symbolic term interface, [TermInterface.jl](https://github.com/JuliaSymbolics/TermInterface.jl). + +TODO link discussion when posted An integration between Metatheory.jl and [Symbolics.jl](https://github.com/JuliaSymbolics/Symbolics.jl) is possible and has previously been shown in the ["High-performance symbolic-numerics via multiple dispatch"](https://arxiv.org/abs/2105.03949) paper. Once we reach consensus for a shared symbolic term interface, Metatheory.jl can be used to: @@ -148,15 +150,19 @@ PLDI is a premier academic forum in the field of programming languages and progr ## Theoretical Developments -TODO write +There's also lots of room for theoretical improvements to the e-graph data structure and equality saturation rewriting. + +#### Associative-Commutative-Distributive e-matching -Associative-Commutative matching: +In classical rewriting, SymbolicUtils.jl offers a mechanism for matching expressions with associative and commutative operations: [`@acrule`](https://docs.sciml.ai/SymbolicUtils/stable/manual/rewrite/#Associative-Commutative-Rules) - a special kind of rule that considers all permutations and combinations of arguments. In e-graph rewriting in Metatheory.jl, associativity and commutativity have to be explicitly defined as rules.
However, the presence of such rules, together with distributivity, will likely cause equality saturation to loop infinitely. See ["Why reasonable rules can create infinite loops"](https://github.com/egraphs-good/egg/discussions/60) for an explanation. +Some workarounds exist for ensuring termination of equality saturation: bounding the depth of search, or merge-only rewriting without introducing new terms (see ["Ensuring the Termination of EqSat over a Terminating Term Rewriting System"](https://effect.systems/blog/ta-completion.html)). -[Why reasonable rules can create infinite loops](https://github.com/egraphs-good/egg/discussions/60) +There are a few theoretical questions left: -https://effect.systems/blog/ta-completion.html +- **What kind of rewrite systems terminate in equality saturation**? +- Can associative-commutative matching be applied efficiently to e-graphs while avoiding combinatorial explosion? +- Can e-graphs be extended to include nodes with special algebraic properties, in order to mitigate the downsides of non-terminating systems? ---