diff --git a/src/EGraphs/saturation.jl b/src/EGraphs/saturation.jl index fcc3555f..6b5522ce 100644 --- a/src/EGraphs/saturation.jl +++ b/src/EGraphs/saturation.jl @@ -78,7 +78,6 @@ function eqsat_search!( @debug "SEARCHING" for (rule_idx, rule) in enumerate(theory) - prev_matches = n_matches @timeit report.to string(rule_idx) begin prev_matches = n_matches # don't apply banned rules @@ -90,24 +89,24 @@ function eqsat_search!( ids_left = cached_ids(g, rule.left) for i in ids_left cansearch(scheduler, rule_idx, i) || continue - n_matches += rule.ematcher_left!(g, rule_idx, i, rule.stack, ematch_buffer) - inform!(scheduler, rule_idx, i, n_matches) + eclass_matches = rule.ematcher_left!(g, rule_idx, i, rule.stack, ematch_buffer) + n_matches += eclass_matches + inform!(scheduler, rule_idx, i, eclass_matches) end if is_bidirectional(rule) ids_right = cached_ids(g, rule.right) for i in ids_right cansearch(scheduler, rule_idx, i) || continue - n_matches += rule.ematcher_right!(g, rule_idx, i, rule.stack, ematch_buffer) - inform!(scheduler, rule_idx, i, n_matches) + eclass_matches = rule.ematcher_right!(g, rule_idx, i, rule.stack, ematch_buffer) + n_matches += eclass_matches + inform!(scheduler, rule_idx, i, eclass_matches) end end n_matches - prev_matches > 0 && @debug "Rule $rule_idx: $rule produced $(n_matches - prev_matches) matches" - # if n_matches - prev_matches > 2 && rule_idx == 2 - # @debug buffer_readable(g, old_len) - # end - inform!(scheduler, rule_idx, n_matches) + + inform!(scheduler, rule_idx, n_matches - prev_matches) end end