diff --git a/src/runtime/hexagon/hexagon/hexagon_common.cc b/src/runtime/hexagon/hexagon/hexagon_common.cc index 00d74f90111e..9aee341d64b8 100644 --- a/src/runtime/hexagon/hexagon/hexagon_common.cc +++ b/src/runtime/hexagon/hexagon/hexagon_common.cc @@ -47,72 +47,6 @@ namespace tvm { namespace runtime { namespace hexagon { -void HexagonLookupLinkedParam(TVMArgs args, TVMRetValue* rv) { - Module mod = args[0]; - int64_t storage_id = args[1]; - DLTensor* template_tensor = args[2]; - Device dev = args[3]; - auto lookup_linked_param = mod.GetFunction(::tvm::runtime::symbol::tvm_lookup_linked_param, true); - if (lookup_linked_param == nullptr) { - *rv = nullptr; - return; - } - - TVMRetValue opaque_handle = lookup_linked_param(storage_id); - if (opaque_handle.type_code() == kTVMNullptr) { - *rv = nullptr; - return; - } - - std::vector shape_vec{template_tensor->shape, - template_tensor->shape + template_tensor->ndim}; - - Optional scope("global"); - auto* param_buffer = - new HexagonBuffer(static_cast(opaque_handle), GetDataSize(*template_tensor), scope); - auto* container = new NDArray::Container(static_cast(param_buffer), shape_vec, - template_tensor->dtype, dev); - container->SetDeleter([](Object* container) { - // The NDArray::Container needs to be deleted - // along with the HexagonBuffer wrapper. However the - // buffer's data points to global const memory and - // so should not be deleted. 
- auto* ptr = static_cast(container); - delete static_cast(ptr->dl_tensor.data); - delete ptr; - }); - *rv = NDArray(GetObjectPtr(container)); -} - -PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr& sptr_to_self) { - return PackedFunc([faddr, sptr_to_self](TVMArgs args, TVMRetValue* rv) { - TVMValue ret_value; - int ret_type_code = kTVMNullptr; - - TVMValue* arg_values = const_cast(args.values); - std::vector> buffer_args; - for (int i = 0; i < args.num_args; i++) { - if (args.type_codes[i] == kTVMDLTensorHandle) { - DLTensor* tensor = static_cast(arg_values[i].v_handle); - buffer_args.emplace_back(i, static_cast(tensor->data)); - HexagonBuffer* hexbuf = buffer_args.back().second; - tensor->data = hexbuf->GetPointer(); - } - } - int ret = (*faddr)(const_cast(args.values), const_cast(args.type_codes), - args.num_args, &ret_value, &ret_type_code, nullptr); - ICHECK_EQ(ret, 0) << TVMGetLastError(); - - for (auto& arg : buffer_args) { - DLTensor* tensor = static_cast(arg_values[arg.first].v_handle); - tensor->data = arg.second; - } - - if (ret_type_code != kTVMNullptr) { - *rv = TVMRetValue::MoveFromCHost(ret_value, ret_type_code); - } - }); -} #if defined(__hexagon__) class HexagonTimerNode : public TimerNode { @@ -165,12 +99,9 @@ void LogMessageImpl(const std::string& file, int lineno, const std::string& mess } } // namespace detail -TVM_REGISTER_GLOBAL("tvm.runtime.hexagon.lookup_linked_params") - .set_body(hexagon::HexagonLookupLinkedParam); - TVM_REGISTER_GLOBAL("runtime.module.loadfile_hexagon").set_body([](TVMArgs args, TVMRetValue* rv) { ObjectPtr n = CreateDSOLibraryObject(args[0]); - *rv = CreateModuleFromLibrary(n, hexagon::WrapPackedFunc); + *rv = CreateModuleFromLibrary(n); }); } // namespace runtime } // namespace tvm diff --git a/src/runtime/hexagon/hexagon/hexagon_common.h b/src/runtime/hexagon/hexagon/hexagon_common.h index e1eca72627a5..9e534bdaf1a9 100644 --- a/src/runtime/hexagon/hexagon/hexagon_common.h +++ 
b/src/runtime/hexagon/hexagon/hexagon_common.h @@ -44,20 +44,6 @@ } \ } while (0) -namespace tvm { -namespace runtime { -namespace hexagon { - -/*! \brief Unpack HexagonBuffers in packed functions - * prior to invoking. - * \param faddr The function address. - * \param mptr The module pointer node. - * \return A packed function wrapping the requested function. - */ -PackedFunc WrapPackedFunc(TVMBackendPackedCFunc faddr, const ObjectPtr& mptr); -} // namespace hexagon -} // namespace runtime -} // namespace tvm inline bool IsHexagonDevice(DLDevice dev) { return TVMDeviceExtType(dev.device_type) == kDLHexagon; } diff --git a/src/runtime/hexagon/hexagon/hexagon_device_api_v2.cc b/src/runtime/hexagon/hexagon/hexagon_device_api_v2.cc index 2804b2d837a5..ea1cf18f3cc0 100644 --- a/src/runtime/hexagon/hexagon/hexagon_device_api_v2.cc +++ b/src/runtime/hexagon/hexagon/hexagon_device_api_v2.cc @@ -74,11 +74,11 @@ void* HexagonDeviceAPIv2::AllocDataSpace(Device dev, int ndim, const int64_t* sh if (ndim == 1) { size_t nbytes = shape[0] * typesize; - return new HexagonBuffer(nbytes, alignment, mem_scope); + return AllocateHexagonBuffer(nbytes, alignment, mem_scope); } else if (ndim == 2) { size_t nallocs = shape[0]; size_t nbytes = shape[1] * typesize; - return new HexagonBuffer(nallocs, nbytes, alignment, mem_scope); + return AllocateHexagonBuffer(nallocs, nbytes, alignment, mem_scope); } else { LOG(FATAL) << "Hexagon Device API supports only 1d and 2d allocations, but received ndim = " << ndim; @@ -94,16 +94,14 @@ void* HexagonDeviceAPIv2::AllocDataSpace(Device dev, size_t nbytes, size_t align if (alignment < kHexagonAllocAlignment) { alignment = kHexagonAllocAlignment; } - return new HexagonBuffer(nbytes, alignment, String("global")); + return AllocateHexagonBuffer(nbytes, alignment, String("global")); } void HexagonDeviceAPIv2::FreeDataSpace(Device dev, void* ptr) { bool is_valid_device = (TVMDeviceExtType(dev.device_type) == kDLHexagon) || (DLDeviceType(dev.device_type) == 
kDLCPU); CHECK(is_valid_device) << "dev.device_type: " << dev.device_type; - auto* hexbuf = static_cast(ptr); - CHECK(hexbuf != nullptr); - delete hexbuf; + FreeHexagonBuffer(ptr); } // WorkSpace: runtime allocations for Hexagon @@ -114,21 +112,14 @@ struct HexagonWorkspacePool : public WorkspacePool { void* HexagonDeviceAPIv2::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) { CHECK(TVMDeviceExtType(dev.device_type) == kDLHexagon) << "dev.device_type: " << dev.device_type; - auto* hexbuf = static_cast( - dmlc::ThreadLocalStore::Get()->AllocWorkspace(dev, size)); - - void* ptr = hexbuf->GetPointer(); - workspace_allocations_.insert({ptr, hexbuf}); - return ptr; + return dmlc::ThreadLocalStore::Get()->AllocWorkspace(dev, size); } void HexagonDeviceAPIv2::FreeWorkspace(Device dev, void* data) { CHECK(TVMDeviceExtType(dev.device_type) == kDLHexagon) << "dev.device_type: " << dev.device_type; - auto it = workspace_allocations_.find(data); - CHECK(it != workspace_allocations_.end()) + CHECK(hexagon_buffer_map_.count(data) != 0) << "Attempt made to free unknown or already freed workspace allocation"; - dmlc::ThreadLocalStore::Get()->FreeWorkspace(dev, it->second); - workspace_allocations_.erase(it); + dmlc::ThreadLocalStore::Get()->FreeWorkspace(dev, data); } void* HexagonDeviceAPIv2::AllocVtcmWorkspace(Device dev, int ndim, const int64_t* shape, @@ -148,21 +139,26 @@ void HexagonDeviceAPIv2::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamH CHECK_EQ(to->byte_offset, 0); CHECK_EQ(GetDataSize(*from), GetDataSize(*to)); - HexagonBuffer* hex_from_buf = static_cast(from->data); - HexagonBuffer* hex_to_buf = static_cast(to->data); + auto lookup_hexagon_buffer = [this](void* ptr) -> HexagonBuffer* { + auto it = this->hexagon_buffer_map_.find(ptr); + CHECK(it != this->hexagon_buffer_map_.end()) + << "Lookup failed for non-HexagonBuffer allocation, CopyDataFromTo can only copy data " + "from, to or between HexagonBuffers"; + return it->second.get(); + }; if 
(TVMDeviceExtType(from->device.device_type) == kDLHexagon && TVMDeviceExtType(to->device.device_type) == kDLHexagon) { - CHECK(hex_from_buf != nullptr); - CHECK(hex_to_buf != nullptr); + HexagonBuffer* hex_from_buf = lookup_hexagon_buffer(from->data); + HexagonBuffer* hex_to_buf = lookup_hexagon_buffer(to->data); hex_to_buf->CopyFrom(*hex_from_buf, GetDataSize(*from)); } else if (from->device.device_type == kDLCPU && TVMDeviceExtType(to->device.device_type) == kDLHexagon) { - CHECK(hex_to_buf != nullptr); + HexagonBuffer* hex_to_buf = lookup_hexagon_buffer(to->data); hex_to_buf->CopyFrom(from->data, GetDataSize(*from)); } else if (TVMDeviceExtType(from->device.device_type) == kDLHexagon && to->device.device_type == kDLCPU) { - CHECK(hex_from_buf != nullptr); + HexagonBuffer* hex_from_buf = lookup_hexagon_buffer(from->data); hex_from_buf->CopyTo(to->data, GetDataSize(*to)); } else { CHECK(false) @@ -177,6 +173,14 @@ void HexagonDeviceAPIv2::CopyDataFromTo(const void* from, size_t from_offset, vo memcpy(static_cast(to) + to_offset, static_cast(from) + from_offset, size); } +void HexagonDeviceAPIv2::FreeHexagonBuffer(void* ptr) { + auto it = hexagon_buffer_map_.find(ptr); + CHECK(it != hexagon_buffer_map_.end()) + << "Attempt made to free unknown or already freed dataspace allocation"; + CHECK(it->second != nullptr); + hexagon_buffer_map_.erase(it); +} + TVM_REGISTER_GLOBAL("device_api.hexagon.mem_copy").set_body([](TVMArgs args, TVMRetValue* rv) { void* dst = args[0]; void* src = args[1]; @@ -187,8 +191,6 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.mem_copy").set_body([](TVMArgs args, TVM *rv = static_cast(0); }); -std::map vtcmallocs; - TVM_REGISTER_GLOBAL("device_api.hexagon.alloc_nd").set_body([](TVMArgs args, TVMRetValue* rv) { int32_t device_type = args[0]; int32_t device_id = args[1]; @@ -210,12 +212,7 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.alloc_nd").set_body([](TVMArgs args, TVM type_hint.lanes = 1; HexagonDeviceAPIv2* hexapi = 
HexagonDeviceAPIv2::Global(); - HexagonBuffer* hexbuf = reinterpret_cast( - hexapi->AllocVtcmWorkspace(dev, ndim, shape, type_hint, String(scope))); - - void* ptr = hexbuf->GetPointer(); - vtcmallocs[ptr] = hexbuf; - *rv = ptr; + *rv = hexapi->AllocVtcmWorkspace(dev, ndim, shape, type_hint, String(scope)); }); TVM_REGISTER_GLOBAL("device_api.hexagon.free_nd").set_body([](TVMArgs args, TVMRetValue* rv) { @@ -224,17 +221,13 @@ TVM_REGISTER_GLOBAL("device_api.hexagon.free_nd").set_body([](TVMArgs args, TVMR std::string scope = args[2]; CHECK(scope.find("global.vtcm") != std::string::npos); void* ptr = args[3]; - CHECK(vtcmallocs.find(ptr) != vtcmallocs.end()); - - HexagonBuffer* hexbuf = vtcmallocs[ptr]; - vtcmallocs.erase(ptr); Device dev; dev.device_type = static_cast(device_type); dev.device_id = device_id; HexagonDeviceAPIv2* hexapi = HexagonDeviceAPIv2::Global(); - hexapi->FreeVtcmWorkspace(dev, hexbuf); + hexapi->FreeVtcmWorkspace(dev, ptr); *rv = static_cast(0); }); diff --git a/src/runtime/hexagon/hexagon/hexagon_device_api_v2.h b/src/runtime/hexagon/hexagon/hexagon_device_api_v2.h index 43f4272f1943..96805e55bb1f 100644 --- a/src/runtime/hexagon/hexagon/hexagon_device_api_v2.h +++ b/src/runtime/hexagon/hexagon/hexagon_device_api_v2.h @@ -23,16 +23,18 @@ #include #include +#include #include #include +#include #include +#include "hexagon_buffer.h" + namespace tvm { namespace runtime { namespace hexagon { -class HexagonBuffer; - /*! * \brief Hexagon Device API that is compiled and run on Hexagon. */ @@ -70,7 +72,7 @@ class HexagonDeviceAPIv2 final : public DeviceAPI { */ void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final; - //! Dereference workspace pool and erase from tracked workspace_allocations_. + //! Check that the allocation is tracked in hexagon_buffer_map_, then release it through the workspace pool. void FreeWorkspace(Device dev, void* data) final; /*! @@ -125,8 +127,23 @@ class HexagonDeviceAPIv2 final : public DeviceAPI { TVMStreamHandle stream) final; private: - //!
Lookup table for the HexagonBuffer managing a workspace allocation. - std::unordered_map workspace_allocations_; + /*! \brief Helper to allocate a HexagonBuffer and register the result + * in the owned buffer map. + * \return Raw data storage managed by the hexagon buffer + */ + template + void* AllocateHexagonBuffer(Args&&... args) { + auto buf = std::make_unique(std::forward(args)...); + void* ptr = buf->GetPointer(); + hexagon_buffer_map_.insert({ptr, std::move(buf)}); + return ptr; + } + /*! \brief Helper to free a HexagonBuffer and unregister the result + * from the owned buffer map. + */ + void FreeHexagonBuffer(void* ptr); + //! Lookup table for the HexagonBuffer managing an allocation. + std::unordered_map> hexagon_buffer_map_; }; } // namespace hexagon } // namespace runtime diff --git a/src/runtime/hexagon/rpc/hexagon/rpc_server.cc b/src/runtime/hexagon/rpc/hexagon/rpc_server.cc index 1bd2a8e16a44..d14b178cf7d7 100644 --- a/src/runtime/hexagon/rpc/hexagon/rpc_server.cc +++ b/src/runtime/hexagon/rpc/hexagon/rpc_server.cc @@ -288,5 +288,5 @@ TVM_REGISTER_GLOBAL("tvm.hexagon.load_module") .set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue* rv) { std::string soname = args[0]; tvm::ObjectPtr n = tvm::runtime::CreateDSOLibraryObject(soname); - *rv = CreateModuleFromLibrary(n, tvm::runtime::hexagon::WrapPackedFunc); + *rv = CreateModuleFromLibrary(n); }); diff --git a/src/runtime/hexagon/rpc/simulator/rpc_server.cc b/src/runtime/hexagon/rpc/simulator/rpc_server.cc index ec04df46b341..76f168cd20ad 100644 --- a/src/runtime/hexagon/rpc/simulator/rpc_server.cc +++ b/src/runtime/hexagon/rpc/simulator/rpc_server.cc @@ -321,5 +321,5 @@ TVM_REGISTER_GLOBAL("tvm.hexagon.load_module") .set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue* rv) { std::string soname = args[0]; tvm::ObjectPtr n = tvm::runtime::CreateDSOLibraryObject(soname); - *rv = CreateModuleFromLibrary(n, tvm::runtime::hexagon::WrapPackedFunc); + *rv = 
CreateModuleFromLibrary(n); });