@@ -422,12 +422,10 @@ Array<LoweredFunc> lower(Schedule sch,
422422  return  Array<LoweredFunc>({ ir::MakeAPI (stmt, name, out_arg_list, 0 , config->restricted_func ) });
423423}
424424
425- runtime::Module build (const  Array<LoweredFunc>& funcs,
426-                       const  Target& target,
427-                       const  Target& target_host,
428-                       const  BuildConfig& config,
429-                       Array<LoweredFunc>* fhost_ret,
430-                       std::vector<runtime::Module>* devmod_ret) {
425+ Array<Array<LoweredFunc> > split_dev_host_funcs (const  Array<LoweredFunc>& funcs,
426+                                                 const  Target& target,
427+                                                 const  Target& target_host,
428+                                                 const  BuildConfig& config) {
431429  std::unordered_set<std::string> all_names;
432430  for  (const  auto  &x : funcs) {
433431    CHECK (all_names.count (x->name ) == 0 ) << " Duplicate function name " name ;
@@ -466,12 +464,6 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
466464    }
467465  }
468466
469-   if  (fhost_ret != nullptr ) {
470-     for  (auto  f : fhost) {
471-       fhost_ret->push_back (f);
472-     }
473-   }
474- 
475467  auto  keys = target->keys ();
476468  bool  target_is_gpu =
477469    std::find (keys.begin (), keys.end (), " gpu" end ();
@@ -500,14 +492,25 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
500492    func = ir::CombineContextCall (func);
501493    fhost.Set (i, func);
502494  }
495+   Array<Array<LoweredFunc> > ret;
496+   ret.push_back (fhost);
497+   ret.push_back (fdevice);
498+   return  ret;
499+ }
500+ 
501+ runtime::Module build (const  Array<LoweredFunc>& funcs,
502+                       const  Target& target,
503+                       const  Target& target_host,
504+                       const  BuildConfig& config) {
505+   auto  target_host_val = target_host.defined () ? target_host : DefaultTargetHost (target);
506+   auto  host_dev_funcs = split_dev_host_funcs (funcs, target, target_host, config);
507+   auto & fhost = host_dev_funcs[0 ];
508+   auto & fdevice = host_dev_funcs[1 ];
503509
504510  auto  mhost = codegen::Build (fhost, target_host_val->str ());
505511
506512  if  (fdevice.size () > 0 ) {
507513    auto  mdev = codegen::Build (fdevice, target->str ());
508-     if  (devmod_ret != nullptr ) {
509-       devmod_ret->push_back (mdev);
510-     }
511514    mhost.Import (mdev);
512515  }
513516
0 commit comments