Skip to content

Commit 1224d56

Browse files
authored
[RELAY][VM] Enable heterogeneous execution for Relay VM (#6337)
* vm heterogeneous execution * context analysis on module * fix profiler * fix memory plan * add more unification * add serialization * add gpu tests for test_adt * cache visited functions * path compression * C++ context analysis * remove python context analysis * add tests * clean * lint * fix * enable gpu test for dynamic namespace * remove GetParamsContext * fix comments and add doc for context analysis * cache context * cache allocator * rebase and fix comments
1 parent 12d66d4 commit 1224d56

33 files changed

+1631
-338
lines changed

include/tvm/relay/analysis.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,17 @@ TVM_DLL IRModule GetCalibrateModule(IRModule mod);
263263
*/
264264
TVM_DLL Map<GlobalVar, Array<Integer>> GetCalibrateOutputMap(const IRModule& mod);
265265

266+
/*!
267+
* \brief Analyze the device context of each IR node in a given relay module.
268+
*
269+
* \param mod The module for analysis.
270+
* \param default_context The default context used by unassigned IR nodes.
271+
*
272+
* \return The mapping between an IR node and its associated context.
273+
*/
274+
TVM_DLL std::unordered_map<Expr, TVMContext, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>
275+
ContextAnalysis(const IRModule& mod, const TVMContext& default_context);
276+
266277
} // namespace relay
267278
} // namespace tvm
268279

include/tvm/runtime/vm/bytecode.h

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ enum class Opcode {
6666
AllocStorage = 16U,
6767
ShapeOf = 17U,
6868
ReshapeTensor = 18U,
69+
DeviceCopy = 19U,
6970
};
7071

7172
/*! \brief A single virtual machine instruction.
@@ -196,6 +197,8 @@ struct Instruction {
196197
Index alignment;
197198
/*! \brief The hint of the dtype. */
198199
DLDataType dtype_hint;
200+
/*! \brief The device type of the allocation. */
201+
Index device_type;
199202
} alloc_storage;
200203
struct /* ShapeOf Operands */ {
201204
RegName tensor;
@@ -204,6 +207,13 @@ struct Instruction {
204207
RegName tensor;
205208
RegName newshape;
206209
} reshape_tensor;
210+
struct /* DeviceCopy Operands */ {
211+
RegName src;
212+
/*! \brief The source device type. */
213+
Index src_device_type;
214+
/*! \brief The destination device type. */
215+
Index dst_device_type;
216+
};
207217
};
208218

209219
/*!
@@ -341,11 +351,12 @@ struct Instruction {
341351
* \param size The size of the allocation.
342352
* \param alignment The allocation's alignment.
343353
* \param dtype_hint The data type hint for the allocator.
354+
* \param device_type The device type for the allocator.
344355
* \param dst The destination to place the storage.
345356
* \return The alloc storage instruction.
346357
*/
347358
static Instruction AllocStorage(RegName size, Index alignment, DLDataType dtype_hint,
348-
RegName dst);
359+
Index device_type, RegName dst);
349360
/*!
350361
* \brief Get the shape of an input tensor.
351362
* \param tensor The input tensor.
@@ -361,6 +372,16 @@ struct Instruction {
361372
* \return The reshape tensor instruction.
362373
*/
363374
static Instruction ReshapeTensor(RegName tensor, RegName newshape, RegName dst);
375+
/*!
376+
 * \brief Copy a tensor across different devices.
377+
* \param src The source register.
378+
* \param src_device_type The device type of the tensor for the source register.
379+
 * \param dst_device_type The device type of the tensor for the destination register.
380+
* \param dst The destination register to store the copied tensor.
381+
* \return The device copy instruction.
382+
*/
383+
static Instruction DeviceCopy(RegName src, Index src_device_type, Index dst_device_type,
384+
RegName dst);
364385

365386
Instruction();
366387
Instruction(const Instruction& instr);

include/tvm/runtime/vm/executable.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,8 @@ class Executable : public ModuleNode {
161161
std::unordered_map<std::string, Index> primitive_map;
162162
/*! \brief The virtual machine's function table. */
163163
std::vector<VMFunction> functions;
164+
/*! \brief The device type for each constant. */
165+
std::vector<Index> const_device_type;
164166

165167
private:
166168
/*!

include/tvm/runtime/vm/vm.h

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,17 @@ struct VMFunction {
8383
std::vector<Instruction> instructions;
8484
/*! \brief The size of the frame for this function */
8585
Index register_file_size;
86+
/*! \brief The device type of each parameter for this function. */
87+
std::vector<Index> params_device_type;
8688

8789
VMFunction(const std::string& name, std::vector<std::string> params,
88-
const std::vector<Instruction>& instructions, Index register_file_size)
90+
const std::vector<Instruction>& instructions, Index register_file_size,
91+
const std::vector<Index> params_device_type = {})
8992
: name(name),
9093
params(params),
9194
instructions(instructions),
92-
register_file_size(register_file_size) {}
95+
register_file_size(register_file_size),
96+
params_device_type(params_device_type) {}
9397

9498
VMFunction() {}
9599

@@ -244,8 +248,8 @@ class VirtualMachine : public runtime::ModuleNode {
244248
/*! \brief Run VM dispatch loop. */
245249
void RunLoop();
246250

247-
/*! \brief Get device context for params. */
248-
TVMContext GetParamsContext() const;
251+
/*! \brief Get context from the context list based on a given device type. */
252+
TVMContext GetContext(Index device_type) const;
249253

250254
/*!
251255
* \brief Invoke a global setting up the VM state to execute.
@@ -273,8 +277,8 @@ class VirtualMachine : public runtime::ModuleNode {
273277
std::unordered_map<std::string, std::vector<ObjectRef>> inputs_;
274278
/*! \brief The set of TVM contexts the VM is currently executing on. */
275279
std::vector<TVMContext> ctxs_;
276-
/*! \brief The mapping from TVM context to memory allocator. */
277-
std::unordered_map<TVMContext, Allocator*> allocators_;
280+
/*! \brief The cached memory allocators. */
281+
std::vector<Allocator*> allocators_;
278282
/*!
279283
* \brief The constant pool for runtime. It caches the device dependent
280284
 * object to avoid reallocation of constants during inference.

python/tvm/ir/module.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,20 @@ def update(self, other):
118118
other = Module(other)
119119
return _ffi_api.Module_Update(self, other)
120120

121+
def update_func(self, var, func):
122+
"""Update the function corresponding to a global variable in the
123+
module.
124+
125+
Parameters
126+
----------
127+
var: GlobalVar
128+
The global variable.
129+
130+
func: tvm.relay.Function
131+
The function to be inserted.
132+
"""
133+
return _ffi_api.Module_UpdateFunction(self, var, func)
134+
121135
def get_global_var(self, name):
122136
"""Get a global variable in the function by name.
123137

python/tvm/relay/analysis/analysis.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,21 @@
2828
from .feature import Feature
2929

3030

31+
def context_analysis(mod, default_context):
32+
"""Analyze the device context information of each IR node in a Relay
33+
program.
34+
35+
Parameters
36+
----------
37+
mod : tvm.IRModule
38+
The input module.
39+
40+
default_context : tvm.runtime.TVMContext
41+
The default context allocated to an IR node.
42+
"""
43+
return _ffi_api.ContextAnalysis(mod, default_context)
44+
45+
3146
def post_order_visit(expr, fvisit):
3247
"""Recursively visit the ir in post DFS order node,
3348
apply fvisit. Each node is guaranteed to be visited

python/tvm/relay/backend/vm.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
import tvm.runtime.vm as vm_rt
2828
from tvm import autotvm
2929
from tvm.relay import expr as _expr
30-
from tvm.relay.ty import is_dynamic
3130
from tvm.relay.backend.interpreter import Executor
3231
from . import _vm
3332

@@ -261,12 +260,6 @@ def _make_executor(self, expr=None):
261260

262261
def _vm_wrapper(*args, **kwargs):
263262
args = self._convert_args(main, args, kwargs)
264-
ret_type = self.mod["main"].checked_type.ret_type
265-
if is_dynamic(ret_type) and "llvm" not in str(self.target) and "arm" not in str(
266-
self.target):
267-
raise ValueError(
268-
"Virtual Machine only supports dynamic graphs on CPU, got output type",
269-
ret_type, "on target", self.target)
270263
return self.vm.run(*args)
271264

272265
return _vm_wrapper

python/tvm/relay/op/_tensor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
register_injective_schedule("left_shift")
8585
register_injective_schedule("shape_of")
8686
register_injective_schedule("ndarray_size")
87+
register_injective_schedule("device_copy")
8788
register_broadcast_schedule("fast_exp")
8889
register_broadcast_schedule("fast_tanh")
8990
register_broadcast_schedule("fast_erf")
@@ -241,3 +242,4 @@ def elemwise_shape_func(attrs, inputs, _):
241242
register_shape_func("fast_erf", False, elemwise_shape_func)
242243
register_shape_func("floor", False, elemwise_shape_func)
243244
register_shape_func("log", False, elemwise_shape_func)
245+
register_shape_func("device_copy", False, elemwise_shape_func)

0 commit comments

Comments
 (0)