diff --git a/docs/dev/virtual_machine.rst b/docs/dev/virtual_machine.rst index 79e1fd94edcd..2791ee71177e 100644 --- a/docs/dev/virtual_machine.rst +++ b/docs/dev/virtual_machine.rst @@ -204,7 +204,7 @@ InvokeClosure **Arguments**: :: RegName closure - size_t closure_args_num + size_t num_closure_args RegName* closure_args Invokes `closure`, consuming the number of arguments declared in the closure's VMFunction. diff --git a/include/tvm/runtime/vm.h b/include/tvm/runtime/vm.h index ea1f1b14f023..3e832d49f5bc 100644 --- a/include/tvm/runtime/vm.h +++ b/include/tvm/runtime/vm.h @@ -36,6 +36,9 @@ namespace tvm { namespace runtime { namespace vm { +/*! \brief Magic number for NDArray list file */ +constexpr uint64_t kTVMNDArrayListMagic = 0xF7E58D4F05049CB7; + /*! \brief A register name. */ using RegName = int64_t; @@ -103,7 +106,7 @@ struct Instruction { /*! \brief The register containing the closure. */ RegName closure; /*! \brief The number of arguments to the closure. */ - Index closure_args_num; + Index num_closure_args; /*! \brief The closure arguments as an array. */ RegName* closure_args; }; @@ -115,7 +118,7 @@ struct Instruction { /*! \brief The source register for a move operation. */ RegName from; }; - struct /* Packed Operands */ { + struct /* InvokePacked Operands */ { /*! \brief The index into the packed function table. */ Index packed_index; /*! \brief The arity of the packed function. */ @@ -149,7 +152,7 @@ struct Instruction { }; struct /* LoadConsti Operands */ { /* \brief The index into the constant pool. */ - size_t val; + Index val; } load_consti; struct /* Jump Operands */ { /*! \brief The jump offset. */ @@ -284,7 +287,7 @@ struct Instruction { * \param dst The destination register. * \return The load_constanti instruction. */ - static Instruction LoadConsti(size_t val, RegName dst); + static Instruction LoadConsti(Index val, RegName dst); /*! \brief Construct a move instruction. * \param src The source register. * \param dst The destination register. @@ -379,6 +382,8 @@ class VirtualMachine : public runtime::ModuleNode { return "VirtualMachine"; } + /*! \brief The runtime module/library that contains generated code. */ + runtime::Module lib; /*! \brief The virtual machine's packed function table. */ std::vector packed_funcs; /*! \brief The virtual machine's function table. */ @@ -448,16 +453,30 @@ class VirtualMachine : public runtime::ModuleNode { void Init(const std::vector& contexts); void Run(); + /*! + * \brief Load parameters from the parameter bytearray. + * \param params The binary file that contains parameters. + */ + void LoadParams(const std::string& params); + /*! \brief A map from globals (as strings) to their index in the function map. */ std::unordered_map global_map; + /*! \brief A mapping from the packed function (as string) to the index that + * corresponds to the position of the `packed_funcs` list. + */ + std::unordered_map primitive_map; + private: /*! \brief Invoke a global setting up the VM state to execute. * * This does not begin execution of the VM. */ void InvokeGlobal(const VMFunction& func, const std::vector& args); + + /*! \brief The parameter name to data mapping. */ + std::unordered_map params_; }; } // namespace vm diff --git a/python/tvm/relay/__init__.py b/python/tvm/relay/__init__.py index e94ef411d29d..da14c80b33b4 100644 --- a/python/tvm/relay/__init__.py +++ b/python/tvm/relay/__init__.py @@ -34,6 +34,8 @@ from . import param_dict from . import feature from .backend import vm +from .backend import serializer +from .backend import deserializer from .backend import vmobj # Root operators diff --git a/python/tvm/relay/backend/deserializer.py b/python/tvm/relay/backend/deserializer.py new file mode 100644 index 000000000000..fde702b1cd04 --- /dev/null +++ b/python/tvm/relay/backend/deserializer.py @@ -0,0 +1,81 @@ +# License .to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +""" +The Relay Virtual Machine deserializer. + +Python interface for deserializing a Relay VM. +""" +from tvm import module +from tvm._ffi.runtime_ctypes import TVMByteArray +from . import _vm +from . import vm as rly_vm + +def _create_deserializer(code, lib): + """Create a deserializer object. + + Parameters + ---------- + code : bytearray + The serialized virtual machine code. + + lib : :py:class:`~tvm.module.Module` + The serialized runtime module/library that contains the hardware + dependent binary code. + + Returns + ------- + ret : Deserializer + The created virtual machine deserializer. + """ + if isinstance(code, (bytes, str)): + code = bytearray(code) + elif not isinstance(code, (bytearray, TVMByteArray)): + raise TypeError("vm is expected to be the type of bytearray or " + + "TVMByteArray, but received {}".format(type(code))) + + if not isinstance(lib, module.Module): + raise TypeError("lib is expected to be the type of tvm.module.Module" + + ", but received {}".format(type(lib))) + return _vm._Deserializer(code, lib) + + +class Deserializer: + """Relay VM deserializer. + + Parameters + ---------- + code : bytearray + The serialized virtual machine code. + + lib : :py:class:`~tvm.module.Module` + The serialized runtime module/library that contains the hardware + dependent binary code. + """ + def __init__(self, code, lib): + self.mod = _create_deserializer(code, lib) + self._deserialize = self.mod["deserialize"] + + def deserialize(self): + """Deserialize the serialized bytecode into a Relay VM. + + Returns + ------- + ret : VirtualMachine + The deserialized Relay VM. + """ + return rly_vm.VirtualMachine(self._deserialize()) diff --git a/python/tvm/relay/backend/serializer.py b/python/tvm/relay/backend/serializer.py new file mode 100644 index 000000000000..b45ba9116a15 --- /dev/null +++ b/python/tvm/relay/backend/serializer.py @@ -0,0 +1,191 @@ +# License .to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +""" +The Relay Virtual Machine serializer. + +Python interface for serializing a Relay VM. +""" +import tvm +from . import _vm +from . import vm as rly_vm + +def _create_serializer(vm): + """Create a VM serializer. + + Parameters + ---------- + vm : Union[VirtualMachine, :py:class:`~tvm.module.Module`] + The virtual machine to be serialized. + + Returns + ------- + ret : Serializer + The created virtual machine serializer. + """ + if isinstance(vm, rly_vm.VirtualMachine): + vm = vm.module + elif not isinstance(vm, tvm.module.Module): + raise TypeError("vm is expected to be the type of VirtualMachine or " + + "tvm.Module, but received {}".format(type(vm))) + + return _vm._Serializer(vm) + + +class Serializer: + """Relay VM serializer.""" + def __init__(self, vm): + self.mod = _create_serializer(vm) + self._get_lib = self.mod["get_lib"] + self._get_bytecode = self.mod["get_bytecode"] + self._get_globals = self.mod["get_globals"] + self._get_stats = self.mod["get_stats"] + self._get_primitive_ops = self.mod["get_primitive_ops"] + self._serialize = self.mod["serialize"] + + @property + def stats(self): + """Get the statistics of the Relay VM. + + Returns + ------- + ret : String + The serialized statistic information. + """ + return self._get_stats() + + @property + def primitive_ops(self): + """Get the name of the primitive ops that are executed in the VM. + + Returns + ------- + ret : List[:py:class:`~tvm.expr.StringImm`] + The list of primitive ops. + """ + return [prim_op.value for prim_op in self._get_primitive_ops()] + + @property + def bytecode(self): + """Get the bytecode of the Relay VM. + + Returns + ------- + ret : String + The serialized bytecode. + + Notes + ----- + The bytecode is in the following format: + func_name reg_file_size num_instructions + param1 param2 ... paramM + instruction1 + instruction2 + ... + instructionN + + Each instruction is printed in the following format: + hash opcode field1 ... fieldX # The text format. + + The part starting from # is only used for visualization and debugging. + The real serialized code doesn't contain it, therefore the deserializer + doesn't need to deal with it as well. + """ + return self._get_bytecode() + + @property + def globals(self): + """Get the globals used by the Relay VM. + + Returns + ------- + ret : List[:py:class:`~tvm.expr.StringImm`] + The serialized globals. + """ + return [glb.value for glb in self._get_globals()] + + def serialize(self): + """Serialize the Relay VM. + + Returns + ------- + code : bytearray + The binary blob representing a serialized Relay VM. It can then be + saved to disk and later deserialized into a new VM. + + lib : :py:class:`~tvm.module.Module` + The runtime module that contains the generated code. It is + basically a library that is composed of hardware dependent code. + + Notes + ----- + The returned code is organized with the following sections in order. + - Global section. This section contains the globals used by the + virtual machine. + - Constant section. This section is used to store the constant pool of + a virtual machine. + - Primitive name section. This section is introduced to accommodate + the list of primitive operator names that will be invoked by the + virtual machine. + - Code section. The VM functions, including bytecode, are sitting in + this section. + + Examples + -------- + .. code-block:: python + + import numpy as np + import tvm + from tvm import relay + + # define a simple network. + x = relay.var('x', shape=(10, 10)) + f = relay.Function([x], x + x) + mod = relay.Module({"main": f}) + + # create a Relay VM. + ctx = tvm.cpu() + target = "llvm" + compiler = relay.vm.VMCompiler() + vm = compiler.compile(mod, target) + vm.init(ctx) + + # serialize. + ser = relay.serializer.Serializer(vm) + code, lib = ser.serialize() + + # save and load the code and lib file. + tmp = tvm.contrib.util.tempdir() + path_lib = tmp.relpath("lib.so") + lib.export_library(path_lib) + with open(tmp.relpath("code.bc"), "wb") as fo: + fo.write(code) + + loaded_lib = tvm.module.load(path_lib) + loaded_code = bytearray(open(tmp.relpath("code.bc"), "rb").read()) + + # deserialize. + deser = relay.deserializer.Deserializer(loaded_code, loaded_lib) + des_vm = deser.deserialize() + + # execute the deserialized vm. + des_vm.init(ctx) + x_data = np.random.rand(10, 10).astype('float32') + res = des_vm.run(x_data) + print(res.asnumpy()) + """ + return self._serialize(), self._get_lib() diff --git a/python/tvm/relay/backend/vm.py b/python/tvm/relay/backend/vm.py index 5f3413ae82c2..752e9b2cb10b 100644 --- a/python/tvm/relay/backend/vm.py +++ b/python/tvm/relay/backend/vm.py @@ -16,13 +16,14 @@ # under the License. # pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, invalid-name """ -The Relay Virtual Vachine. +The Relay Virtual Machine. Implements a Python interface to compiling and executing on the Relay VM. """ import numpy as np import tvm +from tvm._ffi.runtime_ctypes import TVMByteArray from . import _vm from . import vmobj as _obj from .interpreter import Executor @@ -71,6 +72,7 @@ class VirtualMachine(object): def __init__(self, mod): self.mod = mod self._init = self.mod["init"] + self._load_params = self.mod["load_params"] self._invoke = self.mod["invoke"] def init(self, ctx): @@ -84,6 +86,23 @@ def init(self, ctx): args = [ctx.device_type, ctx.device_id] self._init(*args) + def load_params(self, params): + """Load parameters for the VM. + + Parameters + ---------- + params : Union[bytearray, Dict] + The dictionary that contains serialized parameters. + """ + if isinstance(params, dict): + params = tvm.relay.save_param_dict(params) + elif isinstance(params, (bytes, str)): + params = bytearray(params) + if not isinstance(params, (bytearray, TVMByteArray)): + raise TypeError("params must be a bytearray") + + self._load_params(bytearray(params)) + def invoke(self, func_name, *args): """Invoke a function. @@ -118,6 +137,11 @@ def run(self, *args): """ return self.invoke("main", *args) + @property + def module(self): + """Return the runtime module contained in a virtual machine.""" + return self.mod + class VMCompiler(object): """Build Relay module to run on VM runtime.""" diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index 11a3d0403067..853cd30f82c3 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -745,7 +745,7 @@ class VMCompiler : public runtime::ModuleNode { } #endif // USE_RELAY_DEBUG - PopulatePackedFuncMap(); + LibraryCodegen(); for (auto gv : context_.global_map) { vm_->global_map.insert({gv.first->name_hint, gv.second}); @@ -775,26 +775,28 @@ class VMCompiler : public runtime::ModuleNode { } } - void PopulatePackedFuncMap() { + void LibraryCodegen() { auto const& lowered_funcs = context_.lowered_funcs; if (lowered_funcs.size() == 0) { return; } - runtime::Module mod; // TODO(@icemelon9): support heterogeneous targets Target target; for (auto kv : targets_) { target = kv.second; } if (const auto* f = runtime::Registry::Get("relay.backend.build")) { - mod = (*f)(tvm::Array(lowered_funcs.begin(), lowered_funcs.end()), - target, target_host_); + runtime::Module mod = + (*f)(tvm::Array(lowered_funcs.begin(), lowered_funcs.end()), target, + target_host_); + CHECK(mod.operator->()); + vm_->lib = mod; } else { LOG(FATAL) << "relay.backend.build is not registered"; } - CHECK(mod.operator->()); + size_t primitive_index = 0; for (auto lfunc : lowered_funcs) { - vm_->packed_funcs.push_back(mod.GetFunction(lfunc->name)); + vm_->primitive_map.insert({lfunc->name, primitive_index++}); } } diff --git a/src/relay/backend/vm/deserializer.cc b/src/relay/backend/vm/deserializer.cc new file mode 100644 index 000000000000..6cf76081de13 --- /dev/null +++ b/src/relay/backend/vm/deserializer.cc @@ -0,0 +1,324 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file src/relay/backend/vm/deserializer.cc + * \brief Implementation of APIs to deserialize the serialized VM bytecode. + */ + +#include "deserializer.h" + +#include +#include +#include + +#include "serialize_util.h" + +namespace tvm { +namespace relay { +namespace vm { + +#define STREAM_CHECK(val, section) \ + CHECK(val) << "Invalid VM file format in the " << section << " section." \ + << "\n"; + +void Deserializer::Init(const std::string& code, const runtime::Module& lib) { + code_ = code; + vm_ = std::make_shared(); + vm_->lib = lib; + strm_ = new dmlc::MemoryStringStream(&code_); +} + +runtime::PackedFunc Deserializer::GetFunction( + const std::string& name, + const std::shared_ptr& sptr_to_self) { + if (name == "deserialize") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + this->Deserialize(); + *rv = runtime::Module(vm_); + }); + } else { + LOG(FATAL) << "Unknown packed function: " << name; + return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {}); + } +} + +void Deserializer::Deserialize() { + // Check header. + uint64_t header; + STREAM_CHECK(strm_->Read(&header), "header"); + STREAM_CHECK(header == kTVMVMBytecodeMagic, "header"); + + // Check version. + std::string version; + STREAM_CHECK(strm_->Read(&version), "version"); + STREAM_CHECK(version == TVM_VERSION, "version"); + + // Global section. + DeserializeGlobalSection(); + + // Constant section. + DeserializeConstantSection(); + + // Primitive names that will be invoked by `InvokePacked` instructions. + DeserializePrimitiveOpNames(); + + // Code section. + DeserializeCodeSection(); +} + +void Deserializer::DeserializeGlobalSection() { + std::vector globals; + STREAM_CHECK(strm_->Read(&globals), "global"); + for (size_t i = 0; i < globals.size(); i++) { + vm_->global_map.insert({globals[i], i}); + } +} + +void Deserializer::DeserializeConstantSection() { + uint64_t sz; + // Load the number of constants. + STREAM_CHECK(strm_->Read(&sz, sizeof(sz)), "constant"); + + size_t size = static_cast(sz); + // Load each of the constants. + for (size_t i = 0; i < size; i++) { + runtime::NDArray constant; + STREAM_CHECK(constant.Load(strm_), "constant"); + runtime::Object obj = runtime::Object::Tensor(constant); + vm_->constants.push_back(obj); + } +} + +void Deserializer::DeserializePrimitiveOpNames() { + std::vector primitive_names; + STREAM_CHECK(strm_->Read(&primitive_names), "primitive name"); + for (size_t i = 0; i < primitive_names.size(); i++) { + vm_->primitive_map.insert({primitive_names[i], i}); + } +} + +// Extract the `cnt` number of fields started at `start` from the list +// `instr_fields`. +inline std::vector ExtractFields(const std::vector& instr_fields, + Index start, + Index cnt) { + CHECK_LE(static_cast(start + cnt), instr_fields.size()); + std::vector ret; + for (auto i = start; i < start + cnt; i++) { + ret.push_back(instr_fields[i]); + } + return ret; +} + +Instruction DeserializeInstruction(const VMInstructionSerializer& instr) { + Opcode opcode = static_cast(instr.opcode); + switch (opcode) { + case Opcode::Move: { + // Number of fields = 2 + DCHECK_EQ(instr.fields.size(), 2U); + return Instruction::Move(instr.fields[0], instr.fields[1]); + } + case Opcode::Ret: { + // Number of fields = 1 + DCHECK_EQ(instr.fields.size(), 1U); + return Instruction::Ret(instr.fields[0]); + } + case Opcode::Fatal: { + // Number of fields = 0 + DCHECK(instr.fields.empty()); + return Instruction::Fatal(); + } + case Opcode::InvokePacked: { + // Number of fields = 3 + instr.arity + DCHECK_GE(instr.fields.size(), 3U); + DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); + + Index packed_index = instr.fields[0]; + Index arity = instr.fields[1]; + Index output_size = instr.fields[2]; + std::vector args = ExtractFields(instr.fields, 3, arity); + return Instruction::InvokePacked(packed_index, arity, output_size, args); + } + case Opcode::AllocTensor: { + // Number of fields = 5 + instr.alloc_tensor.ndim + DCHECK_GE(instr.fields.size(), 5U); + DCHECK_EQ(instr.fields.size(), 5U + static_cast(instr.fields[3])); + + DLDataType dtype; + dtype.code = instr.fields[0]; + dtype.bits = instr.fields[1]; + dtype.lanes = instr.fields[2]; + + Index ndim = instr.fields[3]; + RegName dst = instr.fields[4]; + + std::vector shape = ExtractFields(instr.fields, 5, ndim); + + return Instruction::AllocTensor(shape, dtype, dst); + } + case Opcode::AllocTensorReg: { + // Number of fields = 5 + DCHECK_EQ(instr.fields.size(), 5U); + Index shape_register = instr.fields[0]; + + DLDataType dtype; + dtype.code = instr.fields[1]; + dtype.bits = instr.fields[2]; + dtype.lanes = instr.fields[3]; + + RegName dst = instr.fields[4]; + + return Instruction::AllocTensorReg(shape_register, dtype, dst); + } + case Opcode::AllocDatatype: { + // Number of fields = 3 + instr.num_fields + DCHECK_GE(instr.fields.size(), 3U); + DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); + + Index constructor_tag = instr.fields[0]; + Index num_fields = instr.fields[1]; + RegName dst = instr.fields[2]; + std::vector fields = ExtractFields(instr.fields, 3, num_fields); + + return Instruction::AllocDatatype(constructor_tag, num_fields, fields, dst); + } + case Opcode::AllocClosure: { + // Number of fields = 3 + instr.num_freevar + DCHECK_GE(instr.fields.size(), 3U); + DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); + + Index clo_index = instr.fields[0]; + Index num_freevar = instr.fields[1]; + RegName dst = instr.fields[2]; + std::vector free_vars = ExtractFields(instr.fields, 3, num_freevar); + + return Instruction::AllocClosure(clo_index, num_freevar, free_vars, dst); + } + case Opcode::If: { + // Number of fields = 4 + DCHECK_EQ(instr.fields.size(), 4U); + Index test = instr.fields[0]; + Index target = instr.fields[1]; + Index true_offset = instr.fields[2]; + Index false_offset = instr.fields[3]; + + return Instruction::If(test, target, true_offset, false_offset); + } + case Opcode::Invoke: { + // Number of fields = 3 + instr.num_args + DCHECK_GE(instr.fields.size(), 3U); + DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); + + Index func_index = instr.fields[0]; + Index num_args = instr.fields[1]; + RegName dst = instr.fields[2]; + std::vector args = ExtractFields(instr.fields, 3, num_args); + + return Instruction::Invoke(func_index, args, dst); + } + case Opcode::InvokeClosure: { + // Number of fields = 3 + instr.num_closure_args + DCHECK_GE(instr.fields.size(), 3U); + DCHECK_EQ(instr.fields.size(), 3U + static_cast(instr.fields[1])); + + Index closure = instr.fields[0]; + Index num_closure_args = instr.fields[1]; + RegName dst = instr.fields[2]; + std::vector args = ExtractFields(instr.fields, 3, num_closure_args); + + return Instruction::InvokeClosure(closure, args, dst); + } + case Opcode::LoadConst: { + // Number of fields = 2 + DCHECK_EQ(instr.fields.size(), 2U); + return Instruction::LoadConst(instr.fields[0], instr.fields[1]); + } + case Opcode::LoadConsti: { + // Number of fields = 2 + DCHECK_EQ(instr.fields.size(), 2U); + return Instruction::LoadConsti(instr.fields[0], instr.fields[1]); + } + case Opcode::GetField: { + // Number of fields = 3 + DCHECK_EQ(instr.fields.size(), 3U); + return Instruction::GetField(instr.fields[0], instr.fields[1], instr.fields[2]); + } + case Opcode::GetTag: { + // Number of fields = 2 + DCHECK_EQ(instr.fields.size(), 2U); + return Instruction::GetTag(instr.fields[0], instr.fields[1]); + } + case Opcode::Goto: { + // Number of fields = 1 + DCHECK_EQ(instr.fields.size(), 1U); + return Instruction::Goto(instr.fields[0]); + } + default: + LOG(FATAL) << "Invalid opcode" << instr.opcode; + return Instruction(); + } +} + +void Deserializer::DeserializeCodeSection() { + // Load the number of functions. + uint64_t sz; + STREAM_CHECK(strm_->Read(&sz, sizeof(sz)), "code"); + + size_t num_funcs = static_cast(sz); + vm_->functions.resize(num_funcs); + for (size_t i = 0; i < num_funcs; i++) { + // Load the function info. + VMFunctionSerializer loaded_func; + STREAM_CHECK(loaded_func.Load(strm_), "code/function"); + + // Load the instructions. + std::vector instructions; + for (size_t j = 0; j < loaded_func.num_instructions; j++) { + VMInstructionSerializer instr; + std::vector instr_fields; + STREAM_CHECK(instr.Load(strm_), "code/instruction"); + instructions.push_back(DeserializeInstruction(instr)); + } + + // Create the VM function. + VMFunction vm_func = VMFunction(loaded_func.name, + loaded_func.params, + instructions, + loaded_func.register_file_size); + auto it = vm_->global_map.find(loaded_func.name); + CHECK(it != vm_->global_map.end()); + CHECK_LE(it->second, vm_->global_map.size()); + vm_->functions[it->second] = vm_func; + } +} + +runtime::Module CreateDeserializer(const std::string& code, const runtime::Module lib) { + std::shared_ptr exec = std::make_shared(); + exec->Init(code, lib); + return runtime::Module(exec); +} + +TVM_REGISTER_GLOBAL("relay._vm._Deserializer") +.set_body_typed(CreateDeserializer); + +} // namespace vm +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/vm/deserializer.h b/src/relay/backend/vm/deserializer.h new file mode 100644 index 000000000000..0caf72bee92c --- /dev/null +++ b/src/relay/backend/vm/deserializer.h @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file src/relay/backend/vm/deserializer.h + * \brief Define a deserializer for the serialized Relay VM. + */ + +#ifndef TVM_RELAY_BACKEND_VM_DESERIALIZER_H_ +#define TVM_RELAY_BACKEND_VM_DESERIALIZER_H_ + +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace tvm { +namespace relay { +namespace vm { + +using namespace tvm::runtime::vm; +namespace runtime = tvm::runtime; + +class Deserializer : public runtime::ModuleNode { + public: + /*! + * \brief Initialize the deserializer for creating a virtual machine object. + * + * \param code The serialized code. + * \param lib The serialized runtime module/library that contains the + * hardware dependent code. + */ + inline void Init(const std::string& code, const runtime::Module& lib); + + /*! + * \brief Return the member function to the frontend. + * + * \param name The name of the function. + * \param sptr_to_self The pointer to the module node. + * + * \return The corresponding member function. + */ + PackedFunc GetFunction(const std::string& name, + const std::shared_ptr& sptr_to_self) final; + + const char* type_key() const final { return "Deserializer"; } + + /*! \brief Deserialize the serialized VM. */ + void Deserialize(); + + virtual ~Deserializer() { delete strm_; } + + private: + /*! \brief Deserialize the globals in `vm_`. */ + void DeserializeGlobalSection(); + + /*! \brief Deserialize the constant pool in `vm_`. */ + void DeserializeConstantSection(); + + /*! \brief Deserialize primitive op names in `vm_`. */ + void DeserializePrimitiveOpNames(); + + /*! \brief Deserialize the vm functions in `vm_`. */ + void DeserializeCodeSection(); + + /*! \brief The code to be serialized. */ + std::string code_; + + /*! \brief The stream used for serialization. */ + dmlc::Stream* strm_; + + /*! \brief The VM to be created. */ + std::shared_ptr vm_; +}; + +} // namespace vm +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_BACKEND_VM_DESERIALIZER_H_ diff --git a/src/relay/backend/vm/serialize_util.h b/src/relay/backend/vm/serialize_util.h new file mode 100644 index 000000000000..3e7508ebee9b --- /dev/null +++ b/src/relay/backend/vm/serialize_util.h @@ -0,0 +1,164 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file src/relay/backend/vm/serialize_util.h + * \brief Definitions of helpers for serializing and deserializing a Relay VM. + */ +#ifndef TVM_RELAY_BACKEND_VM_SERIALIZE_UTIL_H_ +#define TVM_RELAY_BACKEND_VM_SERIALIZE_UTIL_H_ + +#include +#include +#include + +#include +#include +#include + +namespace tvm { +namespace relay { +namespace vm { + +/*! \brief The magic number for the serialized VM bytecode file */ +constexpr uint64_t kTVMVMBytecodeMagic = 0xD225DE2F4214151D; + +template +static inline size_t VectorHash(size_t key, const std::vector& values) { + for (const auto& it : values) { + key = dmlc::HashCombine(key, it); + } + return key; +} + +// A struct to hold the funciton info in the code section. +struct VMFunctionSerializer { + /*! \brief The name of the VMFunction. */ + std::string name; + /*! \brief The number of registers used by the VMFunction. */ + Index register_file_size; + /*! \brief The number of instructions in the VMFunction. */ + size_t num_instructions; + /*! \brief The parameters of the VMFunction. */ + std::vector params; + + VMFunctionSerializer() = default; + + VMFunctionSerializer(const std::string& name, + Index register_file_size, + size_t num_instructions, + const std::vector& params) + : name(name), + register_file_size(register_file_size), + num_instructions(num_instructions), + params(params) {} + + /*! + * \brief Load the serialized function header. + * \param strm The stream used to load data. + * \return True if successful. Otherwise, false. + */ + bool Load(dmlc::Stream* strm) { + std::vector func_info; + if (!strm->Read(&func_info)) return false; + CHECK_EQ(func_info.size(), 3U) << "Failed to decode the vm function." + << "\n"; + name = func_info[0]; + register_file_size = std::stoll(func_info[1]); + // Get the number of instructions. + num_instructions = static_cast(std::stoll(func_info[2])); + return strm->Read(¶ms); + } + + /*! + * \brief Save the VM function header into the serialized form. + * \param strm The stream used to save data. + */ + void Save(dmlc::Stream* strm) const { + std::vector func_info; + func_info.push_back(name); + func_info.push_back(std::to_string(register_file_size)); + func_info.push_back(std::to_string(num_instructions)); + strm->Write(func_info); + strm->Write(params); + } +}; + +struct VMInstructionSerializer { + /*! \brief The opcode of the instruction. */ + Index opcode; + /*! \brief The fields of the instruction. */ + std::vector fields; + + VMInstructionSerializer() = default; + + VMInstructionSerializer(Index opcode, const std::vector& fields) : + opcode(opcode), fields(fields) {} + + /*! + * \brief Compute the hash of the serialized instruction. + * \return The hash that combines the opcode and all fields of the VM + * instruction. + */ + Index Hash() const { + size_t key = static_cast(opcode); + key = VectorHash(key, fields); + return key; + } + + /*! + * \brief Load the serialized instruction. + * \param strm The stream used to load data. + * \return True if successful. Otherwise, false. + */ + bool Load(dmlc::Stream* strm) { + std::vector instr; + if (!strm->Read(&instr)) return false; + CHECK_GE(instr.size(), 2U); + Index loaded_hash = instr[0]; + opcode = instr[1]; + + for (size_t i = 2; i < instr.size(); i++) { + fields.push_back(instr[i]); + } + + Index hash = Hash(); + CHECK_EQ(loaded_hash, hash) << "Found mismatch in hash for opcode: " + << opcode << "\n"; + return true; + } + + /*! + * \brief Save the instruction into the serialized form. + * \param strm The stream used to save data. + */ + void Save(dmlc::Stream* strm) const { + Index hash = Hash(); + std::vector serialized({hash, opcode}); + serialized.insert(serialized.end(), fields.begin(), fields.end()); + strm->Write(serialized); + } +}; + +} // namespace vm +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_BACKEND_VM_SERIALIZE_UTIL_H_ diff --git a/src/relay/backend/vm/serializer.cc b/src/relay/backend/vm/serializer.cc new file mode 100644 index 000000000000..d6e44b4af1f8 --- /dev/null +++ b/src/relay/backend/vm/serializer.cc @@ -0,0 +1,438 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file src/relay/backend/vm/serializer.cc + * \brief Implementation of serializing APIs for the Relay VM. + */ +#include "serializer.h" + +#include +#include + +#include +#include +#include +#include +#include + +#include "serialize_util.h" + +namespace tvm { +namespace relay { +namespace vm { + +void Serializer::Init(const VirtualMachine* vm) { + vm_ = vm; + // Initialize the stream object. + strm_ = new dmlc::MemoryStringStream(&code_); +} + +runtime::PackedFunc Serializer::GetFunction( + const std::string& name, + const std::shared_ptr& sptr_to_self) { + if (name == "get_lib") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->GetLib(); + }); + } else if (name == "get_primitive_ops") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->GetPrimitiveOps(); + }); + } else if (name == "get_bytecode") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->GetBytecode(); + }); + } else if (name == "get_globals") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->GetGlobals(); + }); + } else if (name == "get_stats") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->Stats(); + }); + } else if (name == "serialize") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->Serialize(); + }); + } else { + LOG(FATAL) << "Unknown packed function: " << name; + return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {}); + } +} + +tvm::Array Serializer::GetPrimitiveOps() const { + std::vector ret; + for (const auto& it : vm_->primitive_map) { + auto packed_name = tvm::ir::StringImm::make(it.first); + auto packed_index = static_cast(it.second); + if (ret.size() <= packed_index) { + ret.resize(packed_index + 1); + } + ret[packed_index] = packed_name; + } + return ret; +} + +std::string Serializer::Stats() const { + std::ostringstream oss; + oss << "Relay VM statistics:" << std::endl; + + // Get the number of constants and the shape of each of them. + oss << " Constant shapes (# " << vm_->constants.size() << "): ["; + for (const auto& it : vm_->constants) { + auto cell = it.AsTensor(); + CHECK(cell.operator->()); + runtime::NDArray data = cell->data; + const auto& shape = data.Shape(); + + // Scalar + if (shape.empty()) { + oss << "scalar, "; + continue; + } + + oss << "["; + for (auto s : shape) { + oss << s << ", "; + } + oss.seekp(-2, oss.cur); + oss << "], " << std::endl; + } + if (!vm_->constants.empty()) oss.seekp(-2, oss.cur); + oss << "]" << std::endl; + + // Get the number of globals and the name of each of them. + oss << " Globals (#" << vm_->global_map.size() << "): ["; + for (const auto& it : vm_->global_map) { + oss << "(\"" << it.first << "\", " << it.second << ")" << ", "; + } + if (!vm_->global_map.empty()) oss.seekp(-2, oss.cur); + oss << "]" << std::endl; + + // Get the number of primitive ops and the name of each of them. + oss << " Primitive ops (#" << vm_->primitive_map.size() << "): ["; + const auto& prim_ops = GetPrimitiveOps(); + for (const auto& it : prim_ops) { + oss << it << ", "; + } + if (!prim_ops.empty()) oss.seekp(-2, oss.cur); + oss << "]" << std::endl; + + return oss.str(); +} + +TVMByteArray Serializer::Serialize() { + uint64_t header = kTVMVMBytecodeMagic; + strm_->Write(header); + std::string version = TVM_VERSION; + strm_->Write(version); + + // Global section. + SerializeGlobalSection(); + + // Constant section. + SerializeConstantSection(); + + // Primitive names. + SerializePrimitiveOpNames(); + + // Code section. + SerializeCodeSection(); + + TVMByteArray arr; + arr.data = code_.c_str(); + arr.size = code_.length(); + return arr; +} + +void Serializer::SerializeGlobalSection() { + auto globals = GetGlobals(); + std::vector glbs; + for (const auto& it : globals) { + glbs.push_back(it.as()->value); + } + strm_->Write(glbs); +} + +void Serializer::SerializeConstantSection() { + std::vector arrays; + for (const auto& obj : vm_->constants) { + auto cell = obj.AsTensor(); + runtime::NDArray data = cell->data; + arrays.push_back(const_cast(data.operator->())); + } + strm_->Write(static_cast(vm_->constants.size())); + for (const auto& it : arrays) { + runtime::SaveDLTensor(strm_, it); + } +} + +void Serializer::SerializePrimitiveOpNames() { + auto names = GetPrimitiveOps(); + std::vector primitive_names; + for (const auto& it : names) { + primitive_names.push_back(it.as()->value); + } + strm_->Write(primitive_names); +} + +// Serialize a virtual machine instruction. It creates a list that contains the +// hash, opcode, and all fields of an instruction. +// +// For example, the function signature used to create an `AllocTensor` +// instruction is: +// Instruction AllocTensor(std::vector shape, DLDataType dtype, RegName dst) +// +// The serialized form will be: +// `hash 5 dtype.code dtype.bits dtype.lanes ndim dst_register val1 val2 ... valn` +// +// where hash is the hash of serialized instruction that is computed internally +// by the `VMInstructionSerializer`. It is used for sanity check before decoding. +// 5 shows opcode of `AllocTensor`, `(dtype.code dtype.bits dtype.lanes)` +// represents a `DLDataType`, `ndim` is the number of dimensions, `dst_register` +// is the destination register, and the rest of it together indicates the shape +// of the tensor to be allocated. +VMInstructionSerializer SerializeInstruction(const Instruction& instr) { + std::vector fields; + // Save the opcode. + DLOG(INFO) << "Serializing: " << instr << std::endl; + switch (instr.op) { + case Opcode::Move: { + // Number of fields = 2 + fields.assign({instr.from, instr.dst}); + break; + } + case Opcode::Ret: { + // Number of fields = 1 + fields.push_back(instr.result); + break; + } + case Opcode::Fatal: { + // Number of fields = 0 + break; + } + case Opcode::InvokePacked: { + // Number of fields = 3 + instr.arity + // Note that arity includes both input arguments and outputs. We will + // put all the `arity` number of fields in the end for serialization. + fields.assign({instr.packed_index, instr.arity, instr.output_size}); + // Save the args. + fields.insert(fields.end(), instr.packed_args, instr.packed_args + instr.arity); + break; + } + case Opcode::AllocTensor: { + // Number of fields = 5 + instr.alloc_tensor.ndim + // Save `DLDataType` and the dst register. + const auto& dtype = instr.alloc_tensor.dtype; + fields.assign({dtype.code, dtype.bits, dtype.lanes}); + + // The number of dimensions is not needed for constructing an + // `AllocTensor` instruction as it equals to the length of the `shape` + // vector. However, we save it to conveniently deserialize the instruction + // because we will know how many fields are needed by the `shape` argument. + fields.push_back(instr.alloc_tensor.ndim); + fields.push_back(instr.dst); + + // Save the shape of the tensor. + // Note that this field is rotated to the end of the list. + fields.insert(fields.end(), instr.alloc_tensor.shape, + instr.alloc_tensor.shape + instr.alloc_tensor.ndim); + break; + } + case Opcode::AllocTensorReg: { + // Number of fields = 5 + fields.push_back(instr.alloc_tensor_reg.shape_register); + // Save `DLDataType` and the dst register. + const auto& dtype = instr.alloc_tensor.dtype; + fields.assign({dtype.code, dtype.bits, dtype.lanes}); + fields.push_back(instr.dst); + break; + } + case Opcode::AllocDatatype: { + // Number of fields = 3 + instr.num_fields + fields.assign({instr.constructor_tag, instr.num_fields, instr.dst}); + + // Save the fields. + fields.insert(fields.end(), instr.datatype_fields, + instr.datatype_fields + instr.num_fields); + break; + } + case Opcode::AllocClosure: { + // Number of fields = 3 + instr.num_freevar + fields.assign({instr.clo_index, instr.num_freevar, instr.dst}); + + // Save the free vars. + fields.insert(fields.end(), instr.free_vars, + instr.free_vars + instr.num_freevar); + break; + } + case Opcode::If: { + // Number of fields = 4 + fields.assign({instr.if_op.test, + instr.if_op.target, + instr.if_op.true_offset, + instr.if_op.false_offset}); + break; + } + case Opcode::Invoke: { + // Number of fields = 3 + instr.num_args + fields.assign({instr.func_index, instr.num_args, instr.dst}); + + // Save the args. + fields.insert(fields.end(), instr.invoke_args_registers, + instr.invoke_args_registers + instr.num_args); + break; + } + case Opcode::InvokeClosure: { + // Number of fields = 3 + instr.num_closure_args + fields.assign({instr.closure, instr.num_closure_args, instr.dst}); + + // Save the args. + fields.insert(fields.end(), instr.closure_args, + instr.closure_args + instr.num_closure_args); + break; + } + case Opcode::LoadConst: { + // Number of fields = 2 + fields.assign({instr.const_index, instr.dst}); + break; + } + case Opcode::LoadConsti: { + // Number of fields = 2 + fields.assign({instr.load_consti.val, instr.dst}); + break; + } + case Opcode::GetField: { + // Number of fields = 3 + fields.assign({instr.object, instr.field_index, instr.dst}); + break; + } + case Opcode::GetTag: { + // Number of fields = 2 + fields.assign({instr.get_tag.object, instr.dst}); + break; + } + case Opcode::Goto: { + // Number of fields = 1 + fields.push_back(instr.pc_offset); + break; + } + default: + LOG(FATAL) << "Invalid opcode" << static_cast(instr.op); + break; + } + + return VMInstructionSerializer(static_cast(instr.op), fields); +} + +void Serializer::SerializeCodeSection() { + // Save the number of functions. + strm_->Write(static_cast(vm_->functions.size())); + for (const auto& func : vm_->functions) { + // Serialize the function info. + VMFunctionSerializer func_format(func.name, + func.register_file_size, + func.instructions.size(), + func.params); + func_format.Save(strm_); + + // Serialize each instruction. + for (const auto& instr : func.instructions) { + const auto& serialized_instr = SerializeInstruction(instr); + serialized_instr.Save(strm_); + } + } +} + +tvm::Array Serializer::GetGlobals() const { + tvm::Array ret; + std::vector > globals(vm_->global_map.begin(), + vm_->global_map.end()); + auto comp = [](const std::pair& a, + const std::pair& b) { + return a.second < b.second; + }; + std::sort(globals.begin(), globals.end(), comp); + for (const auto& it : globals) { + ret.push_back(tvm::ir::StringImm::make(it.first)); + } + return ret; +} + +std::string Serializer::GetBytecode() const { + std::ostringstream oss; + + for (const auto& func : vm_->functions) { + // Print the header of the function format. + oss << "# func name, reg file size, param count, inst count:" + << std::endl; + oss << func.name << " " + << func.register_file_size << " " + << func.params.size() << " " + << func.instructions.size() << std::endl; + + // Print pramams of a `VMFunction`. + oss << "# Parameters:"<< std::endl; + for (const auto& param : func.params) { + oss << param << " "; + } + oss << std::endl; + + // Print the instructions of a `VMFunction`. + // The part after ";" is the instruction in text format. + oss << "hash, opcode, fields # inst(text):"<< std::endl; + for (const auto& instr : func.instructions) { + const auto& serialized_instr = SerializeInstruction(instr); + oss << std::hex << "0x" << serialized_instr.Hash() << " " + << std::dec << serialized_instr.opcode << " "; + for (auto it : serialized_instr.fields) { + oss << it << " "; + } + oss << " # " << instr; + if (oss.str().back() != '\n') oss << std::endl; + } + } + + return oss.str(); +} + +runtime::Module Serializer::GetLib() const { + return vm_->lib; +} + +runtime::Module CreateSerializer(const VirtualMachine* vm) { + std::shared_ptr exec = std::make_shared(); + exec->Init(vm); + return runtime::Module(exec); +} + +TVM_REGISTER_GLOBAL("relay._vm._Serializer") +.set_body([](TVMArgs args, TVMRetValue* rv) { + runtime::Module mod = args[0]; + const auto* vm = dynamic_cast(mod.operator->()); + CHECK(vm) << "Virtual machine has not been defined yet." + << "\n"; + *rv = CreateSerializer(vm); +}); + +} // namespace vm +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/vm/serializer.h b/src/relay/backend/vm/serializer.h new file mode 100644 index 000000000000..2371bb4c94f5 --- /dev/null +++ b/src/relay/backend/vm/serializer.h @@ -0,0 +1,202 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file src/relay/backend/vm/serializer.h + * \brief Define a serializer for the Relay VM. + * + * The following components of a Relay VM will be serialized: + * - The `constants`, e.g., the constant pool, that contains the + * constants used in a Relay program. + * - The `packed_funcs` that essentially contains the generated code for + * a specific target. We return it as a runtime module that can be exported as + * a library file (e.g., .so, .o, or .tar). + * - The `global_map` that contains the globals. + * - The `primitive_map` that contains the name of individual primitive operators. + * - The `functions`, e.g., the `VMFunction`. Each `VMFunction` is composed of + * a list of instructions/bytecode. + * + * Note that only the library is returned as a separate module. All othere parts + * are stored in a single serialized code that is organized with the following + * sections in order. + * - Global section, containing all globals. + * - Constant section, storing the constant pool. + * - Primitive name section, containing the function name of the primitive ops + * used by the virtual machine. + * - Code section, handling the VM functions and bytecode. + * + * The code section is again organized as follows for each VM function: + * func_name, register_file_size, num_instructions (N) + * param1, param2, ..., paramM + * instruction1 + * instruction2 + * ... + * instructionN + * + * Serializing an `Instruction` requires us to deal with the bytecode. Each line + * of the instructions could be serialized as the following format: + * hash, opcode, f1, f2, ..., fX, field with variable length + * 1. hash: the hash of the instruction. This number will be used to help us + * validate if an instruction is well-formed during deserialization. + * 2. opcode: the opcode code of the instruction. + * 3. f1, f2, ..., fX. These fields together represent the fixed fields in + * an instruction, e.g., `from` and `dst` fields of a `Move` instruction. For + * example, `DLDataType` will be unpacked into three fields (code, bits, lanes). + * 4. The rest of the line indicates the field with variable length, e.g., + * the shape of a tensor, the args used by an `InvokPacked` instruction, etc. + */ + +#ifndef TVM_RELAY_BACKEND_VM_SERIALIZER_H_ +#define TVM_RELAY_BACKEND_VM_SERIALIZER_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace tvm { +namespace relay { +namespace vm { + +using namespace tvm::runtime; +using namespace tvm::runtime::vm; + +/*! + * \brief The Relay VM serializer. + */ +class Serializer : public runtime::ModuleNode { + public: + /*! + * \brief Initialize the serializer for a virtual machine. + * + * \param vm The Relay virtual machine. + */ + inline void Init(const VirtualMachine* vm); + + /*! + * \brief Return the member function to the frontend. + * + * \param name The name of the function. + * \param sptr_to_self The pointer to the module node. + * + * \return The corresponding member function. + */ + PackedFunc GetFunction(const std::string& name, + const std::shared_ptr& sptr_to_self) final; + + const char* type_key() const final { return "Serializer"; } + + /*! + * \brief Print the detailed statistics of the given code, i.e. number of + * globls and constants, etc. + */ + std::string Stats() const; + + /*! + * \brief Serialize the `vm_` into global section, constant section, and code + * section. + * + * \return The binary representation of the VM. + */ + TVMByteArray Serialize(); + + /*! + * \brief Get a list of the globals used by the `_vm`. + * + * \return The global map in the form a list. + */ + tvm::Array GetGlobals() const; + + /*! + * \brief Get the primitive operators that are contained in the Relay VM. + * + * \return The list of primitve operators. + */ + tvm::Array GetPrimitiveOps() const; + + /*! + * \brief Get the serialized form of the `functions` in `vm_`. This is + * essentially bytecode serialization. + * + * \return The serialized vm bytecode. + * + * \note The bytecode is in the following format: + * func_name reg_file_size num_instructions + * param1 param2 ... paramM + * instruction1 + * instruction2 + * ... + * instructionN + * + * Each instruction is printed in the following format: + * opcode num_fields field1 ... fieldX # The text format. + * + * The field starting from # is only used for debugging. The serialized code + * doesn't contain it, therefore the deserializer doens't need to handle it. + */ + std::string GetBytecode() const; + + /*! \brief Get the `lib` module in vm_. Serialization of `runtime::module` + * has already been supported by TVM. Therefore, we only return the runtime + * module and let users have the flexibility to call `export_library` from + * the frontend to save the library to disk. + * + * \return The runtime module that contains the hardwre dependent code. + */ + inline runtime::Module GetLib() const; + + virtual ~Serializer() { delete strm_; } + + private: + /*! \brief Serialize the globals in vm_. */ + void SerializeGlobalSection(); + + /*! \brief Serialize the constant pool in vm_. */ + void SerializeConstantSection(); + + /*! \brief Serialize primitive op names in vm_. */ + void SerializePrimitiveOpNames(); + + /*! \brief Serialize the vm functions in vm_. */ + void SerializeCodeSection(); + + /*! \brief The Relay virtual machine for to be serialized. */ + const VirtualMachine* vm_; + + /*! \brief The stream used for serialization. */ + dmlc::Stream* strm_; + + /*! \brief The serialized code. */ + std::string code_; +}; + +} // namespace vm +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_BACKEND_VM_SERIALIZER_H_ diff --git a/src/runtime/vm/vm.cc b/src/runtime/vm/vm.cc index b79c45526941..e9b567bed584 100644 --- a/src/runtime/vm/vm.cc +++ b/src/runtime/vm/vm.cc @@ -23,6 +23,7 @@ * \brief The Relay virtual machine. */ +#include #include #include @@ -91,8 +92,8 @@ Instruction::Instruction(const Instruction& instr) { return; case Opcode::InvokeClosure: this->closure = instr.closure; - this->closure_args_num = instr.closure_args_num; - this->closure_args = Duplicate(instr.closure_args, instr.closure_args_num); + this->num_closure_args = instr.num_closure_args; + this->closure_args = Duplicate(instr.closure_args, instr.num_closure_args); return; case Opcode::Invoke: this->func_index = instr.func_index; @@ -179,9 +180,9 @@ Instruction& Instruction::operator=(const Instruction& instr) { return *this; case Opcode::InvokeClosure: this->closure = instr.closure; - this->closure_args_num = instr.closure_args_num; + this->num_closure_args = instr.num_closure_args; FreeIf(this->closure_args); - this->closure_args = Duplicate(instr.closure_args, instr.closure_args_num); + this->closure_args = Duplicate(instr.closure_args, instr.num_closure_args); return *this; case Opcode::Invoke: this->func_index = instr.func_index; @@ -262,7 +263,9 @@ Instruction Instruction::Fatal() { return instr; } -Instruction Instruction::InvokePacked(Index packed_index, Index arity, Index output_size, +Instruction Instruction::InvokePacked(Index packed_index, + Index arity, + Index output_size, const std::vector& args) { Instruction instr; instr.op = Opcode::InvokePacked; @@ -380,7 +383,7 @@ Instruction Instruction::InvokeClosure(RegName closure, const std::vector -std::string StrJoin(T* items, int offset, int cnt, std::string delim = ",") { +std::string StrJoin(T* items, int offset, int cnt, std::string delim = ", ") { if (cnt == 0) { return ""; } @@ -447,11 +450,11 @@ std::string StrJoin(T* items, int offset, int cnt, std::string delim = ",") { void InstructionPrint(std::ostream& os, const Instruction& instr) { switch (instr.op) { case Opcode::Move: { - os << "move $" << instr.dst << " $" << instr.from; + os << "move $" << instr.dst << " $" << instr.from << std::endl; break; } case Opcode::Ret: { - os << "ret $" << instr.result; + os << "ret $" << instr.result << std::endl; break; } case Opcode::Fatal: { @@ -459,74 +462,86 @@ void InstructionPrint(std::ostream& os, const Instruction& instr) { break; } case Opcode::InvokePacked: { - os << "invoke_packed PackedFunc[" << instr.packed_index << "](in: $" - << StrJoin(instr.packed_args, 0, instr.arity - instr.output_size, ",$") + os << "invoke_packed PackedFunc[" << instr.packed_index << "] (in: $" + << StrJoin(instr.packed_args, 0, + instr.arity - instr.output_size, ", $") << ", out: $" << StrJoin(instr.packed_args, instr.arity - instr.output_size, - instr.output_size, ",$") - << ")"; + instr.output_size, ", $") + << ")" << std::endl; break; } case Opcode::AllocTensor: { os << "alloc_tensor $" << instr.dst << " [" - << StrJoin(instr.alloc_tensor.shape, 0, instr.alloc_tensor.ndim) + << StrJoin(instr.alloc_tensor.shape, 0, + instr.alloc_tensor.ndim) << "] "; DLDatatypePrint(os, instr.alloc_tensor.dtype); + os << std::endl; break; } case Opcode::AllocTensorReg: { os << "alloc_tensor_reg $" << instr.dst << " $" << instr.alloc_tensor_reg.shape_register << " "; DLDatatypePrint(os, instr.alloc_tensor_reg.dtype); + os << std::endl; break; } case Opcode::AllocDatatype: { os << "alloc_data $" << instr.dst << " tag(" << instr.constructor_tag << ") [$" - << StrJoin(instr.datatype_fields, 0, instr.num_fields, ",$") << "]"; + << StrJoin(instr.datatype_fields, 0, instr.num_fields, ",$") << "]" + << std::endl; break; } case Opcode::AllocClosure: { os << "alloc_closure $" << instr.dst << " VMFunc[" << instr.clo_index << "]($" << StrJoin(instr.free_vars, 0, instr.num_freevar, ",$") - << ")"; + << ")" + << std::endl; break; } case Opcode::If: { os << "if " << "$" << instr.if_op.test << " " << instr.if_op.target << " " - << instr.if_op.true_offset << " " << instr.if_op.false_offset; + << instr.if_op.true_offset << " " << instr.if_op.false_offset + << std::endl; break; } case Opcode::Invoke: { os << "invoke $" << instr.dst << " VMFunc[" << instr.func_index << "]($" << StrJoin(instr.invoke_args_registers, 0, instr.num_args, ",$") - << ")"; + << ")" + << std::endl; break; } case Opcode::InvokeClosure: { os << "invoke_closure $" << instr.dst << " $" << instr.closure << "($" - << StrJoin(instr.closure_args, 0, instr.closure_args_num, ",$") - << ")"; + << StrJoin(instr.closure_args, 0, instr.num_closure_args, ",$") + << ")" + << std::endl; break; } case Opcode::LoadConst: { - os << "load_const $" << instr.dst << " Const[" << instr.const_index << "]"; + os << "load_const $" << instr.dst << " Const[" << instr.const_index << "]" + << std::endl; break; } case Opcode::LoadConsti: { - os << "load_consti $" << instr.dst << " Const[" << instr.load_consti.val << "]"; + os << "load_consti $" << instr.dst << " Const[" << instr.load_consti.val << "]" + << std::endl; break; } case Opcode::GetField: { os << "get_field $" << instr.dst << " $" << instr.object << "[" - << instr.field_index << "]"; + << instr.field_index << "]" + << std::endl; break; } case Opcode::GetTag: { - os << "get_tag $" << instr.dst << " $" << instr.get_tag.object; + os << "get_tag $" << instr.dst << " $" << instr.get_tag.object << std::endl; break; } case Opcode::Goto: { - os << "goto " << instr.pc_offset; + os << "goto " << instr.pc_offset << std::endl; break; } default: @@ -564,6 +579,23 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, Object obj = args[i]; func_args.push_back(obj); } + auto it = std::find_if(functions.begin(), functions.end(), + [func_name](const VMFunction& func) { + return func.name == func_name; + }); + CHECK(it != functions.end()) << "Cannot find function " << func_name << "\n"; + CHECK_EQ(func_args.size() + params_.size(), it->params.size()) + << "The number of provided parameters doesn't match the number of arguments" + << "\n"; + if (!params_.empty()) { + for (const auto& p : it->params) { + const auto& pit = params_.find(p); + if (pit != params_.end()) { + func_args.push_back(pit->second); + } + } + CHECK_EQ(func_args.size(), it->params.size()); + } *rv = this->Invoke(func_name, func_args); }); } else if (name == "init") { @@ -579,12 +611,40 @@ PackedFunc VirtualMachine::GetFunction(const std::string& name, } this->Init(contexts); }); + } else if (name == "load_params") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + this->LoadParams(args[0].operator std::string()); + }); } else { LOG(FATAL) << "Unknown packed function: " << name; return PackedFunc([sptr_to_self, name](TVMArgs args, TVMRetValue* rv) {}); } } +void VirtualMachine::LoadParams(const std::string& params) { + dmlc::MemoryStringStream mss(const_cast(¶ms)); + dmlc::Stream* strm = &mss; + uint64_t header, reserved; + CHECK(strm->Read(&header)) << "Invalid parameter file"; + CHECK(header == kTVMNDArrayListMagic) << "Invalid parameter file"; + CHECK(strm->Read(&reserved)) << "Invalid parameter file"; + + std::vector names; + CHECK(strm->Read(&names)) << "Invalid parameter file"; + + uint64_t sz; + strm->Read(&sz); + size_t size = static_cast(sz); + CHECK(size == names.size()) << "Invalid parameter file"; + + for (size_t i = 0; i < size; i++) { + NDArray arr; + CHECK(arr.Load(strm)) << "Invalid parameter file"; + runtime::Object obj = runtime::Object::Tensor(arr); + params_.emplace(std::make_pair(names[i], obj)); + } +} + void VirtualMachine::PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func) { auto frame = VMFrame(ret_pc, func_index, arg_count, code, vm_func.register_file_size); frames.push_back(frame); @@ -662,7 +722,22 @@ void InvokePacked(const PackedFunc& func, Index arg_count, Index output_size, func.CallPacked(TVMArgs(values.data(), codes.data(), arity), &rv); } -void VirtualMachine::Init(const std::vector& ctxs) { this->ctxs = ctxs; } +void VirtualMachine::Init(const std::vector& ctxs) { + this->ctxs = ctxs; + + // Get the list of packed functions. + CHECK(primitive_map.empty() || lib.operator->()) + << "runtime module should have been built for primitive functions" + << "\n"; + for (const auto& it : primitive_map) { + const auto& packed_name = it.first; + auto packed_index = static_cast(it.second); + if (packed_funcs.size() <= packed_index) { + packed_funcs.resize(packed_index + 1); + } + packed_funcs[packed_index] = lib.GetFunction(packed_name); + } +} inline void VirtualMachine::WriteRegister(Index r, const Object& val) { frames.back().register_file[r] = val; @@ -716,8 +791,8 @@ void VirtualMachine::Run() { goto main_loop; } case Opcode::LoadConsti: { - auto tensor = NDArray::Empty({1}, {kDLInt, 32, 1}, {kDLCPU, 0}); - reinterpret_cast(tensor->data)[0] = instr.load_consti.val; + auto tensor = NDArray::Empty({1}, {kDLInt, 64, 1}, {kDLCPU, 0}); + reinterpret_cast(tensor->data)[0] = instr.load_consti.val; WriteRegister(instr.dst, Object::Tensor(tensor)); pc++; goto main_loop; @@ -753,7 +828,7 @@ void VirtualMachine::Run() { for (auto free_var : closure->free_vars) { args.push_back(free_var); } - for (Index i = 0; i < instr.closure_args_num; ++i) { + for (Index i = 0; i < instr.num_closure_args; ++i) { args.push_back(ReadRegister(instr.closure_args[i])); } InvokeGlobal(this->functions[closure->func_index], args); diff --git a/tests/python/relay/test_vm_serialization.py b/tests/python/relay/test_vm_serialization.py new file mode 100644 index 000000000000..9a8ab2d87444 --- /dev/null +++ b/tests/python/relay/test_vm_serialization.py @@ -0,0 +1,356 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, missing-docstring, no-else-return +"""Unit tests for the Relay VM serialization and deserialization.""" +import numpy as np + +import tvm +from tvm import relay +from tvm.relay.module import Module as rly_module +from tvm.relay import vm as _vm +from tvm.relay import serializer, deserializer +from tvm.relay.scope_builder import ScopeBuilder +from tvm.relay.prelude import Prelude +from tvm.contrib import util +from tvm.relay import testing + +def create_vm(f, ctx=tvm.cpu(), target="llvm"): + if isinstance(f, relay.Expr): + mod = relay.Module() + mod["main"] = f + compiler = relay.vm.VMCompiler() + vm = compiler.compile(mod, target) + vm.init(ctx) + return vm + else: + assert isinstance(f, relay.Module), "expected mod as relay.Module" + compiler = relay.vm.VMCompiler() + vm = compiler.compile(f, target) + vm.init(ctx) + return vm + + +def veval(vm, *args, ctx=tvm.cpu()): + assert isinstance(vm, _vm.VirtualMachine), "expected VirtualMachine" + vm.init(ctx) + ret = vm.run(*args) + return ret + + +def run_network(mod, + params, + data_shape=(1, 3, 224, 224), + dtype='float32'): + def get_vm_output(mod, data, params, target, ctx, dtype='float32'): + ex = relay.create_executor('vm', mod=mod, ctx=ctx) + result = ex.evaluate()(data, **params) + return result.asnumpy().astype(dtype) + + def get_serialized_output(mod, data, params, target, ctx, dtype='float32'): + vm = create_vm(mod, ctx, target) + ser = serializer.Serializer(vm) + code, lib = ser.serialize() + deser = deserializer.Deserializer(code, lib) + des_vm = deser.deserialize() + des_vm.init(ctx) + des_vm.load_params(params) + result = des_vm.run(data) + return result.asnumpy().astype(dtype) + + data = np.random.uniform(size=data_shape).astype(dtype) + target = "llvm" + ctx = tvm.cpu(0) + + tvm_out = get_vm_output(mod, tvm.nd.array(data.astype(dtype)), params, + target, ctx, dtype) + vm_out = get_serialized_output(mod, tvm.nd.array(data.astype(dtype)), params, + target, ctx, dtype) + tvm.testing.assert_allclose(vm_out, tvm_out, rtol=1e-5, atol=1e-5) + + +def test_serializer(): + mod = rly_module({}) + a = relay.const(1.0, "float32") + x = relay.var('x', shape=(10, 10), dtype='float32') + f1 = relay.Function([x], x + a) + glb_f1 = relay.GlobalVar("f1") + mod[glb_f1] = f1 + + b = relay.const(2.0, "float32") + y = relay.var('y', shape=(10, 10), dtype='float32') + f2 = relay.Function([y], y - b) + glb_f2 = relay.GlobalVar("f2") + mod[glb_f2] = f2 + + x1 = relay.var('x1', shape=(10, 10), dtype='float32') + y1 = relay.var('y1', shape=(10, 10), dtype='float32') + main = relay.Function([x1, y1], glb_f1(x1) * glb_f2(y1)) + mod["main"] = main + + vm = create_vm(mod) + ser = serializer.Serializer(vm) + + stats = ser.stats + assert "scalar" in stats + + glbs = ser.globals + assert len(glbs) == 3 + assert "f1" in glbs + assert "f2" in glbs + assert "main" in glbs + + prim_ops = ser.primitive_ops + assert any(item.startswith('fused_add') for item in prim_ops) + assert any(item.startswith('fused_subtract') for item in prim_ops) + assert any(item.startswith('fused_multiply') for item in prim_ops) + + code = ser.bytecode + assert "main 5 2 5" in code + assert "f1 3 1 4" in code + assert "f2 3 1 4" in code + + code, lib = ser.serialize() + assert isinstance(code, bytearray) + assert isinstance(lib, tvm.module.Module) + + +def test_save_load(): + x = relay.var('x', shape=(10, 10)) + f = relay.Function([x], x + x) + x_data = np.random.rand(10, 10).astype('float32') + + # serialize. + vm = create_vm(f) + ser = serializer.Serializer(vm) + code, lib = ser.serialize() + assert isinstance(code, bytearray) + + # save and load the code and lib file. + tmp = util.tempdir() + path_lib = tmp.relpath("lib.so") + lib.export_library(path_lib) + with open(tmp.relpath("code.bc"), "wb") as fo: + fo.write(code) + + loaded_lib = tvm.module.load(path_lib) + loaded_code = bytearray(open(tmp.relpath("code.bc"), "rb").read()) + + # deserialize. + deser = deserializer.Deserializer(loaded_code, loaded_lib) + des_vm = deser.deserialize() + + res = veval(des_vm, x_data) + tvm.testing.assert_allclose(res.asnumpy(), x_data + x_data) + + +def test_const(): + c = relay.const(1.0, "float32") + x = relay.var('x', shape=(10, 10), dtype='float32') + f = relay.Function([x], x + c) + vm = create_vm(f) + ser = serializer.Serializer(vm) + code, lib = ser.serialize() + assert isinstance(code, bytearray) + deser = deserializer.Deserializer(code, lib) + des_vm = deser.deserialize() + x_data = np.random.rand(10, 10).astype('float32') + res = veval(des_vm, x_data) + tvm.testing.assert_allclose(res.asnumpy(), x_data + 1) + + +def test_if(): + x = relay.var('x', shape=(10, 10)) + y = relay.var('y', shape=(10, 10)) + equal = relay.op.equal(x, y) + equal = relay.op.nn.batch_flatten(equal) + f = relay.Function([x, y], relay.If(relay.op.min(equal, axis=[0, 1]), x, + y)) + x_data = np.random.rand(10, 10).astype('float32') + y_data = np.random.rand(10, 10).astype('float32') + + vm = create_vm(f) + ser = serializer.Serializer(vm) + code, lib = ser.serialize() + deser = deserializer.Deserializer(code, lib) + des_vm = deser.deserialize() + + # same + res = veval(des_vm, x_data, x_data) + tvm.testing.assert_allclose(res.asnumpy(), x_data) + + # diff + res = veval(des_vm, x_data, y_data) + tvm.testing.assert_allclose(res.asnumpy(), y_data) + + +def test_loop(): + mod = relay.module.Module({}) + sum_up = relay.GlobalVar('sum_up') + i = relay.var('i', shape=[], dtype='int32') + accum = relay.var('accum', shape=[], dtype='int32') + sb = ScopeBuilder() + with sb.if_scope(relay.equal(i, relay.const(0, 'int32'))): + sb.ret(accum) + with sb.else_scope(): + one_less = relay.subtract(i, relay.const(1, 'int32')) + new_accum = relay.add(accum, i) + sb.ret(relay.Call(sum_up, [one_less, new_accum])) + func = relay.Function([i, accum], sb.get()) + mod[sum_up] = func + loop_bound = 0 + i_data = np.array(loop_bound, dtype='int32') + accum_data = np.array(0, dtype='int32') + iarg = relay.var('i', shape=[], dtype='int32') + aarg = relay.var('accum', shape=[], dtype='int32') + mod["main"] = relay.Function([iarg, aarg], sum_up(iarg, aarg)) + + vm = create_vm(mod) + ser = serializer.Serializer(vm) + code, lib = ser.serialize() + deser = deserializer.Deserializer(code, lib) + des_vm = deser.deserialize() + + result = veval(des_vm, i_data, accum_data) + tvm.testing.assert_allclose(result.asnumpy(), sum(range(1, loop_bound + 1))) + + +def test_tuple(): + ttype = relay.TupleType([relay.TensorType((1,)), relay.TensorType((10,))]) + tup = relay.var('tup', type_annotation=ttype) + f = relay.Function([tup], relay.TupleGetItem(tup, 1)) + i_data = np.random.rand(41).astype('float32') + j_data = np.random.rand(10).astype('float32') + + vm = create_vm(f) + ser = serializer.Serializer(vm) + code, lib = ser.serialize() + deser = deserializer.Deserializer(code, lib) + des_vm = deser.deserialize() + + result = veval(des_vm, (i_data, j_data)) + tvm.testing.assert_allclose(result.asnumpy(), j_data) + + +def test_adt_list(): + mod = relay.Module() + p = Prelude(mod) + + l1 = p.cons(relay.const(1), p.nil()) + l21 = p.cons(relay.const(2), l1) + l321 = p.cons(relay.const(3), l21) + + f = relay.Function([], l321) + mod["main"] = f + + vm = create_vm(mod) + ser = serializer.Serializer(vm) + code, lib = ser.serialize() + deser = deserializer.Deserializer(code, lib) + des_vm = deser.deserialize() + + result = veval(des_vm) + assert len(result) == 2 + assert len(result[1]) == 2 + assert len(result[1][1]) == 2 + res = [] + res.append(result[0].asnumpy().tolist()) + res.append(result[1][0].asnumpy().tolist()) + res.append(result[1][1][0].asnumpy().tolist()) + tvm.testing.assert_allclose(res, np.array([3, 2, 1])) + + +def test_adt_compose(): + mod = relay.Module() + p = Prelude(mod) + + compose = p.compose + + # add_one = fun x -> x + 1 + sb = relay.ScopeBuilder() + x = relay.var('x', 'float32') + x1 = sb.let('x1', x) + xplusone = x1 + relay.const(1.0, 'float32') + sb.ret(xplusone) + body = sb.get() + add_one = relay.GlobalVar("add_one") + add_one_func = relay.Function([x], body) + + # add_two = compose(add_one, add_one) + sb = relay.ScopeBuilder() + y = relay.var('y', 'float32') + add_two_func = sb.let('add_two', compose(add_one_func, add_one_func)) + add_two_res = add_two_func(y) + sb.ret(add_two_res) + add_two_body = sb.get() + + mod[add_one] = add_one_func + + f = relay.Function([y], add_two_body) + mod["main"] = f + + vm = create_vm(mod) + ser = serializer.Serializer(vm) + code, lib = ser.serialize() + deser = deserializer.Deserializer(code, lib) + des_vm = deser.deserialize() + + x_data = np.array(np.random.rand()).astype('float32') + result = veval(des_vm, x_data) + + tvm.testing.assert_allclose(result.asnumpy(), x_data + 2.0) + + +def test_closure(): + x = relay.var('x', shape=()) + y = relay.var('y', shape=()) + f = relay.Function([x], x + y) + ff = relay.Function([y], f) + clo = ff(relay.const(1.0)) + main = clo(relay.const(2.0)) + + vm = create_vm(main) + ser = serializer.Serializer(vm) + code, lib = ser.serialize() + deser = deserializer.Deserializer(code, lib) + des_vm = deser.deserialize() + + res = veval(des_vm) + tvm.testing.assert_allclose(res.asnumpy(), 3.0) + + +def test_resnet(): + mod, params = testing.resnet.get_workload(batch_size=1, num_layers=18) + run_network(mod, params) + + +def test_mobilenet(): + mod, params = testing.mobilenet.get_workload(batch_size=1) + run_network(mod, params) + + +if __name__ == "__main__": + test_serializer() + test_save_load() + test_const() + test_if() + test_loop() + test_tuple() + test_adt_list() + test_adt_compose() + test_closure() + test_resnet() + test_mobilenet()