Skip to content

Commit 40d78be

Browse files
committed
Use main_func_info rather than bespoke logic in AOT
This moves from using the bespoke AOT UpdateMainWorkspaceSize to the LoweredModule main_func_info property to unify with Graph executor codegen.
1 parent 8b075fd commit 40d78be

File tree

1 file changed

+1
-36
lines changed

1 file changed

+1
-36
lines changed

src/relay/backend/aot_executor_codegen.cc

Lines changed: 1 addition & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -363,40 +363,6 @@ class AOTExecutorCodegen : public ExprVisitor {
363363
return ss.str();
364364
}
365365

366-
/*!
367-
* \brief Update the "main" control function's metadata
368-
*
369-
* \param func The main function that contains calls to operator tir primitive functions
370-
*/
371-
void UpdateMainWorkspaceSize(const tir::PrimFunc& primfunc, const relay::Function& func) {
372-
auto workspace_byte_alignment = target_host_->GetAttr<Integer>("workspace-byte-alignment")
373-
.value_or(tvm::runtime::kDefaultWorkspaceAlignment);
374-
Integer workspace_size = CalculateWorkspaceBytes(primfunc, workspace_byte_alignment);
375-
// Populate FunctionInfo
376-
auto fi_node = make_object<FunctionInfoNode>();
377-
// Initialize all target workspaces to zero
378-
for (const auto& kv : targets_) {
379-
auto tgt = kv.second;
380-
fi_node->workspace_sizes.Set(tgt, 0);
381-
}
382-
fi_node->workspace_sizes.Set(target_host_, workspace_size);
383-
fi_node->relay_primfuncs.Set(target_host_, func);
384-
385-
int64_t io_size = 0;
386-
for (const auto& input : input_vars_) {
387-
io_size += CalculateRelayExprSizeBytes(input->checked_type());
388-
}
389-
io_size += CalculateRelayExprSizeBytes(func->body->checked_type());
390-
fi_node->io_sizes.Set(target_host_, io_size);
391-
392-
int64_t const_size = 0;
393-
for (const auto& kv : params_by_expr_) {
394-
const_size += CalculateRelayExprSizeBytes(kv.first->checked_type());
395-
}
396-
fi_node->constant_sizes.Set(target_host_, const_size);
397-
function_metadata_.Set(String(runtime::symbol::tvm_module_main), FunctionInfo(fi_node));
398-
}
399-
400366
void VisitExpr_(const CallNode* op) override {
401367
// Descend the call tree
402368
for (auto arg : op->args) {
@@ -635,6 +601,7 @@ class AOTExecutorCodegen : public ExprVisitor {
635601
tec::UpdateFunctionMetadata(func, this->function_metadata_);
636602
});
637603

604+
function_metadata_.Set(runtime::symbol::tvm_module_main, lowered_module.main_func_info);
638605
auto lowered_main = lowered_module.main_module->Lookup("main");
639606
auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());
640607

@@ -670,8 +637,6 @@ class AOTExecutorCodegen : public ExprVisitor {
670637
// because the packed calls arguments are not wrapped in TVMValues. To make this happen we need
671638
// to run the LegalizePackedCalls pass.
672639
auto prim_func = CreateMainFunc(mod_name, lowered_main_func->params.size());
673-
UpdateMainWorkspaceSize(prim_func, lowered_main_func);
674-
675640
LoweredOutput ret;
676641

677642
ret.params = std::unordered_map<std::string, std::pair<int, const tvm::runtime::NDArray>>();

0 commit comments

Comments
 (0)