diff --git a/include/tvm/relay/analysis.h b/include/tvm/relay/analysis.h index c65bb41282cf..00da9408408b 100644 --- a/include/tvm/relay/analysis.h +++ b/include/tvm/relay/analysis.h @@ -263,6 +263,17 @@ TVM_DLL IRModule GetCalibrateModule(IRModule mod); */ TVM_DLL Map> GetCalibrateOutputMap(const IRModule& mod); +/*! + * \brief Analyze the device context of each IR node in a given relay module. + * + * \param mod The module for analysis. + * \param default_context The default context used by unassigned IR nodes. + * + * \return The mapping between an IR node and its associated context. + */ +TVM_DLL std::unordered_map +ContextAnalysis(const IRModule& mod, const TVMContext& default_context); + } // namespace relay } // namespace tvm diff --git a/include/tvm/runtime/vm/bytecode.h b/include/tvm/runtime/vm/bytecode.h index 89a3164f7483..edcbd881e074 100644 --- a/include/tvm/runtime/vm/bytecode.h +++ b/include/tvm/runtime/vm/bytecode.h @@ -66,6 +66,7 @@ enum class Opcode { AllocStorage = 16U, ShapeOf = 17U, ReshapeTensor = 18U, + DeviceCopy = 19U, }; /*! \brief A single virtual machine instruction. @@ -196,6 +197,8 @@ struct Instruction { Index alignment; /*! \brief The hint of the dtype. */ DLDataType dtype_hint; + /*! \brief The device type of the allocation. */ + Index device_type; } alloc_storage; struct /* ShapeOf Operands */ { RegName tensor; @@ -204,6 +207,13 @@ struct Instruction { RegName tensor; RegName newshape; } reshape_tensor; + struct /* DeviceCopy Operands */ { + RegName src; + /*! \brief The source device type. */ + Index src_device_type; + /*! \brief The destination device type. */ + Index dst_device_type; + }; }; /*! @@ -341,11 +351,12 @@ struct Instruction { * \param size The size of the allocation. * \param alignment The allocation's alignment. * \param dtype_hint The data type hint for the allocator. + * \param device_type The device type for the allocator. * \param dst The destination to place the storage. * \return The alloc storage instruction. */ static Instruction AllocStorage(RegName size, Index alignment, DLDataType dtype_hint, - RegName dst); + Index device_type, RegName dst); /*! * \brief Get the shape of an input tensor. * \param tensor The input tensor. @@ -361,6 +372,16 @@ struct Instruction { * \return The reshape tensor instruction. */ static Instruction ReshapeTensor(RegName tensor, RegName newshape, RegName dst); + /*! + * \brief Copy a tensor across different devices. + * \param src The source register. + * \param src_device_type The device type of the tensor for the source register. + * \param dst_device_type The device type of the tensor for the destination register. + * \param dst The destination register to store the copied tensor. + * \return The device copy instruction. + */ + static Instruction DeviceCopy(RegName src, Index src_device_type, Index dst_device_type, + RegName dst); Instruction(); Instruction(const Instruction& instr); diff --git a/include/tvm/runtime/vm/executable.h b/include/tvm/runtime/vm/executable.h index cc38da75a0c7..8d3f651758d1 100644 --- a/include/tvm/runtime/vm/executable.h +++ b/include/tvm/runtime/vm/executable.h @@ -161,6 +161,8 @@ class Executable : public ModuleNode { std::unordered_map primitive_map; /*! \brief The virtual machine's function table. */ std::vector functions; + /*! \brief The device type for each constant. */ + std::vector const_device_type; private: /*!
diff --git a/include/tvm/runtime/vm/vm.h b/include/tvm/runtime/vm/vm.h index 273b8fe60847..e9f51de611b6 100644 --- a/include/tvm/runtime/vm/vm.h +++ b/include/tvm/runtime/vm/vm.h @@ -83,13 +83,17 @@ struct VMFunction { std::vector instructions; /*! \brief The size of the frame for this function */ Index register_file_size; + /*! \brief The device type of each parameter for this function. */ + std::vector params_device_type; VMFunction(const std::string& name, std::vector params, - const std::vector& instructions, Index register_file_size) + const std::vector& instructions, Index register_file_size, + const std::vector params_device_type = {}) : name(name), params(params), instructions(instructions), - register_file_size(register_file_size) {} + register_file_size(register_file_size), + params_device_type(params_device_type) {} VMFunction() {} @@ -244,8 +248,8 @@ class VirtualMachine : public runtime::ModuleNode { /*! \brief Run VM dispatch loop. */ void RunLoop(); - /*! \brief Get device context for params. */ - TVMContext GetParamsContext() const; + /*! \brief Get context from the context list based on a given device type. */ + TVMContext GetContext(Index device_type) const; /*! * \brief Invoke a global setting up the VM state to execute. @@ -273,8 +277,8 @@ class VirtualMachine : public runtime::ModuleNode { std::unordered_map> inputs_; /*! \brief The set of TVM contexts the VM is currently executing on. */ std::vector ctxs_; - /*! \brief The mapping from TVM context to memory allocator. */ - std::unordered_map allocators_; + /*! \brief The cached memory allocators. */ + std::vector allocators_; /*! * \brief The constant pool for runtime. It caches the device dependent * object to avoid rellocation of constants during inference. diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py index 8d75d8e8ee21..2f6fd2069460 100644 --- a/python/tvm/ir/module.py +++ b/python/tvm/ir/module.py @@ -118,6 +118,20 @@ def update(self, other): other = Module(other) return _ffi_api.Module_Update(self, other) + def update_func(self, var, func): + """Update the function corresponding to a global variable in the + module. + + Parameters + ---------- + var: GlobalVar + The global variable. + + func: tvm.relay.Function + The function to be inserted. + """ + return _ffi_api.Module_UpdateFunction(self, var, func) + def get_global_var(self, name): """Get a global variable in the function by name. diff --git a/python/tvm/relay/analysis/analysis.py b/python/tvm/relay/analysis/analysis.py index 99f4252ac4f7..d417c2b39b08 100644 --- a/python/tvm/relay/analysis/analysis.py +++ b/python/tvm/relay/analysis/analysis.py @@ -28,6 +28,21 @@ from .feature import Feature +def context_analysis(mod, default_context): + """Analyze the device context information of each IR node in a Relay + program. + + Parameters + ---------- + mod : tvm.IRModule + The input module. + + default_context : tvm.runtime.TVMContext + The default context allocated to an IR node. + """ + return _ffi_api.ContextAnalysis(mod, default_context) + + def post_order_visit(expr, fvisit): """Recursively visit the ir in post DFS order node, apply fvisit. 
Each node is guaranteed to be visited diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index 73b0d22804bd..656652c23004 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -27,7 +27,6 @@ import tvm.runtime.vm as vm_rt from tvm import autotvm from tvm.relay import expr as _expr -from tvm.relay.ty import is_dynamic from tvm.relay.backend.interpreter import Executor from . import _vm @@ -261,12 +260,6 @@ def _make_executor(self, expr=None): def _vm_wrapper(*args, **kwargs): args = self._convert_args(main, args, kwargs) - ret_type = self.mod["main"].checked_type.ret_type - if is_dynamic(ret_type) and "llvm" not in str(self.target) and "arm" not in str( - self.target): - raise ValueError( - "Virtual Machine only supports dynamic graphs on CPU, got output type", - ret_type, "on target", self.target) return self.vm.run(*args) return _vm_wrapper diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py index eccc2c3c5f15..c81d4c51c502 100644 --- a/python/tvm/relay/op/_tensor.py +++ b/python/tvm/relay/op/_tensor.py @@ -84,6 +84,7 @@ register_injective_schedule("left_shift") register_injective_schedule("shape_of") register_injective_schedule("ndarray_size") +register_injective_schedule("device_copy") register_broadcast_schedule("fast_exp") register_broadcast_schedule("fast_tanh") register_broadcast_schedule("fast_erf") @@ -241,3 +242,4 @@ def elemwise_shape_func(attrs, inputs, _): register_shape_func("fast_erf", False, elemwise_shape_func) register_shape_func("floor", False, elemwise_shape_func) register_shape_func("log", False, elemwise_shape_func) +register_shape_func("device_copy", False, elemwise_shape_func) diff --git a/python/tvm/relay/transform/memory_alloc.py b/python/tvm/relay/transform/memory_alloc.py index ae7db3384214..e6f17f996bbf 100644 --- a/python/tvm/relay/transform/memory_alloc.py +++ b/python/tvm/relay/transform/memory_alloc.py @@ -19,9 +19,12 @@ A pass for manifesting explicit memory allocations. """ import numpy as np + +from tvm.ir.transform import PassContext, module_pass +from tvm import nd, container +from ..function import Function from ..expr_functor import ExprVisitor, ExprMutator from ..scope_builder import ScopeBuilder -from . import transform from .. import op from ... import DataType, register_func from .. import ty, expr @@ -29,16 +32,32 @@ from ..op.memory import flatten_tuple_type, from_tuple_type, to_tuple_type from ...import cpu from ..op.memory import alloc_storage +from ..analysis import context_analysis +from ..._ffi.runtime_ctypes import TVMContext def alloc_tensor(storage, shape, dtype='float32', assert_shape=None): offset = expr.const(0, dtype="int64") return op.memory.alloc_tensor(storage, offset, shape, dtype, assert_shape) + def is_primitive(call): return hasattr(call, 'op') and hasattr(call.op, 'attrs') and \ hasattr(call.op.attrs, 'Primitive') and int(call.op.attrs.Primitive) == 1 +def is_device_copy(func): + """ + Check if the current relay expression is a device copy call. We can simply check + its body if it is a function because the device_copy op is opaque.
+ """ + if isinstance(func, Function): + body = func.body + return isinstance(body, expr.Call) and body.op == op.get("device_copy") + if isinstance(func, expr.Call): + return func.op == op.get("device_copy") + return False + + class CheckReshapeOnly(ExprVisitor): """A pass to check if the fused op contains only reshape ops.""" def __init__(self): @@ -66,7 +85,7 @@ def is_reshape_only(func): class ManifestAllocPass(ExprMutator): """A pass for explicitly manifesting all memory allocations in Relay.""" - def __init__(self, target_host): + def __init__(self, target_host, context_analysis_map): self.invoke_tvm = op.vm.invoke_tvm_op self.shape_func = op.vm.shape_func self.shape_of = op.vm.shape_of @@ -75,8 +94,22 @@ def __init__(self, target_host): self.target_host = target_host self.default_context = cpu(0) self.compute_dtype = "int64" + self.context_analysis_map = context_analysis_map super().__init__() + def get_context(self, exp): + """Get the context of a given expression""" + assert exp in self.context_analysis_map, exp.astext(False) + val = self.context_analysis_map[exp] + # val[0], val[1] are device_type and device_id, respectively. + # We don't need to unpack after porting this pass to C++. + assert len(val) == 2 + return TVMContext(val[0].value, val[1].value) + + def device_copy(self, inp, src_ctx, dst_ctx): + """Insert a device copy node.""" + return self.visit(op.tensor.device_copy(inp, src_ctx, dst_ctx)) + def current_scope(self): return self.scopes[-1] @@ -116,7 +149,7 @@ def compute_storage(self, tensor_type): size *= (dtype.bits * dtype.lanes + 7) // 8 return expr.const(size, dtype=self.compute_dtype) - def make_static_allocation(self, scope, tensor_type, i): + def make_static_allocation(self, scope, tensor_type, ctx, name_hint): """Allocate a tensor with a statically known shape.""" shape = [int(sh) for sh in tensor_type.shape] if len(shape) == 0: @@ -126,11 +159,13 @@ def make_static_allocation(self, scope, tensor_type, i): size = self.compute_storage(tensor_type) alignment = self.compute_alignment(tensor_type.dtype) dtype = tensor_type.dtype - sto = scope.let("storage_{0}".format(i), alloc_storage( - size, alignment, self.default_context, dtype)) + sto = scope.let("storage_{0}".format(name_hint), alloc_storage(size, + alignment, + ctx, + dtype)) # TODO(@jroesch): There is a bug with typing based on the constant shape. 
tensor = alloc_tensor(sto, shape, dtype, tensor_type.shape) - return scope.let("tensor_{0}".format(i), tensor) + return scope.let("tensor_{0}".format(name_hint), tensor) def visit_let(self, let): scope = ScopeBuilder() @@ -156,13 +191,13 @@ def emit_shape_func(self, scope, func, new_args): is_inputs = [] input_pos = 0 + cpu_ctx = nd.cpu(0) for i, (arg, state) in enumerate(zip(new_args, input_states)): state = int(state) # Pass Shapes if state == 2: for j, subexp in enumerate(from_tuple_type(arg.type_annotation, arg)): - let_in_arg = scope.let("in_arg_{0}".format(input_pos + j), subexp) - sh_of = self.visit(self.shape_of(let_in_arg)) + sh_of = self.visit(self.shape_of(subexp)) shape_func_ins.append( scope.let("in_shape_{0}".format(input_pos + j), sh_of)) input_pos += 1 @@ -170,6 +205,9 @@ def emit_shape_func(self, scope, func, new_args): # Pass Inputs elif state == 1: new_arg = self.visit(arg) + ctx = self.get_context(arg) + if ctx.device_type != cpu_ctx.device_type: + new_arg = self.device_copy(new_arg, ctx, cpu_ctx) shape_func_ins.append( scope.let("in_shape_{0}".format(input_pos), new_arg)) input_pos += 1 @@ -181,7 +219,9 @@ def emit_shape_func(self, scope, func, new_args): out_shapes = [] for i, out in enumerate(cfunc.outputs): tt = ty.TensorType(out.shape, out.dtype) - alloc = self.make_static_allocation(scope, tt, i) + # Put shape func on CPU. This also ensures that everything between + # shape_of and shape_func are on CPU. + alloc = self.make_static_allocation(scope, tt, cpu_ctx, i) alloc = scope.let("shape_func_out_{0}".format(i), alloc) out_shapes.append(alloc) @@ -198,12 +238,12 @@ def dynamic_invoke(self, scope, func, ins, new_args, out_types, ret_type): out_shapes = self.emit_shape_func(scope, func, new_args) storages = [] + func_ctx = self.get_context(func) for i, (out_shape, out_type) in enumerate(zip(out_shapes, out_types)): - size = self.compute_storage_in_relay( - out_shape, out_type.dtype) + size = self.compute_storage_in_relay(out_shape, out_type.dtype) alignment = self.compute_alignment(out_type.dtype) sto = scope.let("storage_{i}".format(i=i), alloc_storage( - size, alignment, self.default_context, out_type.dtype)) + size, alignment, func_ctx, out_type.dtype)) storages.append(sto) outs = [] @@ -253,6 +293,16 @@ def visit_call(self, call): # Handle fused op that only contains reshape op return self.emit_reshape_tensor(scope, call.op, new_args, ret_type) + if is_device_copy(call.op): + # Handle device copy op + if isinstance(call.op, Function): + attr = call.op.body.attrs + else: + attr = call.attr + return self.device_copy(new_args[0], + TVMContext(attr.src_dev_type, 0), + TVMContext(attr.dst_dev_type, 0)) + if self.is_dynamic(ret_type): # Handle dynamic case. return self.dynamic_invoke(scope, call.op, ins, new_args, out_types, ret_type) @@ -260,7 +310,9 @@ def visit_call(self, call): # Handle static case. 
outs = [] for i, out_ty in enumerate(out_types): - out = self.make_static_allocation(scope, out_ty, i) + ctx = self.get_context(call) + assert isinstance(ctx, TVMContext) + out = self.make_static_allocation(scope, out_ty, ctx, i) outs.append(out) output = expr.Tuple(outs) @@ -270,19 +322,59 @@ def visit_call(self, call): return super().visit_call(call) -@transform.function_pass(opt_level=0) +def mk_analysis_annotator(results): + """Pretty print the annotated relay program with device info""" + def _annotator(exp): + if exp in results: + val = results[exp] + assert len(val) == 2 + ctx = TVMContext(val[0].value, val[1].value) + return f"<{ctx}>" + else: + return "" + + return _annotator + + +@module_pass(opt_level=0) class ManifestAlloc: """The explicit pass wrapper around ManifestAlloc.""" - def __init__(self, target_host): + # TODO(zhiics, jroesch) Port this pass to C++. + def __init__(self, target_host, targets): self.target_host = target_host + self.targets = targets - def transform_function(self, func, mod, _): + def transform_module(self, mod, _): + """Invokes the pass""" # TODO(@jroesch): Is there a way to do one shot initialization? # can we have def pass_init? mod.import_from_std("core.rly") - ea = ManifestAllocPass(self.target_host) - func = ea.visit(func) - return func + + assert isinstance(self.targets, (dict, container.Map)) + if len(self.targets) > 1: + pass_ctx = PassContext.current() + if "relay.fallback_device_type" in pass_ctx.config: + fallback_ctx = nd.context(pass_ctx.config["relay.fallback_device_type"]) + else: + fallback_ctx = cpu(0) + ca = context_analysis(mod, TVMContext(fallback_ctx.device_type, 0)) + else: + if isinstance(self.targets, dict): + dev = list(self.targets.keys())[0] + else: + dev, _ = self.targets.items()[0] + ca = context_analysis(mod, nd.context(dev.value)) + + # The following code can be used for debugging the module after + # annotation. + # print(mod.astext(show_meta_data=False, annotate=mk_analysis_annotator(ca))) + + gv_funcs = mod.functions + for gv, f in gv_funcs.items(): + ea = ManifestAllocPass(self.target_host, ca) + f = ea.visit(f) + mod.update_func(gv, f) + return mod register_func("relay.transform.ManifestAlloc", ManifestAlloc) diff --git a/python/tvm/relay/transform/memory_plan.py b/python/tvm/relay/transform/memory_plan.py index 8f21af9292a9..248a79ba44de 100644 --- a/python/tvm/relay/transform/memory_plan.py +++ b/python/tvm/relay/transform/memory_plan.py @@ -280,6 +280,13 @@ def process_alloc_storage(self, dynamic_regions, lhs, call): if not isinstance(size, expr.Constant): self.enter_scope() dynamic_regions.append(lhs) + else: + # A new scope is created when entering a new region with different + # device context. 
+ region = self.current_region(dtype) + if region.ctx and region.ctx.device_type != ctx.device_type: + self.enter_scope() + dynamic_regions.append(lhs) region = self.current_region(dtype) region.grow(lhs, size, alignment, ctx, dtype) diff --git a/python/tvm/runtime/vm.py b/python/tvm/runtime/vm.py index f88f43d838d5..fbc7a7d7b71e 100644 --- a/python/tvm/runtime/vm.py +++ b/python/tvm/runtime/vm.py @@ -307,8 +307,17 @@ def __init__(self, exe, ctx, memory_cfg=None): def _setup_ctx(self, ctx, memory_cfg): """Init context and allocators.""" - if isinstance(ctx, tvm.runtime.TVMContext): - ctx = [ctx] + ctxs = ctx + if not isinstance(ctx, (list, tuple)): + if not isinstance(ctx, tvm.runtime.TVMContext): + raise TypeError("ctx is expected to be TVMContext or \ + List[TVMContext]") + ctxs = [ctx] + + # CPU is required for executing shape functions + if not any(c.device_type == tvm.cpu().device_type for c in ctxs): + ctxs.append(tvm.cpu()) + default_alloc_type = VirtualMachine.POOLED_ALLOCATOR if memory_cfg is None: memory_cfg = {} @@ -321,7 +330,7 @@ def _setup_ctx(self, ctx, memory_cfg): raise TypeError("memory_cfg is expected be string or dictionary, " + "but received {}".format(type(memory_cfg))) init_args = [] - for context in ctx: + for context in ctxs: init_args.append(context.device_type) init_args.append(context.device_id) alloc_type = memory_cfg[context] if context in memory_cfg else default_alloc_type diff --git a/src/ir/module.cc b/src/ir/module.cc index bcab39aabf32..66bce0f6b882 100644 --- a/src/ir/module.cc +++ b/src/ir/module.cc @@ -448,6 +448,9 @@ TVM_REGISTER_GLOBAL("ir.Module_Update").set_body_typed([](IRModule mod, IRModule mod->Update(from); }); +TVM_REGISTER_GLOBAL("ir.Module_UpdateFunction") + .set_body_typed([](IRModule mod, GlobalVar gv, BaseFunc func) { mod->Update(gv, func); }); + TVM_REGISTER_GLOBAL("ir.Module_Import").set_body_typed([](IRModule mod, String path) { mod->Import(path); }); diff --git a/src/relay/analysis/context_analysis.cc b/src/relay/analysis/context_analysis.cc new file mode 100644 index 000000000000..bbea0399c117 --- /dev/null +++ b/src/relay/analysis/context_analysis.cc @@ -0,0 +1,720 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/analysis/context_analysis.cc + * \brief A pass for analyzing device attribute of each IR node. + * + * We use union-find data structures to analyze the context information of each + * sub-expression in a Relay program in this pass. Only the device copy node in + * Relay directly contains bidiretional device information. We use it to + * bidirectionally propagate the device info of its inputs and outputs. 
+ * + * However, to support dynamism (e.g., dynamic inputs), Relay introduces several + * concepts to compute the shape of tensors and operators at runtime, i.e. + * shape_of, shape_func, and reshape_tensor. These nodes are also referred to as + * VM dialects as we have native VM instructions for them. These dialects are + * intrinsically CPU friendly, therefore, they are only designed to be + * executed on CPU. We, hence, unify their inputs and outputs to CPU as well. + * Note the input of shape_of is a tensor and we only need the tensor shape. + * Therefore, the input could be sitting on GPU as well since no real data is + * needed. The context of the input would be propagated from its other + * consumers or fall back to the default device. + * + * Another type of dialect is used for memory allocation, namely, alloc_storage + * and alloc_tensor. alloc_storage contains a context field to indicate where + * the chunk of memory is allocated. Therefore, we unify the context of + * alloc_storage with the context field. Other inputs, such as size and + * alignment, are left on CPU. + * + * Based on the above rules, we keep unifying the connected expressions and + * propagating their device information. An error will be raised whenever there + * is a unification conflict. All IR nodes that are not propagated with device + * context will fall back to the specified device. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace tvm { +namespace relay { + +using PackedAnalysisResultMap = Map>; +using AnalysisResultMap = + std::unordered_map; + +namespace analysis { + +// Cache ops +static const Op& device_copy_op = Op::Get("device_copy"); +static const Op& alloc_storage_op = Op::Get("memory.alloc_storage"); +static const Op& alloc_tensor_op = Op::Get("memory.alloc_tensor"); +static const Op& shape_of_op = Op::Get("vm.shape_of"); +static const Op& invoke_tvm_op = Op::Get("vm.invoke_tvm_op"); +static const Op& shape_func_of = Op::Get("vm.shape_func"); +static const Op& reshape_tensor_op = Op::Get("vm.reshape_tensor"); + +class DeviceDomain; +using DeviceDomainPtr = std::shared_ptr; + +/* + * \brief A class to represent the device of a domain, i.e. a segment of relay program. + */ +class DeviceDomain { + public: + // Construct an empty domain. + DeviceDomain() { + ctx_.device_type = static_cast(-1); + ctx_.device_id = -1; + } + + // Construct a domain based on a given context. + explicit DeviceDomain(const TVMContext& ctx) : ctx_(ctx) {} + + // Check if the current domain is empty. + bool IsEmptyDomain() const { + return static_cast(ctx_.device_type) == -1 && ctx_.device_id == -1; + } + + // Check if the current domain equals the other one. + bool operator==(const DeviceDomain& other) const { + return ctx_.device_type == other.ctx_.device_type && ctx_.device_id == other.ctx_.device_id; + } + + bool operator!=(const DeviceDomain& other) const { return !(*this == other); } + + private: + // Create a hash for a domain. + struct Hash { + size_t operator()(const DeviceDomainPtr& domain) const { + if (domain->IsEmptyDomain()) { + return (size_t)(domain.get()); + } else { + size_t const h1(std::hash()(static_cast(domain->ctx_.device_type))); + size_t const h2(std::hash()(domain->ctx_.device_id)); + return h1 ^ (h2 << 1); + } + } + }; + + // Create an equality for domains. + struct Equal { + public: + bool operator()(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) const { + // We compare the pointer for empty domains.
+ if (lhs->IsEmptyDomain() && rhs->IsEmptyDomain()) return lhs.get() == rhs.get(); + + // Otherwise device type and id are used to check equality. + return (*lhs.get() == *rhs.get()); + } + }; + + /* \brief The device to be assigned to the current domain. */ + TVMContext ctx_; + + friend DeviceDomainPtr Join(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs); + friend class ContextAnalyzer; +}; + +// Join two domains. +DeviceDomainPtr Join(const DeviceDomainPtr& lhs, const DeviceDomainPtr& rhs) { + if (lhs->IsEmptyDomain() && rhs->IsEmptyDomain()) { + return lhs; + } else if (lhs->IsEmptyDomain()) { + return rhs; + } else if (rhs->IsEmptyDomain()) { + return lhs; + } else { + CHECK(*lhs.get() == *rhs.get()) << "All expressions must have a singular device to unify"; + return lhs; + } +} + +/* + * \brief Compute on which device each sub-expression will execute. A union find + * algorithm is used to assign and merge the context domains. + */ +class ContextAnalyzer : public ExprVisitor { + public: + ContextAnalyzer(const IRModule& mod, const GlobalVar& current_func, + const TVMContext& default_context) + : mod_(mod), current_func_(current_func), default_context_(default_context) { + cpu_ctx_.device_type = kDLCPU; + cpu_ctx_.device_id = 0; + } + + // Create an empty domain. + // This usually happens when we enter a new scope, i.e. Function. + DeviceDomainPtr Bottom() { return std::make_shared(DeviceDomain()); } + + // Create a domain with the given device context. + DeviceDomainPtr DeviceType(const TVMContext& ctx) { + return std::make_shared(DeviceDomain(ctx)); + } + + // Find the root of a device. + DeviceDomainPtr Lookup(DeviceDomainPtr device) { + while (device_uf_.count(device) && device != device_uf_[device]) { + // Path compression + if (device_uf_.count(device_uf_[device])) { + device_uf_[device] = device_uf_[device_uf_[device]]; + } + device = device_uf_[device]; + } + return device; + } + + // Unify two domains. + DeviceDomainPtr Unify(DeviceDomainPtr lhs, DeviceDomainPtr rhs) { + lhs = Lookup(lhs); + rhs = Lookup(rhs); + auto unified_device = Join(lhs, rhs); + if (lhs != unified_device) { + device_uf_[lhs] = unified_device; + } + + if (rhs != unified_device) { + device_uf_[rhs] = unified_device; + } + + return unified_device; + } + + // Unify the domain for two IR nodes. + DeviceDomainPtr UnifyExpr(const Expr& lhs, const Expr& rhs) { + auto lhs_dom = DeviceFor(lhs); + auto rhs_dom = DeviceFor(rhs); + return Unify(lhs_dom, rhs_dom); + } + + // Lookup or insert an IR node to device domain map. + DeviceDomainPtr DeviceFor(const Expr& expr) { + auto it = expr_to_device_.find(expr); + if (it == expr_to_device_.end()) { + auto bottom = Bottom(); + expr_to_device_[expr] = bottom; + return bottom; + } else { + return it->second; + } + } + + // Unify the device context for a device copy node. Device copy node is + // the only node that carries bidirectional devices in the input program. The device + // attribute of other nodes can be propagated from it. 
+ void UnifyDeviceCopy(const std::vector& inps, const std::vector& outputs, + DLDeviceType src_dev_type, DLDeviceType dst_dev_type) { + TVMContext src_ctx; + src_ctx.device_type = src_dev_type; + src_ctx.device_id = 0; + auto src_domain = DeviceType(src_ctx); + for (const auto& it : inps) { + auto lhs = DeviceFor(it); + Unify(lhs, src_domain); + } + + TVMContext dst_ctx; + dst_ctx.device_type = dst_dev_type; + dst_ctx.device_id = 0; + auto dst_domain = DeviceType(dst_ctx); + for (const auto& it : outputs) { + auto lhs = DeviceFor(it); + Unify(lhs, dst_domain); + } + } + + // Unify the domain of inputs and outputs of a relay call. + // + // For most call nodes, the op, inputs, and outputs should all be in the + // same domain, i.e. having the same context. However, device_copy call node + // needs to be handled differently as it copies data from one device to + // another. + DeviceDomainPtr UnifyCall(const Expr& call_op, const Array& inps, + const Array& outputs, DeviceDomainPtr device) { + device = Unify(device, DeviceFor(call_op)); + + for (const auto& it : inps) { + device = Unify(device, DeviceFor(it)); + } + + for (const auto& it : outputs) { + device = Unify(device, DeviceFor(it)); + } + + return device; + } + + void VisitExpr_(const CallNode* cn) final { + Call call = GetRef(cn); + + if (IsDeviceCopy(call)) { + UnifyDeviceCopyCall(cn); + } else if (call->op == alloc_storage_op) { + UnifyAllocStorageCall(cn); + } else if (call->op == alloc_tensor_op) { + UnifyAllocTensorCall(cn); + } else if (call->op == shape_func_of) { + UnifyShapeFuncCall(cn); + } else if (call->op == shape_of_op) { + UnifyShapeOfCall(cn); + } else if (call->op == invoke_tvm_op) { + UnifyInvokeTVMOpCall(cn); + } else if (call->op == reshape_tensor_op) { + UnifyReshapeTensorCall(cn); + } else if (call->op.as()) { + UnifyFunctionCall(cn); + } else if (call->op.as()) { + UnifyGlobalVarCall(cn); + } else if (call->op.as()) { + UnifyVarCall(cn); + } else { + UnifyCall(call, cn->args, {call}, Bottom()); + ExprVisitor::VisitExpr_(cn); + } + } + + void VisitExpr_(const LetNode* ln) final { + Expr expr = GetRef(ln); + // Iteratively visit let nodes to avoid stack overflow. + while (expr->IsInstance()) { + Let let = Downcast(expr); + // Save currying/closures since they will be invoked later + auto ty = let->value->checked_type(); + if (ty->IsInstance()) { + auto gv = ExtractClosure(let); + CHECK(gv.defined() && gv->IsInstance()); + closures_[let->var] = Downcast(gv); + } + + // Unify let var, value, and body + Unify(DeviceFor(let->var), DeviceFor(let->value)); + UnifyExpr(let, let->body); + ExprVisitor::VisitExpr(let->value); + expr = let->body; + } + // Visit the last body + ExprVisitor::VisitExpr(expr); + } + + void VisitExpr_(const FunctionNode* fn) final { + auto func = GetRef(fn); + auto it = visited_.find(func); + // No need to step into fused primitive functions as they are handled as + // a whole. + if (fn->HasNonzeroAttr(attr::kPrimitive) || + (it != visited_.end() && !DeviceFor(func)->IsEmptyDomain())) { + return; + } + + auto device = Unify(DeviceFor(func), DeviceFor(fn->body)); + for (const auto& it : fn->params) { + DeviceFor(it); + } + ExprVisitor::VisitExpr(fn->body); + visited_.insert(func); + } + + void VisitExpr_(const TupleNode* tn) final { + // We only support tuple with the same of device. 
+ Tuple tup = GetRef(tn); + if (tn->fields.size() > 0) { + auto device = DeviceFor(tup->fields[0]); + for (size_t i = 1; i < tup->fields.size(); i++) { + device = Unify(device, DeviceFor(tup->fields[i])); + } + Unify(device, DeviceFor(tup)); + } + ExprVisitor::VisitExpr_(tn); + } + + void VisitExpr_(const TupleGetItemNode* tn) final { + TupleGetItem item = GetRef(tn); + + Unify(DeviceFor(item), DeviceFor(item->tuple)); + + ExprVisitor::VisitExpr_(tn); + } + + void VisitExpr_(const MatchNode* mn) final { + // For match node, we unify the value and the rhs of each clause + Match m = GetRef(mn); + auto device = Unify(DeviceFor(m), DeviceFor(m->data)); + for (const auto& c : m->clauses) { + device = Unify(device, DeviceFor(c->rhs)); + } + ExprVisitor::VisitExpr_(mn); + } + + void VisitExpr_(const GlobalVarNode* gvn) final { DeviceFor(GetRef(gvn)); } + + void VisitExpr_(const VarNode* vn) { DeviceFor(GetRef(vn)); } + + void VisitExpr_(const ConstantNode* cn) final { DeviceFor(GetRef(cn)); } + + // Return the analysis results. + AnalysisResultMap Results() { + AnalysisResultMap ret; + for (const auto& it : expr_to_device_) { + auto device = Lookup(it.second); + if (device->IsEmptyDomain()) { + ret[it.first] = default_context_; + } else { + ret[it.first] = device->ctx_; + } + } + + return ret; + } + + private: + Expr ExtractClosure(Expr expr) const { + while (expr->IsInstance()) { + Let let = Downcast(expr); + expr = let->value; + if (expr->IsInstance()) { + return expr; + } else { + const auto* cn = expr.as(); + if (cn && cn->op->IsInstance()) { + return cn->op; + } + } + } + return Expr(nullptr); + } + + // Check if an expression is a device copy call. + bool IsDeviceCopy(const Expr& expr) const { + if (!expr->IsInstance()) return false; + + Call call = Downcast(expr); + if (call->op == device_copy_op) return true; + + // Fused function with device copy op as the body + // device copy op is opaque therefore the fused function only has one node. + if (const FunctionNode* fn = call->op.as()) { + if (const CallNode* cn = fn->body.as()) { + return cn->op == device_copy_op; + } + } + + return false; + } + + // Check if a function is a closure. + bool IsClosure(const Function& func) { return func->GetAttr(attr::kClosure, 0) != 0; } + + // Check if a function is a currying function. + bool IsCurrying(const Function& func) { + if (const auto* let = func->body.as()) { + return closures_.find(let->var) != closures_.end(); + } + return false; + } + + // Process device copy call node + void UnifyDeviceCopyCall(const CallNode* call) { + CHECK_EQ(call->args.size(), 1U); + + std::vector inps{call->args[0]}; + std::vector outs{GetRef(call)}; + DLDeviceType src_dev_type, dst_dev_type; + const DeviceCopyAttrs* attrs = nullptr; + if (const auto* fn = call->op.as()) { + // device_copy is fused, propagate device to the fused function. + inps.push_back(fn->params[0]); + outs.push_back(call->op); + Expr body = fn->body; + CHECK(body->IsInstance() && IsDeviceCopy(body)); + Call call_body = Downcast(body); + attrs = call_body->attrs.as(); + } else { + attrs = call->attrs.as(); + } + CHECK(attrs != nullptr); + src_dev_type = static_cast(attrs->src_dev_type); + dst_dev_type = static_cast(attrs->dst_dev_type); + + // Device copy op only has one input which is now annotated with the + // same device to the source device type of the device copy op. + // The call itself has the same device type to the destination. 
+ UnifyDeviceCopy(inps, outs, src_dev_type, dst_dev_type); + ExprVisitor::VisitExpr_(call); + } + + void UnifyAllocStorageCall(const CallNode* call) { + // [size, alignment] + CHECK_EQ(call->args.size(), 2U); + + // The arguments of alloc storage should be on CPU. + for (int i = 0; i < 2; i++) { + Unify(DeviceFor(call->args[i]), DeviceType(cpu_ctx_)); + ExprVisitor::VisitExpr(call->args[i]); + } + TVMContext ctx; + const auto* attrs = call->attrs.as(); + ctx.device_type = static_cast(attrs->device_type); + ctx.device_id = attrs->device_id; + Unify(DeviceFor(GetRef(call)), DeviceType(ctx)); + } + + void UnifyAllocTensorCall(const CallNode* call) { + // [storage, offset, shape] + CHECK_EQ(call->args.size(), 3U); + + Expr storage = call->args[0]; + Expr shape = call->args[1]; + Unify(DeviceFor(storage), DeviceFor(GetRef(call))); + + // The shape for alloc_tensor should be on CPU. + Unify(DeviceFor(shape), DeviceType(cpu_ctx_)); + ExprVisitor::VisitExpr(shape); + } + + void UnifyShapeFuncCall(const CallNode* call) { + // [func, inputs, outputs] + CHECK_EQ(call->args.size(), 3U); + auto shape_func_domain = DeviceType(cpu_ctx_); + + // No need to unify the op of a shape_func as shape_func doesn't + // invoke the op itself. It should be handled by invoke_tvm_op. + // Therefore, we skip call.args[0] here. + Tuple inps = Downcast(call->args[1]); + Tuple outputs = Downcast(call->args[2]); + UnifyCall(GetRef(call), inps->fields, outputs->fields, shape_func_domain); + for (const auto& it : inps->fields) { + ExprVisitor::VisitExpr(it); + } + + for (const auto& it : outputs->fields) { + ExprVisitor::VisitExpr(it); + } + } + + void UnifyInvokeTVMOpCall(const CallNode* call) { + // [op, inputs, outputs] + CHECK_EQ(call->args.size(), 3U); + Tuple inps = Downcast(call->args[1]); + Tuple outputs = Downcast(call->args[2]); + UnifyCall(call->args[0], inps->fields, outputs->fields, Bottom()); + ExprVisitor::VisitExpr_(call); + } + + void UnifyShapeOfCall(const CallNode* call) { + // vm shape_of is always on the CPU. + CHECK_EQ(call->args.size(), 1U); + ExprVisitor::VisitExpr(call->args[0]); + // Note we don't unify the input of a shape_of with the cpu domain. This is + // because vm.shape_of has a native instruction to compute the shape of + // a tensor regardless its device type. + // Instead, the device type of the input is left for its other consumers to + // unify or it will fallback to the default context. + Unify(DeviceFor(GetRef(call)), DeviceType(cpu_ctx_)); + } + + void UnifyReshapeTensorCall(const CallNode* call) { + // [data, shape] + CHECK_EQ(call->args.size(), 2U); + Expr data = call->args[0]; + Expr shape = call->args[1]; + Unify(DeviceFor(GetRef(call)), DeviceFor(data)); + + // The shape field of reshape_tensor is always on the CPU. + Unify(DeviceFor(shape), DeviceType(cpu_ctx_)); + ExprVisitor::VisitExpr(data); + ExprVisitor::VisitExpr(shape); + } + + void UnifyFunctionCall(const CallNode* call) { + auto device = DeviceFor(GetRef(call)); + // Unify the arguments of the caller. + for (const auto& arg : call->args) { + device = Unify(device, DeviceFor(arg)); + ExprVisitor::VisitExpr(arg); + } + + // Unify the parameters of the callee. + if (!call->op->IsInstance()) return; + Function func = Downcast(call->op); + for (const auto& param : func->params) { + device = Unify(device, DeviceFor(param)); + ExprVisitor::VisitExpr(param); + } + + // Unify the function expression and its body + Unify(device, DeviceFor(call->op)); + Unify(device, DeviceFor(func->body)); + + // Step into the callee. 
It will be skipped if the callee is a primitive + // function + ExprVisitor::VisitExpr(call->op); + } + + // Invoke a global function. + void UnifyGlobalVarCall(const CallNode* call) { + auto device = DeviceFor(GetRef(call)); + CHECK(mod_.defined()) << "Cannot analyze context on a globalvar without module"; + GlobalVar gv = Downcast(call->op); + auto func = Downcast(mod_->Lookup(gv)); + CHECK_EQ(call->args.size(), func->params.size()) + << "The number of arguments doesn't match the number of parameters of the function."; + + for (size_t i = 0; i < call->args.size(); i++) { + Expr arg = call->args[i]; + Expr param = func->params[i]; + ExprVisitor::VisitExpr(arg); + + // Save the arg to function mapping for closures as it will + // be invoked/unified later. + CHECK(arg->checked_type().defined()) + << "Type inference is required to run the context analysis passes."; + if (arg->checked_type()->IsInstance()) { + auto it = closures_.find(arg); + if (it != closures_.end()) { + closures_[param] = it->second; + } else { + CHECK(arg->IsInstance()); + closures_[param] = Downcast(arg); + } + } + Unify(DeviceFor(arg), DeviceFor(param)); + } + device = Unify(device, DeviceFor(call->op)); + device = Unify(device, DeviceFor(func)); + device = Unify(device, DeviceFor(func->body)); + + // Step into the callee. We need to skip recursive calls, otherwise, it + // would be an infinite loop. + // + // TODO(@zhiics) This may cause problems for mutually recursive calls as well. + auto cur_func = current_func_; + current_func_ = gv; + if (cur_func->name_hint != gv->name_hint) { + ExprVisitor::VisitExpr(func); + visited_.insert(func); + } + // Exit the frame. + current_func_ = cur_func; + } + + void UnifyVarCall(const CallNode* call) { + // It is a closure when we call a var. + // Unify the corresponding argument and parameter. + auto device = DeviceFor(GetRef(call)); + auto it = closures_.find(call->op); + CHECK(it != closures_.end()) << "Cannot find var: " << call->op; + auto glb_var = it->second; + CHECK(mod_.defined()) << "Cannot analyze context on a globalvar without module"; + Function func = Downcast(mod_->Lookup(glb_var)); + // Unify the underlying function for closures or currying functions. + while (IsClosure(func) || IsCurrying(func)) { + device = Unify(device, DeviceFor(func)); + if (IsClosure(func)) { + func = Downcast(func->body); + } else if (IsCurrying(func)) { + Let let = Downcast(func->body); + func = Downcast(mod_->Lookup(closures_[let->var])); + } else { + LOG(FATAL) << "func is expected to be a closure or a currying function"; + } + } + + CHECK_EQ(call->args.size(), func->params.size()); + for (size_t i = 0; i < call->args.size(); i++) { + Unify(DeviceFor(call->args[i]), DeviceFor(func->params[i])); + ExprVisitor::VisitExpr(call->args[i]); + } + device = Unify(device, DeviceFor(call->op)); + device = Unify(device, DeviceFor(glb_var)); + device = Unify(device, DeviceFor(func)); + + // Step into the global function. + auto cur_func = current_func_; + current_func_ = glb_var; + if (cur_func->name_hint != glb_var->name_hint) { + ExprVisitor::VisitExpr(func); + visited_.insert(func); + } + current_func_ = cur_func; + } + + private: + /* \brief The cpu context. */ + TVMContext cpu_ctx_; + /* \brief The module that helps context analysis. */ + const IRModule& mod_; + /* \brief The current function that is being analyzed. */ + GlobalVar current_func_; + /* \brief The default device that could be attached to an expression.
*/ + const TVMContext& default_context_; + /* \brief The IR node to device domain mapping. */ + std::unordered_map + expr_to_device_; + /* \brief The domain map for union-find. */ + std::unordered_map + device_uf_; + /* + * \brief The expr to global var map. It saves the closures/currying that + * will be invoked lazily. + */ + std::unordered_map closures_; + /* \brief Cache the visited functions. */ + std::unordered_set visited_; +}; + +} // namespace analysis + +AnalysisResultMap ContextAnalysis(const IRModule& mod, const TVMContext& default_context) { + // TODO(@zhiics) Apply the pass to all functions/entries + auto entry = mod->GetGlobalVar("main"); + auto ca = analysis::ContextAnalyzer(mod, entry, default_context); + auto expr = mod->Lookup(entry); + ca.VisitExpr(expr); + return ca.Results(); +} + +// Unpack the device type and device id fields in TVMContext for PackedFunc calls +// as TVMContext is not in the object system. +PackedAnalysisResultMap ContextAnalysisPacked(const IRModule& mod, + const TVMContext& default_context) { + PackedAnalysisResultMap ret; + auto res = ContextAnalysis(mod, default_context); + for (const auto& it : res) { + Integer dev_ty = static_cast(it.second.device_type); + Integer dev_id = it.second.device_id; + ret.Set(it.first, {dev_ty, dev_id}); + } + + return ret; +} + +TVM_REGISTER_GLOBAL("relay.analysis.ContextAnalysis").set_body_typed(ContextAnalysisPacked); + +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 33854f783d45..18b23c42c6ea 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -26,6 +26,8 @@ #include #include +#include +#include #include #include #include @@ -56,10 +58,10 @@ namespace transform { Pass LambdaLift(); Pass InlinePrimitives(); -Pass ManifestAlloc(Target target_host) { +Pass ManifestAlloc(Target target_host, vm::TargetsMap targets) { auto f = tvm::runtime::Registry::Get("relay.transform.ManifestAlloc"); CHECK(f != nullptr) << "unable to load allocation manifestation pass"; - return (*f)(target_host); + return (*f)(target_host, targets); } Pass MemoryPlan() { @@ -228,15 +230,41 @@ std::vector ToAllocTensorShape(NDArray shape) { return raw_shape; } +/*! + * \brief Create a default target for a given device type. + * \param device_type The device type index. + * \return the default target for the device.
+ */ +Target CreateDefaultTarget(int device_type) { + std::string name = runtime::DeviceName(device_type); + if (name == "cpu") return Target::Create("llvm"); + if (name == "gpu") return Target::Create("cuda"); + return Target::Create(name); +} + +int GetFallbackDevice() { + transform::PassContext pass_ctx = PassContext::Current(); + Optional opt_fallback_dev = + pass_ctx->GetConfig("relay.fallback_device_type", Integer(static_cast(kDLCPU))); + auto fallback_dev = opt_fallback_dev.value(); + CHECK_GT(fallback_dev->value, 0U); + return fallback_dev->value; +} + class VMFunctionCompiler : ExprFunctor { public: - VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host) + VMFunctionCompiler(VMCompilerContext* context, TargetsMap targets, Target target_host, + ExprDeviceMap expr_device_map) : last_register_(0), registers_num_(0), engine_(CompileEngine::Global()), context_(context), - targets_(targets), - target_host_(target_host) {} + target_host_(target_host), + expr_device_map_(std::move(expr_device_map)) { + for (const auto& it : targets) { + targets_[it.first->value] = it.second; + } + } VMFunction Compile(const GlobalVar& var, const Function& func) { size_t i = 0; @@ -263,7 +291,19 @@ class VMFunctionCompiler : ExprFunctor { this->VisitExpr(func->body); } instructions_.push_back(Instruction::Ret(last_register_)); - return VMFunction(var->name_hint, params_, instructions_, registers_num_); + + std::vector params_device_type; + for (const auto& it : func->params) { + if (!expr_device_map_.empty()) { + CHECK_GT(expr_device_map_.count(it), 0U); + params_device_type.push_back(expr_device_map_[it].device_type); + } else { + CHECK_EQ(targets_.size(), 1U); + params_device_type.push_back((targets_.begin())->first); + } + } + + return VMFunction(var->name_hint, params_, instructions_, registers_num_, params_device_type); } protected: @@ -287,6 +327,7 @@ class VMFunctionCompiler : ExprFunctor { case Opcode::ReshapeTensor: case Opcode::Move: case Opcode::InvokeClosure: + case Opcode::DeviceCopy: last_register_ = instr.dst; break; case Opcode::InvokePacked: @@ -310,6 +351,13 @@ class VMFunctionCompiler : ExprFunctor { } } size_t konst_idx = context_->constants.size(); + if (expr_device_map_.empty()) { + context_->const_device_type.push_back(targets_.begin()->first); + } else { + auto con = GetRef(const_node); + CHECK_GT(expr_device_map_.count(con), 0U); + context_->const_device_type.push_back(expr_device_map_[con].device_type); + } context_->constants.push_back(const_node->data); Emit(Instruction::LoadConst(konst_idx, NewRegister())); } @@ -477,13 +525,21 @@ class VMFunctionCompiler : ExprFunctor { target = tvm::target::ext_dev(); } else { // Next generate the invoke instruction. - if (targets_.size() == 1) { + if (expr_device_map_.empty()) { // homogeneous execution. + CHECK_EQ(targets_.size(), 1U); const auto& it = targets_.begin(); target = (*it).second; } else { - // heterogeneous execution. 
- LOG(FATAL) << "Currently VM compiler doesn't support heterogeneous compilation"; + CHECK_GT(expr_device_map_.count(func), 0U) + << "Found not annotated expression, please make sure " + "context analysis has been executed"; + int dev_type = expr_device_map_[func].device_type; + if (targets_.count(dev_type) == 0) { + target = CreateDefaultTarget(dev_type); + } else { + target = targets_[expr_device_map_[func].device_type]; + } } } @@ -561,7 +617,8 @@ class VMFunctionCompiler : ExprFunctor { } }) .Match("memory.alloc_storage", - [this](const Array& args, const Attrs& attrs, const Array& type_arg) { + [this, call_node](const Array& args, const Attrs& attrs, + const Array& type_arg) { CHECK_EQ(args.size(), 2); // Compute the size of the allocation. this->VisitExpr(args[0]); @@ -577,10 +634,23 @@ class VMFunctionCompiler : ExprFunctor { // Get the dtype hint from the attributes. auto alloc_attrs = attrs.as(); - CHECK(alloc_attrs != nullptr) << "must be the alloc tensor attrs"; + CHECK(alloc_attrs != nullptr) << "must be the AllocStorage attrs"; auto dtype = alloc_attrs->dtype; - Emit(Instruction::AllocStorage(size_register, alignment, dtype, NewRegister())); + Index device_type; + if (expr_device_map_.empty()) { + // TODO(zhiics) There is bug if all expressions are annotated with the device + // that is different the first one in the target list. + auto& kv = *(targets_.begin()); + device_type = kv.first; + } else { + CHECK_GT(expr_device_map_.count(GetRef(call_node)), 0U) + << " The alloc_storage node is not annotated"; + device_type = expr_device_map_[GetRef(call_node)].device_type; + } + + Emit(Instruction::AllocStorage(size_register, alignment, dtype, device_type, + NewRegister())); }) .Match("vm.shape_func", [this](const Array& args, const Attrs& attrs, const Array& type_arg) { @@ -611,6 +681,19 @@ class VMFunctionCompiler : ExprFunctor { auto shape_reg = last_register_; Emit(Instruction::ReshapeTensor(tensor_reg, shape_reg, NewRegister())); }) + .Match("device_copy", + [this](const Array& args, const Attrs& attrs, const Array& type_arg) { + CHECK_EQ(args.size(), 1U); + this->VisitExpr(args[0]); + auto src_reg = last_register_; + + auto device_copy_attrs = attrs.as(); + CHECK(device_copy_attrs != nullptr) << "Must be the device copy attrs"; + Index src_device_type = device_copy_attrs->src_dev_type; + Index dst_device_type = device_copy_attrs->dst_dev_type; + Emit(Instruction::DeviceCopy(src_reg, src_device_type, dst_device_type, + NewRegister())); + }) .Match("memory.kill", [](const Array& args, const Attrs& attrs, const Array& type_arg) { LOG(FATAL) << "memory.kill is not yet supported"; @@ -769,9 +852,11 @@ class VMFunctionCompiler : ExprFunctor { /*! \brief Global shared meta data */ VMCompilerContext* context_; /*! \brief Target devices. */ - TargetsMap targets_; + std::unordered_map targets_; /*! \brief Host target. */ Target target_host_; + /*! \brief Map from Relay expr to device type. 
*/ + ExprDeviceMap expr_device_map_; }; PackedFunc VMCompiler::GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) { @@ -820,7 +905,6 @@ void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) { } void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host) { - CHECK_EQ(targets.size(), 1) << "Currently VM compiler doesn't support heterogeneous compilation"; if (params_.size()) { BaseFunc base_func = mod->Lookup("main"); CHECK(base_func->IsInstance()) @@ -847,11 +931,15 @@ void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Targe // the global state. exec_->functions.resize(context_.module->functions.size()); + // Collect the annotated device information. + // This indicates which device each Relay expr should be executed on. + ExprDeviceMap expr_device_map = AnalyzeContext(); + for (auto named_func : context_.module->functions) { auto gvar = named_func.first; if (auto* n = named_func.second.as()) { auto func = GetRef(n); - VMFunctionCompiler func_compiler(&context_, targets_, target_host_); + VMFunctionCompiler func_compiler(&context_, targets_, target_host_, expr_device_map); auto vm_func = func_compiler.Compile(gvar, func); size_t func_index = context_.global_map.at(gvar); @@ -871,6 +959,10 @@ void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Targe exec_->constants.push_back(data); } + for (auto i : context_.const_device_type) { + exec_->const_device_type.push_back(i); + } + // update global function map for (auto gv : context_.global_map) { exec_->global_map.insert({gv.first->name_hint, gv.second}); @@ -883,10 +975,10 @@ void VMCompiler::Lower(IRModule mod, const TargetsMap& targets, const tvm::Targe } } -transform::Sequential MemoryOpt(tvm::Target host_target) { +transform::Sequential MemoryOpt(tvm::Target host_target, TargetsMap targets) { Array pass_seqs; // Manifest the allocations. - pass_seqs.push_back(transform::ManifestAlloc(host_target)); + pass_seqs.push_back(transform::ManifestAlloc(host_target, targets)); // Compute away possibly introduced constant computation. pass_seqs.push_back(transform::FoldConstant()); @@ -895,7 +987,7 @@ transform::Sequential MemoryOpt(tvm::Target host_target) { pass_seqs.push_back(transform::FuseOps()); // Manifest the allocations needed for the shape functions. - pass_seqs.push_back(transform::ManifestAlloc(host_target)); + pass_seqs.push_back(transform::ManifestAlloc(host_target, targets)); // Fuse the shape functions. pass_seqs.push_back(transform::FuseOps()); @@ -910,7 +1002,7 @@ transform::Sequential MemoryOpt(tvm::Target host_target) { pass_seqs.push_back(transform::FuseOps()); // Create allocations for math introduced by dynamic region math. - pass_seqs.push_back(transform::ManifestAlloc(host_target)); + pass_seqs.push_back(transform::ManifestAlloc(host_target, targets)); // Compute away possibly introduced constant computation. pass_seqs.push_back(transform::FoldConstant()); @@ -977,6 +1069,12 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe pass_seqs.push_back(transform::FastMath()); pass_seqs.push_back(transform::FoldConstant()); + if (targets_.size() > 1) { + // Handle heterogeneous compilation. 
+ int fallback_dev = GetFallbackDevice(); + pass_seqs.push_back(transform::RewriteAnnotatedOps(fallback_dev)); + } + pass_seqs.push_back(transform::FuseOps()); pass_seqs.push_back(transform::ToANormalForm()); pass_seqs.push_back(transform::LambdaLift()); @@ -989,11 +1087,10 @@ IRModule VMCompiler::OptimizeModule(const IRModule& mod, const TargetsMap& targe // external codegen. pass_seqs.push_back(transform::Inline()); - pass_seqs.push_back(MemoryOpt(target_host)); + pass_seqs.push_back(MemoryOpt(target_host, targets)); transform::Sequential seq(pass_seqs); transform::PassContext pass_ctx = PassContext::Current(); - // TODO(wweic): Support heterogenous execution tvm::With ctx(pass_ctx); if (targets.size() == 1) { const auto& it = targets.begin(); @@ -1061,6 +1158,25 @@ void VMCompiler::Codegen() { } } +ExprDeviceMap VMCompiler::AnalyzeContext() const { + TVMContext default_device; + ExprDeviceMap expr_device_map; + if (targets_.size() > 1) { + int fallback_dev = GetFallbackDevice(); + default_device.device_type = static_cast(fallback_dev); + default_device.device_id = 0; + expr_device_map = ContextAnalysis(context_.module, default_device); + } else { + const auto& tgt = targets_.begin(); + default_device.device_type = static_cast((*tgt).first->value); + if (default_device.device_type != kDLCPU) { + default_device.device_id = 0; + expr_device_map = ContextAnalysis(context_.module, default_device); + } + } + return expr_device_map; +} + runtime::Module CreateVMCompiler() { auto exec = make_object(); return runtime::Module(exec); diff --git a/src/relay/backend/vm/compiler.h b/src/relay/backend/vm/compiler.h index b4b86d3d6d8e..19924ab38358 100644 --- a/src/relay/backend/vm/compiler.h +++ b/src/relay/backend/vm/compiler.h @@ -62,6 +62,7 @@ using GlobalMap = NodeMap; using ConstMap = NodeMap; using ConstTensorShapeMap = NodeMap>; using TargetsMap = Map; +using ExprDeviceMap = std::unordered_map; struct VMCompilerContext { // The module context for the compilation @@ -76,6 +77,8 @@ struct VMCompilerContext { GlobalMap global_map; // List of constants std::vector constants; + // Device type for constants + std::vector const_device_type; // List of cached functions std::vector cached_funcs; // The functions that have been lowered. @@ -103,7 +106,7 @@ class VMCompiler : public runtime::ModuleNode { * * \param mod Relay Module * \param targets For heterogeneous compilation, it is a dictionary indicating context - to target mapping. For homogeneous compilation, it is a build target. + * to target mapping. For homogeneous compilation, it is a build target. * \param target_host Host compilation target, if target is device. */ void Lower(IRModule mod, const TargetsMap& targets, const tvm::Target& target_host); @@ -112,11 +115,28 @@ class VMCompiler : public runtime::ModuleNode { void Codegen(); protected: + /* + * \brief Perform a series of optimizations on the input IR module. + * + * \param mod The input IRModule. + * \param targets For heterogeneous compilation, it is a dictionary indicating context + * to target mapping. For homogeneous compilation, it is a build target. + * \param target_host Host compilation target. + * + * \return The optimized IRModule. + */ IRModule OptimizeModule(const IRModule& mod, const TargetsMap& targets, const Target& target_host); + /*! + * \brief Populate the global function names in a map where the value is used + * as the index by the VMFunctions. + */ void PopulateGlobalMap(); + /*! \brief Analyze the device context of each expression. 
*/ + ExprDeviceMap AnalyzeContext() const; + protected: /*! \brief Target devices. */ TargetsMap targets_; diff --git a/src/relay/op/device_copy.cc b/src/relay/op/device_copy.cc index 923965f98192..3a58607e6dd8 100644 --- a/src/relay/op/device_copy.cc +++ b/src/relay/op/device_copy.cc @@ -31,6 +31,7 @@ #include #include #include +#include #include "../transforms/infer_layout_util.h" #include "type_relations.h" @@ -60,7 +61,12 @@ on different devices. .add_type_rel("Identity", IdentityRel) .set_attr("TOpPattern", kOpaque) .set_attr("TOpIsStateful", false) - .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); + .set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout) + .set_attr("FTVMCompute", + [](const Attrs& attrs, const Array& inputs, + const Type& out_dtype) -> Array { + return {topi::identity(inputs[0])}; + }); } // namespace relay } // namespace tvm diff --git a/src/relay/transforms/fold_constant.cc b/src/relay/transforms/fold_constant.cc index 7273c28a0e93..bdc613d85cdb 100644 --- a/src/relay/transforms/fold_constant.cc +++ b/src/relay/transforms/fold_constant.cc @@ -79,6 +79,7 @@ class ConstantFolder : public ExprMutator { public: explicit ConstantFolder(IRModule module) : module_(module), + device_copy_op_(Op::Get("device_copy")), shape_of_op_(Op::Get("shape_of")), vm_shape_of_op_(Op::Get("vm.shape_of")), invoke_tvm_op_(Op::Get("vm.invoke_tvm_op")), @@ -134,7 +135,7 @@ class ConstantFolder : public ExprMutator { // We should think about potentially constant evaluation over these ops too. if (call->op == invoke_tvm_op_ || call->op == shape_func_op_ || call->op == alloc_tensor_op_ || - call->op == alloc_storage_op_) { + call->op == alloc_storage_op_ || call->op == device_copy_op_) { return GetRef(call); } @@ -168,6 +169,7 @@ class ConstantFolder : public ExprMutator { IRModule module_; // Cache the following ops for equivalence checking in this pass. 
+ const Op& device_copy_op_; const Op& shape_of_op_; const Op& vm_shape_of_op_; const Op& invoke_tvm_op_; diff --git a/src/runtime/vm/bytecode.cc b/src/runtime/vm/bytecode.cc index edfd3acfb3e2..78972beb1ed2 100644 --- a/src/runtime/vm/bytecode.cc +++ b/src/runtime/vm/bytecode.cc @@ -123,6 +123,11 @@ Instruction::Instruction(const Instruction& instr) { this->reshape_tensor.tensor = instr.reshape_tensor.tensor; this->reshape_tensor.newshape = instr.reshape_tensor.newshape; return; + case Opcode::DeviceCopy: + this->src = instr.src; + this->src_device_type = instr.src_device_type; + this->dst_device_type = instr.dst_device_type; + return; default: std::ostringstream out; out << "Invalid instruction " << static_cast(instr.op); @@ -220,6 +225,15 @@ Instruction& Instruction::operator=(const Instruction& instr) { case Opcode::ShapeOf: this->shape_of.tensor = instr.shape_of.tensor; return *this; + case Opcode::ReshapeTensor: + this->reshape_tensor.tensor = instr.reshape_tensor.tensor; + this->reshape_tensor.newshape = instr.reshape_tensor.newshape; + return *this; + case Opcode::DeviceCopy: + this->src = instr.src; + this->src_device_type = instr.src_device_type; + this->dst_device_type = instr.dst_device_type; + return *this; default: std::ostringstream out; out << "Invalid instruction " << static_cast(instr.op); @@ -241,6 +255,7 @@ Instruction::~Instruction() { case Opcode::AllocStorage: case Opcode::ShapeOf: case Opcode::ReshapeTensor: + case Opcode::DeviceCopy: case Opcode::Fatal: return; case Opcode::AllocTensor: @@ -324,13 +339,14 @@ Instruction Instruction::AllocTensorReg(RegName storage, RegName offset, RegName } Instruction Instruction::AllocStorage(RegName size, Index alignment, DLDataType dtype_hint, - RegName dst) { + Index device_type, RegName dst) { Instruction instr; instr.op = Opcode::AllocStorage; instr.dst = dst; instr.alloc_storage.allocation_size = size; instr.alloc_storage.alignment = alignment; instr.alloc_storage.dtype_hint = dtype_hint; + instr.alloc_storage.device_type = device_type; return instr; } @@ -351,6 +367,17 @@ Instruction Instruction::ReshapeTensor(RegName tensor, RegName newshape, RegName return instr; } +Instruction Instruction::DeviceCopy(RegName src, Index src_device_type, Index dst_device_type, + RegName dst) { + Instruction instr; + instr.op = Opcode::DeviceCopy; + instr.dst = dst; + instr.src = src; + instr.src_device_type = src_device_type; + instr.dst_device_type = dst_device_type; + return instr; +} + Instruction Instruction::AllocADT(Index tag, Index num_fields, const std::vector& datatype_fields, RegName dst) { Instruction instr; @@ -582,7 +609,8 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { case Opcode::AllocStorage: { os << "alloc_storage $" << instr.dst << " $" << instr.alloc_storage.allocation_size << " " << instr.alloc_storage.alignment << " " - << DLDataType2String(instr.alloc_storage.dtype_hint); + << DLDataType2String(instr.alloc_storage.dtype_hint) << " " + << instr.alloc_storage.device_type; break; } case Opcode::ShapeOf: { @@ -594,6 +622,11 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { << instr.reshape_tensor.newshape; break; } + case Opcode::DeviceCopy: { + os << "device_copy $" << instr.dst << " $" << instr.src << " " << instr.dst_device_type << " " + << instr.src_device_type; + break; + } default: LOG(FATAL) << "should never hit this case" << static_cast(instr.op); break; diff --git a/src/runtime/vm/executable.cc b/src/runtime/vm/executable.cc index 998762187dfc..cc1dc8dd19e5 100644 --- 
a/src/runtime/vm/executable.cc +++ b/src/runtime/vm/executable.cc @@ -32,6 +32,7 @@ #include #include #include +#include #include #include #include @@ -249,6 +250,13 @@ void Executable::SaveConstantSection(dmlc::Stream* strm) { for (const auto& it : arrays) { runtime::SaveDLTensor(strm, it); } + + // Save the const to device mapping. + std::vector const_device_type; + for (auto dev_type : this->const_device_type) { + const_device_type.push_back(static_cast(dev_type)); + } + strm->Write(const_device_type); } void Executable::SavePrimitiveOpNames(dmlc::Stream* strm) { @@ -351,6 +359,7 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { fields.push_back(dtype.code); fields.push_back(dtype.bits); fields.push_back(dtype.lanes); + fields.push_back(instr.alloc_storage.device_type); fields.push_back(instr.dst); break; } @@ -428,6 +437,11 @@ VMInstructionSerializer SerializeInstruction(const Instruction& instr) { fields.assign({instr.reshape_tensor.tensor, instr.reshape_tensor.newshape, instr.dst}); break; } + case Opcode::DeviceCopy: { + // Number of fields = 4 + fields.assign({instr.src, instr.src_device_type, instr.dst_device_type, instr.dst}); + break; + } default: LOG(FATAL) << "Invalid opcode" << static_cast(instr.op); break; @@ -442,7 +456,7 @@ void Executable::SaveCodeSection(dmlc::Stream* strm) { for (const auto& func : this->functions) { // Save the function info. VMFunctionSerializer func_format(func.name, func.register_file_size, func.instructions.size(), - func.params); + func.params, func.params_device_type); func_format.Save(strm); // Serialize each instruction. @@ -509,6 +523,14 @@ void Executable::LoadConstantSection(dmlc::Stream* strm) { STREAM_CHECK(constant.Load(strm), "constant"); this->constants.push_back(constant); } + + // Load the const to device mapping. 
+ std::vector const_device_type; + STREAM_CHECK(strm->Read(&const_device_type), "constant"); + CHECK_EQ(size, const_device_type.size()); + for (auto dev : const_device_type) { + this->const_device_type.push_back(static_cast(dev)); + } } void Executable::LoadPrimitiveOpNames(dmlc::Stream* strm) { @@ -622,7 +644,8 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) { return Instruction::AllocClosure(clo_index, num_freevar, free_vars, dst); } case Opcode::AllocStorage: { - DCHECK_GE(instr.fields.size(), 6U); + // Number of fields = 7 + DCHECK_GE(instr.fields.size(), 7U); Index allocation_size = instr.fields[0]; Index alignment = instr.fields[1]; @@ -631,9 +654,10 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) { dtype.bits = instr.fields[3]; dtype.lanes = instr.fields[4]; - RegName dst = instr.fields[5]; + Index device_type = instr.fields[5]; + RegName dst = instr.fields[6]; - return Instruction::AllocStorage(allocation_size, alignment, dtype, dst); + return Instruction::AllocStorage(allocation_size, alignment, dtype, device_type, dst); } case Opcode::If: { // Number of fields = 4 @@ -704,6 +728,12 @@ Instruction DeserializeInstruction(const VMInstructionSerializer& instr) { DCHECK_EQ(instr.fields.size(), 3U); return Instruction::ReshapeTensor(instr.fields[0], instr.fields[1], instr.fields[2]); } + case Opcode::DeviceCopy: { + // Number of fields = 4 + DCHECK_EQ(instr.fields.size(), 4U); + return Instruction::DeviceCopy(instr.fields[0], instr.fields[1], instr.fields[2], + instr.fields[3]); + } default: LOG(FATAL) << "Invalid opcode" << instr.opcode; return Instruction(); @@ -733,7 +763,7 @@ void Executable::LoadCodeSection(dmlc::Stream* strm) { // Create the VM function. VMFunction vm_func = VMFunction(loaded_func.name, loaded_func.params, instructions, - loaded_func.register_file_size); + loaded_func.register_file_size, loaded_func.params_device_type); auto it = this->global_map.find(loaded_func.name); CHECK(it != this->global_map.end()); CHECK_LE(it->second, this->global_map.size()); diff --git a/src/runtime/vm/profiler/vm.cc b/src/runtime/vm/profiler/vm.cc index 7273b565cd69..63001634558e 100644 --- a/src/runtime/vm/profiler/vm.cc +++ b/src/runtime/vm/profiler/vm.cc @@ -105,9 +105,19 @@ void VirtualMachineDebug::LoadExecutable(const Executable* exec) { void VirtualMachineDebug::InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count, Index output_size, const std::vector& args) { CHECK(exec_); - auto ctx = this->GetParamsContext(); - // warmup - VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size, args); + CHECK(!ctxs_.empty()) << "Context has not been initialized yet."; + // The device context of any input of the operator is used for + // synchronization. + CHECK_GT(arg_count, 0U); + ObjectRef arg = args[0]; + while (arg->IsInstance()) { + ADT adt = Downcast(arg); + arg = adt[0]; + } + CHECK(arg->IsInstance()); + auto nd_array = Downcast(arg); + auto ctx = nd_array->ctx; + TVMSynchronize(ctx.device_type, ctx.device_id, nullptr); auto op_begin = std::chrono::high_resolution_clock::now(); diff --git a/src/runtime/vm/serialize_util.h b/src/runtime/vm/serialize_util.h index d52b73d81a78..d17256d6a079 100644 --- a/src/runtime/vm/serialize_util.h +++ b/src/runtime/vm/serialize_util.h @@ -57,15 +57,19 @@ struct VMFunctionSerializer { size_t num_instructions; /*! \brief The parameters of the VMFunction. */ std::vector params; + /*! \brief The device type of each parameter of the VMFunction. 
*/ + std::vector params_device_type; VMFunctionSerializer() = default; VMFunctionSerializer(const std::string& name, Index register_file_size, size_t num_instructions, - const std::vector& params) + const std::vector& params, + const std::vector& params_device_type) : name(name), register_file_size(register_file_size), num_instructions(num_instructions), - params(params) {} + params(params), + params_device_type(params_device_type) {} /*! * \brief Load the serialized function header. @@ -81,7 +85,9 @@ struct VMFunctionSerializer { register_file_size = std::stoll(func_info[1]); // Get the number of instructions. num_instructions = static_cast(std::stoll(func_info[2])); - return strm->Read(&params); + if (!strm->Read(&params)) return false; + if (!strm->Read(&params_device_type)) return false; + return true; } /*! @@ -95,6 +101,7 @@ struct VMFunctionSerializer { func_info.push_back(std::to_string(num_instructions)); strm->Write(func_info); strm->Write(params); + strm->Write(params_device_type); } }; diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index 9af520228fee..aeee137530b1 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -68,8 +68,17 @@ inline ObjectRef CopyTo(ObjectRef src, const DLContext& ctx) { if (nd_array->ctx.device_type != ctx.device_type) { return nd_array.CopyTo(ctx); } + return src; + } else { + CHECK(src->IsInstance()) + << "VM data must be NDArray or a list of NDArray, but received: " << src->_type_key; + std::vector ret; + ADT adt = Downcast(src); + for (size_t i = 0; i < adt.size(); i++) { + ret.push_back(CopyTo(adt[i], ctx)); + } + return ADT(adt->tag, ret.begin(), ret.end()); } - return src; } std::vector ToShape(NDArray shape_tensor) { @@ -146,12 +155,14 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, auto func_index = gvit->second; const auto& vm_func = exec_->functions[func_index]; const auto& param_names = vm_func.params; - // TODO(icemelon9): For heterogeneous execution, get input device information - TVMContext ctx = ctxs_[0]; CHECK_EQ(args.size() - 1, param_names.size()) << "The number of provided parameters doesn't match the number of arguments"; + CHECK_EQ(param_names.size(), vm_func.params_device_type.size()) << "The number of provided parameters doesn't match the number of assigned devices"; std::vector func_args(param_names.size()); for (int i = 1; i < args.size(); ++i) { + Index device_type = vm_func.params_device_type[i - 1]; + DLContext ctx = GetContext(device_type); ObjectRef obj = CopyTo(args[i], ctx); func_args[i - 1] = obj; } @@ -164,18 +175,13 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, } } -TVMContext VirtualMachine::GetParamsContext() const { - CHECK(!ctxs_.empty()) << "Context has not been initialized yet."; +inline TVMContext VirtualMachine::GetContext(Index device_type) const { + CHECK_GE(ctxs_.size(), device_type) << "ctxs_ list doesn't contain device:" << device_type; - // Use the fallback device if no device index is available. - int fallback_device_type = static_cast(ctxs_[0].device_type); - // TODO(wweic): For heterogeneous execution, get device information from byte - - const auto& cit = - std::find_if(ctxs_.begin(), ctxs_.end(), [&fallback_device_type](const TVMContext& c) { - return fallback_device_type == static_cast(c.device_type); - }); - return (cit == ctxs_.end() ?
ctxs_[0] : *cit); + auto ctx = ctxs_[device_type]; + CHECK_EQ(static_cast(ctx.device_type), device_type) + << "device type " << device_type << " has not been initialized in the context list."; + return ctx; } void VirtualMachine::PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func) { @@ -283,10 +289,16 @@ void VirtualMachine::LoadExecutable(const Executable* exec) { void VirtualMachine::Init(const std::vector& ctxs, const std::vector& alloc_types) { CHECK_EQ(ctxs.size(), alloc_types.size()); - ctxs_ = ctxs; - for (size_t i = 0; i < ctxs.size(); ++i) { + // Cache the context + for (size_t i = 0; i < ctxs.size(); i++) { + auto dev_type = static_cast(ctxs[i].device_type); auto alloc = MemoryManager::GetOrCreateAllocator(ctxs[i], alloc_types[i]); - allocators_.emplace(ctxs[i], alloc); + if (ctxs_.size() <= dev_type) { + ctxs_.resize(dev_type + 1); + allocators_.resize(dev_type + 1); + } + ctxs_[dev_type] = ctxs[i]; + allocators_[dev_type] = alloc; } } @@ -364,8 +376,8 @@ void VirtualMachine::RunLoop() { } if (!const_pool_[instr.const_index].defined()) { - // TODO(wweic) ctx could be obtained from the ctxs list. - const_pool_[instr.const_index] = CopyTo(constant_obj, ctxs_[0]); + TVMContext ctx = GetContext(exec_->const_device_type[instr.const_index]); + const_pool_[instr.const_index] = CopyTo(constant_obj, ctx); } WriteRegister(instr.dst, const_pool_[instr.const_index]); pc_++; @@ -473,9 +485,7 @@ void VirtualMachine::RunLoop() { goto main_loop; } case Opcode::AllocTensorReg: { - DLContext cpu_ctx; - cpu_ctx.device_type = kDLCPU; - cpu_ctx.device_id = 0; + DLContext cpu_ctx = GetContext(static_cast(kDLCPU)); auto shape_obj = ReadRegister(instr.alloc_tensor_reg.shape_register); NDArray shape_tensor = Downcast(CopyTo(shape_obj, cpu_ctx)); auto shape = ToShape(shape_tensor); @@ -511,14 +521,16 @@ void VirtualMachine::RunLoop() { auto size = LoadScalarInt(instr.alloc_storage.allocation_size); auto alignment = instr.alloc_storage.alignment; - DLOG(INFO) << "AllocStorage: allocation_size=" << size << "alignment=" << alignment - << "dtype_hint=" << DLDataType2String(instr.alloc_storage.dtype_hint); + DLOG(INFO) << "AllocStorage: allocation_size=" << size << ", alignment=" << alignment + << ", dtype_hint=" << DLDataType2String(instr.alloc_storage.dtype_hint) + << ", device_type=" << instr.alloc_storage.device_type; auto storage_obj = SimpleObjAllocator().make_object(); - auto it = allocators_.find(ctxs_[0]); - CHECK(it != allocators_.end()) - << "Did you forget to init the VirtualMachine with contexts?"; - auto alloc = it->second; + auto dev_type = instr.alloc_storage.device_type; + CHECK_LT(static_cast(dev_type), allocators_.size()) + << "Memory allocator for device " << dev_type << " has not been initialized"; + auto* alloc = allocators_[dev_type]; + CHECK(alloc) << "Did you forget to init the VirtualMachine with contexts?"; storage_obj->buffer = alloc->Alloc(size, alignment, instr.alloc_storage.dtype_hint); Storage storage(storage_obj); WriteRegister(instr.dst, storage); @@ -553,9 +565,7 @@ void VirtualMachine::RunLoop() { } } case Opcode::ReshapeTensor: { - DLContext cpu_ctx; - cpu_ctx.device_type = kDLCPU; - cpu_ctx.device_id = 0; + DLContext cpu_ctx = GetContext(static_cast(kDLCPU)); auto tensor_obj = ReadRegister(instr.reshape_tensor.tensor); NDArray tensor_arr = Downcast(tensor_obj); // Read the shape from shape tensor @@ -573,6 +583,21 @@ void VirtualMachine::RunLoop() { pc_++; goto main_loop; } + case Opcode::DeviceCopy: { + auto tensor_src = ReadRegister(instr.src); +
NDArray src_data = Downcast(tensor_src); + DLContext src_ctx = src_data->ctx; + CHECK_EQ(static_cast(src_ctx.device_type), instr.src_device_type); + + DLContext dst_ctx; + dst_ctx.device_type = static_cast(instr.dst_device_type); + dst_ctx.device_id = 0; + + NDArray dst_data = src_data.CopyTo(dst_ctx); + WriteRegister(instr.dst, dst_data); + pc_++; + goto main_loop; + } default: LOG(FATAL) << "Unknown instruction opcode: " << int(instr.op); } diff --git a/tests/python/relay/benchmarking/benchmark_vm.py b/tests/python/relay/benchmarking/benchmark_vm.py index 80e9e4141c1d..4fcf39d0aae2 100644 --- a/tests/python/relay/benchmarking/benchmark_vm.py +++ b/tests/python/relay/benchmarking/benchmark_vm.py @@ -67,8 +67,8 @@ def get_vm_output(mod, data, params, target, ctx, dtype='float32', if measure: print("Evaluate vm inference cost of {} on {}".format(model, repr(ctx))) - ftimer = rly_vm.mod.time_evaluator("invoke", ctx, number=number, - repeat=repeat) + ftimer = rly_vm.module.time_evaluator("invoke", ctx, number=number, + repeat=repeat) # Measure in millisecond. prof_res = np.array(ftimer("main", data).results) * 1000 print("Mean vm inference time (std dev): %.2f ms (%.2f ms)" % @@ -78,14 +78,13 @@ def get_vm_output(mod, data, params, target, ctx, dtype='float32', # random input data = np.random.uniform(size=data_shape).astype(dtype) - target = "llvm" - ctx = tvm.cpu(0) - - tvm_out = get_graph_runtime_output(mod, tvm.nd.array(data.astype(dtype)), - params, target, ctx, dtype) - vm_out = get_vm_output(mod, tvm.nd.array(data.astype(dtype)), params, - target, ctx, dtype) - tvm.testing.assert_allclose(vm_out, tvm_out, rtol=1e-5, atol=1e-5) + + for target, ctx in testing.enabled_targets(): + tvm_out = get_graph_runtime_output(mod, tvm.nd.array(data.astype(dtype)), + params, target, ctx, dtype) + vm_out = get_vm_output(mod, tvm.nd.array(data.astype(dtype)), params, + target, ctx, dtype) + tvm.testing.assert_allclose(vm_out, tvm_out, rtol=1e-5, atol=1e-5) def test_mlp(): diff --git a/tests/python/relay/dyn/test_dynamic_op_level10.py b/tests/python/relay/dyn/test_dynamic_op_level10.py index 0097a4eed9dc..8bc551be0ff1 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level10.py +++ b/tests/python/relay/dyn/test_dynamic_op_level10.py @@ -47,12 +47,11 @@ def test_dyn_broadcast_to(): dyn_shape = (1, ) * rank ref_res = np.broadcast_to(x, dyn_shape) for target, ctx in tvm.testing.enabled_targets(): - if (target != 'cuda'): #skip cuda because we don't have dynamic support for GPU - for kind in ["vm", "debug"]: - mod = tvm.ir.IRModule.from_expr(func) - intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) - op_res = intrp.evaluate(func)(x, np.array(dyn_shape).astype(shape_type)) - tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) + for kind in ["vm", "debug"]: + mod = tvm.ir.IRModule.from_expr(func) + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate(func)(x, np.array(dyn_shape).astype(shape_type)) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res, rtol=1e-5) @tvm.testing.uses_gpu diff --git a/tests/python/relay/dyn/test_dynamic_op_level2.py b/tests/python/relay/dyn/test_dynamic_op_level2.py index bab4869b3fe0..15b6b7acd7e9 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level2.py +++ b/tests/python/relay/dyn/test_dynamic_op_level2.py @@ -52,7 +52,6 @@ def verify_upsampling(dshape, scale_h, scale_w, layout, method, align_corners=Fa func = relay.Function([x, scale_h_var, scale_w_var], z) for target, ctx in 
tvm.testing.enabled_targets(): - if "llvm" not in target: continue for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) diff --git a/tests/python/relay/dyn/test_dynamic_op_level3.py b/tests/python/relay/dyn/test_dynamic_op_level3.py index 193de85a5242..d6a2806719ab 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level3.py +++ b/tests/python/relay/dyn/test_dynamic_op_level3.py @@ -28,8 +28,6 @@ def verify_func(func, data, ref_res): assert isinstance(data, list) for target, ctx in tvm.testing.enabled_targets(): - #TODO(mbrookhart): enable Cuda tests onces the VM supports dynamic shapes - if "llvm" not in target: continue for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) diff --git a/tests/python/relay/dyn/test_dynamic_op_level5.py b/tests/python/relay/dyn/test_dynamic_op_level5.py index 226bbfe2678e..eb804fe430e3 100644 --- a/tests/python/relay/dyn/test_dynamic_op_level5.py +++ b/tests/python/relay/dyn/test_dynamic_op_level5.py @@ -60,7 +60,6 @@ def verify_resize(dshape, scale, method, layout): func = relay.Function([x, size_var], z) for target, ctx in tvm.testing.enabled_targets(): - if "llvm" not in target: continue for kind in ["vm", "debug"]: mod = tvm.ir.IRModule.from_expr(func) intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) diff --git a/tests/python/relay/test_adt.py b/tests/python/relay/test_adt.py index ff76e1c64bcb..d0e010570c8a 100644 --- a/tests/python/relay/test_adt.py +++ b/tests/python/relay/test_adt.py @@ -15,8 +15,8 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm import te from tvm import relay +from tvm.relay import testing from tvm.relay.backend.interpreter import ConstructorValue from tvm.relay import create_executor from tvm.relay.prelude import Prelude, StaticTensorArrayOps @@ -719,13 +719,15 @@ def test_iterate(): assert count(res) == 12 -def check_tensor_array(ta_mod, ref_res, *args, dtype="float32", - ta_ctx=tvm.cpu(), target="llvm", rtol=1e-5): +def check_tensor_array(ta_mod, ref_res, *args, dtype="float32", rtol=1e-5): for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=ta_mod, ctx=ta_ctx, target=target) - result = ex.evaluate()(*args) - got = vmobj_to_list(result, dtype) - tvm.testing.assert_allclose(ref_res, got, rtol=rtol, atol=rtol) + for target, ctx in testing.enabled_targets(): + if kind == "debug" and ctx.device_type != tvm.cpu().device_type: + continue + ex = relay.create_executor(kind, mod=ta_mod, ctx=ctx, target=target) + result = ex.evaluate()(*args) + got = vmobj_to_list(result, dtype) + tvm.testing.assert_allclose(ref_res, got, rtol=rtol, atol=rtol) def test_tensor_expand_dims(): diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 7f373369d301..e33e2679fe5d 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -33,8 +33,26 @@ def any_dims(ndim): shape.append(relay.Any()) return tuple(shape) -# TODO(@wweic): because vm doesn't support heterogeneous exec, we can only test -# shape function on CPU. 
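The test_any.py hunks that follow route every case through a shared check_result helper (defined just below), which exercises both the debug and vm executors across all enabled targets instead of hard-coding LLVM on the CPU. As a rough call-site sketch, assuming that helper and purely illustrative test data, a rewritten check looks like:

import numpy as np
import tvm
from tvm import relay

# check_result is the helper added to test_any.py in this change; the module
# and input data below are illustrative only.
x = relay.var("x", shape=(relay.Any(),), dtype="float32")
mod = tvm.IRModule()
mod["main"] = relay.Function([x], relay.sqrt(x))
x_np = np.random.uniform(size=(8,)).astype("float32")
check_result([x_np], mod, np.sqrt(x_np))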
+def check_result(args, mod, expected, flatten=False, assert_shape=False, + only_vm=False): + for kind in ["debug", "vm"]: + for tgt, ctx in tvm.testing.enabled_targets(): + if kind == "debug" and (only_vm or ctx.device_type != + tvm.cpu().device_type): + continue + ex = relay.create_executor(kind, mod=mod, ctx=ctx, target=tgt) + result = ex.evaluate()(*args) + result = result.asnumpy() + if assert_shape: + assert result.shape == expected, \ + "Shape mismatch: expect %s but got %s." \ + % (str(expected), str(result.shape)) + return + + if flatten: + result = result.flatten() + expected = expected.flatten() + tvm.testing.assert_allclose(result, expected) def verify_any_broadcast(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op): dtype = 'float32' @@ -45,10 +63,7 @@ def verify_any_broadcast(x_shape, y_shape, x_np_shape, y_np_shape, op, np_op): x_np = np.random.uniform(size=x_np_shape).astype(dtype) y_np = np.random.uniform(size=y_np_shape).astype(dtype) res_np = np_op(x_np, y_np) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(x_np, y_np) - tvm.testing.assert_allclose(result.asnumpy(), res_np) + check_result([x_np, y_np], mod, res_np) def test_any_broadcast(): # Test broadcast with 1s @@ -69,10 +84,7 @@ def verify_any_elemwise(x_shape, x_np_shape, op, np_op): mod["main"] = relay.Function([x], op(x)) x_np = np.random.uniform(size=x_np_shape).astype(dtype) res_np = np_op(x_np) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(x_np) - tvm.testing.assert_allclose(result.asnumpy(), res_np) + check_result([x_np], mod, res_np) def test_any_elemwise(): verify_any_elemwise((relay.Any(),), (3,), relay.sqrt, np.sqrt) @@ -103,10 +115,7 @@ def verify_any_full_like(x_shape, x_np_shape, relay_op, np_op, dtype='float32'): mod['main'] = relay.Function([x], relay_op(x)) x_np = np.random.uniform(size=x_np_shape).astype(dtype) res_np = np_op(x_np) - for kind in ['debug', 'vm']: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target='llvm') - result = ex.evaluate()(x_np).asnumpy() - tvm.testing.assert_allclose(result, res_np) + check_result([x_np], mod, res_np) def test_any_full_like(): # zeros_like, ones_like @@ -124,10 +133,7 @@ def verify_any_full(x_np_shape, relay_op, np_op, dtype='float32', value=None): mod['main'] = relay.Function([x], out) res_np = np_op(x_np_shape) if value is None else np_op(x_np_shape, value) x_np = np.array(x_np_shape).astype("int32") - for kind in ['debug', 'vm']: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target='llvm') - result = ex.evaluate()(x_np).asnumpy() - tvm.testing.assert_allclose(result, res_np) + check_result([x_np], mod, res_np) def test_any_full(): # zeros, ones, full @@ -151,10 +157,7 @@ def test_any_concat(): x_np = np.random.uniform(size=(3, 2)).astype('float32') y_np = np.random.uniform(size=(1, 2)).astype('float32') ref = np.concatenate([x_np - 3.0, y_np * 5.0], axis=0) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(x_np, y_np) - tvm.testing.assert_allclose(result.asnumpy(), ref) + check_result([x_np, y_np], mod, ref) def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape, variable_newshape=False): x = relay.var('x', shape=x_shape, dtype="float32") @@ -172,12 +175,7 @@ def verify_any_reshape(x_shape, newshape, x_np_shape, out_shape, variable_newsha y = relay.reshape(relu_x, 
newshape=newshape) mod = tvm.IRModule() mod["main"] = relay.Function(params, y) - - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(*args).asnumpy() - assert result.shape == out_shape - tvm.testing.assert_allclose(result.flatten(), data.flatten()) + check_result(args, mod, data, flatten=True) def test_any_reshape(): for variable_newshape in [False, True]: @@ -201,6 +199,9 @@ def verify_any_argwhere(x_shape, x_np_shape, dtype="bool"): assert result.shape == expected.shape tvm.testing.assert_allclose(result.flatten(), expected.flatten()) + # TODO(@zhiics) argwhere gpu schedule is currently not avaiable + # check_result([data], mod, expected, flatten=True) + def test_any_argwhere(): verify_any_argwhere(any_dims(1), (5,)) verify_any_argwhere(any_dims(2), (5, 5)) @@ -231,10 +232,7 @@ def verify_any_take(data_shape, indices_shape, axis, data_np_shape, indices_np_s max_index = data_np.shape[axis] indices_np = np.random.randint(max_index, size=indices_np_shape).astype('int32') ref = np.take(data_np, indices_np, axis=axis) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np, indices_np) - tvm.testing.assert_allclose(result.asnumpy(), ref) + check_result([data_np, indices_np], mod, ref) def test_any_take(): verify_any_take(any_dims(2), (1,), 0, (4, 5), (1,)) @@ -251,11 +249,7 @@ def verify_any_tile(dshape, reps, np_dshape, np_reps): mod["main"] = relay.Function([x], y) x_data = np.random.uniform(size=np_dshape).astype("float32") ref_res = np.tile(x_data, reps=np_reps) - - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - res = ex.evaluate()(x_data) - tvm.testing.assert_allclose(res.asnumpy(), ref_res, rtol=1e-5) + check_result([x_data], mod, ref_res) def test_any_tile(): verify_any_tile(any_dims(3), (3, 2, 1), (2, 3, 4), (3, 2, 1)) @@ -269,10 +263,7 @@ def test_any_shape_of(): mod = tvm.IRModule() mod["main"] = relay.Function([x], y) data = np.random.uniform(size=(3, 4)).astype('float32') - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data) - tvm.testing.assert_allclose(result.asnumpy(), np.array([3,4]).astype("int64")) + check_result([data], mod, np.array([3,4]).astype("int64")) x = relay.var('x', shape=any_dims(3), dtype='float32') y0 = relay.shape_of(x) @@ -280,10 +271,7 @@ def test_any_shape_of(): mod = tvm.IRModule() mod["main"] = relay.Function([x], y1) data = np.random.uniform(size=(2, 3, 4)).astype('float32') - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data) - tvm.testing.assert_allclose(result.asnumpy(), np.array(3).astype("int64")) + check_result([data], mod, np.array(3).astype("int64")) def verify_any_reduce(reduce_op, data_shape, axis, exclude, keepdims, static_data_shape, ref_out_shape): @@ -293,11 +281,7 @@ def verify_any_reduce(reduce_op, data_shape, axis, exclude, keepdims, y = reduce_op(data, axis, keepdims, exclude) mod["main"] = relay.Function([data], y) data_np = np.random.uniform(size=static_data_shape).astype(dtype) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np) - assert result.asnumpy().shape == ref_out_shape, \ - "Shape mismatch: expect %s but got %s." 
% (str(ref_out_shape), str(result.asnumpy().shape)) + check_result([data_np], mod, ref_out_shape, assert_shape=True) def test_any_reduce(): verify_any_reduce(relay.argmax, any_dims(3), None, False, False, (3, 4, 5), ()) @@ -316,11 +300,7 @@ def verify_any_layout_transform(data_shape, src_layout, dst_layout, static_data_ y = relay.layout_transform(data, src_layout, dst_layout) mod["main"] = relay.Function([data], y) data_np = np.random.uniform(size=static_data_shape).astype(dtype) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np) - assert result.asnumpy().shape == ref_out_shape, \ - "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape)) + check_result([data_np], mod, ref_out_shape, assert_shape=True) def test_any_layout_transform(): verify_any_layout_transform(any_dims(4), "NCHW", "NHWC", (3, 4, 5, 6), (3, 5, 6, 4)) @@ -336,11 +316,7 @@ def verify_any_expand_dims(data_shape, axis, num_newaxis, static_data_shape, ref y = relay.expand_dims(data, axis=axis, num_newaxis=num_newaxis) mod["main"] = relay.Function([data], y) data_np = np.random.uniform(size=static_data_shape).astype(dtype) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np) - assert result.asnumpy().shape == ref_out_shape, \ - "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape)) + check_result([data_np], mod, ref_out_shape, assert_shape=True) def test_any_expand_dims(): verify_any_expand_dims(any_dims(3), 1, 2, (1, 2, 3), (1, 1, 1, 2, 3)) @@ -354,10 +330,7 @@ def verify_any_transpose(data_shape, axes, static_data_shape): mod["main"] = relay.Function([data], y) data_np = np.random.uniform(size=static_data_shape).astype(dtype) ref_out = np.transpose(data_np, axes) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np) - tvm.testing.assert_allclose(result.asnumpy(), ref_out) + check_result([data_np], mod, ref_out) def test_any_transpose(): verify_any_transpose(any_dims(3), (1, 0, 2), (10, 3, 2)) @@ -373,10 +346,7 @@ def verify_any_squeeze(data_shape, axis, static_data_shape): mod["main"] = relay.Function([data], y) data_np = np.random.uniform(size=static_data_shape).astype(dtype) ref_out = np.squeeze(data_np, axis) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np) - tvm.testing.assert_allclose(result.asnumpy(), ref_out) + check_result([data_np], mod, ref_out) def test_any_squeeze(): verify_any_squeeze((1, relay.Any(), relay.Any()), (0,), (1, 9, 8)) @@ -391,11 +361,7 @@ def test_any_reshape_like(): mod["main"] = relay.Function([data, shape_like], y) data_np = np.random.uniform(size=(3, 3, 10)).astype(dtype) shape_like_np = np.random.uniform(size=(3, 5, 6)).astype(dtype) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np, shape_like_np) - assert result.asnumpy().shape == shape_like_np.shape, \ - "Shape mismatch: expect %s but got %s." 
% (str(shape_like_np.shape), str(result.asnumpy().shape)) + check_result([data_np, shape_like_np], mod, shape_like_np.shape, assert_shape=True) def verify_any_conv2d_NCHWc(data_shape, kernel_shape, strides, padding, dilation, data_layout, kernel_layout, out_layout, @@ -412,11 +378,7 @@ def verify_any_conv2d_NCHWc(data_shape, kernel_shape, strides, padding, dilation mod["main"] = relay.Function([data, kernel], y) data_np = np.random.uniform(size=static_data_shape).astype(dtype) kernel_np = np.random.uniform(size=kernel_shape).astype(dtype) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np, kernel_np) - assert result.asnumpy().shape == ref_out_shape, \ - "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape)) + check_result([data_np, kernel_np], mod, ref_out_shape, assert_shape=True) # TODO(@kevinthesun): Need to fix the compute in conv2d_NCHWc to support any @pytest.mark.skip @@ -435,11 +397,7 @@ def verify_any_pool2d(pool_type, data_shape, pool_size, strides, padding, y = pool_func(data, pool_size, strides, padding, layout) mod["main"] = relay.Function([data], y) data_np = np.random.uniform(size=static_data_shape).astype(dtype) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np) - assert result.asnumpy().shape == ref_out_shape, \ - "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape)) + check_result([data_np], mod, ref_out_shape, assert_shape=True) def test_any_pool2d(): verify_any_pool2d("max", (relay.Any(), 3, relay.Any(), relay.Any()), @@ -457,11 +415,7 @@ def verify_any_global_pool2d(pool_type, data_shape, layout, static_data_shape, r y = pool_func(data, layout) mod["main"] = relay.Function([data], y) data_np = np.random.uniform(size=static_data_shape).astype(dtype) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np) - assert result.asnumpy().shape == ref_out_shape, \ - "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape)) + check_result([data_np], mod, ref_out_shape, assert_shape=True) def test_any_global_pool2d(): verify_any_global_pool2d("max", (relay.Any(), 3, relay.Any(), relay.Any()), @@ -499,11 +453,7 @@ def test_any_batch_flatten(): mod["main"] = relay.Function([data], y) data_np = np.random.uniform(size=(3, 3, 10)).astype(dtype) ref_out_shape = (3, 30) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np) - assert result.asnumpy().shape == ref_out_shape, \ - "Shape mismatch: expect %s but got %s." 
% (str(ref_out_shape), str(result.asnumpy().shape)) + check_result([data_np], mod, ref_out_shape, assert_shape=True) def verify_any_dense(data_shape, weight_shape, units, static_data_shape, static_weight_shape, ref_out_shape): @@ -515,11 +465,7 @@ def verify_any_dense(data_shape, weight_shape, units, static_data_shape, mod["main"] = relay.Function([data, weight], y) data_np = np.random.uniform(size=static_data_shape).astype(dtype) weight_np = np.random.uniform(size=static_weight_shape).astype(dtype) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np, weight_np) - assert result.asnumpy().shape == ref_out_shape, \ - "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape)) + check_result([data_np, weight_np], mod, ref_out_shape, assert_shape=True) def test_any_dense(): verify_any_dense(any_dims(2), any_dims(2), None, (4, 16), (8, 16), (4, 8)) @@ -533,10 +479,7 @@ def verify_any_pad(data_shape, pad_width, static_data_shape): mod["main"] = relay.Function([data], y) data_np = np.random.uniform(size=static_data_shape).astype(dtype) ref_out = np.pad(data_np, pad_width) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np) - tvm.testing.assert_allclose(result.asnumpy(), ref_out) + check_result([data_np], mod, ref_out) def test_any_pad(): verify_any_pad(any_dims(3), ((0, 0), (1, 1), (2, 2)), (1, 2, 3)) @@ -554,11 +497,7 @@ def verify_any_dilate(data_shape, strides, static_data_shape): for i in range(len(static_data_shape))) ref_out = np.zeros(shape=ref_shape, dtype=dtype) ref_out[tuple(slice(None, None, strides[i]) for i in range(len(data_shape)))] = data_np - - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np) - tvm.testing.assert_allclose(result.asnumpy(), ref_out) + check_result([data_np], mod, ref_out) def test_any_dilate(): verify_any_dilate(any_dims(1), (1,), (1,)) @@ -577,11 +516,7 @@ def verify_any_softmax(data_shape, axis, static_data_shape, ref_out_shape): y = relay.nn.softmax(data, axis) mod["main"] = relay.Function([data], y) data_np = np.random.uniform(size=static_data_shape).astype(dtype) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np) - assert result.asnumpy().shape == ref_out_shape, \ - "Shape mismatch: expect %s but got %s." 
% (str(ref_out_shape), str(result.asnumpy().shape)) + check_result([data_np], mod, ref_out_shape, assert_shape=True) def test_any_softmax(): verify_any_softmax(any_dims(3), -1, (1, 2, 3), (1, 2, 3)) @@ -613,6 +548,9 @@ def verify_any_topk(data_shape, kval, np_dshape, dtype, const_k=False): result = ex.evaluate()(*in_vals) tvm.testing.assert_allclose(result.asnumpy(), ref_out) + # TODO(@zhiics) Fix topk cuda schedule for dynamic inputs + # check_result(in_vals, mod, ref_out) + def test_any_topk(): verify_any_topk(any_dims(1), 5, (10,), "float32") verify_any_topk(any_dims(2), 2, (6, 3), "int32") @@ -625,10 +563,7 @@ def test_fused_ops(): mod = tvm.IRModule() mod["main"] = relay.Function([x], y1) data = np.random.uniform(size=(5, 4)).astype('float32') - for kind in ["vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data) - tvm.testing.assert_allclose(result.asnumpy(), (data + 1) * 2) + check_result([data], mod, (data + 1) * 2) def test_arange_with_dynamic_shape(): # m, n, k = relay.ShapeVar('m'), relay.ShapeVar('n'), relay.ShapeVar('k') @@ -641,10 +576,7 @@ def test_arange_with_dynamic_shape(): data = np.random.rand(10, 5, 3).astype('float32') mod = tvm.IRModule() mod["main"] = relay.Function([x], y3) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data) - tvm.testing.assert_allclose(result.asnumpy(), np.array(range(10)).astype("int32")+1) + check_result([data], mod, np.array(range(10)).astype("int32")+1) def verify_any_strided_slice(data_shape, begin_shape, end_shape, strides_shape, data_np_shape, slice_mode="end", const_attrs=False): @@ -677,11 +609,7 @@ def verify_any_strided_slice(data_shape, begin_shape, end_shape, strides_shape, strides=strides, slice_mode=slice_mode) mod["main"] = relay.Function(args, y) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(*np_inputs) - tvm.testing.assert_allclose(result.asnumpy(), ref_res) - + check_result(np_inputs, mod, ref_res) def test_any_strided_slice(): verify_any_strided_slice(any_dims(2), (2,), (2,), (2,), (15, 21)) @@ -724,10 +652,7 @@ def _body(i, st): mod["main"] = func data = np.array(0.0, dtype='int32') ref = np.array([0] + list(range(10))).reshape((11, 1)).astype("int32") - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data) - np.testing.assert_allclose(result.asnumpy(), ref) + check_result([data], mod, ref) def test_recursive_concat_with_wrong_annotation(): """ @@ -789,11 +714,7 @@ def test_tuple_get_item(): mod["main"] = relay.Function([data], y) data_np = np.random.uniform(size=static_data_shape).astype(dtype) ref_out_shape = (9, 2) - for kind in ["vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np) - assert result.asnumpy().shape == ref_out_shape, \ - "Shape mismatch: expect %s but got %s." 
% (str(ref_out_shape), str(ret.asnumpy().shape)) + check_result([data_np], mod, ref_out_shape, assert_shape=True) def test_mixed_input_type(): mod = tvm.IRModule() @@ -811,11 +732,8 @@ def test_mixed_input_type(): data_np0 = np.random.uniform(size=static_data_shape).astype(dtype) data_np1 = np.random.uniform(size=static_data_shape).astype(dtype) ref_out_shape = (9, 4) - for kind in ["vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()([[data_np0, data_np0], data_np0], data_np1) - assert result.asnumpy().shape == ref_out_shape, \ - "Shape mismatch: expect %s but got %s." % (str(ref_out_shape), str(result.asnumpy().shape)) + check_result([[[data_np0, data_np0], data_np0], data_np1], mod, + ref_out_shape, assert_shape=True, only_vm=True) def verify_any_crop_and_resize(data_shape, boxes_shape, box_indices_shape, crop_size, layout, static_boxes, static_box_indices_shape, ref_out_shape): @@ -829,11 +747,8 @@ def verify_any_crop_and_resize(data_shape, boxes_shape, box_indices_shape, crop_ mod["main"] = relay.Function([data, boxes, box_indices], y) data_np = np.random.uniform(size=data_shape).astype(dtype) boxes_np = np.random.uniform(size=static_boxes).astype(dtype) - box_indices_np = np.random.uniform(size=static_box_indices_shape).astype(indices_dtype) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np, boxes_np, box_indices_np) - tvm.testing.assert_allclose(result.asnumpy().shape, ref_out_shape) + box_indices_np = np.random.uniform(size=static_box_indices_shape).astype(indices_dtype) + check_result([data_np, boxes_np, box_indices_np], mod, ref_out_shape, assert_shape=True) def test_any_crop_and_resize(): verify_any_crop_and_resize( @@ -863,10 +778,7 @@ def verify_any_mirror_pad(data_shape, pad_width, static_data_shape, ref_out_shap y = relay.nn.mirror_pad(data, pad_width) mod["main"] = relay.Function([data], y) data_np = np.random.uniform(size=static_data_shape).astype(dtype) - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(data_np) - tvm.testing.assert_allclose(result.asnumpy().shape, ref_out_shape) + check_result([data_np], mod, ref_out_shape, assert_shape=True) def test_any_mirror_pad(): verify_any_mirror_pad( @@ -882,11 +794,7 @@ def verify_any_ndarray_size(data_np_shape): mod['main'] = relay.Function([v], n) np_data = np.zeros(data_np_shape, dtype='float32') ref_res = np.size(np_data) - - for kind in ["debug", "vm"]: - ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm") - result = ex.evaluate()(np_data) - tvm.testing.assert_allclose(result.asnumpy(), ref_res) + check_result([np_data], mod, ref_res) def test_any_ndarray_size(): verify_any_ndarray_size((2,)) diff --git a/tests/python/relay/test_pass_annotation.py b/tests/python/relay/test_pass_annotation.py index 7a2ff55790a7..c55120e4b7cf 100644 --- a/tests/python/relay/test_pass_annotation.py +++ b/tests/python/relay/test_pass_annotation.py @@ -25,6 +25,41 @@ from tvm.relay import transform import tvm.testing +def _trace(module, metadata, _): + if metadata.name == 'ManifestAlloc': + pass # import pdb; pdb.set_trace() + + +def check_graph_runtime(target, ref_res, device, func, params, config, + opt_level, expected_index=None): + with tvm.transform.PassContext(opt_level=opt_level, config=config): + graph, lib, new_params = relay.build( + func, + target, + params=params) + contexts = [tvm.cpu(0), 
tvm.context(device)] + graph_json = json.loads(graph) + if "device_index" in graph_json["attrs"]: + device_index = graph_json["attrs"]["device_index"][1] + assert device_index == expected_index + mod = graph_runtime.create(graph, lib, contexts) + mod.set_input(**new_params) + mod.run() + res = mod.get_output(0).asnumpy() + tvm.testing.assert_allclose(res, ref_res, rtol=1e-5, atol=1e-5) + + +def check_vm_runtime(target, ref_res, device, func, params, config, + opt_level, expected_index=None): + with tvm.transform.PassContext(opt_level=opt_level, trace=_trace, config=config): + mod = tvm.IRModule() + mod["main"] = func + exe = relay.vm.compile(mod, target) + ctx = [tvm.cpu(0), tvm.context(device)] + vm = tvm.runtime.vm.VirtualMachine(exe, ctx) + res = vm.invoke("main", **params) + tvm.testing.assert_allclose(res.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) + def run_opt_pass(expr, passes): passes = passes if isinstance(passes, list) else [passes] mod = tvm.IRModule.from_expr(expr) @@ -400,6 +435,7 @@ def run_fusible_network(dev, tgt): tmp_log = np.log(tmp_add) tmp_sub = np.subtract(tmp_sqrt, tmp_log) ref_res = np.exp(tmp_sub) + params = {"x": x_data, "y": y_data} def get_func(): add = relay.add(x, y) @@ -411,28 +447,6 @@ def get_func(): func = relay.Function([x, y], exp) return func - def test_runtime(target, device, func, fallback_device=None, - expected_index=None): - params = {"x": x_data, "y": y_data} - config = {} - if fallback_device: - config["relay.fallback_device_type"] = fallback_device.device_type - with tvm.transform.PassContext(opt_level=1, config=config): - graph, lib, params = relay.build( - func, - target, - params=params) - contexts = [tvm.cpu(0), tvm.context(device)] - graph_json = json.loads(graph) - if "device_index" in graph_json["attrs"]: - device_index = graph_json["attrs"]["device_index"][1] - assert device_index == expected_index - mod = graph_runtime.create(graph, lib, contexts) - mod.set_input(**params) - mod.run() - res = mod.get_output(0).asnumpy() - tvm.testing.assert_allclose(res, ref_res, rtol=1e-5, atol=1e-5) - def test_fuse_log_add(device, tgt): """ Only log and add are fused.""" fallback_device = tvm.context("cpu") @@ -473,8 +487,13 @@ def expected(): dev_idx = ctx.device_type expected_index = [1, 1, 1, dev_idx, dev_idx, 1, 1, dev_idx, dev_idx] check_annotated_graph(annotated_func, expected_func) - test_runtime(target, device, annotated_func, fallback_device, - expected_index) + opt_level = 1 + config = {"relay.fallback_device_type": fallback_device.device_type} + check_graph_runtime(target, ref_res, device, annotated_func, params, + config, opt_level, expected_index) + opt_level = 2 + check_vm_runtime(target, ref_res, device, annotated_func, params, + config, opt_level, expected_index) def test_fuse_all(device, tgt): """Fuse all operators.""" @@ -503,7 +522,13 @@ def annotated(): annotated_func = annotated() expected_func = get_func() check_annotated_graph(annotated_func, expected_func) - test_runtime(target, device, annotated_func, fallback_device) + opt_level = 1 + config = {"relay.fallback_device_type": fallback_device.device_type} + check_graph_runtime(target, ref_res, device, annotated_func, params, + config, opt_level) + opt_level = 2 + check_vm_runtime(target, ref_res, device, annotated_func, params, + config, opt_level) def test_fallback_exp(device, tgt): fallback_device = tvm.context("cpu") @@ -540,16 +565,25 @@ def expected(): ctx = tvm.context(device, 0) dev_idx = ctx.device_type expected_index = [dev_idx, dev_idx, dev_idx, 1, 1] + opt_level = 1 + 
config = {"relay.fallback_device_type": fallback_device.device_type} check_annotated_graph(annotated_func, expected_func) - test_runtime(target, device, annotated_func, fallback_device, - expected_index) + check_graph_runtime(target, ref_res, device, annotated_func, params, config, + opt_level, expected_index) + opt_level = 2 + check_vm_runtime(target, ref_res, device, annotated_func, params, config, + opt_level, expected_index) def test_fallback_all_operators(device, tgt): target = {device: tgt, "cpu": "llvm"} annotated_func = get_func() expected_func = get_func() check_annotated_graph(annotated_func, expected_func) - test_runtime(target, device, annotated_func) + opt_level = 2 + check_graph_runtime(target, ref_res, device, annotated_func, params, {}, + opt_level) + check_vm_runtime(target, ref_res, device, annotated_func, params, {}, + opt_level) test_fuse_log_add(dev, tgt) @@ -557,6 +591,7 @@ def test_fallback_all_operators(device, tgt): test_fallback_exp(dev, tgt) test_fallback_all_operators(dev, tgt) + def run_unpropagatable_graph(dev, tgt): R""" The network is as following: a b c d @@ -608,20 +643,15 @@ def expected(): expected_index = [2, 2, 2, 1, 1, 1, 2, 2] check_annotated_graph(annotated_func, expected_func) params = {"a": a_data, "b": b_data, "c": c_data, "d": d_data} - with tvm.transform.PassContext(opt_level=0, - config={"relay.fallback_device_type": - fallback_device.device_type}): - graph, lib, params = relay.build(annotated_func, target, params=params) - contexts = [tvm.cpu(0), tvm.context(dev)] - graph_json = json.loads(graph) - if "device_index" in graph_json["attrs"]: - device_index = graph_json["attrs"]["device_index"][1] - assert device_index == expected_index - mod = graph_runtime.create(graph, lib, contexts) - mod.set_input(**params) - mod.run() - res = mod.get_output(0).asnumpy() - tvm.testing.assert_allclose(res, ref_res, rtol=1e-5, atol=1e-5) + opt_level = 0 + config = {"relay.fallback_device_type": fallback_device.device_type} + + check_graph_runtime(target, ref_res, dev, annotated_func, params, config, + opt_level, expected_index) + + opt_level = 2 + check_vm_runtime(target, ref_res, dev, annotated_func, params, config, + opt_level) @tvm.testing.requires_opencl @@ -686,5 +716,4 @@ def annotated(): test_annotate_all() test_annotate_none() test_conv_network() - test_check_run() test_tuple_get_item() diff --git a/tests/python/relay/test_pass_context_analysis.py b/tests/python/relay/test_pass_context_analysis.py new file mode 100644 index 000000000000..e54682be7871 --- /dev/null +++ b/tests/python/relay/test_pass_context_analysis.py @@ -0,0 +1,205 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=no-else-return,invalid-name,len-as-condition,too-many-nested-blocks + +import numpy as np +import pytest + +import tvm +from tvm import relay +from tvm.relay import expr as _expr +from tvm.relay.analysis import context_analysis + + +def test_device_copy(): + if not tvm.testing.device_enabled("cuda") or not tvm.gpu(0).exist: + return + + mod = tvm.IRModule() + x = relay.var("x", shape=(2, 3)) + copy = relay.op.device_copy(x, tvm.cpu(), tvm.gpu()) + out = copy + relay.const(np.random.rand(2, 3)) + glb_var = relay.GlobalVar("main") + mod[glb_var] = relay.Function([x], out) + ca = context_analysis(mod, tvm.cpu()) + + cpu_dev = tvm.cpu().device_type + gpu_dev = tvm.gpu().device_type + for expr, dev in ca.items(): + if isinstance(expr, _expr.Call): + assert dev[0].value == gpu_dev + elif isinstance(expr, _expr.Var): + assert dev[0].value == cpu_dev + elif isinstance(expr, _expr.Constant): + assert dev[0].value == gpu_dev + + +def test_shape_func(): + if not tvm.testing.device_enabled("cuda") or not tvm.gpu(0).exist: + return + + mod = tvm.IRModule() + data_shape = (relay.Any(),) + x = relay.var("x", shape=data_shape) + y = relay.op.vm.shape_of(x) + z = relay.nn.relu(y) + p0 = relay.var("p0", shape=data_shape) + fn = relay.Function([p0], z) + out = relay.var("out", shape=(1,), dtype="int64") + ins = relay.Tuple([y]) + outs = relay.Tuple([out]) + is_inputs = [False] + shape_func = relay.op.vm.shape_func(fn, ins, outs, is_inputs) + mod["main"] = relay.Function([x, out], shape_func) + ca = context_analysis(mod, tvm.gpu()) + main = mod["main"] + + cpu_dev = tvm.cpu().device_type + gpu_dev = tvm.gpu().device_type + assert main.params[0] in ca and ca[main.params[0]][0].value == gpu_dev + # The output of shape func should be on cpu. + assert main.params[1] in ca and ca[main.params[1]][0].value == cpu_dev + # shape func is the body and it should be on cpu + assert main.body in ca and ca[main.body][0].value == cpu_dev + + +def test_vm_shape_of(): + if not tvm.testing.device_enabled("cuda") or not tvm.gpu(0).exist: + return + + mod = tvm.IRModule() + data_shape = (relay.Any(),) + x = relay.var("x", shape=data_shape) + y = relay.op.vm.shape_of(x) + mod["main"] = relay.Function([x], y) + ca = context_analysis(mod, tvm.gpu()) + main = mod["main"] + + cpu_dev = tvm.cpu().device_type + gpu_dev = tvm.gpu().device_type + assert main.params[0] in ca and ca[main.params[0]][0].value == gpu_dev + assert main.body in ca and ca[main.body][0].value == cpu_dev + + +def test_alloc_storage(): + if not tvm.testing.device_enabled("cuda") or not tvm.gpu(0).exist: + return + + mod = tvm.IRModule() + mod.import_from_std("core.rly") + size = relay.Var("size", relay.scalar_type("int64")) + alignment = relay.Var("alignment", relay.scalar_type("int64")) + # allocate a chunk of memory on gpu. + sto = relay.op.memory.alloc_storage(size, alignment, tvm.gpu()) + mod["main"] = relay.Function([size, alignment], sto) + ca = context_analysis(mod, tvm.gpu()) + main = mod["main"] + body = main.body + + cpu_dev = tvm.cpu().device_type + gpu_dev = tvm.gpu().device_type + # Inputs are unified with alloc storage inputs which are on cpu + assert main.params[0] in ca and ca[main.params[0]][0].value == cpu_dev + assert main.params[1] in ca and ca[main.params[1]][0].value == cpu_dev + + assert isinstance(body, relay.Call) and len(body.args) == 2 + # size of alloc_storage is on cpu + assert body.args[0] in ca and ca[body.args[0]][0].value == cpu_dev + # alignment of alloc_storage is on cpu + assert body.args[1] in ca and ca[body.args[1]][0].value == cpu_dev + # alloc_storage is on gpu as specified + assert body in ca and ca[body][0].value == gpu_dev + + +def test_alloc_tensor(): + if not tvm.testing.device_enabled("cuda") or not tvm.gpu(0).exist: + return + + mod = tvm.IRModule() + mod.import_from_std("core.rly") + sto_type = relay.TypeCall(mod.get_global_type_var("Storage"), []) + sto = relay.Var("x", sto_type) + sh = relay.const(np.array([3, 2]), dtype="int64") + at = relay.op.memory.alloc_tensor(sto, relay.const(0, dtype="int64"), sh) + mod["main"] = relay.Function([sto], at) + ca = context_analysis(mod, tvm.gpu()) + main = mod["main"] + body = main.body + + cpu_dev = tvm.cpu().device_type + gpu_dev = tvm.gpu().device_type + # Input of the function falls back to the default device gpu + assert main.params[0] in ca and ca[main.params[0]][0].value == gpu_dev + + assert isinstance(body, relay.Call) and len(body.args) == 3 + # storage of alloc_tensor falls back to the default device gpu + assert body.args[0] in ca and ca[body.args[0]][0].value == gpu_dev + # shape of alloc_tensor is on cpu + assert body.args[1] in ca and ca[body.args[1]][0].value == cpu_dev + # alloc_tensor keeps the same device context as storage which is on gpu + assert body in ca and ca[body][0].value == gpu_dev + + +def test_vm_reshape_tensor(): + if not tvm.testing.device_enabled("cuda") or not tvm.gpu(0).exist: + return + + x = relay.var("x", shape=(2, 8), dtype="float32") + shape = relay.const([-1, 4, 2], dtype="int64") + y = relay.op.vm.reshape_tensor(x, shape, [2, 4, 2]) + mod = tvm.IRModule() + mod["main"] = relay.Function([x], y) + ca = context_analysis(mod, tvm.gpu()) + main = mod["main"] + body = main.body + + cpu_dev = tvm.cpu().device_type + gpu_dev = tvm.gpu().device_type + # Input of the function falls back to the default device gpu + assert main.params[0] in ca and ca[main.params[0]][0].value == gpu_dev + + # data of reshape_tensor falls back to the default device gpu + assert body.args[0] in ca and ca[body.args[0]][0].value == gpu_dev + # shape of reshape_tensor is on cpu + assert body.args[1] in ca and ca[body.args[1]][0].value == cpu_dev + # reshape_tensor sits on the same device as the data + assert body in ca and ca[body][0].value == gpu_dev + + +def test_dynamic_input(): + if not tvm.testing.device_enabled("cuda") or not tvm.gpu(0).exist: + return + + mod = tvm.IRModule() + data_shape = (relay.Any(), relay.Any()) + x0 = relay.var("x0", shape=data_shape) + x1 = relay.var("x1", shape=data_shape) + mod["main"] = relay.Function([x0, x1], x0 + x1) + + compiler = relay.vm.VMCompiler() + mod, _ = compiler.optimize(mod, target="cuda") + ca = context_analysis(mod, tvm.cpu()) + main = mod["main"] + + gpu_dev = tvm.gpu().device_type + assert main.params[0] in ca and ca[main.params[0]][0].value == gpu_dev + assert
main.params[1] in ca and ca[main.params[1]][0].value == gpu_dev + assert main.body in ca and ca[main.body][0].value == gpu_dev + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/python/relay/test_vm_serialization.py b/tests/python/relay/test_vm_serialization.py index df3bbc19cb58..b304c435fdfa 100644 --- a/tests/python/relay/test_vm_serialization.py +++ b/tests/python/relay/test_vm_serialization.py @@ -298,5 +298,20 @@ def test_vm_shape_of(): tvm.testing.assert_allclose(res.flatten(), data.flatten()) +def test_dynamic_bcast(): + dtype = 'float32' + x = relay.var('x', shape=(relay.Any(), 2), dtype=dtype) + y = relay.var('y', shape=(3, 2), dtype=dtype) + mod = tvm.IRModule() + mod['main'] = relay.Function([x, y], relay.add(x, y)) + x_data = np.random.uniform(size=(1, 2)).astype(dtype) + y_data = np.random.uniform(size=(3, 2)).astype(dtype) + res_np = np.add(x_data, y_data) + for target, ctx in testing.enabled_targets(): + res = get_serialized_output(mod, *(x_data, y_data), target=target, + ctx=ctx) + tvm.testing.assert_allclose(res.asnumpy(), res_np) + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/python/unittest/test_runtime_vm_profiler.py b/tests/python/unittest/test_runtime_vm_profiler.py index 97b54c6c5505..9e484357cf3e 100644 --- a/tests/python/unittest/test_runtime_vm_profiler.py +++ b/tests/python/unittest/test_runtime_vm_profiler.py @@ -16,25 +16,23 @@ # under the License. import numpy as np -import tvm -from tvm import te from tvm.runtime import profiler_vm from tvm import relay -from tvm.relay.testing import resnet +from tvm.relay.testing import resnet, enabled_targets def test_basic(): mod, params = resnet.get_workload() - target = 'llvm' - ctx = tvm.cpu() if not profiler_vm.enabled(): return - exe = relay.vm.compile(mod, target, params=params) - vm = profiler_vm.VirtualMachineProfiler(exe, ctx) - data = np.random.rand(1, 3, 224, 224).astype('float32') - res = vm.invoke("main", [data]) - print("\n{}".format(vm.get_stat())) - print("\n{}".format(vm.get_stat(False))) + for target, ctx in enabled_targets(): + exe = relay.vm.compile(mod, target, params=params) + vm = profiler_vm.VirtualMachineProfiler(exe, ctx) + + data = np.random.rand(1, 3, 224, 224).astype('float32') + res = vm.invoke("main", [data]) + print("\n{}".format(vm.get_stat())) + print("\n{}".format(vm.get_stat(False))) if __name__ == "__main__": test_basic()
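Taken together, the compiler and runtime changes above let the VM execute a module that spans devices. A minimal end-to-end sketch, assuming a CUDA-enabled build; the module, constant, and target dictionary here are illustrative and follow the patterns used in the tests above:

import numpy as np
import tvm
from tvm import relay

# The input starts on the CPU and is explicitly copied to the GPU before the add.
x = relay.var("x", shape=(2, 3))
copied = relay.op.device_copy(x, tvm.cpu(), tvm.gpu())
mod = tvm.IRModule()
mod["main"] = relay.Function([x], copied + relay.const(np.ones((2, 3), dtype="float32")))

# Heterogeneous compilation: one target per device type plus a fallback device.
target = {"cpu": "llvm", "cuda": "cuda"}
with tvm.transform.PassContext(opt_level=2,
                               config={"relay.fallback_device_type": tvm.cpu(0).device_type}):
    exe = relay.vm.compile(mod, target)

# Initialize the VM with one context per device type; inputs and constants are
# placed according to the per-function device information in the executable.
vm = tvm.runtime.vm.VirtualMachine(exe, [tvm.cpu(0), tvm.gpu(0)])
res = vm.invoke("main", tvm.nd.array(np.random.rand(2, 3).astype("float32")))
print(res.asnumpy())

At run time the DeviceCopy instruction performs the cross-device transfer, while GetContext resolves each allocation and constant to the context registered for its device type.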