From ccfe250b245876d3b910d77beff8c65875582762 Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 13 Jun 2019 14:43:28 -0700 Subject: [PATCH] [RELAY][PASS] Enable decorating python class as Pass --- 3rdparty/dmlc-core | 2 +- python/tvm/relay/__init__.py | 1 + python/tvm/relay/transform.py | 185 +++++++++++++++++++----- tests/python/relay/test_pass_manager.py | 49 +++++++ 4 files changed, 201 insertions(+), 36 deletions(-) diff --git a/3rdparty/dmlc-core b/3rdparty/dmlc-core index 3943914eed66..fbe142b267a8 160000 --- a/3rdparty/dmlc-core +++ b/3rdparty/dmlc-core @@ -1 +1 @@ -Subproject commit 3943914eed66470bd010df581e29e4dca4f7df6f +Subproject commit fbe142b267a8edd1f1188fa2140d88f7ae308661 diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index 1c8f5d6ceed3..5536e503e6b6 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -101,6 +101,7 @@ bind = expr.bind module_pass = transform.module_pass function_pass = transform.function_pass +alpha_equal = ir_pass.alpha_equal # ExprFunctor ExprFunctor = expr_functor.ExprFunctor diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index b76c2361605c..d7a7c26a2c92 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -19,6 +19,8 @@ Relay pass transformation infrastructure. """ import types +import inspect +import functools from tvm._ffi.runtime_ctypes import TVMContext from . import _transform @@ -444,16 +446,47 @@ def PartialEvaluate(): return _transform.PartialEvaluate() +def _wrap_class_module_pass(pass_cls, pass_info): + """Wrap a python class as function pass""" + class PyModulePass(ModulePass): + """Internal wrapper class to create a class instance.""" + def __init__(self, *args, **kwargs): + # initialize handle in cass pass_cls creation failed.fg + self.handle = None + inst = pass_cls(*args, **kwargs) + # it is important not to capture self to + # avoid a cyclic dependency + def _pass_func(mod, ctx): + return inst.transform_module(mod, ctx) + self.__init_handle_by_constructor__( + _transform.MakeModulePass, _pass_func, pass_info) + self._inst = inst + + def __getattr__(self, name): + # fall back to instance attribute if there is not any + return self._inst.__getattribute__(name) + + functools.update_wrapper(PyModulePass.__init__, pass_cls.__init__) + PyModulePass.__name__ = pass_cls.__name__ + PyModulePass.__doc__ = pass_cls.__doc__ + PyModulePass.__module__ = pass_cls.__module__ + return PyModulePass + + def module_pass(pass_func=None, opt_level=None, name=None, required=None): - """Create a module pass. This function returns a callback when pass_func - is provided. Otherwise, it returns the created module level pass using the - given optimization function. + """Decorate a module pass. + + This function returns a callback when pass_func is provided. + Otherwise, it serves a decorator function. + + pass_func can also be a class type with a method transform_module. + This function will create a decorated ModulePass using transform_module + as the pass function. Parameters ---------- - pass_func : Optional[Callable[(Module/Function, PassContext) -> - Module/Function]] - The implemented optimization pass. + pass_func : Optional[Callable[(Module, PassContext) ->Module]] + The transformation function or class. opt_level : int The optimization level of this module pass. @@ -468,14 +501,39 @@ def module_pass(pass_func=None, opt_level=None, name=None, required=None): Returns ------- create_module_pass : Union[Callable, ModulePass] - The callable that will create a module pass is returned when - pass_func is not passed in. Otherwise, a ModulePass object will be - directly created. + A decorator will be returned if pass_func is not provided, + otherwise return the decorated result. + The returned decorator has two behaviors depending on the input: + A new ModulePass will be returned when we decorate a pass function. + A new ModulePass class will be returned when we decorate a class type. Examples -------- - The following code creates a module level pass and adds an abs function to - the module. + The following code block decorates a module pass class. + + .. code-block:: python + + @relay.transform.module_pass + class CustomPipeline: + def __init__(self, enable_fold): + self.enable_fold = enable_fold + self.cse = relay.transform.EliminateCommonSubexpr() + self.const_fold = relay.transform.FoldConstant() + + def transform_module(self, mod, ctx): + mod = self.cse(mod, ctx) + if self.enable_fold: + mod = self.const_fold(mod, ctx) + return mod + + # create an instance of customized pipeline + pipeline = CustomPipeline(enable_fold=False) + assert isinstance(pipeline, transform.ModulePass) + # run the pipeline. + output_module = pipeline(input_module) + + The following code creates a module pass by decorating + a user defined transform function. .. code-block:: python @@ -497,7 +555,6 @@ def transform(mod, ctx): updated_mod = module_pass(m) # Now a function abs should be added to the module m. """ - if opt_level is None: raise ValueError("Please provide opt_level for the module pass.") @@ -506,30 +563,59 @@ def transform(mod, ctx): raise TypeError("Required is expected to be the type of " + "list/tuple.") - def create_module_pass(pass_func): + def create_module_pass(pass_arg): """Internal function that creates a module pass""" - if not isinstance(pass_func, (types.FunctionType, types.LambdaType)): - raise TypeError("pass_func must be a callable for Module pass") - - fname = name if name else pass_func.__name__ + fname = name if name else pass_arg.__name__ info = PassInfo(opt_level, fname, required) - return _transform.MakeModulePass(pass_func, info) + if inspect.isclass(pass_arg): + return _wrap_class_module_pass(pass_arg, info) + if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): + raise TypeError("pass_func must be a callable for Module pass") + return _transform.MakeModulePass(pass_arg, info) if pass_func: return create_module_pass(pass_func) return create_module_pass +def _wrap_class_function_pass(pass_cls, pass_info): + """Wrap a python class as function pass""" + class PyFunctionPass(FunctionPass): + """Internal wrapper class to create a class instance.""" + def __init__(self, *args, **kwargs): + # initialize handle in cass pass_cls creation failed.fg + self.handle = None + inst = pass_cls(*args, **kwargs) + # it is important not to capture self to + # avoid a cyclic dependency + def _pass_func(func, mod, ctx): + return inst.transform_function(func, mod, ctx) + self.__init_handle_by_constructor__( + _transform.MakeFunctionPass, _pass_func, pass_info) + self._inst = inst + + def __getattr__(self, name): + # fall back to instance attribute if there is not any + return self._inst.__getattribute__(name) + + functools.update_wrapper(PyFunctionPass.__init__, pass_cls.__init__) + PyFunctionPass.__name__ = pass_cls.__name__ + PyFunctionPass.__doc__ = pass_cls.__doc__ + PyFunctionPass.__module__ = pass_cls.__module__ + return PyFunctionPass + + def function_pass(pass_func=None, opt_level=None, name=None, required=None): - """Create a function pass. This function returns a callback when pass_func + """Decorate a function pass. + + This function returns a callback when pass_func is provided. Otherwise, it returns the created function pass using the given optimization function. Parameters ---------- - pass_func : Optional[Callable[(Module/Function, PassContext) -> - Module/Function]] - The implemented optimization pass. + pass_func : Optional[Callable[(Function, Module, PassContext) -> Function]] + The transformation function or class. opt_level : int The optimization level of this module pass. @@ -544,20 +630,48 @@ def function_pass(pass_func=None, opt_level=None, name=None, required=None): Returns ------- create_function_pass : Union[Callable, FunctionPass] - The callable that will create a function pass is returned when - pass_func is not passed in. Otherwise, a FunctionPass object will be - created. + + A decorator will be returned if pass_func is not provided, + otherwise return the decorated result. + The returned decorator has two behaviors depending on the input: + A new FunctionPass will be returned when we decorate a pass function. + A new FunctionPass class will be returned when we decorate a class type. Examples -------- - The following code creates a function level pass that performs constant - folding. + The following code block decorates a function pass class. + + .. code-block:: python + + @relay.transform.function_pass(opt_level=1) + class TestReplaceFunc: + def __init__(self, new_func): + self.new_func = new_func + + def transform_function(self, func, mod, ctx): + # just for demo purposes + # transform func to new_func + return self.new_func + + x = relay.var("x", shape=(10, 20)) + f1 = relay.Function([x], x) + f2 = relay.Function([x], relay.log(x)) + # fpass is now a special pass that replaces every + # function to f1 + fpass = TestReplaceFunc(f1) + # now every function in input_mod is replaced by f1 + res_mod = fpass(input_mod) + + + The following code creates a function pass by decorating + a user defined transform function. .. code-block:: python @relay.transform.function_pass(opt_level=2) - def transform(func, ctx): - return ir_pass.fold_constant(func) + def transform(func, mod, ctx): + # my transformations here. + return func function_pass = transform assert isinstance(function_pass, transform.FunctionPass) @@ -577,14 +691,15 @@ def transform(func, ctx): raise TypeError("Required is expected to be the type of " + "list/tuple.") - def create_function_pass(pass_func): + def create_function_pass(pass_arg): """Internal function that creates a function pass""" - if not isinstance(pass_func, (types.FunctionType, types.LambdaType)): - raise TypeError("pass_func must be a callable for Module pass") - - fname = name if name else pass_func.__name__ + fname = name if name else pass_arg.__name__ info = PassInfo(opt_level, fname, required) - return _transform.MakeFunctionPass(pass_func, info) + if inspect.isclass(pass_arg): + return _wrap_class_function_pass(pass_arg, info) + if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): + raise TypeError("pass_func must be a callable for Module pass") + return _transform.MakeFunctionPass(pass_arg, info) if pass_func: return create_function_pass(pass_func) diff --git a/tests/python/relay/test_pass_manager.py b/tests/python/relay/test_pass_manager.py index 7505aa9ab981..a8f50bdb8f55 100644 --- a/tests/python/relay/test_pass_manager.py +++ b/tests/python/relay/test_pass_manager.py @@ -189,6 +189,29 @@ def test_pass_run(): test_pass_run() +def test_function_class_pass(): + @relay.transform.function_pass(opt_level=1) + class TestReplaceFunc: + """Simple test function to replace one argument to another.""" + def __init__(self, new_func): + self.new_func = new_func + + def transform_function(self, func, mod, ctx): + return self.new_func + + x = relay.var("x", shape=(10, 20)) + f1 = relay.Function([x], x) + f2 = relay.Function([x], relay.log(x)) + fpass = TestReplaceFunc(f1) + assert fpass.info.opt_level == 1 + assert fpass.info.name == "TestReplaceFunc" + mod = relay.Module.from_expr(f2) + mod = fpass(mod) + # wrap in expr + mod2 = relay.Module.from_expr(f1) + assert relay.alpha_equal(mod["main"], mod2["main"]) + + def test_function_pass(): shape = (10, ) dtype = 'float32' @@ -259,6 +282,30 @@ def test_pass_run(): test_pass_run() +def test_module_class_pass(): + @relay.transform.module_pass(opt_level=1) + class TestPipeline: + """Simple test function to replace one argument to another.""" + def __init__(self, new_mod, replace): + self.new_mod = new_mod + self.replace = replace + + def transform_module(self, mod, ctx): + if self.replace: + return self.new_mod + return mod + + x = relay.var("x", shape=(10, 20)) + m1 = relay.Module.from_expr(relay.Function([x], x)) + m2 = relay.Module.from_expr(relay.Function([x], relay.log(x))) + fpass = TestPipeline(m2, replace=True) + assert fpass.info.name == "TestPipeline" + mod3 = fpass(m1) + assert mod3.same_as(m2) + mod4 = TestPipeline(m2, replace=False)(m1) + assert mod4.same_as(m1) + + def test_pass_info(): info = relay.transform.PassInfo(opt_level=1, name="xyz") assert info.opt_level == 1 @@ -451,6 +498,8 @@ def expected(): if __name__ == "__main__": + test_function_class_pass() + test_module_class_pass() test_module_pass() test_function_pass() test_sequential_pass()