Skip to content

Commit 7b91e62

Browse files
Change target string to Target object in the TE compiler and interpreter (#8835)
* # This is a combination of 2 commits. # This is the 1st commit message: Initial changes # This is the commit message #2: Ftarget string -> Target object works! * Fix remaining target strings * fix bad rebase * Fix typo * 1 more bad rebase fix * Lint * typo * Forgot to commit this * Add TargetStrHash and Map<Target... to std::unordered_map<Target... conversion fn * Passing most tests, yay * remove some comments * lint * target-str-to-target-object * Respond to change requests Co-authored-by: Jared Roesch <[email protected]>
1 parent 400baf2 commit 7b91e62

File tree

8 files changed

+121
-40
lines changed

8 files changed

+121
-40
lines changed

include/tvm/target/target.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include <tvm/target/target_kind.h>
3232

3333
#include <string>
34+
#include <unordered_map>
3435
#include <unordered_set>
3536
#include <vector>
3637

@@ -203,5 +204,6 @@ void CheckAndUpdateHostConsistency(Map<Integer, Target>* target, Target* host);
203204
* \param host The Target typed object for target host to be updated
204205
*/
205206
void CheckAndUpdateHostConsistency(Map<Target, IRModule>* target, Target* host);
207+
206208
} // namespace tvm
207209
#endif // TVM_TARGET_TARGET_H_

src/relay/backend/aot_executor_codegen.cc

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -665,11 +665,10 @@ class AOTExecutorCodegen : public MixedModeVisitor {
665665
ret.lowered_funcs = lowered_module.per_target_module;
666666
ret.external_mods = lowered_module.external_mods;
667667

668-
auto target_host_str = target_host_->str();
669-
if (ret.lowered_funcs.find(target_host_str) != ret.lowered_funcs.end()) {
670-
ret.lowered_funcs[target_host_str]->Update(mod_run);
668+
if (ret.lowered_funcs.find(target_host_) != ret.lowered_funcs.end()) {
669+
ret.lowered_funcs[target_host_]->Update(mod_run);
671670
} else {
672-
ret.lowered_funcs.Set(target_host_str, mod_run);
671+
ret.lowered_funcs.Set(target_host_, mod_run);
673672
}
674673

675674
std::vector<String> input_var_names(input_vars_.size());
@@ -774,7 +773,7 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode {
774773
return (*it).second.first;
775774
}
776775

777-
Map<String, IRModule> get_irmodule() { return this->output_.lowered_funcs; }
776+
Map<Target, IRModule> get_irmodule() { return this->output_.lowered_funcs; }
778777

779778
std::shared_ptr<AOTExecutorCodegen> codegen_;
780779
LoweredOutput output_;

src/relay/backend/build_module.cc

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,8 @@ struct ExecutorCodegen {
9292
return CallFunc<Array<tvm::runtime::Module>>("get_external_modules", nullptr);
9393
}
9494

95-
Map<String, IRModule> GetIRModule() {
96-
return CallFunc<Map<String, IRModule>>("get_irmodule", nullptr);
95+
Map<Target, IRModule> GetIRModule() {
96+
return CallFunc<Map<Target, IRModule>>("get_irmodule", nullptr);
9797
}
9898

9999
runtime::Metadata GetMetadata() { return CallFunc<runtime::Metadata>("get_metadata"); }
@@ -491,8 +491,9 @@ class RelayBuildModule : public runtime::ModuleNode {
491491
auto lowered_funcs = executor_codegen_->GetIRModule();
492492

493493
// No need to build for external functions.
494-
if (lowered_funcs.find("ext_dev") != lowered_funcs.end()) {
495-
lowered_funcs.Set("ext_dev", IRModule());
494+
Target ext_dev("ext_dev");
495+
if (lowered_funcs.find(ext_dev) != lowered_funcs.end()) {
496+
lowered_funcs.Set(ext_dev, IRModule());
496497
}
497498

498499
// Generate a placeholder function that attaches linked params as its arguments.
@@ -510,11 +511,11 @@ class RelayBuildModule : public runtime::ModuleNode {
510511
DictAttrs attrs{dict};
511512
auto prim = tir::PrimFunc(Array<tir::Var>(), tir::SeqStmt(Array<tir::Stmt>()), VoidType(),
512513
Map<tir::Var, tir::Buffer>(), attrs);
513-
if (lowered_funcs.find(target_host->str()) == lowered_funcs.end()) {
514-
lowered_funcs.Set(target_host->str(), IRModule(Map<GlobalVar, BaseFunc>({})));
514+
if (lowered_funcs.find(target_host) == lowered_funcs.end()) {
515+
lowered_funcs.Set(target_host, IRModule(Map<GlobalVar, BaseFunc>({})));
515516
}
516-
lowered_funcs[target_host->str()]->Add(
517-
GlobalVar(::tvm::runtime::symbol::tvm_lookup_linked_param), prim);
517+
lowered_funcs[target_host]->Add(GlobalVar(::tvm::runtime::symbol::tvm_lookup_linked_param),
518+
prim);
518519
}
519520

520521
// When there is no lowered_funcs due to reasons such as optimization.

src/relay/backend/interpreter.cc

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,11 @@ namespace {
5353
struct PairHash {
5454
template <typename T1, typename T2>
5555
std::size_t operator()(const std::pair<T1, T2>& k) const {
56-
return std::hash<T1>()(k.first) ^ std::hash<T2>()(k.second);
56+
return dmlc::HashCombine(std::hash<T1>()(k.first), std::hash<T2>()(k.second));
57+
}
58+
template <typename T2>
59+
std::size_t operator()(const std::pair<Target, T2>& k) const {
60+
return dmlc::HashCombine(ObjectHash()(k.first), std::hash<T2>()(k.second));
5761
}
5862
};
5963

@@ -289,7 +293,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
289293
PatternFunctor<bool(const Pattern& p, const ObjectRef& v)> {
290294
public:
291295
// TODO(mbs): Collapse mod and per_target_module once IRModule subsumes LoweredModule.
292-
Interpreter(IRModule mod, Map<String, IRModule> per_target_module, Device device, Target target)
296+
Interpreter(IRModule mod, Map<Target, IRModule> per_target_module, Device device, Target target)
293297
: mod_(mod),
294298
per_target_module_(per_target_module),
295299
device_(device),
@@ -373,7 +377,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
373377
*/
374378
PackedFunc TIRToPackedFunc(const GlobalVar& tir_fn_var, const Array<GlobalVar>& all_tir_fn_vars,
375379
Target target) {
376-
std::pair<std::string, std::string> packed_func_key(target->str(), tir_fn_var->name_hint);
380+
std::pair<Target, std::string> packed_func_key(target, tir_fn_var->name_hint);
377381
auto packed_itr = compiled_packed_funcs_.find(packed_func_key);
378382
if (packed_itr != compiled_packed_funcs_.end()) {
379383
// Already compiled.
@@ -382,8 +386,11 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
382386

383387
// Project out just the function(s) we need.
384388
IRModule lowered_projected_mod;
385-
auto mod_itr = per_target_module_.find(target->str());
386-
ICHECK(mod_itr != per_target_module_.end())
389+
std::unordered_map<Target, IRModule, backend::TargetStrHash, backend::TargetStrEqual>
390+
per_target_module_std_map =
391+
backend::TargetModuleMapToTargetStrModuleMap(per_target_module_);
392+
auto mod_itr = per_target_module_std_map.find(target);
393+
ICHECK(mod_itr != per_target_module_std_map.end())
387394
<< "No target module for target '" << target->str() << "'";
388395
const IRModule& target_module = (*mod_itr).second;
389396
for (const auto& var : all_tir_fn_vars) {
@@ -407,7 +414,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
407414
PackedFunc packed_func = runtime_module.GetFunction(var->name_hint);
408415
ICHECK(packed_func != nullptr) << "No packed function for global var '" << var->name_hint
409416
<< "' in compiled module for target '" << target->str() << "'";
410-
compiled_packed_funcs_.emplace(std::make_pair(target->str(), var->name_hint), packed_func);
417+
compiled_packed_funcs_.emplace(std::make_pair(target, var->name_hint), packed_func);
411418
}
412419

413420
// Return just what we need for this call.
@@ -874,11 +881,10 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
874881
// Map from target key to lowered TIR functions derived from mod_.
875882
// Note that primitives are implicitly executed on target_, while shape functions are implicitly
876883
// executed on the default 'cpu' host. Thus this map has at most two entries.
877-
Map<String, IRModule> per_target_module_;
884+
Map<Target, IRModule> per_target_module_;
878885
// Cached packed functions for the primitives and shape functions, keyed by target and
879886
// global var name.
880-
std::unordered_map<std::pair<std::string, std::string>, PackedFunc, PairHash>
881-
compiled_packed_funcs_;
887+
std::unordered_map<std::pair<Target, std::string>, PackedFunc, PairHash> compiled_packed_funcs_;
882888
// Unique device on which primitives (but not shape functions) will be executed.
883889
// (For simplicity we only run the interpreter on a single device.)
884890
Device device_;
@@ -895,7 +901,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
895901
* rewritten \p mod and target-specific modules containing bindings for all TIR primitive
896902
* functions needed by the rewritten module.
897903
*/
898-
std::pair<IRModule, Map<String, IRModule>> Prepare(IRModule mod, Device device, Target target) {
904+
std::pair<IRModule, Map<Target, IRModule>> Prepare(IRModule mod, Device device, Target target) {
899905
// Run minimal transforms on module to establish invariants needed by interpreter.
900906
transform::Sequential seq({transform::SimplifyInference(),
901907
// FuseOps will mark wrapped calls to prim-ops with the 'Primitive'
@@ -1014,7 +1020,7 @@ TypedPackedFunc<ObjectRef(Array<Expr>)> EvalFunction(IRModule mod, Expr expr, De
10141020
// and can just eval it directly.
10151021
expr_to_eval = expr;
10161022
}
1017-
std::pair<IRModule, Map<String, IRModule>> main_and_lowered =
1023+
std::pair<IRModule, Map<Target, IRModule>> main_and_lowered =
10181024
Prepare(mod_with_expr, device, target);
10191025
std::shared_ptr<Interpreter> intrp = std::make_shared<Interpreter>(
10201026
/*mod=*/main_and_lowered.first, /*per_target_module=*/main_and_lowered.second, device,
@@ -1057,7 +1063,7 @@ ObjectRef Eval(Expr expr, Map<GlobalTypeVar, TypeData> type_definitions,
10571063
std::unordered_set<String> import_set, Device device, Target target) {
10581064
std::pair<IRModule, GlobalVar> mod_and_global =
10591065
IRModule::FromExprInContext(expr, /*global_funcs=*/{}, type_definitions, import_set);
1060-
std::pair<IRModule, Map<String, IRModule>> main_and_lowered =
1066+
std::pair<IRModule, Map<Target, IRModule>> main_and_lowered =
10611067
Prepare(mod_and_global.first, device, target);
10621068
Interpreter intrp(
10631069
/*mod=*/main_and_lowered.first, /*per_target_module=*/main_and_lowered.second, device,

src/relay/backend/te_compiler.cc

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -85,32 +85,33 @@ class TECompilerImpl : public TECompilerNode {
8585
return LowerShapeFuncInternal(key)->cached_func;
8686
}
8787

88-
Map<String, IRModule> GetLoweredFunctions() {
89-
Map<String, IRModule> lowered_functions;
88+
Map<Target, IRModule> GetLoweredFunctions() {
89+
std::unordered_map<Target, IRModule, backend::TargetStrHash, backend::TargetStrEqual>
90+
lowered_functions;
9091
for (const auto& it : cache_) {
9192
auto source_func = it.first;
9293
auto lowered_func = it.second;
9394
auto target = source_func->target;
9495

95-
if (!lowered_functions.count(target->str())) {
96-
lowered_functions.Set(target->str(), IRModule(Map<GlobalVar, BaseFunc>({})));
96+
if (!lowered_functions.count(target)) {
97+
lowered_functions[target] = IRModule(Map<GlobalVar, BaseFunc>({}));
9798
}
9899

99-
lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs);
100+
lowered_functions[target]->Update(lowered_func->cached_func->funcs);
100101
}
101102

102103
for (const auto& it : shape_func_cache_) {
103104
auto source_func = it.first;
104105
auto lowered_func = it.second;
105106
auto target = source_func->target;
106107

107-
if (!lowered_functions.count(target->str())) {
108-
lowered_functions.Set(target->str(), IRModule(Map<GlobalVar, BaseFunc>({})));
108+
if (!lowered_functions.count(target)) {
109+
lowered_functions[target] = IRModule(Map<GlobalVar, BaseFunc>({}));
109110
}
110111

111-
lowered_functions[target->str()]->Update(lowered_func->cached_func->funcs);
112+
lowered_functions[target]->Update(lowered_func->cached_func->funcs);
112113
}
113-
return lowered_functions;
114+
return backend::TargetStrModuleMapToTargetModuleMap(lowered_functions);
114115
}
115116

116117
Array<tvm::runtime::Module> LowerExternalFunctions() {
@@ -884,7 +885,7 @@ IRModule LoweredModuleToIRModule(LoweredModule mod) {
884885

885886
// Annotate the per-target functions with their target and add them to the unified module
886887
for (const auto& kv : mod.per_target_module) {
887-
const String target = kv.first;
888+
const Target target = kv.first;
888889
const IRModule target_module = kv.second;
889890

890891
// Right now, per-target functions are TIR functions, which don't have type definitions, so
@@ -926,15 +927,15 @@ LoweredModule IRModuleToLoweredModule(IRModule mod) {
926927
main_mod->AddTypeDef(kv.first, kv.second);
927928
}
928929

929-
Map<String, IRModule> per_target_modules;
930+
Map<Target, IRModule> per_target_modules;
930931
for (const auto& kv : mod->functions) {
931932
const GlobalVar& var = kv.first;
932933
const BaseFunc& func = kv.second;
933934
if (func->IsInstance<relay::FunctionNode>()) {
934935
main_mod->Add(var, func);
935936
} else if (func->IsInstance<tir::PrimFuncNode>()) {
936937
// Extract target
937-
Optional<String> target = func->GetAttr<String>(tvm::attr::kTarget);
938+
Optional<Target> target = func->GetAttr<Target>(tvm::attr::kTarget);
938939
ICHECK(target) << "Target should be set at this point";
939940

940941
// Put the function in per_target_modules

src/relay/backend/te_compiler.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ class TECompilerNode : public Object {
9797
virtual CachedFunc Lower(const CCacheKey& key, const String mod_name) = 0;
9898

9999
/* Return all functions which have been lowered by the compiler, keyed by target. */
100-
virtual Map<String, IRModule> GetLoweredFunctions() = 0;
100+
virtual Map<Target, IRModule> GetLoweredFunctions() = 0;
101101

102102
/*!
103103
* \brief Just in time compile to get a PackedFunc.
@@ -144,7 +144,7 @@ struct LoweredModule {
144144
/*! \brief The module which contains the Relay code. */
145145
IRModule main_module;
146146
/*! \brief The module which contains per target code. */
147-
Map<String, IRModule> per_target_module;
147+
Map<Target, IRModule> per_target_module;
148148
/*! \brief The external runtime modules which must be combined with the lowered code. */
149149
Array<tvm::runtime::Module> external_mods;
150150
// TODO(@electriclilies): THis might need to become a map

src/relay/backend/utils.cc

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,24 @@ Array<Pass> GetPassPrefix(const Map<tvm::Integer, tvm::Target>& targets, bool is
187187
return pass_seqs;
188188
}
189189

190+
std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual>
191+
TargetModuleMapToTargetStrModuleMap(Map<Target, IRModule> input_map) {
192+
std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual> std_map;
193+
for (auto kv : input_map) {
194+
std_map[kv.first] = kv.second;
195+
}
196+
return std_map;
197+
}
198+
199+
Map<Target, IRModule> TargetStrModuleMapToTargetModuleMap(
200+
std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual> input_map) {
201+
Map<Target, IRModule> tvm_map;
202+
for (auto kv : input_map) {
203+
tvm_map.Set(kv.first, kv.second);
204+
}
205+
return tvm_map;
206+
}
207+
190208
} // namespace backend
191209
} // namespace relay
192210
} // namespace tvm

src/relay/backend/utils.h

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ int64_t CalculateRelayExprSizeBytes(const Type& expr_type);
139139
*/
140140
struct LoweredOutput {
141141
std::string graph_json;
142-
Map<String, IRModule> lowered_funcs;
142+
Map<Target, IRModule> lowered_funcs;
143143
Array<tvm::runtime::Module> external_mods;
144144
Map<String, FunctionInfo> function_metadata;
145145
std::unordered_map<std::string, std::pair<int, const tvm::runtime::NDArray>> params;
@@ -427,6 +427,60 @@ inline bool IsCompileEngineCacheDisabled() {
427427
*/
428428
Array<Pass> GetPassPrefix(const Map<tvm::Integer, tvm::Target>& targets, bool is_vm);
429429

430+
/*! \brief Target hash function */
431+
struct TargetStrHash {
432+
/*!
433+
* \brief Calculate the hash code of a Target based on the string value of the Target.
434+
Note that this hash should NOT be used in new usecases, equality of targets based on their
435+
value is not well-defined.
436+
This will be removed when maps from Targets to IRModules are removed from the codebase.
437+
* \param target The Target to hash
438+
* \return String hash of the target
439+
*/
440+
size_t operator()(const Target& target) const {
441+
return String::HashBytes(target->str().c_str(), target->str().size());
442+
}
443+
};
444+
445+
/*! \brief Target equality function based on the string value of Target
446+
Note that this equality function should NOT be used in new usecases, equality of targets based on
447+
their value is not well-defined. This will be removed when maps from Targets to IRModules are
448+
removed from the codebase.*/
449+
struct TargetStrEqual {
450+
/*!
451+
* \brief Check if the two Targets are equal
452+
* \param target One Target
453+
* \param other_target The other Target
454+
* \return String equality of the targets
455+
*/
456+
const bool operator()(const Target& target, const Target& other_target) const {
457+
TargetStrHash target_hash = TargetStrHash();
458+
return target_hash(target) == target_hash(other_target);
459+
}
460+
};
461+
462+
/*!
463+
* \brief Convert a Map<Target, IRModule> to std::unordered_map<Target, IRmodule, TargetStrHash,
464+
* TargetStrEqual> Target equality is currently based on pointer equality, which is a problem since
465+
* we have a lot of Map<Target, IRModule> in the codebase. This function converts the map to a
466+
* version that is keyed based on string value of the Target instead. Note that once we remove
467+
* Map<Target, IRModule>, this function will be removed.
468+
* \param input_map The map to convert
469+
* \return The converted map
470+
*/
471+
std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual>
472+
TargetModuleMapToTargetStrModuleMap(Map<Target, IRModule> input_map);
473+
474+
/*!
475+
* \brief Convert a std::unordered_map<Target, IRmodule, TargetStrHash, TargetStrEqual> to
476+
* Map<Target, IRModule> This function is a helper that undoes TargetModuleMapToTargetStr. Note that
477+
* once we remove Map<Target, IRModule>, this function will be removed.
478+
* \param input_map The map to convert
479+
* \return The converted map
480+
*/
481+
Map<Target, IRModule> TargetStrModuleMapToTargetModuleMap(
482+
std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual> input_map);
483+
430484
} // namespace backend
431485
} // namespace relay
432486
} // namespace tvm

0 commit comments

Comments
 (0)