@@ -570,15 +570,16 @@ class GraphExecutor(_interpreter.Executor):
570570 device : :py:class:`Device`
571571 The runtime device to run the code on.
572572
573- raw_targets : Array[tvm.target.Target]
574- The available targets.
573+ target : any multi-target like object, see Target.canon_multi_target
574+ For homogeneous compilation, the unique build target.
575+ For heterogeneous compilation, a dictionary or list of possible build targets.
575576 """
576577
577- def __init__ (self , mod , device , raw_targets ):
578+ def __init__ (self , mod , device , target ):
578579 assert mod is not None
579580 self .mod = mod
580581 self .device = device
581- self .raw_targets = raw_targets
582+ self .target = target
582583
583584 def _make_executor (self , expr = None ):
584585 if expr :
@@ -589,7 +590,7 @@ def _make_executor(self, expr=None):
589590 raise ValueError (
590591 "Graph Executor only supports static graphs, got output type" , ret_type
591592 )
592- mod = build (self .mod , target = self .raw_targets )
593+ mod = build (self .mod , target = self .target )
593594 gmodule = _graph_executor .GraphModule (mod ["default" ](self .device ))
594595
595596 def _unflatten (flat_iter , cur_type ):
@@ -630,16 +631,16 @@ class AotExecutor(_interpreter.Executor):
630631 device : :py:class:`Device`
631632 The runtime device to run the code on.
632633
633- raw_targets : Array[tvm.target.Target]
634- The available targets.
634+ target : any multi-target like object, see Target.canon_multi_target
635+ For homogeneous compilation, the unique build target.
636+ For heterogeneous compilation, a dictionary or list of possible build targets.
635637 """
636638
637- def __init__ (self , mod , device , raw_targets ):
639+ def __init__ (self , mod , device , target ):
638640 assert mod is not None
639641 self .mod = mod
640642 self .device = device
641- self .raw_targets = raw_targets
642- assert raw_targets [0 ].attrs .get ("executor" , "graph" ) == "aot"
643+ self .target = target
643644
644645 def _make_executor (self , expr = None ):
645646 if expr :
@@ -648,7 +649,7 @@ def _make_executor(self, expr=None):
648649 ret_type = self .mod ["main" ].checked_type .ret_type
649650 if _ty .is_dynamic (ret_type ):
650651 raise ValueError ("AOT Executor only supports static graphs, got output type" , ret_type )
651- mod = build (self .mod , target = self .raw_targets )
652+ mod = build (self .mod , target = self .target )
652653
653654 # NOTE: Given AOT requires use of the "c" backend, must export/import to compile the
654655 # generated code.
@@ -722,6 +723,8 @@ def create_executor(kind="debug", mod=None, device=None, target="llvm", params=N
722723 target : any multi-target like object, see Target.canon_multi_target
723724 For homogeneous compilation, the unique build target.
724725 For heterogeneous compilation, a dictionary or list of possible build targets.
726+ CAUTION: Though this API allows multiple targets, it does not allow multiple devices, so
727+ heterogenous compilation is not yet supported.
725728
726729 params : dict of str to NDArray
727730 Input parameters to the graph that do not change
@@ -737,11 +740,14 @@ def create_executor(kind="debug", mod=None, device=None, target="llvm", params=N
737740 if device is not None :
738741 assert device .device_type == raw_targets [0 ].kind .device_type
739742 else :
743+ # Use the first target as the device.
740744 device = _nd .device (raw_targets [0 ].kind .device_type , 0 )
741745
742746 if params is not None :
743747 mod = IRModule .from_expr (bind_params_by_name (mod ["main" ], params ))
744748
749+ assert raw_targets [0 ].attrs .get ("executor" ) == kind
750+
745751 if kind == "debug" :
746752 assert len (raw_targets ) == 1 , "The interpreter currently only supports a single target"
747753 return _interpreter .Interpreter (mod , device , raw_targets [0 ])
0 commit comments