Skip to content

Commit cfb297c

Browse files
committed
Correct Graph Executor Python API
1 parent 229484e commit cfb297c

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

python/tvm/autotvm/task/relay_integration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def _lower(mod, target, params):
4646
with vta.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}):
4747
mod, _ = relay.optimize(mod, target, params)
4848
grc = graph_executor_codegen.GraphExecutorCodegen(None, target)
49-
grc.codegen(mod["main"])
49+
grc.codegen(mod, mod["main"])
5050
return
5151

5252
compiler = relay.vm.VMCompiler()

python/tvm/relay/backend/graph_executor_codegen.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,11 +64,13 @@ def _setup(self, mod, target):
6464
tgts[_expr.IntImm("int32", 0)] = Target(target)
6565
self._init(mod, tgts)
6666

67-
def codegen(self, func):
67+
def codegen(self, ir_module, func):
6868
"""Compile a single function into a graph.
6969
7070
Parameters
7171
----------
72+
ir_module: tvm.ir.Module
73+
The module to compile
7274
func: tvm.relay.Expr
7375
The function to compile.
7476
@@ -82,7 +84,7 @@ def codegen(self, func):
8284
Additional constant parameters.
8385
"""
8486
default_mod_name = mangle_module_name("default")
85-
self._codegen(func, default_mod_name)
87+
self._codegen(ir_module, func, default_mod_name)
8688
graph_json = self._get_graph_json()
8789
lowered_func = self._get_irmodule()
8890
param_names = self._list_params_name()

0 commit comments

Comments
 (0)