Skip to content

Commit cd1318b

Browse files
committed
- Masa & Andrew's comments
1 parent 0383153 commit cd1318b

File tree

16 files changed

+61
-65
lines changed

16 files changed

+61
-65
lines changed

include/tvm/ir/module.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -519,12 +519,12 @@ constexpr const char* kConstantMemoryPools = "constant_memory_pools";
519519

520520
/*
521521
* \brief All the runtime::NDArrays extracted from PrimFunc tir::AllocateConst nodes. The
522-
* node will record the index into this array. See also kConstNameToNDArray below, which is
522+
* node will record the index into this array. See also kConstNameToConstant below, which is
523523
* the analog for Relay Functions.
524524
*
525525
* Type: Array<runtime::NDArray>
526526
*/
527-
constexpr const char* kConstantsArray = "Constants";
527+
constexpr const char* kConstants = "constants";
528528

529529
/*!
530530
* \brief All the runtime::Modules accumulated during compilation by external codegen. These
@@ -542,7 +542,7 @@ constexpr const char* kExternalMods = "external_mods";
542542
*
543543
* Type: Map<String, runtime::NDArray>
544544
*/
545-
constexpr const char* kConstNameToNDArray = "const_name_to_ndarray";
545+
constexpr const char* kConstNameToConstant = "const_name_to_constant";
546546

547547
} // namespace attr
548548
} // namespace tvm

python/tvm/relay/backend/interpreter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ class Interpreter(Executor):
195195
The runtime device to run the code on.
196196
197197
target : tvm.Target
198-
The target option to build the function.
198+
The target option to build the function. Only homogeneous execution is supported.
199199
200200
CAUTION: Despite the API the module is prepared upon each call to evaluate
201201
rather than once in create_executor.

python/tvm/relay/backend/vm.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,9 @@ class VMExecutor(Executor):
198198
device : :py:class:`~tvm.runtime.Device`
199199
The runtime device to run the code on.
200200
201-
target : :py:class:`Target`
202-
The target option to build the function.
201+
target : any multi-target like object, see Target.canon_multi_target
202+
For homogeneous compilation, the unique build target.
203+
For heterogeneous compilation, a dictionary or list of possible build targets.
203204
"""
204205

205206
def __init__(self, mod, device, target):

python/tvm/relay/build_module.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -570,15 +570,16 @@ class GraphExecutor(_interpreter.Executor):
570570
device : :py:class:`Device`
571571
The runtime device to run the code on.
572572
573-
raw_targets : Array[tvm.target.Target]
574-
The available targets.
573+
target : any multi-target like object, see Target.canon_multi_target
574+
For homogeneous compilation, the unique build target.
575+
For heterogeneous compilation, a dictionary or list of possible build targets.
575576
"""
576577

577-
def __init__(self, mod, device, raw_targets):
578+
def __init__(self, mod, device, target):
578579
assert mod is not None
579580
self.mod = mod
580581
self.device = device
581-
self.raw_targets = raw_targets
582+
self.target = target
582583

583584
def _make_executor(self, expr=None):
584585
if expr:
@@ -589,7 +590,7 @@ def _make_executor(self, expr=None):
589590
raise ValueError(
590591
"Graph Executor only supports static graphs, got output type", ret_type
591592
)
592-
mod = build(self.mod, target=self.raw_targets)
593+
mod = build(self.mod, target=self.target)
593594
gmodule = _graph_executor.GraphModule(mod["default"](self.device))
594595

595596
def _unflatten(flat_iter, cur_type):
@@ -630,16 +631,16 @@ class AotExecutor(_interpreter.Executor):
630631
device : :py:class:`Device`
631632
The runtime device to run the code on.
632633
633-
raw_targets : Array[tvm.target.Target]
634-
The available targets.
634+
target : any multi-target like object, see Target.canon_multi_target
635+
For homogeneous compilation, the unique build target.
636+
For heterogeneous compilation, a dictionary or list of possible build targets.
635637
"""
636638

637-
def __init__(self, mod, device, raw_targets):
639+
def __init__(self, mod, device, target):
638640
assert mod is not None
639641
self.mod = mod
640642
self.device = device
641-
self.raw_targets = raw_targets
642-
assert raw_targets[0].attrs.get("executor", "graph") == "aot"
643+
self.target = target
643644

644645
def _make_executor(self, expr=None):
645646
if expr:
@@ -648,7 +649,7 @@ def _make_executor(self, expr=None):
648649
ret_type = self.mod["main"].checked_type.ret_type
649650
if _ty.is_dynamic(ret_type):
650651
raise ValueError("AOT Executor only supports static graphs, got output type", ret_type)
651-
mod = build(self.mod, target=self.raw_targets)
652+
mod = build(self.mod, target=self.target)
652653

653654
# NOTE: Given AOT requires use of the "c" backend, must export/import to compile the
654655
# generated code.
@@ -722,6 +723,8 @@ def create_executor(kind="debug", mod=None, device=None, target="llvm", params=N
722723
target : any multi-target like object, see Target.canon_multi_target
723724
For homogeneous compilation, the unique build target.
724725
For heterogeneous compilation, a dictionary or list of possible build targets.
726+
CAUTION: Though this API allows multiple targets, it does not allow multiple devices, so
727+
heterogeneous compilation is not yet supported.
725728
726729
params : dict of str to NDArray
727730
Input parameters to the graph that do not change
@@ -737,11 +740,14 @@ def create_executor(kind="debug", mod=None, device=None, target="llvm", params=N
737740
if device is not None:
738741
assert device.device_type == raw_targets[0].kind.device_type
739742
else:
743+
# Use the first target as the device.
740744
device = _nd.device(raw_targets[0].kind.device_type, 0)
741745

742746
if params is not None:
743747
mod = IRModule.from_expr(bind_params_by_name(mod["main"], params))
744748

749+
assert raw_targets[0].attrs.get("executor") == kind
750+
745751
if kind == "debug":
746752
assert len(raw_targets) == 1, "The interpreter currently only supports a single target"
747753
return _interpreter.Interpreter(mod, device, raw_targets[0])

src/relay/backend/aot_executor_codegen.cc

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1170,12 +1170,10 @@ class AOTExecutorCodegen : public MixedModeVisitor {
11701170

11711171
// Collect any constants extracted by external codegen.
11721172
ret.params = std::unordered_map<std::string, tvm::runtime::NDArray>();
1173-
Map<String, runtime::NDArray> const_name_to_ndarray =
1174-
lowered_mod
1175-
->GetAttr<Map<String, runtime::NDArray>>(tvm::attr::kConstNameToNDArray,
1176-
Map<String, runtime::NDArray>())
1177-
.value();
1178-
for (const auto& kv : const_name_to_ndarray) {
1173+
Map<String, runtime::NDArray> const_name_to_constant =
1174+
lowered_mod->GetAttr<Map<String, runtime::NDArray>>(tvm::attr::kConstNameToConstant)
1175+
.value_or({});
1176+
for (const auto& kv : const_name_to_constant) {
11791177
ICHECK(ret.params.emplace(kv.first, kv.second).second);
11801178
}
11811179

@@ -1223,10 +1221,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
12231221
}
12241222

12251223
// Collect any runtime modules generated by external codegen.
1226-
ret.external_mods = lowered_mod
1227-
->GetAttr<Array<tvm::runtime::Module>>(tvm::attr::kExternalMods,
1228-
Array<tvm::runtime::Module>())
1229-
.value();
1224+
ret.external_mods =
1225+
lowered_mod->GetAttr<Array<tvm::runtime::Module>>(tvm::attr::kExternalMods).value_or({});
12301226

12311227
// This is the point where we separate the functions in the module by target
12321228
VLOG(1) << "lowered module:" << std::endl << PrettyPrint(lowered_mod);

src/relay/backend/contrib/arm_compute_lib/codegen.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,7 @@ runtime::Module ACLCompiler(const ObjectRef& ref) {
393393
serializer.serialize();
394394
std::string graph_json = serializer.GetJSON();
395395

396-
// Note that serializer.const_name_to_ndarray() is ignored. Instead the TECompiler invokes
396+
// Note that serializer.const_name_to_constant() is ignored. Instead the TECompiler invokes
397397
// a callback which calls backend::UpdateConstants to capture the map before the function
398398
// 'disappears' into lowered form, on the assumption the visit order and thus constant
399399
// names match those generated by the JSONSerializer.

src/relay/backend/contrib/bnns/codegen.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ runtime::Module BNNSCompiler(const ObjectRef& ref) {
137137
serializer.serialize();
138138
std::string graph_json = serializer.GetJSON();
139139

140-
// Note that serializer.const_name_to_ndarray() is ignored. Instead the TECompiler invokes
140+
// Note that serializer.const_name_to_constant() is ignored. Instead the TECompiler invokes
141141
// a callback which calls backend::UpdateConstants to capture the map before the function
142142
// 'disappears' into lowered form, on the assumption the visit order and thus constant
143143
// names match those generated by the JSONSerializer.

src/relay/backend/contrib/codegen_json/codegen_json.h

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,8 @@ class JSONSerializer : public MemoizedExprTranslator<std::vector<JSONGraphNodeEn
169169
* \brief Returns the accumulated map from constant names to the NDArray they must be bound to
170170
* at runtime. Also referred to as 'params' elsewhere in the code.
171171
*/
172-
const std::unordered_map<std::string, runtime::NDArray>& const_name_to_ndarray() const {
173-
return const_name_to_ndarray_;
172+
const std::unordered_map<std::string, runtime::NDArray>& const_name_to_constant() const {
173+
return const_name_to_constant_;
174174
}
175175

176176
/*!
@@ -260,9 +260,10 @@ class JSONSerializer : public MemoizedExprTranslator<std::vector<JSONGraphNodeEn
260260

261261
std::vector<JSONGraphNodeEntry> VisitExpr_(const ConstantNode* constant_node) {
262262
std::string name = symbol_ + "_const_" + std::to_string(const_names_.size());
263-
VLOG(1) << "Will require parameter '" << name << "' to be available at runtime";
264-
ICHECK_EQ(const_name_to_ndarray_.count(name), 0);
265-
const_name_to_ndarray_.emplace(name, constant_node->data);
263+
VLOG(1) << "Will require parameter '" << name
264+
<< "' to be supplied by the ConstLoaderModule at runtime";
265+
ICHECK_EQ(const_name_to_constant_.count(name), 0);
266+
const_name_to_constant_.emplace(name, constant_node->data);
266267
const_names_.push_back(name);
267268
auto node = std::make_shared<JSONGraphNode>(name, /*op_type=*/"const");
268269
return AddNode(node, GetRef<Expr>(constant_node));
@@ -361,7 +362,7 @@ class JSONSerializer : public MemoizedExprTranslator<std::vector<JSONGraphNodeEn
361362
* translation to JSON. The JSON will record only the constant name. The actual NDArray must
362363
* be made available at runtime from a ConstLoaderModule.
363364
*/
364-
std::unordered_map<std::string, runtime::NDArray> const_name_to_ndarray_;
365+
std::unordered_map<std::string, runtime::NDArray> const_name_to_constant_;
365366
/*!
366367
* \brief The domain of the above map, but in order the constants were encountered during
367368
* translation.

src/relay/backend/contrib/cutlass/codegen.cc

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,9 @@ namespace {
4747
Target GetCutlassTarget() {
4848
Target target = Target::Current(/*allow_not_defined=*/true);
4949
if (!target.defined() || target->kind->name != "cutlass") {
50-
// Use the default CUTLASS compilation options.
50+
// Use the default CUTLASS compilation options if no specific "cutlass" target was given
51+
// in the overall targets list. In that case target_hooks.cc will invoke the custom pass
52+
// without pushing any target instance onto the implicit target stack.
5153
target = Target("cutlass");
5254
}
5355
return target;
@@ -912,8 +914,7 @@ transform::Pass CompileForCutlassImpl() {
912914
Target target = GetCutlassTarget();
913915
runtime::Module runtime_mod = (*pf)(mod, target);
914916
Array<runtime::Module> external_mods =
915-
mod->GetAttr<Array<runtime::Module>>(tvm::attr::kExternalMods, Array<runtime::Module>())
916-
.value();
917+
mod->GetAttr<Array<runtime::Module>>(tvm::attr::kExternalMods).value_or({});
917918
external_mods.push_back(runtime_mod);
918919
return WithAttr(mod, tvm::attr::kExternalMods, external_mods);
919920
};

src/relay/backend/contrib/dnnl/codegen.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -586,7 +586,7 @@ runtime::Module DNNLCompiler(const ObjectRef& ref) {
586586
serializer.serialize();
587587
std::string graph_json = serializer.GetJSON();
588588

589-
// Note that serializer.const_name_to_ndarray() is ignored. Instead the TECompiler invokes
589+
// Note that serializer.const_name_to_constant() is ignored. Instead the TECompiler invokes
590590
// a callback which calls backend::UpdateConstants to capture the map before the function
591591
// 'disappears' into lowered form, on the assumption the visit order and thus constant
592592
// names match those generated by the JSONSerializer.

0 commit comments

Comments
 (0)