Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 19 additions & 12 deletions pytensor/graph/rewriting/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def register(
rewriter: Union["RewriteDatabase", RewritesType],
*tags: str,
use_db_name_as_tag=True,
overwrite_existing=False,
):
"""Register a new rewriter to the database.

Expand All @@ -56,7 +57,8 @@ def register(
``local_remove_all_assert``. Setting `use_db_name_as_tag` to
``False`` removes that behavior. This means that only the rewrite's name
and/or its tags will enable it.

overwrite_existing:
Overwrite the existing rewriter with a new one having the same name
"""
if not isinstance(
rewriter,
Expand All @@ -66,22 +68,27 @@ def register(
):
raise TypeError(f"{rewriter} is not a valid rewrite type.")

if name in self.__db__:
raise ValueError(f"The tag '{name}' is already present in the database.")

if use_db_name_as_tag:
if self.name is not None:
tags = (*tags, self.name)

rewriter.name = name
# This restriction is there because in many place we suppose that
# something in the RewriteDatabase is there only once.
if rewriter.name in self.__db__:
raise ValueError(
f"Tried to register {rewriter.name} again under the new name {name}. "
"The same rewrite cannot be registered multiple times in"
" an `RewriteDatabase`; use `ProxyDB` instead."
)

# if tag collides with name
if name in self.__db__ and name not in self._names:
raise ValueError(f"The tag '{name}' is already present in the database.")

if name in self.__db__ or rewriter.name in self.__db__:
if overwrite_existing:
self.remove_tags(name, *tags)
old_rewriter = self.__db__[name].pop()
self._names.remove(name)
self.__db__[old_rewriter.__class__.__name__].remove(old_rewriter)
else:
raise ValueError(
f"The tag '{name}' is already present in the database."
)

self.__db__[name] = OrderedSet([rewriter])
self._names.add(name)
self.__db__[rewriter.__class__.__name__].add(rewriter)
Expand Down
44 changes: 43 additions & 1 deletion tests/graph/rewriting/test_db.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pytest

from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import GraphRewriter, SequentialGraphRewriter
from pytensor.graph.rewriting.db import (
EquilibriumDB,
Expand All @@ -17,6 +18,13 @@ def apply(self, fgraph):
pass


class NewTestRewriter(GraphRewriter):
name = "bleh"

def apply(self, fgraph):
pass


class TestDB:
def test_register(self):
db = RewriteDatabase()
Expand All @@ -31,7 +39,7 @@ def test_register(self):
assert "c" in db

with pytest.raises(ValueError, match=r"The tag.*"):
db.register("c", TestRewriter()) # name taken
db.register("c", NewTestRewriter()) # name taken

with pytest.raises(ValueError, match=r"The tag.*"):
db.register("z", TestRewriter()) # name collides with tag
Expand All @@ -42,6 +50,40 @@ def test_register(self):
with pytest.raises(TypeError, match=r".* is not a valid.*"):
db.register("d", 1)

def test_overwrite_existing(self):
class TestOverwrite1(GraphRewriter):
def apply(self, fgraph):
fgraph.counter[0] += 1

class TestOverwrite2(GraphRewriter):
def apply(self, fgraph):
fgraph.counter[1] += 1

db = SequenceDB()
fg = FunctionGraph([], [])
fg.counter = [0, 0]

db.register("a", TestRewriter(), "basic")
rewriter = db.query("+basic")
rewriter.rewrite(fg)
assert fg.counter == [0, 0]

with pytest.raises(ValueError, match=r"The tag.*"):
db.register("a", TestOverwrite1(), "basic")
rewriter = db.query("+basic")
rewriter.rewrite(fg)
assert fg.counter == [0, 0]

db.register("a", TestOverwrite1(), "basic", overwrite_existing=True)
rewriter = db.query("+basic")
rewriter.rewrite(fg)
assert fg.counter == [1, 0]

db.register("a", TestOverwrite2(), "basic", overwrite_existing=True)
rewriter = db.query("+basic")
rewriter.rewrite(fg)
assert fg.counter == [1, 1]

def test_EquilibriumDB(self):
eq_db = EquilibriumDB()

Expand Down