@@ -192,25 +192,26 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
192192 << " }\n " ;
193193 }
194194
195- void GenerateEntrypointForUnpackedAPI (const std::string& run_func) {
195+ void GenerateEntrypointForUnpackedAPI (const std::string& entrypoint_name,
196+ const std::string& run_func) {
196197 code_ << " TVM_DLL int32_t " << run_func << " (" ;
197- int total_args = (metadata_->num_inputs + metadata_->num_outputs );
198- for (int i = 0 ; i < total_args; ++i) {
199- code_ << " arg" << i;
198+ unsigned int total_args = (metadata_->inputs . size () + metadata_->num_outputs );
199+ for (unsigned int i = 0 ; i < total_args; ++i) {
200+ code_ << " void* arg" << i;
200201 if (i + 1 != total_args) {
201202 code_ << " ," ;
202203 }
203204 }
204205 code_ << " );\n " ;
205- code_ << " static int32_t " << ::tvm::runtime::symbol::tvm_module_main ;
206+ code_ << " int32_t " << entrypoint_name ;
206207 code_ << " (void* args, void* type_code, int num_args, void* out_value, void* "
207208 " out_type_code, void* resource_handle) {\n " ;
208209 code_ << " return " << run_func << " (" ;
209- for (int i = 0 ; i < metadata_->num_inputs ; ++i) {
210+ for (unsigned int i = 0 ; i < metadata_->inputs . size () ; ++i) {
210211 code_ << " ((DLTensor*)(((TVMValue*)args)[" << i << " ].v_handle))[0].data," ;
211212 }
212213 for (int i = 0 ; i < metadata_->num_outputs ; ++i) {
213- int j = metadata_->num_inputs + i;
214+ int j = metadata_->inputs . size () + i;
214215 code_ << " ((DLTensor*)(((TVMValue*)args)[" << j << " ].v_handle))[0].data" ;
215216 if (i + 1 != metadata_->num_outputs ) {
216217 code_ << " ," ;
@@ -220,37 +221,85 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode {
220221 code_ << " }\n " ;
221222 }
222223
223- void GenerateEntrypointForPackedAPI (const std::string& run_func) {
224+ void GenerateEntrypointForPackedAPI (const std::string& entrypoint_name,
225+ const std::string& run_func) {
224226 code_ << " TVM_DLL int32_t " << run_func;
225227 code_ << " (void* args, void* type_code, int num_args, void* out_value, void* "
226228 " out_type_code, void* resource_handle);\n " ;
227- code_ << " static int32_t " << ::tvm::runtime::symbol::tvm_module_main ;
229+ code_ << " int32_t " << entrypoint_name ;
228230 code_ << " (void* args, void* type_code, int num_args, void* out_value, void* "
229231 " out_type_code, void* resource_handle) {\n " ;
230232 code_ << " return " << run_func;
231233 code_ << " (args, type_code, num_args, out_value, out_type_code, resource_handle);\n " ;
232234 code_ << " }\n " ;
233235 }
234236
237+ void GenerateCInterfaceEntrypoint (const std::string& entrypoint_name, const std::string& run_func,
238+ const std::string& mod_name) {
239+ code_ << " #include <" << mod_name << " .h>\n " ;
240+ code_ << " TVM_DLL int32_t " << run_func << " (" ;
241+ unsigned int total_args = (metadata_->inputs .size () + metadata_->num_outputs );
242+ for (unsigned int i = 0 ; i < total_args; ++i) {
243+ code_ << " void* arg" << i;
244+ if (i + 1 != total_args) {
245+ code_ << " ," ;
246+ }
247+ }
248+ code_ << " );\n " ;
249+ code_ << " int32_t " << entrypoint_name << " (" ;
250+ code_ << " struct " << runtime::get_name_mangled (mod_name, " inputs" ) << " * inputs,"
251+ << " struct " << runtime::get_name_mangled (mod_name, " outputs" ) << " * outputs,"
252+ << " struct " << runtime::get_name_mangled (mod_name, " memory" ) << " * memory,"
253+ << " struct " << runtime::get_name_mangled (mod_name, " devices" ) << " * devices"
254+ << " ) {" ;
255+ code_ << " return " << run_func << " (" ;
256+ for (const auto & input : metadata_->inputs ) {
257+ code_ << " inputs->" << input->name_hint () << " ," ;
258+ }
259+ if (metadata_->num_outputs == 1 ) {
260+ code_ << " outputs->output" ;
261+ } else {
262+ for (int i = 0 ; i < metadata_->num_outputs ; ++i) {
263+ code_ << " outputs->output" << i;
264+ if (i + 1 != metadata_->num_outputs ) {
265+ code_ << " ," ;
266+ }
267+ }
268+ }
269+ code_ << " );\n " ;
270+ code_ << " }\n " ;
271+ }
272+
235273 void GenerateAOTDescriptor () {
236- const std::string run_func = ::tvm::runtime::symbol::tvm_run_func_suffix;
237- const std::string run_func_mangled = runtime::get_name_mangled (metadata_->mod_name , run_func);
274+ const std::string run_func_suffix = ::tvm::runtime::symbol::tvm_run_func_suffix;
275+ const std::string tvm_entrypoint_suffix = ::tvm::runtime::symbol::tvm_entrypoint_suffix;
276+ const std::string run_func_mangled =
277+ runtime::get_name_mangled (metadata_->mod_name , run_func_suffix);
278+ const std::string entrypoint_mangled =
279+ runtime::get_name_mangled (metadata_->mod_name , tvm_entrypoint_suffix);
238280 const std::string network_mangled = runtime::get_name_mangled (metadata_->mod_name , " network" );
239- code_ << " #include \" tvm/runtime/crt/internal/aot_executor/aot_executor.h\"\n " ;
281+ auto unpacked_api = target_->GetAttr <Bool>(" unpacked-api" ).value_or (Bool (false ));
282+ auto interface_api = target_->GetAttr <String>(" interface-api" ).value_or (String (" packed" ));
283+
240284 code_ << " #include \" tvm/runtime/c_runtime_api.h\"\n " ;
241285 code_ << " #ifdef __cplusplus\n " ;
242- code_ << " extern \" C\"\n " ;
286+ code_ << " extern \" C\" { \n " ;
243287 code_ << " #endif\n " ;
244- if (target_->GetAttr <Bool>(" unpacked-api" ).value_or (Bool (false ))) {
245- GenerateEntrypointForUnpackedAPI (run_func_mangled);
288+
289+ if (unpacked_api) {
290+ if (interface_api == " c" ) {
291+ GenerateCInterfaceEntrypoint (entrypoint_mangled, run_func_mangled, metadata_->mod_name );
292+ } else {
293+ GenerateEntrypointForUnpackedAPI (entrypoint_mangled, run_func_mangled);
294+ }
246295 } else {
247- GenerateEntrypointForPackedAPI (run_func_mangled);
296+ ICHECK_EQ (interface_api, " packed" ) << " Packed interface required for packed operators" ;
297+ GenerateEntrypointForPackedAPI (entrypoint_mangled, run_func_mangled);
248298 }
249- code_ << " const tvm_model_t " << network_mangled << " = {\n "
250- << " .run_func = &" << ::tvm::runtime::symbol::tvm_module_main << " ,\n "
251- << " .num_input_tensors = " << metadata_->num_inputs << " ,\n "
252- << " .num_output_tensors = " << metadata_->num_outputs << " , \n "
253- << " };\n " ;
299+
300+ code_ << " #ifdef __cplusplus\n " ;
301+ code_ << " }\n " ;
302+ code_ << " #endif\n " ;
254303 }
255304
256305 void CreateSource () {
0 commit comments