Skip to content

Commit cb91d7e

Browse files
zhiicsWei Chen
authored andcommitted
[Relay][Compilation] replace relay.build_module with C++ BuildModule (apache#3174)
1 parent f835627 commit cb91d7e

File tree

13 files changed

+534
-541
lines changed

13 files changed

+534
-541
lines changed

python/tvm/relay/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from . import module
2626
from . import adt
2727
from . import ir_pass
28-
from .build_module import build, build_config, create_executor, optimize
28+
from .build_module import build, build_config, create_executor
2929
from . import prelude
3030
from . import parser
3131
from . import debug

python/tvm/relay/_build_module.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
18+
"""The interface for building Relay functions exposed from C++."""
19+
from tvm._ffi.function import _init_api
20+
21+
_init_api("relay.build_module", __name__)

python/tvm/relay/backend/graph_runtime_codegen.py

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,9 @@
3636
from __future__ import absolute_import
3737

3838
from tvm.ndarray import empty
39-
from tvm._ffi.function import _init_api
40-
4139
from tvm.relay import build_module
4240
from tvm import target as _target
43-
44-
_init_api("tvm.relay.build_module")
41+
from tvm import expr as _expr
4542

4643
class GraphRuntimeCodegen(object):
4744
"""The compiler from Relay to the TVM runtime system."""
@@ -57,17 +54,14 @@ def __init__(self, mod, target):
5754
self._setup(mod, target)
5855

5956
def _setup(self, mod, target):
60-
tgts = []
57+
tgts = {}
6158
if isinstance(target, dict):
62-
for kv in target.items():
63-
tgts.append(kv[0])
64-
if isinstance(kv[1], (str, _target.Target)):
65-
tgts.append(str(kv[1]))
66-
else:
59+
for dev, tgt in target.items():
60+
if not isinstance(tgt, (str, _target.Target)):
6761
raise Exception("Unknown target type")
62+
tgts[dev] = _target.create(tgt)
6863
elif isinstance(target, (str, _target.Target)):
69-
tgts.append("0")
70-
tgts.append(str(target))
64+
tgts[_expr.IntImm("int32", 0)] = _target.create(target)
7165
self._init(mod, tgts)
7266

7367
def codegen(self, func):

0 commit comments

Comments
 (0)