@@ -1804,22 +1804,83 @@ def callback(self, pre, post, node_map):
18041804 if new_args :
18051805 return relay .op .concatenate (relay .expr .Tuple (new_args ), axis = 0 )
18061806 else :
1807- return concat_args
1807+ return concat_args [ 0 ]
18081808
18091809 x = relay .var ("x" )
18101810 y = relay .var ("y" )
18111811 z = relay .var ("z" )
18121812 concat = relay .op .concatenate (relay .expr .Tuple ([x , y , z ]), axis = 0 )
18131813
1814- # Let the rewriter run recursively
1815- out = rewrite (ConcatRewriter (False ), concat )
1816- expected = relay .expr .Tuple ([x ])
1817- assert tvm .ir .structural_equal (out , expected )
1814+ def test_one_callback ():
1815+ # Let the rewriter run recursively
1816+ out = rewrite (ConcatRewriter (False ), concat )
1817+ expected = x
1818+ assert tvm .ir .structural_equal (out , expected )
1819+
1820+ # Run the rewriter once
1821+ out = rewrite (ConcatRewriter (True ), concat )
1822+ expected = relay .op .concatenate (relay .expr .Tuple ([x , y ]), axis = 0 )
1823+ assert tvm .ir .structural_equal (out , expected )
1824+
1825+ def test_multi_callbacks ():
1826+ # This class recursively add a nn.relu operator after nn.softmax
1827+ class OneMoreReluRewriter (DFPatternCallback ):
1828+ def __init__ (self , rewrite_once ):
1829+ super ().__init__ (rewrite_once = rewrite_once )
1830+ self .pattern = is_op ("nn.softmax" )(None )
1831+
1832+ def callback (self , pre , post , node_map ):
1833+ return relay .nn .relu (post )
1834+
1835+ def before ():
1836+ # Before:
1837+ # x y z
1838+ # | | |
1839+ # concat
1840+ # |
1841+ # softmax
1842+ return relay .nn .softmax (concat )
1843+
1844+ def once_concat ():
1845+ # ConcatRewrite once, OneMoreReluRewrite once
1846+ # Expected:
1847+ # x y
1848+ # | |
1849+ # concat
1850+ # |
1851+ # softmax
1852+ # |
1853+ # relu
1854+ return relay .nn .relu (
1855+ relay .nn .softmax (relay .op .concatenate (relay .expr .Tuple ([x , y ]), axis = 0 ))
1856+ )
1857+
1858+ def recursive_concat ():
1859+ # ConcatRewrite recursively, OneMoreReluRewrite once
1860+ # Expected:
1861+ # x
1862+ # |
1863+ # softmax
1864+ # |
1865+ # relu
1866+ return relay .nn .relu (relay .nn .softmax (x ))
1867+
1868+ # Run ConcatRewriter once, OneMoreReluRewriter once
1869+ out = rewrite (
1870+ [OneMoreReluRewriter (True ), ConcatRewriter (True )],
1871+ before (),
1872+ )
1873+ assert tvm .ir .structural_equal (out , once_concat ())
1874+
1875+ # Run ConcatRewriter recursively, OneMoreReluRewriter once
1876+ out = rewrite (
1877+ [OneMoreReluRewriter (True ), ConcatRewriter (False )],
1878+ before (),
1879+ )
1880+ assert tvm .ir .structural_equal (out , recursive_concat ())
18181881
1819- # Run the rewriter once
1820- out = rewrite (ConcatRewriter (True ), concat )
1821- expected = relay .op .concatenate (relay .expr .Tuple ([x , y ]), axis = 0 )
1822- assert tvm .ir .structural_equal (out , expected )
1882+ test_one_callback ()
1883+ test_multi_callbacks ()
18231884
18241885
18251886def test_matched_outside_but_dominated ():
0 commit comments