Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions include/tvm/relay/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,17 @@ TVM_DLL IRModule GetCalibrateModule(IRModule mod);
*/
TVM_DLL Map<GlobalVar, Array<Integer>> GetCalibrateOutputMap(const IRModule& mod);

/*!
* \brief Analyze the device context of each IR node in a given relay module.
*
* \param mod The module for analysis.
* \param default_context The default context used by unassigned IR nodes.
*
* \return The mapping between an IR node and its associated context.
*/
TVM_DLL std::unordered_map<Expr, TVMContext, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>
ContextAnalysis(const IRModule& mod, const TVMContext& default_context);

} // namespace relay
} // namespace tvm

Expand Down
23 changes: 22 additions & 1 deletion include/tvm/runtime/vm/bytecode.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ enum class Opcode {
AllocStorage = 16U,
ShapeOf = 17U,
ReshapeTensor = 18U,
DeviceCopy = 19U,
};

/*! \brief A single virtual machine instruction.
Expand Down Expand Up @@ -196,6 +197,8 @@ struct Instruction {
Index alignment;
/*! \brief The hint of the dtype. */
DLDataType dtype_hint;
/*! \brief The device type of the allocation. */
Index device_type;
} alloc_storage;
struct /* ShapeOf Operands */ {
RegName tensor;
Expand All @@ -204,6 +207,13 @@ struct Instruction {
RegName tensor;
RegName newshape;
} reshape_tensor;
struct /* DeviceCopy Operands */ {
RegName src;
/*! \brief The source device type. */
Index src_device_type;
/*! \brief The destination device type. */
Index dst_device_type;
};
};

/*!
Expand Down Expand Up @@ -341,11 +351,12 @@ struct Instruction {
* \param size The size of the allocation.
* \param alignment The allocation's alignment.
* \param dtype_hint The data type hint for the allocator.
* \param device_type The device type for the allocator.
* \param dst The destination to place the storage.
* \return The alloc storage instruction.
*/
static Instruction AllocStorage(RegName size, Index alignment, DLDataType dtype_hint,
RegName dst);
Index device_type, RegName dst);
/*!
* \brief Get the shape of an input tensor.
* \param tensor The input tensor.
Expand All @@ -361,6 +372,16 @@ struct Instruction {
* \return The reshape tensor instruction.
*/
static Instruction ReshapeTensor(RegName tensor, RegName newshape, RegName dst);
/*!
* \brief Copy tensor cross different devices.
* \param src The source register.
* \param src_device_type The device type of the tensor for the source register.
* \param dst_device_type The device type of the tensor ofr the destination register.
* \param dst The destination register to store the copied tensor.
* \return The device copy instruction.
*/
static Instruction DeviceCopy(RegName src, Index src_device_type, Index dst_device_type,
RegName dst);

Instruction();
Instruction(const Instruction& instr);
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/runtime/vm/executable.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ class Executable : public ModuleNode {
std::unordered_map<std::string, Index> primitive_map;
/*! \brief The virtual machine's function table. */
std::vector<VMFunction> functions;
/*! \brief The device type for each constant. */
std::vector<Index> const_device_type;

private:
/*!
Expand Down
16 changes: 10 additions & 6 deletions include/tvm/runtime/vm/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,17 @@ struct VMFunction {
std::vector<Instruction> instructions;
/*! \brief The size of the frame for this function */
Index register_file_size;
/*! \brief The device type of each parameter for this function. */
std::vector<Index> params_device_type;

VMFunction(const std::string& name, std::vector<std::string> params,
const std::vector<Instruction>& instructions, Index register_file_size)
const std::vector<Instruction>& instructions, Index register_file_size,
const std::vector<Index> params_device_type = {})
: name(name),
params(params),
instructions(instructions),
register_file_size(register_file_size) {}
register_file_size(register_file_size),
params_device_type(params_device_type) {}

VMFunction() {}

Expand Down Expand Up @@ -244,8 +248,8 @@ class VirtualMachine : public runtime::ModuleNode {
/*! \brief Run VM dispatch loop. */
void RunLoop();

/*! \brief Get device context for params. */
TVMContext GetParamsContext() const;
/*! \brief Get context from the context list based on a given device type. */
TVMContext GetContext(Index device_type) const;

/*!
* \brief Invoke a global setting up the VM state to execute.
Expand Down Expand Up @@ -273,8 +277,8 @@ class VirtualMachine : public runtime::ModuleNode {
std::unordered_map<std::string, std::vector<ObjectRef>> inputs_;
/*! \brief The set of TVM contexts the VM is currently executing on. */
std::vector<TVMContext> ctxs_;
/*! \brief The mapping from TVM context to memory allocator. */
std::unordered_map<TVMContext, Allocator*> allocators_;
/*! \brief The cached memory allocators. */
std::vector<Allocator*> allocators_;
/*!
* \brief The constant pool for runtime. It caches the device dependent
* object to avoid rellocation of constants during inference.
Expand Down
14 changes: 14 additions & 0 deletions python/tvm/ir/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,20 @@ def update(self, other):
other = Module(other)
return _ffi_api.Module_Update(self, other)

def update_func(self, var, func):
"""Update the function corresponding to a global variable in the
module.

Parameters
----------
var: GlobalVar
The global variable.

func: tvm.relay.Function
The function to be inserted.
"""
return _ffi_api.Module_UpdateFunction(self, var, func)

def get_global_var(self, name):
"""Get a global variable in the function by name.

Expand Down
15 changes: 15 additions & 0 deletions python/tvm/relay/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,21 @@
from .feature import Feature


def context_analysis(mod, default_context):
"""Analyze the device context information of each IR node in a Relay
program.

Parameters
----------
mod : tvm.IRModule
The input module.

default_context : tvm.runtime.TVMContext
The default context allocated to an IR node.
"""
return _ffi_api.ContextAnalysis(mod, default_context)


def post_order_visit(expr, fvisit):
"""Recursively visit the ir in post DFS order node,
apply fvisit. Each node is guaranteed to be visited
Expand Down
7 changes: 0 additions & 7 deletions python/tvm/relay/backend/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import tvm.runtime.vm as vm_rt
from tvm import autotvm
from tvm.relay import expr as _expr
from tvm.relay.ty import is_dynamic
from tvm.relay.backend.interpreter import Executor
from . import _vm

Expand Down Expand Up @@ -261,12 +260,6 @@ def _make_executor(self, expr=None):

def _vm_wrapper(*args, **kwargs):
args = self._convert_args(main, args, kwargs)
ret_type = self.mod["main"].checked_type.ret_type
if is_dynamic(ret_type) and "llvm" not in str(self.target) and "arm" not in str(
self.target):
raise ValueError(
"Virtual Machine only supports dynamic graphs on CPU, got output type",
ret_type, "on target", self.target)
return self.vm.run(*args)

return _vm_wrapper
2 changes: 2 additions & 0 deletions python/tvm/relay/op/_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
register_injective_schedule("left_shift")
register_injective_schedule("shape_of")
register_injective_schedule("ndarray_size")
register_injective_schedule("device_copy")
register_broadcast_schedule("fast_exp")
register_broadcast_schedule("fast_tanh")
register_broadcast_schedule("fast_erf")
Expand Down Expand Up @@ -241,3 +242,4 @@ def elemwise_shape_func(attrs, inputs, _):
register_shape_func("fast_erf", False, elemwise_shape_func)
register_shape_func("floor", False, elemwise_shape_func)
register_shape_func("log", False, elemwise_shape_func)
register_shape_func("device_copy", False, elemwise_shape_func)
Loading