@@ -53,7 +53,11 @@ namespace {
5353struct PairHash {
5454 template <typename T1, typename T2>
5555 std::size_t operator ()(const std::pair<T1, T2>& k) const {
56- return std::hash<T1>()(k.first ) ^ std::hash<T2>()(k.second );
56+ return dmlc::HashCombine (std::hash<T1>()(k.first ), std::hash<T2>()(k.second ));
57+ }
58+ template <typename T2>
59+ std::size_t operator ()(const std::pair<Target, T2>& k) const {
60+ return dmlc::HashCombine (ObjectHash ()(k.first ), std::hash<T2>()(k.second ));
5761 }
5862};
5963
@@ -289,7 +293,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
289293 PatternFunctor<bool (const Pattern& p, const ObjectRef& v)> {
290294 public:
291295 // TODO(mbs): Collapse mod and per_target_module once IRModule subsumes LoweredModule.
292- Interpreter (IRModule mod, Map<String , IRModule> per_target_module, Device device, Target target)
296+ Interpreter (IRModule mod, Map<Target , IRModule> per_target_module, Device device, Target target)
293297 : mod_(mod),
294298 per_target_module_ (per_target_module),
295299 device_(device),
@@ -373,7 +377,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
373377 */
374378 PackedFunc TIRToPackedFunc (const GlobalVar& tir_fn_var, const Array<GlobalVar>& all_tir_fn_vars,
375379 Target target) {
376- std::pair<std::string , std::string> packed_func_key (target-> str () , tir_fn_var->name_hint );
380+ std::pair<Target , std::string> packed_func_key (target, tir_fn_var->name_hint );
377381 auto packed_itr = compiled_packed_funcs_.find (packed_func_key);
378382 if (packed_itr != compiled_packed_funcs_.end ()) {
379383 // Already compiled.
@@ -382,8 +386,11 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
382386
383387 // Project out just the function(s) we need.
384388 IRModule lowered_projected_mod;
385- auto mod_itr = per_target_module_.find (target->str ());
386- ICHECK (mod_itr != per_target_module_.end ())
389+ std::unordered_map<Target, IRModule, backend::TargetStrHash, backend::TargetStrEqual>
390+ per_target_module_std_map =
391+ backend::TargetModuleMapToTargetStrModuleMap (per_target_module_);
392+ auto mod_itr = per_target_module_std_map.find (target);
393+ ICHECK (mod_itr != per_target_module_std_map.end ())
387394 << " No target module for target '" << target->str () << " '" ;
388395 const IRModule& target_module = (*mod_itr).second ;
389396 for (const auto & var : all_tir_fn_vars) {
@@ -407,7 +414,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
407414 PackedFunc packed_func = runtime_module.GetFunction (var->name_hint );
408415 ICHECK (packed_func != nullptr ) << " No packed function for global var '" << var->name_hint
409416 << " ' in compiled module for target '" << target->str () << " '" ;
410- compiled_packed_funcs_.emplace (std::make_pair (target-> str () , var->name_hint ), packed_func);
417+ compiled_packed_funcs_.emplace (std::make_pair (target, var->name_hint ), packed_func);
411418 }
412419
413420 // Return just what we need for this call.
@@ -874,11 +881,10 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
874881 // Map from target key to lowered TIR functions derived from mod_.
875882 // Note that primitives are implicitly executed on target_, while shape functions are implicitly
876883 // executed on the default 'cpu' host. Thus this map has at most two entries.
877- Map<String , IRModule> per_target_module_;
884+ Map<Target , IRModule> per_target_module_;
878885 // Cached packed functions for the primitives and shape functions, keyed by target and
879886 // global var name.
880- std::unordered_map<std::pair<std::string, std::string>, PackedFunc, PairHash>
881- compiled_packed_funcs_;
887+ std::unordered_map<std::pair<Target, std::string>, PackedFunc, PairHash> compiled_packed_funcs_;
882888 // Unique device on which primitives (but not shape functions) will be executed.
883889 // (For simplicity we only run the interpreter on a single device.)
884890 Device device_;
@@ -895,7 +901,7 @@ class Interpreter : public ExprFunctor<ObjectRef(const Expr& n)>,
895901 * rewritten \p mod and target-specific modules containing bindings for all TIR primitive
896902 * functions needed by the rewritten module.
897903 */
898- std::pair<IRModule, Map<String , IRModule>> Prepare (IRModule mod, Device device, Target target) {
904+ std::pair<IRModule, Map<Target , IRModule>> Prepare (IRModule mod, Device device, Target target) {
899905 // Run minimal transforms on module to establish invariants needed by interpreter.
900906 transform::Sequential seq ({transform::SimplifyInference (),
901907 // FuseOps will mark wrapped calls to prim-ops with the 'Primitive'
@@ -1014,7 +1020,7 @@ TypedPackedFunc<ObjectRef(Array<Expr>)> EvalFunction(IRModule mod, Expr expr, De
10141020 // and can just eval it directly.
10151021 expr_to_eval = expr;
10161022 }
1017- std::pair<IRModule, Map<String , IRModule>> main_and_lowered =
1023+ std::pair<IRModule, Map<Target , IRModule>> main_and_lowered =
10181024 Prepare (mod_with_expr, device, target);
10191025 std::shared_ptr<Interpreter> intrp = std::make_shared<Interpreter>(
10201026 /* mod=*/ main_and_lowered.first , /* per_target_module=*/ main_and_lowered.second , device,
@@ -1057,7 +1063,7 @@ ObjectRef Eval(Expr expr, Map<GlobalTypeVar, TypeData> type_definitions,
10571063 std::unordered_set<String> import_set, Device device, Target target) {
10581064 std::pair<IRModule, GlobalVar> mod_and_global =
10591065 IRModule::FromExprInContext (expr, /* global_funcs=*/ {}, type_definitions, import_set);
1060- std::pair<IRModule, Map<String , IRModule>> main_and_lowered =
1066+ std::pair<IRModule, Map<Target , IRModule>> main_and_lowered =
10611067 Prepare (mod_and_global.first , device, target);
10621068 Interpreter intrp (
10631069 /* mod=*/ main_and_lowered.first , /* per_target_module=*/ main_and_lowered.second , device,
0 commit comments