Skip to content

Commit e31ca8a

Browse files
committed
Fix auto scheduler code
1 parent 671febc commit e31ca8a

File tree

1 file changed

+13
-9
lines changed

1 file changed

+13
-9
lines changed

src/relay/backend/te_compiler.cc

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -726,19 +726,23 @@ LoweredModule LowerTE(const IRModule& module, TargetMap targets, DeviceMap devic
726726

727727
auto updated_module = pass(module);
728728

729-
const auto* te_compiler_update_weights =
730-
runtime::Registry::Get("auto_scheduler.relay_integration.te_compiler_update_weights");
729+
// A temporary solution until we can rewrite the auto-scheduler task extraction code to work
730+
// in a more reasonable way.
731+
if (backend::IsAutoSchedulerEnabled()) {
732+
const auto* te_compiler_update_weights =
733+
runtime::Registry::Get("auto_scheduler.relay_integration.te_compiler_update_weights");
731734

732-
ICHECK(te_compiler_update_weights != nullptr)
733-
<< "auto_scheduler.relay_integration.te_compiler_update_weights";
735+
ICHECK(te_compiler_update_weights != nullptr)
736+
<< "auto_scheduler.relay_integration.te_compiler_update_weights";
734737

735-
Map<String, tvm::Integer> weight_map;
738+
Map<String, tvm::Integer> weight_map;
736739

737-
for (auto pair : compiler->GetOpWeights()) {
738-
weight_map.Set(pair.first, pair.second);
739-
}
740+
for (auto pair : compiler->GetOpWeights()) {
741+
weight_map.Set(pair.first, pair.second);
742+
}
740743

741-
(*te_compiler_update_weights)(weight_map);
744+
(*te_compiler_update_weights)(weight_map);
745+
}
742746

743747
LoweredModule lowered_module;
744748
lowered_module.main_module = updated_module;

0 commit comments

Comments
 (0)