diff --git a/.clang-tidy b/.clang-tidy index 8aac1d0b25e..33505432a60 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -1,5 +1,6 @@ Checks: '*, -altera-id-dependent-backward-branch, + -altera-struct-pack-align, -altera-unroll-loops, -boost-use-ranges, -cppcoreguidelines-avoid-do-while, @@ -9,8 +10,10 @@ Checks: '*, -fuchsia-default-arguments-calls, -fuchsia-default-arguments-declarations, -fuchsia-overloaded-operator, + -fuchsia-virtual-inheritance, -hicpp-vararg, -llvm-else-after-return, -llvmlibc-*, + -misc-include-cleaner, -misc-non-private-member-variables-in-classes, -modernize-use-trailing-return-type' diff --git a/cpp/include/tensorrt_llm/runtime/virtualMemory.h b/cpp/include/tensorrt_llm/runtime/virtualMemory.h new file mode 100644 index 00000000000..c39a60995eb --- /dev/null +++ b/cpp/include/tensorrt_llm/runtime/virtualMemory.h @@ -0,0 +1,540 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/runtime/cudaEvent.h" +#include "tensorrt_llm/runtime/iBuffer.h" +#include "tensorrt_llm/runtime/memoryCounters.h" + +#include +#include +#include +#include +#include + +class VirtualMemoryManagerTest; + +namespace tensorrt_llm::runtime +{ + +/** + * CUDAVirtualMemoryChunk is a handle to a piece of CUDA memory allocation, + * providing the ability to release and rematerialize the allocation. + */ +class CUDAVirtualMemoryChunk +{ +public: + /** + * CUDAVirtualMemoryChunk::Creator is the interface to obtain a CUmemGenericAllocationHandle, + * either by creating one locally, or importing one from remote. + */ + struct Creator + { + Creator() = default; + virtual ~Creator() = default; + Creator(Creator const&) = default; + Creator& operator=(Creator const&) = default; + Creator(Creator&&) = default; + Creator& operator=(Creator&&) = default; + + // Note: create() shall not leak resources when throwing exceptions. + // release() will only, and will always be called if create() success. + // release() will be called with destructing=true when the CUDAVirtualMemoryChunk + // is being destructed. + virtual CUmemGenericAllocationHandle create() = 0; + virtual void release(CUmemGenericAllocationHandle handle, bool destructing) = 0; + }; + + using CreatorPtr = std::unique_ptr; + + /** + * CUDAVirtualMemoryChunk::Configurator is the interface to configure a CUmemGenericAllocationHandle: + * - Map into virtual address + * - Bind to multicast object + * - Backup and restore memory content + */ + struct Configurator + { + Configurator() = default; + virtual ~Configurator() = default; + Configurator(Configurator const&) = default; + Configurator& operator=(Configurator const&) = default; + Configurator(Configurator&&) = default; + Configurator& operator=(Configurator&&) = default; + + // Note: setup() shall not leak resources when throwing exceptions. 
+ // teardown() will only, and will always be called if setup() success. + // teardown() will be called with destructing=true when the CUDAVirtualMemoryChunk + // is being destructed. + virtual void setup(CUmemGenericAllocationHandle handle) = 0; + virtual void teardown(CUmemGenericAllocationHandle handle, bool destructing) = 0; + }; + + using ConfiguratorPtr = std::unique_ptr; + using Configurators = std::vector; + + enum Status + { + INVALID, // This is a default constructed invalid CUDAVirtualMemoryChunk. + RELEASED, // The memory represented by this CUDAVirtualMemoryChunk is not allocated. + MATERIALIZED, // The memory represented by this CUDAVirtualMemoryChunk is allocated. + ERRORED, // Error happened during materialize() or release(). + // This CUDAVirtualMemoryChunk cannot be used anymore. + }; + + [[nodiscard]] Status status() const noexcept + { + if (mCreator == nullptr) + { + return INVALID; + } + + if (mState == 0 && mHandle == 0) + { + return RELEASED; + } + + if (mState == mConfigurators.size() && mHandle != 0) + { + return MATERIALIZED; + } + + return ERRORED; + } + + /** + * Materialize this CUDAVirtualMemoryChunk. + * Shall be called only when status() == RELEASED. + * + * Calls creator.create(), and then configurator.setup() for each configurator in order. + * + * Stop at the first thrown exception and propagates it. + */ + void materialize(); + + /** + * Release this CUDAVirtualMemoryChunk. + * Shall be called only when status() == MATERIALIZED, or materialize() throws. + * Will be called automatically by destructor if necessary. + * + * Calls configurator.teardown() for each configurator that setup() succeed in materialize() in reversed order, + * and then creator.release(). + * + * Never stops early upon exception. The last thrown exception will be propagated, and others logged. + */ + void release() + { + _release(false); + } + + CUDAVirtualMemoryChunk(CUDAVirtualMemoryChunk const&) = delete; + CUDAVirtualMemoryChunk& operator=(CUDAVirtualMemoryChunk const&) = delete; + + CUDAVirtualMemoryChunk(CUDAVirtualMemoryChunk&& other) noexcept + { + mCreator = std::move(other.mCreator); + mConfigurators = std::move(other.mConfigurators); + mHandle = other.mHandle; + mState = other.mState; + new (&other) CUDAVirtualMemoryChunk; // Put other into default constructed state + } + + CUDAVirtualMemoryChunk& operator=(CUDAVirtualMemoryChunk&& other) + { + this->~CUDAVirtualMemoryChunk(); // May throw if current virtual memory need release + new (this) CUDAVirtualMemoryChunk(std::move(other)); + return *this; + } + + CUDAVirtualMemoryChunk() noexcept = default; + + CUDAVirtualMemoryChunk(CreatorPtr&& creator, Configurators&& configurators) + : mCreator(std::move(creator)) + , mConfigurators(std::move(configurators)) + { + } + + virtual ~CUDAVirtualMemoryChunk() + { + // Calling release() is necessary if materialize() succeed or threw an exception. + // If release() is already called by the user, whether succeed or threw an exception, + // we shouldn't call release() again. + if (mHandle != 0 && mState != INVALID_STATE) + { + _release(true); + } + } + + /** + * Test if this CUDAVirtualMemoryChunk is managing a memory block. 
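+     * Returns true for any chunk constructed with a Creator, whether it is currently RELEASED,
+     * MATERIALIZED or ERRORED; only a default-constructed chunk returns false.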
+ */ + explicit operator bool() const noexcept + { + return mCreator != nullptr; + } + +private: + void _release(bool destructing); + + constexpr static size_t INVALID_STATE = static_cast(-1); + size_t mState = 0; + CUmemGenericAllocationHandle mHandle{}; + std::unique_ptr mCreator; + std::vector> mConfigurators; +}; + +/** + * LocalCreator creates memory allocation locally through cuMemCreate. + */ +template +struct LocalCreator : CUDAVirtualMemoryChunk::Creator +{ + LocalCreator(CUmemAllocationProp const& prop, size_t size) + : mProp(prop) + , mSize(size) + { + } + + CUmemGenericAllocationHandle create() override + { + CUmemGenericAllocationHandle handle{}; + TLLM_CU_CHECK(cuMemCreate(&handle, mSize, &mProp, 0)); + if constexpr (count) + { + MemoryCounters::getInstance().allocate( + mProp.location.type == CU_MEM_LOCATION_TYPE_DEVICE ? MemoryType::kGPU : MemoryType::kPINNED, mSize); + } + return handle; + } + + void release(CUmemGenericAllocationHandle handle, bool destructing) override + { + TLLM_CU_CHECK_FREE_RESOURCE(cuMemRelease(handle)); + if constexpr (count) + { + MemoryCounters::getInstance().deallocate( + mProp.location.type == CU_MEM_LOCATION_TYPE_DEVICE ? MemoryType::kGPU : MemoryType::kPINNED, mSize); + } + } + + CUmemAllocationProp mProp{}; + size_t mSize{}; +}; + +/** + * UnicastConfigurator maps the allocation handle into the specified unicast address range. + */ +struct UnicastConfigurator : CUDAVirtualMemoryChunk::Configurator +{ + UnicastConfigurator(CUdeviceptr address, size_t size, CUmemAccessDesc const& desc) + : mAddress(address) + , mSize(size) + , mDesc(desc) + { + } + + void setup(CUmemGenericAllocationHandle handle) override + { + TLLM_CU_CHECK(cuMemMap(mAddress, mSize, 0, handle, 0)); + TLLM_CU_CHECK(cuMemSetAccess(mAddress, mSize, &mDesc, 1)); + } + + void teardown(CUmemGenericAllocationHandle, bool) override + { + TLLM_CU_CHECK_FREE_RESOURCE(cuMemUnmap(mAddress, mSize)); + } + + CUdeviceptr mAddress; + size_t mSize; + CUmemAccessDesc mDesc; +}; + +/** + * MulticastConfigurator binds the allocation handle to the given multicast object and offset. + */ +struct MulticastConfigurator : CUDAVirtualMemoryChunk::Configurator +{ + void setup(CUmemGenericAllocationHandle handle) override + { + TLLM_CU_CHECK(cuMulticastBindMem(mMulticast, 0, handle, mBindOffset, mSize, 0)); + } + + void teardown(CUmemGenericAllocationHandle, bool) override + { + TLLM_CU_CHECK_FREE_RESOURCE(cuMulticastUnbind(mMulticast, mDevice, 0, mSize)); + } + + CUmemGenericAllocationHandle mMulticast; + size_t mBindOffset; + CUdevice mDevice; + size_t mSize; +}; + +/** + * MemsetConfigurator fills the memory with given value. + */ +struct MemsetConfigurator : CUDAVirtualMemoryChunk::Configurator +{ + MemsetConfigurator(CUdeviceptr address, size_t size, uint8_t value, CUstream stream) + : mAddress(address) + , mSize(size) + , mStream(stream) + , mValue(value) + { + } + + void setup(CUmemGenericAllocationHandle) override + { + if (mFirstTime) + { + mFirstTime = false; + } + else + { + TLLM_CU_CHECK(cuMemsetD8Async(mAddress, mValue, mSize, mStream)); + } + } + + void teardown(CUmemGenericAllocationHandle, bool) noexcept override {} + + CUdeviceptr mAddress; + size_t mSize; + CUstream mStream{}; + uint8_t mValue; + bool mFirstTime = true; +}; + +/** + * OffloadConfigurator offload the content of the allocation to the backup storage when teardown, + * and restore the content on the following setup. 
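+ * When constructed with ondemand == true, setup() restores the content with a synchronous copy and
+ * releases the backing buffer immediately; otherwise the copy is issued asynchronously on the provided stream.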
+ */ +struct OffloadConfigurator : CUDAVirtualMemoryChunk::Configurator +{ + OffloadConfigurator(CUdeviceptr address, size_t size, MemoryType backType, CUstream stream, bool ondemand = false) + : mAddress(address) + , mSize(size) + , mBackType(backType) + , mStream(stream) + , mOndemand(ondemand) + { + } + + void setup(CUmemGenericAllocationHandle handle) override; + void teardown(CUmemGenericAllocationHandle handle, bool destructing) override; + + CUdeviceptr mAddress; + size_t mSize; + MemoryType mBackType; + CUstream mStream; + bool mOndemand; + + IBuffer::UniquePtr mBackedStorage; +}; + +class CudaVirtualMemoryManager +{ +public: + /** + * Add memory to be managed by this manager. + * @param handle Unique handle provided to reference this memory in `remove`. + * @param tag Tag the memory, so this memory can be targeted in `releaseWithTag` and `materializeWithTag`. + * @param memory The CUDAVirtualMemory object. + * + * The memory and internal state will remain valid if any exception is thrown. + */ + void add(uintptr_t handle, std::string tag, CUDAVirtualMemoryChunk&& memory); + + /** + * Creates and adds memory to be managed by this manager. The created memory is automatically materialized. + * @param handle Unique handle provided to reference this memory in `remove`. + * @param tag Tag the memory, so this memory can be targeted in `releaseWithTag` and + * `materializeWithTag`. + * @param creator The creator for the memory. + * @param configurators The configurators for the memory. + * + * The internal state will remain valid if any exception is thrown. + */ + void add(uintptr_t handle, std::string tag, CUDAVirtualMemoryChunk::CreatorPtr&& creator, + CUDAVirtualMemoryChunk::Configurators&& configurators); + + template + void add(uintptr_t handle, std::string tag, CUDAVirtualMemoryChunk::CreatorPtr&& creator, + Configurators&&... configurators) + { + add(handle, tag, std::move(creator), {std::forward(configurators)...}); + } + + /** + * Remove the memory from the manager. + * @param handle The handle provided to `add`. + * @return The CUDAVirtualMemory object. If the handle is unknown, an empty CUDAVirtualMemory will be returned. + */ + CUDAVirtualMemoryChunk remove(uintptr_t handle) noexcept; + + /** + * Call release for CUDAVirtualMemoryChunk objects with a given tag. + * @param tag the tag to select target memories. + * @return Number of objects selected. + * + * This function will always call `CUDAVirtualMemoryChunk::release` on all selected objects. + * The last exception thrown by `CUDAVirtualMemoryChunk::release` will be rethrown, and others will be logged. + * + * If any CUDAVirtualMemoryChunk threw an exception during `release`, it will be removed from the manager. + * Call `retrieveBadHandles` to retrieve handles of all CUDAVirtualMemoryChunk that got removed due to exception. + */ + size_t releaseWithTag(std::string const& tag); + + /** + * Call materialize for CUDAVirtualMemoryChunk objects with a given tag. + * @param tag the tag to select target memories. + * @return Number of objects selected. + * + * This function will stop at the first `CUDAVirtualMemoryChunk::materialize` that throws exception, + * and attempt to roll back previous successful `materialize` by calling `release`. + * The exception thrown by `CUDAVirtualMemoryChunk::materialize` will be rethrown, + * and any exception thrown by `release` will be logged. + * + * If any CUDAVirtualMemoryChunk threw an exception during `materialize` or `release`, it will be removed from the + * manager. 
Successfully roll backed CUDAVirtualMemoryChunk will not be removed. + * Call `retrieveBadHandles` to retrieve handles of all CUDAVirtualMemoryChunk that got removed due to exception. + */ + size_t materializeWithTag(std::string const& tag); + + /** + * Retrieve handles of all CUDAVirtualMemoryChunk that got removed due to exception and reset the list. + * The returned list may not include all removed CUDAVirtualMemoryChunk handles if OOM happened. + * This method is only for diagnostic purpose, and should not be called concurrently with other methods. + * @return The handle list. + */ + std::vector retrieveBadHandles() noexcept; + +private: + CUDAVirtualMemoryChunk unsafeRemove(uintptr_t handle) noexcept; + void addBadHandle(uintptr_t handle) noexcept; + + struct Entry; + // Unordered map invalidates iterator upon rehash, so we can only use the ordered map. + using PointerMemoryMap = std::map; + using TagEntryMap = std::multimap; + + struct Entry + { + CUDAVirtualMemoryChunk mMemory; + TagEntryMap::iterator mEntryIt; + }; + + std::mutex mMutex; + PointerMemoryMap mMemories; + TagEntryMap mEntries; + std::vector mBadHandles; + + friend VirtualMemoryManagerTest; +}; + +class CudaVirtualMemoryAllocator +{ + using CudaStreamPtr = std::shared_ptr; + using Pointer = void*; + +public: + enum RestoreMode + { + NONE, // The memory is not backed. Upon rematerialize, memory has uninitialized content. + MEMSET, // The memory is memset to zero upon rematerialize. + CPU, // The memory is backed by normal CPU memory. The content is restored upon rematerialize. + PINNED // The memory is backed by pinned CPU memory. The content is restored upon rematerialize. + }; + + class Configuration + { + CudaVirtualMemoryManager& mManager; + std::string mTag; + CudaStreamPtr mBackStream; + std::size_t mPageSize; + RestoreMode mMode; + bool mBackground{}; + + friend class CudaVirtualMemoryAllocator; + friend void setVirtualMemoryAllocator( + std::string const& tag, RestoreMode mode, std::shared_ptr backStream); + + public: + /** + * CudaVirtualMemoryAllocator::Configuration + * @param manager Manager used to track and manage virtual memories + * @param tag The tag for allocated memories + * @param mode Backed storage mode + * @param backStream The CUDA stream used for restoring memory content + * Note: Virtual Address Allocation is not async. The stream is not used in allocation. + */ + Configuration(CudaVirtualMemoryManager& manager, std::string tag, RestoreMode mode, CudaStreamPtr backStream) + : mManager(manager) + , mTag(std::move(tag)) + , mBackStream(std::move(backStream)) + , mPageSize(getpagesize()) + , mMode(mode) + { + } + + [[nodiscard]] std::size_t pageAligned(std::size_t n) const noexcept + { + return (n + mPageSize - 1) & ~(mPageSize - 1); + } + + // Background configuration, used to indicate no virtual memory allocator is explicitly configured by the user. + static Configuration backgroundConfiguration; + + private: + Configuration(CudaVirtualMemoryManager& manager, std::string tag, RestoreMode mode, CudaStreamPtr backStream, + bool background) + : Configuration(manager, std::move(tag), mode, std::move(backStream)) + { + mBackground = background; + } + }; + + explicit CudaVirtualMemoryAllocator(std::shared_ptr config) + : mConfig(std::move(config)) + { + } + + // Tells if this is the background allocator. 
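+    // When the background configuration is active, BufferManager falls back to the regular CUDA
+    // allocators instead of allocating virtual memory.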
+ explicit operator bool() const noexcept + { + return !mConfig->mBackground; + } + + void allocate(Pointer* ptr, std::size_t n, int device) const; + void deallocate(Pointer ptr, std::size_t n) const; + +private: + std::shared_ptr mConfig; +}; + +} // namespace tensorrt_llm::runtime + +namespace tensorrt_llm::runtime +{ +CudaVirtualMemoryManager& getVirtualMemoryManager(); +CudaVirtualMemoryAllocator getVirtualMemoryAllocator(); +void setVirtualMemoryAllocator( + std::string const& tag, CudaVirtualMemoryAllocator::RestoreMode mode, std::shared_ptr backStream); +void clearVirtualMemoryAllocator(); + +} // namespace tensorrt_llm::runtime diff --git a/cpp/tensorrt_llm/common/cudaDriverWrapper.h b/cpp/tensorrt_llm/common/cudaDriverWrapper.h index 51ec498d92c..cc3328993c9 100644 --- a/cpp/tensorrt_llm/common/cudaDriverWrapper.h +++ b/cpp/tensorrt_llm/common/cudaDriverWrapper.h @@ -155,6 +155,16 @@ void checkDriver( } } +template +void checkDriverExitSafe(T result, char const* const func, char const* const file, int const line) +{ + if (result != CUDA_SUCCESS && result != CUDA_ERROR_DEINITIALIZED) + { + throw TllmException( + file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA driver error in %s: %d.", func, result).c_str()); + } +} + } // namespace tensorrt_llm::common /* @@ -167,4 +177,11 @@ void checkDriver( (stat), *tensorrt_llm::common::CUDADriverWrapper::getInstance(), #stat, __FILE__, __LINE__); \ } while (0) +// Avoid using CUDADriverWrapper when freeing resource, during which the global instance may already be freed. +#define TLLM_CU_CHECK_FREE_RESOURCE(stat) \ + do \ + { \ + tensorrt_llm::common::checkDriverExitSafe((stat), #stat, __FILE__, __LINE__); \ + } while (0) + #endif // CUDA_DRIVER_WRAPPER_H diff --git a/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp b/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp index f6ba107d446..1550ca4e97c 100644 --- a/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/runtime/bindings.cpp @@ -39,6 +39,7 @@ #include "tensorrt_llm/runtime/speculativeDecodingMode.h" #include "tensorrt_llm/runtime/tllmRuntime.h" #include "tensorrt_llm/runtime/torchView.h" +#include "tensorrt_llm/runtime/virtualMemory.h" #include #include @@ -116,6 +117,10 @@ void initBindings(nb::module_& m) .def_rw("scaling_vec_pointer", &tr::LoraCache::TaskLayerModuleConfig::scalingVecPointer) .def(nb::self == nb::self); + nb::class_(m, "CudaVirtualMemoryManager") + .def("release_with_tag", &tr::CudaVirtualMemoryManager::releaseWithTag, nb::arg("tag")) + .def("materialize_with_tag", &tr::CudaVirtualMemoryManager::materializeWithTag, nb::arg("tag")); + nb::class_(m, "BufferManager") .def(nb::init(), nb::arg("stream"), nb::arg("trim_pool") = false) .def_prop_ro("stream", &tr::BufferManager::getStream); @@ -312,6 +317,29 @@ void initBindings(nb::module_& m) [](int32_t tp_size) { return tensorrt_llm::kernels::max_workspace_size_lowprecision(tp_size); }, "Calculate the maximum workspace size needed for low precision all-reduce operations"); + nb::enum_(m, "CudaVirtualMemoryAllocatorRestoreMode") + .value("NONE", tr::CudaVirtualMemoryAllocator::RestoreMode::NONE) + .value("CPU", tr::CudaVirtualMemoryAllocator::RestoreMode::CPU) + .value("PINNED", tr::CudaVirtualMemoryAllocator::RestoreMode::PINNED) + .value("MEMSET", tr::CudaVirtualMemoryAllocator::RestoreMode::MEMSET); + + m.def("get_virtual_memory_manager", &tr::getVirtualMemoryManager, "Get the virtual memory manager", + nb::rv_policy::reference); + + m.def( + "set_virtual_memory_allocator", + [](std::string const& 
tag, tr::CudaVirtualMemoryAllocator::RestoreMode mode, uintptr_t stream) + { + static_assert(sizeof(uintptr_t) == sizeof(cudaStream_t)); + tr::setVirtualMemoryAllocator(tag, mode, + std::make_shared( + reinterpret_cast(stream), tensorrt_llm::common::getDevice(), false)); + }, + "Set the virtual memory allocator and start allocating virtual memory for CUDA allocations"); + + m.def("clear_virtual_memory_allocator", &tr::clearVirtualMemoryAllocator, + "Reset the current virtual memory allocator and stop allocating virtual memory for CUDA allocations"); + nb::class_(m, "McastGPUBuffer") .def(nb::init()) .def("get_uc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getUCBuffer) diff --git a/cpp/tensorrt_llm/pybind/runtime/bindings.cpp b/cpp/tensorrt_llm/pybind/runtime/bindings.cpp index db942c8d33e..2d387f3afec 100644 --- a/cpp/tensorrt_llm/pybind/runtime/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/runtime/bindings.cpp @@ -38,6 +38,7 @@ #include "tensorrt_llm/runtime/speculativeDecodingMode.h" #include "tensorrt_llm/runtime/tllmRuntime.h" #include "tensorrt_llm/runtime/torchView.h" +#include "tensorrt_llm/runtime/virtualMemory.h" #include #include @@ -213,6 +214,10 @@ void initBindings(pybind11::module_& m) .def_readwrite("scaling_vec_pointer", &tr::LoraCache::TaskLayerModuleConfig::scalingVecPointer) .def(py::self == py::self); + py::class_(m, "CudaVirtualMemoryManager") + .def("release_with_tag", &tr::CudaVirtualMemoryManager::releaseWithTag, py::arg("tag")) + .def("materialize_with_tag", &tr::CudaVirtualMemoryManager::materializeWithTag, py::arg("tag")); + py::classh(m, "BufferManager") .def(py::init(), py::arg("stream"), py::arg("trim_pool") = false) .def_property_readonly("stream", &tr::BufferManager::getStream); @@ -406,6 +411,29 @@ void initBindings(pybind11::module_& m) [](int32_t tp_size) { return tensorrt_llm::kernels::max_workspace_size_lowprecision(tp_size); }, "Calculate the maximum workspace size needed for low precision all-reduce operations"); + py::enum_(m, "CudaVirtualMemoryAllocatorRestoreMode") + .value("NONE", tr::CudaVirtualMemoryAllocator::RestoreMode::NONE) + .value("CPU", tr::CudaVirtualMemoryAllocator::RestoreMode::CPU) + .value("PINNED", tr::CudaVirtualMemoryAllocator::RestoreMode::PINNED) + .value("MEMSET", tr::CudaVirtualMemoryAllocator::RestoreMode::MEMSET); + + m.def("get_virtual_memory_manager", &tr::getVirtualMemoryManager, "Get the virtual memory manager", + py::return_value_policy::reference); + + m.def( + "set_virtual_memory_allocator", + [](std::string const& tag, tr::CudaVirtualMemoryAllocator::RestoreMode mode, uintptr_t stream) + { + static_assert(sizeof(uintptr_t) == sizeof(cudaStream_t)); + tr::setVirtualMemoryAllocator(tag, mode, + std::make_shared( + reinterpret_cast(stream), tensorrt_llm::common::getDevice(), false)); + }, + "Set the virtual memory allocator and start allocating virtual memory for CUDA allocations"); + + m.def("clear_virtual_memory_allocator", &tr::clearVirtualMemoryAllocator, + "Reset the current virtual memory allocator and stop allocating virtual memory for CUDA allocations"); + py::class_(m, "McastGPUBuffer") .def(py::init()) .def("get_uc_buffer", &tensorrt_llm::runtime::McastGPUBuffer::getUCBuffer) diff --git a/cpp/tensorrt_llm/runtime/CMakeLists.txt b/cpp/tensorrt_llm/runtime/CMakeLists.txt index 79b8ea7d784..3ed61445d60 100644 --- a/cpp/tensorrt_llm/runtime/CMakeLists.txt +++ b/cpp/tensorrt_llm/runtime/CMakeLists.txt @@ -55,7 +55,8 @@ set(SRCS tllmStreamReaders.cpp tllmLogger.cpp workerPool.cpp - worldConfig.cpp) + 
worldConfig.cpp + virtualMemory.cpp) include_directories(${API_INCLUDE_DIR}/tensorrt_llm/runtime) diff --git a/cpp/tensorrt_llm/runtime/bufferManager.cpp b/cpp/tensorrt_llm/runtime/bufferManager.cpp index 13e904b9e1e..3de42a25315 100644 --- a/cpp/tensorrt_llm/runtime/bufferManager.cpp +++ b/cpp/tensorrt_llm/runtime/bufferManager.cpp @@ -39,6 +39,10 @@ BufferManager::BufferManager(CudaStreamPtr stream, bool trimPool) BufferManager::IBufferPtr BufferManager::gpu(std::size_t size, nvinfer1::DataType type) const { + if (auto vmAllocator = getVirtualMemoryAllocator()) + { + return std::make_unique(size, type, std::move(vmAllocator)); + } if (static_cast(mPool)) { return std::make_unique(size, type, CudaAllocatorAsync{mStream, mPool}); @@ -49,6 +53,10 @@ BufferManager::IBufferPtr BufferManager::gpu(std::size_t size, nvinfer1::DataTyp BufferManager::ITensorPtr BufferManager::gpu(nvinfer1::Dims dims, nvinfer1::DataType type) const { + if (auto vmAllocator = getVirtualMemoryAllocator()) + { + return std::make_unique(dims, type, std::move(vmAllocator)); + } if (static_cast(mPool)) { return std::make_unique(dims, type, CudaAllocatorAsync{mStream, mPool}); @@ -59,11 +67,19 @@ BufferManager::ITensorPtr BufferManager::gpu(nvinfer1::Dims dims, nvinfer1::Data BufferManager::IBufferPtr BufferManager::gpuSync(std::size_t size, nvinfer1::DataType type) { + if (auto vmAllocator = getVirtualMemoryAllocator()) + { + return std::make_unique(size, type, std::move(vmAllocator)); + } return std::make_unique(size, type, CudaAllocator{}); } BufferManager::ITensorPtr BufferManager::gpuSync(nvinfer1::Dims dims, nvinfer1::DataType type) { + if (auto vmAllocator = getVirtualMemoryAllocator()) + { + return std::make_unique(dims, type, std::move(vmAllocator)); + } return std::make_unique(dims, type, CudaAllocator{}); } diff --git a/cpp/tensorrt_llm/runtime/tllmBuffers.h b/cpp/tensorrt_llm/runtime/tllmBuffers.h index 38263bb5aa8..faed36537e5 100644 --- a/cpp/tensorrt_llm/runtime/tllmBuffers.h +++ b/cpp/tensorrt_llm/runtime/tllmBuffers.h @@ -25,6 +25,7 @@ #include "tensorrt_llm/runtime/iTensor.h" #include "tensorrt_llm/runtime/ipcNvlsMemory.h" #include "tensorrt_llm/runtime/memoryCounters.h" +#include "tensorrt_llm/runtime/virtualMemory.h" #include #include @@ -500,6 +501,36 @@ class PoolAllocator : public BaseAllocator, TAllocator using PinnedPoolAllocator = PoolAllocator; +class CudaVirtualMemoryAllocatorAdaptor + : public BaseAllocator, + CudaVirtualMemoryAllocator +{ + // Update to MemoryCounters is done in Creator to more precisely reflect the memory usage. + using Base = BaseAllocator; + friend Base; + +public: + // No explicit, to allow implicit conversion from CudaVirtualMemoryAllocator + CudaVirtualMemoryAllocatorAdaptor(CudaVirtualMemoryAllocator const& allocator) + : CudaVirtualMemoryAllocator(allocator) + { + } + + using Base::allocate; + using Base::deallocate; + +protected: + void allocateImpl(PointerType* ptr, std::size_t n) const + { + this->CudaVirtualMemoryAllocator::allocate(ptr, n, tensorrt_llm::common::getDevice()); + } + + void deallocateImpl(PointerType ptr, std::size_t n) const + { + this->CudaVirtualMemoryAllocator::deallocate(ptr, n); + } +}; + // Adopted from https://github.com/NVIDIA/TensorRT/blob/release/8.6/samples/common/buffers.h //! @@ -508,17 +539,10 @@ using PinnedPoolAllocator = PoolAllocator; //! \details This templated RAII (Resource Acquisition Is Initialization) class handles the allocation, //! deallocation, querying of buffers on both the device and the host. //! 
It can handle data of arbitrary types because it stores byte buffers. -//! The template parameters AllocFunc and FreeFunc are used for the -//! allocation and deallocation of the buffer. -//! AllocFunc must be a functor that takes in (void** ptr, size_t size) -//! and returns bool. ptr is a pointer to where the allocated buffer address should be stored. -//! size is the amount of memory in bytes to allocate. -//! The boolean indicates whether or not the memory allocation was successful. -//! FreeFunc must be a functor that takes in (void* ptr) and returns void. -//! ptr is the allocated buffer address. It must work with nullptr input. +//! The template parameter TAllocator must inherit from BaseAllocator. //! template -class GenericBuffer : virtual public IBuffer +class GenericBuffer : virtual public IBuffer, TAllocator // Inherit from TAllocator for EBO { public: using AllocatorType = TAllocator; @@ -527,20 +551,27 @@ class GenericBuffer : virtual public IBuffer //! \brief Construct an empty buffer. //! explicit GenericBuffer(nvinfer1::DataType type, TAllocator allocator = {}) // NOLINT(*-pro-type-member-init) - : GenericBuffer{0, type, std::move(allocator)} {}; + : GenericBuffer{0, type, std::move(allocator)} + { + } //! //! \brief Construct a buffer with the specified allocation size in number of elements. //! explicit GenericBuffer( // NOLINT(*-pro-type-member-init) std::size_t size, nvinfer1::DataType type, TAllocator allocator = {}) - : GenericBuffer{size, size, type, std::move(allocator)} {}; + : GenericBuffer{size, size, type, std::move(allocator)} + { + } + + GenericBuffer(GenericBuffer const& other) = delete; + GenericBuffer& operator=(GenericBuffer const& buf) = delete; GenericBuffer(GenericBuffer&& buf) noexcept - : mSize{buf.mSize} + : TAllocator(static_cast(buf)) + , mSize{buf.mSize} , mCapacity{buf.mCapacity} , mType{buf.mType} - , mAllocator{std::move(buf.mAllocator)} , mBuffer{buf.mBuffer} { buf.mSize = 0; @@ -552,11 +583,11 @@ class GenericBuffer : virtual public IBuffer { if (this != &buf) { - mAllocator.deallocate(mBuffer, toBytes(mCapacity)); + this->TAllocator::deallocate(mBuffer, toBytes(mCapacity)); mSize = buf.mSize; mCapacity = buf.mCapacity; mType = buf.mType; - mAllocator = std::move(buf.mAllocator); + *static_cast(this) = static_cast(buf); mBuffer = buf.mBuffer; // Reset buf. buf.mSize = 0; @@ -615,7 +646,7 @@ class GenericBuffer : virtual public IBuffer //! [[nodiscard]] MemoryType getMemoryType() const override { - return mAllocator.getMemoryType(); + return this->TAllocator::getMemoryType(); } //! @@ -625,8 +656,8 @@ class GenericBuffer : virtual public IBuffer { if (mCapacity < newSize) { - mAllocator.deallocate(mBuffer, toBytes(mCapacity)); - mBuffer = mAllocator.allocate(toBytes(newSize)); + this->TAllocator::deallocate(mBuffer, toBytes(mCapacity)); + mBuffer = this->TAllocator::allocate(toBytes(newSize)); mCapacity = newSize; } mSize = newSize; @@ -637,7 +668,7 @@ class GenericBuffer : virtual public IBuffer //! 
void release() override { - mAllocator.deallocate(mBuffer, toBytes(mCapacity)); + this->TAllocator::deallocate(mBuffer, toBytes(mCapacity)); mSize = 0; mCapacity = 0; mBuffer = nullptr; @@ -647,7 +678,7 @@ class GenericBuffer : virtual public IBuffer { try { - mAllocator.deallocate(mBuffer, toBytes(mCapacity)); + this->TAllocator::deallocate(mBuffer, toBytes(mCapacity)); } catch (std::exception const& e) { @@ -657,11 +688,11 @@ class GenericBuffer : virtual public IBuffer protected: explicit GenericBuffer(std::size_t size, std::size_t capacity, nvinfer1::DataType type, TAllocator allocator = {}) - : mSize{size} + : TAllocator{std::move(allocator)} + , mSize{size} , mCapacity{capacity} , mType{type} - , mAllocator{std::move(allocator)} - , mBuffer{capacity > 0 ? mAllocator.allocate(toBytes(capacity)) : nullptr} + , mBuffer{capacity > 0 ? this->TAllocator::allocate(toBytes(capacity)) : nullptr} { TLLM_CHECK(size <= capacity); TLLM_CHECK(capacity == 0 || size > 0); @@ -670,7 +701,6 @@ class GenericBuffer : virtual public IBuffer private: std::size_t mSize{0}, mCapacity{0}; nvinfer1::DataType mType; - TAllocator mAllocator; void* mBuffer; }; @@ -834,6 +864,7 @@ using HostBuffer = GenericBuffer; using PinnedBuffer = GenericBuffer; using PinnedPoolBuffer = GenericBuffer; using UVMBuffer = GenericBuffer; +using VirtualAddressDeviceBuffer = GenericBuffer; template std::make_unsigned_t nonNegative(T value) @@ -1069,5 +1100,6 @@ using HostTensor = GenericTensor; using PinnedTensor = GenericTensor; using PinnedPoolTensor = GenericTensor; using UVMTensor = GenericTensor; +using VirtualAddressDeviceTensor = GenericTensor; } // namespace tensorrt_llm::runtime diff --git a/cpp/tensorrt_llm/runtime/virtualMemory.cpp b/cpp/tensorrt_llm/runtime/virtualMemory.cpp new file mode 100644 index 00000000000..488da30d653 --- /dev/null +++ b/cpp/tensorrt_llm/runtime/virtualMemory.cpp @@ -0,0 +1,433 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/runtime/virtualMemory.h" +#include "bufferManager.h" + +#include +#include + +namespace tensorrt_llm::runtime +{ + +namespace +{ + +template +struct ScopeGuard +{ + bool const& ok; + T t; + + ~ScopeGuard() noexcept(noexcept(t())) + { + if (!ok) + { + t(); + } + } +}; + +template +ScopeGuard(bool const&, T) -> ScopeGuard; + +} // namespace + +void CUDAVirtualMemoryChunk::materialize() +{ + TLLM_CHECK_WITH_INFO(status() == RELEASED, "virtual memory not in RELEASED status, is: %d", status()); + mHandle = mCreator->create(); + + // Track the number of configurators ran, so release can correctly teardown. + for (auto const& conf : mConfigurators) + { + conf->setup(mHandle); // May throw + ++mState; + } +} + +template +static bool safe_invoke_helper(std::exception_ptr& ep, char const* msg, Callable&& f, Args&&... args) noexcept +{ + try + { + std::invoke(std::forward(f), std::forward(args)...); + return true; + } + catch (...) 
+ { + if (ep) + { + try + { + std::rethrow_exception(ep); + } + catch (std::exception& e) + { + TLLM_LOG_ERROR(msg, e.what()); + } + } + ep = std::current_exception(); + return false; + } +} + +void CUDAVirtualMemoryChunk::_release(bool destructing) +{ + TLLM_CHECK_WITH_INFO(status() == MATERIALIZED || (status() == ERRORED && mState != INVALID_STATE), + "virtual memory is in status %d which cannot be released", status()); + size_t const count = mConfigurators.size(); + size_t const start = count - mState; + + // Revert materialize(). Only configurators that ran setup() successfully + // will have their teardown() been called. + // Never early returns on exceptions. The last exception will be rethrown, and + // previous ones will be logged. + std::exception_ptr ePtr{}; + auto const* msg = "Multiple exceptions thrown during release. The previous exception is: %s"; + for (size_t i = start; i < count; ++i) + { + safe_invoke_helper( + ePtr, msg, &Configurator::teardown, mConfigurators[count - i - 1].get(), mHandle, destructing); + } + safe_invoke_helper(ePtr, msg, &Creator::release, mCreator.get(), mHandle, destructing); + mHandle = {}; + mState = 0; + + if (ePtr != nullptr) + { + mState = INVALID_STATE; + std::rethrow_exception(ePtr); + } +} + +void OffloadConfigurator::setup(CUmemGenericAllocationHandle) +{ + if (mBackedStorage != nullptr) + { + if (mOndemand) + { + TLLM_CU_CHECK(cuMemcpyHtoD_v2(mAddress, mBackedStorage->data(), mSize)); + mBackedStorage.reset(); + } + else + { + TLLM_CU_CHECK(cuMemcpyHtoDAsync_v2(mAddress, mBackedStorage->data(), mSize, mStream)); + } + } +} + +void OffloadConfigurator::teardown(CUmemGenericAllocationHandle, bool destructing) +{ + if (destructing) + { + return; + } + + if (mBackedStorage == nullptr) + { + switch (mBackType) + { + case MemoryType::kCPU: mBackedStorage = BufferManager::cpu(mSize, nvinfer1::DataType::kINT8); break; + case MemoryType::kPINNED: mBackedStorage = BufferManager::pinned(mSize, nvinfer1::DataType::kINT8); break; + default: TLLM_THROW("Unknown memory type: %d", static_cast(mBackType)); + } + } + // We have to synchronize here, or the memory may be unmapped before the copy operation. 
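+    // The blocking cuMemcpyDtoH_v2 below returns only after the device-to-host copy has completed,
+    // so no explicit stream synchronization is issued here.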
+ TLLM_CU_CHECK_FREE_RESOURCE(cuMemcpyDtoH_v2(mBackedStorage->data(), mAddress, mSize)); +} + +void CudaVirtualMemoryManager::add(uintptr_t handle, std::string tag, CUDAVirtualMemoryChunk&& memory) +{ + bool success = false; + + TLLM_CHECK_WITH_INFO( + memory.status() == CUDAVirtualMemoryChunk::RELEASED || memory.status() == CUDAVirtualMemoryChunk::MATERIALIZED, + "CudaVirtualMemoryManager: bad virtual memory status"); + + std::unique_lock lock(mMutex); + auto [memIt, created] = mMemories.try_emplace(handle, Entry{}); + TLLM_CHECK_WITH_INFO( + created, "CudaVirtualMemoryManager: handle 0x%016zx already being used by another memory", handle); + ScopeGuard eraseMemIt{success, [&, memIt_ = memIt] { mMemories.erase(memIt_); }}; + + auto const entryIt = mEntries.emplace(std::move(tag), memIt); + entryIt->second->second.mEntryIt = entryIt; + + memIt->second.mMemory = std::move(memory); + success = true; +} + +void CudaVirtualMemoryManager::add(uintptr_t handle, std::string tag, CUDAVirtualMemoryChunk::CreatorPtr&& creator, + CUDAVirtualMemoryChunk::Configurators&& configurators) +{ + std::unique_lock lock(mMutex); + bool success = false; + + auto [memIt, created] = mMemories.try_emplace(handle, + Entry{ + {std::move(creator), std::move(configurators)}, + }); + TLLM_CHECK_WITH_INFO( + created, "CudaVirtualMemoryManager: handle 0x%016zx already being used by another memory", handle); + ScopeGuard eraseMemIt{success, [&, memIt_ = memIt] { mMemories.erase(memIt_); }}; + + auto const entryIt = mEntries.emplace(std::move(tag), memIt); + memIt->second.mEntryIt = entryIt; + ScopeGuard eraseTagIt{success, [&] { mEntries.erase(entryIt); }}; + + try + { + // Hopefully we don't need to hold the mutex guarding mMemories and mEntries anymore. + lock.unlock(); + memIt->second.mMemory.materialize(); + success = true; + } + catch (...) + { + // ...unless materialize() throws and we need to rollback. + lock.lock(); + throw; + } +} + +CUDAVirtualMemoryChunk CudaVirtualMemoryManager::remove(uintptr_t handle) noexcept +{ + std::unique_lock lock(mMutex); + + return unsafeRemove(handle); +} + +CUDAVirtualMemoryChunk CudaVirtualMemoryManager::unsafeRemove(uintptr_t handle) noexcept +{ + auto const nodeHandle = mMemories.extract(handle); + if (!nodeHandle) + { + return {}; + } + mEntries.erase(nodeHandle.mapped().mEntryIt); + + return std::move(nodeHandle.mapped().mMemory); +} + +void CudaVirtualMemoryManager::addBadHandle(uintptr_t handle) noexcept +{ + try + { + mBadHandles.push_back(handle); + } + catch (...) + { + } +} + +std::vector CudaVirtualMemoryManager::retrieveBadHandles() noexcept +{ + return std::move(mBadHandles); +} + +size_t CudaVirtualMemoryManager::releaseWithTag(std::string const& tag) +{ + std::unique_lock lock(mMutex); + + std::exception_ptr ePtr{}; + auto [begin, end] = mEntries.equal_range(tag); + size_t count = 0; + for (auto it = begin; it != end;) + { + auto const handle = it->second->first; + auto& memory = it->second->second.mMemory; + ++it; // element referenced by `it` will be invalidated by unsafeRemove(handle) + if (memory.status() == CUDAVirtualMemoryChunk::MATERIALIZED) + { + if (!safe_invoke_helper(ePtr, + "Multiple exceptions thrown during releaseWithTag. 
The previous exception is: %s", + &CUDAVirtualMemoryChunk::release, &memory)) + { + addBadHandle(handle); + unsafeRemove(handle); + } + ++count; + } + } + + if (ePtr != nullptr) + { + std::rethrow_exception(ePtr); + } + + return count; +} + +size_t CudaVirtualMemoryManager::materializeWithTag(std::string const& tag) +{ + std::unique_lock lock(mMutex); + + auto [begin, end] = mEntries.equal_range(tag); + size_t count = 0; + + auto it = begin; + + try + { + for (; it != end; ++it) + { + auto& memory = it->second->second.mMemory; + if (memory.status() == CUDAVirtualMemoryChunk::RELEASED) + { + memory.materialize(); + ++count; + } + } + } + catch (...) + { + for (auto itRollback = begin; itRollback != it;) + { + auto const handle = itRollback->second->first; + auto& memory = itRollback->second->second.mMemory; + ++itRollback; + try + { + memory.release(); + } + catch (std::exception& e) + { + addBadHandle(handle); + unsafeRemove(handle); + TLLM_LOG_ERROR("Additional exception thrown during rollback of materializeWithTag: %s", e.what()); + } + } + + addBadHandle(it->second->first); + unsafeRemove(it->second->first); + + throw; + } + return count; +} + +static_assert(sizeof(void*) == sizeof(CUdeviceptr)); + +static CUdeviceptr deviceptr_cast(void* ptr) +{ + CUdeviceptr ret{}; + std::memcpy(&ret, &ptr, sizeof(CUdeviceptr)); + return ret; +} + +static void* deviceptr_cast(CUdeviceptr ptr) +{ + void* ret{}; + std::memcpy(&ret, &ptr, sizeof(CUdeviceptr)); + return ret; +} + +void CudaVirtualMemoryAllocator::allocate(Pointer* ptr, std::size_t n, int device) const +{ + CUdeviceptr address{}; + std::size_t const pageAlignedSize = mConfig->pageAligned(n); + TLLM_CU_CHECK(cuMemAddressReserve(&address, pageAlignedSize, 0, {}, 0)); + + CUDAVirtualMemoryChunk::Configurators configurators; + configurators.push_back(std::make_unique(address, n, + CUmemAccessDesc{{ + CU_MEM_LOCATION_TYPE_DEVICE, + device, + }, + CU_MEM_ACCESS_FLAGS_PROT_READWRITE})); + + switch (mConfig->mMode) + { + case NONE: break; + case MEMSET: + configurators.push_back(std::make_unique(address, n, 0, mConfig->mBackStream->get())); + break; + case CPU: + configurators.push_back( + std::make_unique(address, n, MemoryType::kCPU, mConfig->mBackStream->get())); + break; + case PINNED: + configurators.push_back( + std::make_unique(address, n, MemoryType::kPINNED, mConfig->mBackStream->get())); + break; + } + + mConfig->mManager.add(address, mConfig->mTag, + std::make_unique>(CUmemAllocationProp{CU_MEM_ALLOCATION_TYPE_PINNED, CU_MEM_HANDLE_TYPE_NONE, + { + CU_MEM_LOCATION_TYPE_DEVICE, + device, + }}, + n), + std::move(configurators)); + + *ptr = deviceptr_cast(address); +} + +void CudaVirtualMemoryAllocator::deallocate(Pointer ptr, std::size_t n) const +{ + auto const address = deviceptr_cast(ptr); + mConfig->mManager.remove(address); + + std::size_t const pageAlignedSize = mConfig->pageAligned(n); + TLLM_CU_CHECK_FREE_RESOURCE(cuMemAddressFree(address, pageAlignedSize)); +} + +} // namespace tensorrt_llm::runtime + +namespace tensorrt_llm::runtime +{ + +CudaVirtualMemoryManager& getVirtualMemoryManager() +{ + static CudaVirtualMemoryManager manager; + return manager; +} + +using AllocConf = CudaVirtualMemoryAllocator::Configuration; + +AllocConf AllocConf::backgroundConfiguration{getVirtualMemoryManager(), "", NONE, nullptr, true}; + +static const std::shared_ptr bgConf{std::shared_ptr{}, &AllocConf::backgroundConfiguration}; + +static std::shared_mutex currentConfMutex; +static std::shared_ptr currentConf = bgConf; + 
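+// getVirtualMemoryAllocator() takes a shared lock, while setVirtualMemoryAllocator() and
+// clearVirtualMemoryAllocator() take the exclusive lock, so allocations only contend with configuration changes.
+// Illustrative flow (a sketch only; "kv_cache" is an example tag, `stream` a std::shared_ptr<CudaStream>):
+//   setVirtualMemoryAllocator("kv_cache", CudaVirtualMemoryAllocator::CPU, stream);
+//   ... BufferManager::gpu() allocations made now are tagged "kv_cache" ...
+//   getVirtualMemoryManager().releaseWithTag("kv_cache");     // offload and free device memory
+//   getVirtualMemoryManager().materializeWithTag("kv_cache"); // reallocate and restore content
+//   clearVirtualMemoryAllocator();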
+CudaVirtualMemoryAllocator getVirtualMemoryAllocator() +{ + std::shared_lock lock(currentConfMutex); + return CudaVirtualMemoryAllocator{currentConf}; +} + +void setVirtualMemoryAllocator( + std::string const& tag, CudaVirtualMemoryAllocator::RestoreMode mode, std::shared_ptr backStream) +{ + std::unique_lock lock(currentConfMutex); + + TLLM_CHECK_WITH_INFO(currentConf == bgConf, + "An active virtual memory allocator (tag: %s, mode: %d, stream: %p) is already present", + currentConf->mTag.c_str(), currentConf->mMode, currentConf->mBackStream.get()); + currentConf = std::make_shared(getVirtualMemoryManager(), tag, mode, backStream); +} + +void clearVirtualMemoryAllocator() +{ + std::unique_lock lock(currentConfMutex); + currentConf = bgConf; +} + +} // namespace tensorrt_llm::runtime diff --git a/cpp/tensorrt_llm/thop/CMakeLists.txt b/cpp/tensorrt_llm/thop/CMakeLists.txt index 8e41e2a2886..a9d0d4009f9 100644 --- a/cpp/tensorrt_llm/thop/CMakeLists.txt +++ b/cpp/tensorrt_llm/thop/CMakeLists.txt @@ -85,6 +85,7 @@ add_library( selectiveScanOp.cpp userbuffersFinalizeOp.cpp userbuffersTensor.cpp + virtualMemoryAllocator.cpp weightOnlyQuantGemm.cpp weightOnlyQuantOp.cpp mtpOp.cpp diff --git a/cpp/tensorrt_llm/thop/virtualMemoryAllocator.cpp b/cpp/tensorrt_llm/thop/virtualMemoryAllocator.cpp new file mode 100644 index 00000000000..77c2c3dfd82 --- /dev/null +++ b/cpp/tensorrt_llm/thop/virtualMemoryAllocator.cpp @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/runtime/virtualMemory.h" +#include +#include + +extern "C" +{ + + void* tensorrt_llm_virtual_memory_alloc(ssize_t size, int device, cudaStream_t) noexcept + { + void* ptr{}; + try + { + tensorrt_llm::runtime::getVirtualMemoryAllocator().allocate(&ptr, size, device); + } + catch (std::exception const& e) + { + TLLM_LOG_EXCEPTION(e); + ptr = {}; + } + catch (...) + { + TLLM_LOG_ERROR("Unknown exception thrown allocating virtual memory"); + ptr = {}; + } + + return ptr; + } + + void tensorrt_llm_virtual_memory_free(void* ptr, ssize_t size, cudaStream_t) noexcept + { + try + { + tensorrt_llm::runtime::getVirtualMemoryAllocator().deallocate(ptr, size); + } + catch (std::exception const& e) + { + TLLM_LOG_EXCEPTION(e); + } + catch (...) 
+ { + TLLM_LOG_ERROR("Unknown exception thrown deallocating virtual memory"); + } + } +} diff --git a/cpp/tests/unit_tests/runtime/CMakeLists.txt b/cpp/tests/unit_tests/runtime/CMakeLists.txt index 25ec2ab1b5e..77db1d9fd0e 100644 --- a/cpp/tests/unit_tests/runtime/CMakeLists.txt +++ b/cpp/tests/unit_tests/runtime/CMakeLists.txt @@ -28,6 +28,7 @@ add_gtest(tllmRuntimeTest tllmRuntimeTest.cpp) add_gtest(transposeKVKernelTest transposeKVKernelTest.cpp) add_gtest(userBufferTest userBufferTest.cpp) add_gtest(utilsTest utilsTest.cpp) +add_gtest(virtualMemoryTest virtualMemoryTest.cpp) add_gtest(workerPoolTest workerPoolTest.cpp) add_gtest(worldConfigTest worldConfigTest.cpp) diff --git a/cpp/tests/unit_tests/runtime/virtualMemoryTest.cpp b/cpp/tests/unit_tests/runtime/virtualMemoryTest.cpp new file mode 100644 index 00000000000..970a05299b1 --- /dev/null +++ b/cpp/tests/unit_tests/runtime/virtualMemoryTest.cpp @@ -0,0 +1,1572 @@ +/* + * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "tensorrt_llm/common/cudaUtils.h" +#include "tensorrt_llm/runtime/bufferManager.h" +#include "tensorrt_llm/runtime/tllmBuffers.h" +#include "tensorrt_llm/runtime/virtualMemory.h" + +#include +#include +#include +#include + +using namespace tensorrt_llm::runtime; +namespace tc = tensorrt_llm::common; + +struct DummyException : std::runtime_error +{ + DummyException() + : runtime_error("dummy exception") + { + } +}; + +class VirtualMemoryTestBase : public ::testing::Test +{ +protected: + void SetUp() override + { + if (tc::getDeviceCount() == 0) + { + GTEST_SKIP() << "This test suite cannot run on systems with no devices."; + } + + TLLM_CU_CHECK(cuInit(0)); + + CUdevice dev; + TLLM_CU_CHECK(cuDeviceGet(&dev, 0)); + + CUcontext ctx; + TLLM_CU_CHECK(cuDevicePrimaryCtxRetain(&ctx, dev)); + TLLM_CU_CHECK(cuCtxSetCurrent(ctx)); + + // Initialize NVML + nvmlReturn_t nvmlResult = nvmlInit(); + TLLM_CHECK_WITH_INFO(nvmlResult == NVML_SUCCESS, "Failed to initialize NVML: %s", nvmlErrorString(nvmlResult)); + + if (!memoryInfoAvailable()) + { + TLLM_LOG_WARNING("Per process memory information unavailable."); + } + + TLLM_CUDA_CHECK(cudaDeviceSynchronize()); + } + + void TearDown() override + { + TLLM_CUDA_CHECK(cudaDeviceSynchronize()); + } + + static bool memoryInfoAvailable() + { + static bool available = [] + { + auto blob = BufferManager::gpuSync(4096); + auto usage = getCurrentProcessMemoryInfo(); + return usage != 0; + }(); + + return available; + } + + static size_t getCurrentProcessMemoryInfo() + { + // Get current process ID + uint32_t currentPid = static_cast(getpid()); + + // Get device handle for GPU 0 + nvmlDevice_t device; + auto nvmlResult = nvmlDeviceGetHandleByIndex(0, &device); + TLLM_CHECK_WITH_INFO( + nvmlResult == NVML_SUCCESS, "Failed to get device handle: %s", nvmlErrorString(nvmlResult)); + + // Get running processes + unsigned int processCount = 1; + std::vector processes(processCount); + 
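+        // Grow the list and query again until NVML no longer reports INSUFFICIENT_SIZE.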
nvmlResult = NVML_ERROR_INSUFFICIENT_SIZE; + while (nvmlResult == NVML_ERROR_INSUFFICIENT_SIZE) + { + nvmlResult = nvmlDeviceGetComputeRunningProcesses_v3(device, &processCount, processes.data()); + TLLM_CHECK_WITH_INFO(nvmlResult == NVML_SUCCESS || nvmlResult == NVML_ERROR_INSUFFICIENT_SIZE, + "Failed to get process count: %s", nvmlErrorString(nvmlResult)); + processes.resize(processCount); + } + + // Find current process + for (auto const& process : processes) + { + if (process.pid == currentPid) + { + return process.usedGpuMemory; + } + } + + return 0; + } +}; + +class VirtualMemoryTest : public VirtualMemoryTestBase +{ +}; + +// Test CUDAVirtualMemoryChunk materialize and release memory correctly +TEST_F(VirtualMemoryTest, TestBasic) +{ + CUdeviceptr address{}; + std::size_t constexpr size = 256 * 1024 * 1024; + TLLM_CU_CHECK(cuMemAddressReserve(&address, size, 0, {}, 0)); + + CUDAVirtualMemoryChunk::CreatorPtr creator + = std::make_unique>(CUmemAllocationProp{CU_MEM_ALLOCATION_TYPE_PINNED, CU_MEM_HANDLE_TYPE_NONE, + { + CU_MEM_LOCATION_TYPE_DEVICE, + 0, + }}, + size); + + CUDAVirtualMemoryChunk::Configurators configurators; + configurators.push_back(std::make_unique(address, size, + CUmemAccessDesc{{ + CU_MEM_LOCATION_TYPE_DEVICE, + 0, + }, + CU_MEM_ACCESS_FLAGS_PROT_READWRITE})); + + CUDAVirtualMemoryChunk vm(std::move(creator), std::move(configurators)); + ASSERT_EQ(vm.status(), CUDAVirtualMemoryChunk::RELEASED); + + auto memoryBegin = getCurrentProcessMemoryInfo(); + vm.materialize(); + ASSERT_EQ(vm.status(), CUDAVirtualMemoryChunk::MATERIALIZED); + + auto memoryMaterialized = getCurrentProcessMemoryInfo(); + if (memoryInfoAvailable()) + { + ASSERT_EQ(memoryBegin + size, memoryMaterialized) << "materialize does not allocate memory"; + } + + auto result = cuMemsetD8_v2(address, 255, size); + ASSERT_EQ(result, CUDA_SUCCESS) << "Accessing memory returned failure (first materialize)"; + TLLM_CU_CHECK(cuStreamSynchronize(nullptr)); + + vm.release(); + ASSERT_EQ(vm.status(), CUDAVirtualMemoryChunk::RELEASED); + auto memoryReleased = getCurrentProcessMemoryInfo(); + if (memoryInfoAvailable()) + { + ASSERT_EQ(memoryBegin, memoryReleased) << "release does not release memory"; + } + + vm.materialize(); + ASSERT_EQ(vm.status(), CUDAVirtualMemoryChunk::MATERIALIZED); + result = cuMemsetD8_v2(address, 255, size); + ASSERT_EQ(result, CUDA_SUCCESS) << "Accessing memory returned failure (second materialize)"; + TLLM_CU_CHECK(cuStreamSynchronize(nullptr)); + + vm.release(); + ASSERT_EQ(vm.status(), CUDAVirtualMemoryChunk::RELEASED); + memoryReleased = getCurrentProcessMemoryInfo(); + if (memoryInfoAvailable()) + { + ASSERT_EQ(memoryBegin, memoryReleased) << "release does not release memory"; + } +} + +// Test BackedConfigurator refills memory correctly for both CPU and PINNED memory types +class VirtualMemoryOffloadConfigurator : public VirtualMemoryTest, public ::testing::WithParamInterface +{ +}; + +TEST_P(VirtualMemoryOffloadConfigurator, Test) +{ + MemoryType backType = GetParam(); + CUdeviceptr address{}; + std::size_t constexpr size = 4 * 1024 * 1024; + TLLM_CU_CHECK(cuMemAddressReserve(&address, size, 0, {}, 0)); + + CudaStream stream; + + CUDAVirtualMemoryChunk::CreatorPtr creator + = std::make_unique>(CUmemAllocationProp{CU_MEM_ALLOCATION_TYPE_PINNED, CU_MEM_HANDLE_TYPE_NONE, + { + CU_MEM_LOCATION_TYPE_DEVICE, + 0, + }}, + size); + + CUDAVirtualMemoryChunk::Configurators configurators; + configurators.push_back(std::make_unique(address, size, + CUmemAccessDesc{{ + 
CU_MEM_LOCATION_TYPE_DEVICE, + 0, + }, + CU_MEM_ACCESS_FLAGS_PROT_READWRITE})); + configurators.push_back(std::make_unique(address, size, backType, stream.get(), false)); + + CUDAVirtualMemoryChunk vm(std::move(creator), std::move(configurators)); + + std::vector data(size / sizeof(uint64_t), 0); + std::generate(data.begin(), data.end(), [engine = std::mt19937_64(address)]() mutable { return engine(); }); + + vm.materialize(); + + auto pointer = reinterpret_cast(address); + auto result = cudaMemcpyAsync(pointer, data.data(), size, cudaMemcpyHostToDevice, stream.get()); + ASSERT_EQ(result, CUDA_SUCCESS) << "Copying memory returned failure"; + + vm.release(); + + vm.materialize(); + + std::fill(data.begin(), data.end(), 0); + result = cudaMemcpyAsync(data.data(), pointer, size, cudaMemcpyDeviceToHost, stream.get()); + stream.synchronize(); + ASSERT_EQ(result, CUDA_SUCCESS) << "Copying memory returned failure"; + + auto engine = std::mt19937_64(static_cast(address)); + for (size_t i = 0; i < data.size(); ++i) + { + ASSERT_EQ(data[i], engine()) << "Mismatched at index " << i; + } +} + +INSTANTIATE_TEST_SUITE_P( + Backends, VirtualMemoryOffloadConfigurator, ::testing::Values(MemoryType::kCPU, MemoryType::kPINNED)); + +// Test CUDAVirtualMemoryChunk calls creator and configurators in correct order +TEST_F(VirtualMemoryTest, TestOrder) +{ + // Order tracking - local counter to track call sequence + int callOrder = 0; + + // OrderTrackingCreator that records when its methods are called + class OrderTrackingCreator : public CUDAVirtualMemoryChunk::Creator + { + public: + int& mCallOrder; + int createOrder = -1; + int releaseOrder = -1; + CUmemGenericAllocationHandle createdHandle = 0; + + OrderTrackingCreator(int& callOrder) + : mCallOrder(callOrder) + { + } + + CUmemGenericAllocationHandle create() override + { + createOrder = ++mCallOrder; + createdHandle = 0xbaadf00dbaadf00d; + return createdHandle; + } + + void release(CUmemGenericAllocationHandle handle, bool destructing) override + { + releaseOrder = ++mCallOrder; + ASSERT_EQ(handle, createdHandle); + } + }; + + // OrderTrackingConfigurator that records when its methods are called + class OrderTrackingConfigurator : public CUDAVirtualMemoryChunk::Configurator + { + public: + int& mCallOrder; + std::string name; + int setupOrder = -1; + int teardownOrder = -1; + + OrderTrackingConfigurator(int& callOrder, std::string n) + : mCallOrder(callOrder) + , name(std::move(n)) + { + } + + void setup(CUmemGenericAllocationHandle handle) override + { + setupOrder = ++mCallOrder; + } + + void teardown(CUmemGenericAllocationHandle handle, bool destructing) override + { + teardownOrder = ++mCallOrder; + } + }; + + // Create creator and configurators + auto creator = std::make_unique(callOrder); + auto* creatorPtr = creator.get(); + + auto config1 = std::make_unique(callOrder, "config1"); + auto config2 = std::make_unique(callOrder, "config2"); + auto config3 = std::make_unique(callOrder, "config3"); + auto* config1Ptr = config1.get(); + auto* config2Ptr = config2.get(); + auto* config3Ptr = config3.get(); + + CUDAVirtualMemoryChunk::Configurators configurators; + configurators.push_back(std::move(config1)); + configurators.push_back(std::move(config2)); + configurators.push_back(std::move(config3)); + + CUDAVirtualMemoryChunk vm(std::move(creator), std::move(configurators)); + + // Test materialize() order: creator.create() first, then configurators.setup() in order + vm.materialize(); + + // Verify materialize order + EXPECT_EQ(creatorPtr->createOrder, 1); 
// creator.create() should be called first + EXPECT_EQ(config1Ptr->setupOrder, 2); // config1.setup() should be called second + EXPECT_EQ(config2Ptr->setupOrder, 3); // config2.setup() should be called third + EXPECT_EQ(config3Ptr->setupOrder, 4); // config3.setup() should be called fourth + + // Verify release() hasn't been called yet + EXPECT_EQ(creatorPtr->releaseOrder, -1); + EXPECT_EQ(config1Ptr->teardownOrder, -1); + EXPECT_EQ(config2Ptr->teardownOrder, -1); + EXPECT_EQ(config3Ptr->teardownOrder, -1); + + // Test release() order: configurators.teardown() in reverse order, then creator.release() + vm.release(); + + // Verify release order + EXPECT_EQ(config3Ptr->teardownOrder, 5); // config3.teardown() should be called first (reverse order) + EXPECT_EQ(config2Ptr->teardownOrder, 6); // config2.teardown() should be called second + EXPECT_EQ(config1Ptr->teardownOrder, 7); // config1.teardown() should be called third + EXPECT_EQ(creatorPtr->releaseOrder, 8); // creator.release() should be called last +} + +// Test CUDAVirtualMemoryChunk behaves correctly when exceptions were thrown +TEST_F(VirtualMemoryTest, TestException) +{ + // Dummy Creator that can be configured to throw on create() or release() + class DummyCreator : public CUDAVirtualMemoryChunk::Creator + { + public: + bool throwOnCreate = false; + bool throwOnRelease = false; + bool createCalled = false; + bool releaseCalled = false; + CUmemGenericAllocationHandle createdHandle = 0; + + CUmemGenericAllocationHandle create() override + { + createCalled = true; + if (throwOnCreate) + { + throw DummyException(); + } + createdHandle = 0xbaadf00dbaadf00d; + return createdHandle; + } + + void release(CUmemGenericAllocationHandle handle, bool destructing) override + { + releaseCalled = true; + ASSERT_EQ(handle, createdHandle); + if (throwOnRelease) + { + throw DummyException(); + } + } + }; + + // Dummy Configurator that can be configured to throw on setup() or teardown() + class DummyConfigurator : public CUDAVirtualMemoryChunk::Configurator + { + public: + bool throwOnSetup = false; + bool throwOnTeardown = false; + bool setupCalled = false; + bool teardownCalled = false; + std::string name; + + DummyConfigurator(std::string n) + : name(std::move(n)) + { + } + + void setup(CUmemGenericAllocationHandle) override + { + setupCalled = true; + if (throwOnSetup) + { + throw DummyException(); + } + } + + void teardown(CUmemGenericAllocationHandle handle, bool destructing) override + { + teardownCalled = true; + if (throwOnTeardown) + { + throw DummyException(); + } + } + }; + + // Test 1: Exception in creator->create() + { + auto creator = std::make_unique(); + creator->throwOnCreate = true; + auto* creatorPtr = creator.get(); + + auto config1 = std::make_unique("config1"); + auto config2 = std::make_unique("config2"); + auto* config1Ptr = config1.get(); + auto* config2Ptr = config2.get(); + + CUDAVirtualMemoryChunk::Configurators configurators; + configurators.push_back(std::move(config1)); + configurators.push_back(std::move(config2)); + + CUDAVirtualMemoryChunk vm(std::move(creator), std::move(configurators)); + + // materialize() should throw due to creator->create() exception + EXPECT_THROW(vm.materialize(), DummyException); + + // Verify creator->create() was called but no configurators were setup + EXPECT_TRUE(creatorPtr->createCalled); + EXPECT_FALSE(config1Ptr->setupCalled); + EXPECT_FALSE(config2Ptr->setupCalled); + + // Internal state is still valid. + // If the failure from creator is temporary, materialize() can be reattempted. 
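+        // (No handle was created and no configurator ran, so status() remains RELEASED.)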
+ EXPECT_EQ(vm.status(), CUDAVirtualMemoryChunk::RELEASED); + } + + // Test 2: Exception in first configurator setup() + { + auto creator = std::make_unique(); + auto* creatorPtr = creator.get(); + + auto config1 = std::make_unique("config1"); + auto config2 = std::make_unique("config2"); + config1->throwOnSetup = true; + auto* config1Ptr = config1.get(); + auto* config2Ptr = config2.get(); + + CUDAVirtualMemoryChunk::Configurators configurators; + configurators.push_back(std::move(config1)); + configurators.push_back(std::move(config2)); + + CUDAVirtualMemoryChunk vm(std::move(creator), std::move(configurators)); + + // materialize() should throw due to first configurator exception + EXPECT_THROW(vm.materialize(), DummyException); + + // Verify creator->create() was called and first configurator setup() was called + EXPECT_TRUE(creatorPtr->createCalled); + EXPECT_TRUE(config1Ptr->setupCalled); + EXPECT_FALSE(config2Ptr->setupCalled); + + // Status should be ERRORED + EXPECT_EQ(vm.status(), CUDAVirtualMemoryChunk::ERRORED); + + // release() should still work and only teardown what was set up + vm.release(); + EXPECT_TRUE(creatorPtr->releaseCalled); + EXPECT_FALSE(config1Ptr->teardownCalled); // Failed setup, so no teardown + EXPECT_FALSE(config2Ptr->teardownCalled); // Never setup + } + + // Test 3: Exception in second configurator setup() + { + auto creator = std::make_unique(); + auto* creatorPtr = creator.get(); + + auto config1 = std::make_unique("config1"); + auto config2 = std::make_unique("config2"); + config2->throwOnSetup = true; + auto* config1Ptr = config1.get(); + auto* config2Ptr = config2.get(); + + CUDAVirtualMemoryChunk::Configurators configurators; + configurators.push_back(std::move(config1)); + configurators.push_back(std::move(config2)); + + CUDAVirtualMemoryChunk vm(std::move(creator), std::move(configurators)); + + // materialize() should throw due to second configurator exception + EXPECT_THROW(vm.materialize(), DummyException); + + // Verify both creator and first configurator were called + EXPECT_TRUE(creatorPtr->createCalled); + EXPECT_TRUE(config1Ptr->setupCalled); + EXPECT_TRUE(config2Ptr->setupCalled); + + // Status should be ERRORED + EXPECT_EQ(vm.status(), CUDAVirtualMemoryChunk::ERRORED); + + // release() should teardown the first configurator (successful setup) but not the second + vm.release(); + EXPECT_TRUE(creatorPtr->releaseCalled); + EXPECT_TRUE(config1Ptr->teardownCalled); // Successful setup, so teardown called + EXPECT_FALSE(config2Ptr->teardownCalled); // Failed setup, so no teardown + } + + // Test 4: Exception in configurator teardown() during release() + { + auto creator = std::make_unique(); + auto* creatorPtr = creator.get(); + + auto config1 = std::make_unique("config1"); + auto config2 = std::make_unique("config2"); + config2->throwOnTeardown = true; + auto* config1Ptr = config1.get(); + auto* config2Ptr = config2.get(); + + CUDAVirtualMemoryChunk::Configurators configurators; + configurators.push_back(std::move(config1)); + configurators.push_back(std::move(config2)); + + CUDAVirtualMemoryChunk vm(std::move(creator), std::move(configurators)); + + // materialize() should succeed + vm.materialize(); + EXPECT_EQ(vm.status(), CUDAVirtualMemoryChunk::MATERIALIZED); + + // release() should throw due to teardown exception but still complete cleanup + EXPECT_THROW(vm.release(), DummyException); + + // Verify all teardown methods were called despite exception + EXPECT_TRUE(config1Ptr->teardownCalled); + EXPECT_TRUE(config2Ptr->teardownCalled); + 
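+        // release() does not stop at config2's throwing teardown; the remaining
+        // configurator and the creator are still cleaned up before the exception
+        // propagates.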
EXPECT_TRUE(creatorPtr->releaseCalled); + + // Status should be ERRORED due to exception + EXPECT_EQ(vm.status(), CUDAVirtualMemoryChunk::ERRORED); + } + + // Test 5: Exception in creator->release() + { + auto creator = std::make_unique(); + creator->throwOnRelease = true; + auto* creatorPtr = creator.get(); + + auto config1 = std::make_unique("config1"); + auto* config1Ptr = config1.get(); + + CUDAVirtualMemoryChunk::Configurators configurators; + configurators.push_back(std::move(config1)); + + CUDAVirtualMemoryChunk vm(std::move(creator), std::move(configurators)); + + // materialize() should succeed + vm.materialize(); + EXPECT_EQ(vm.status(), CUDAVirtualMemoryChunk::MATERIALIZED); + + // release() should throw due to creator exception but still complete configurator cleanup + EXPECT_THROW(vm.release(), DummyException); + + // Verify configurator teardown was called despite creator exception + EXPECT_TRUE(config1Ptr->teardownCalled); + EXPECT_TRUE(creatorPtr->releaseCalled); + + // Status should be ERRORED due to exception + EXPECT_EQ(vm.status(), CUDAVirtualMemoryChunk::ERRORED); + } +} + +// Test various class facilities +TEST_F(VirtualMemoryTest, TestFacilities) +{ + // Test default constructed CUDAVirtualMemoryChunk + { + CUDAVirtualMemoryChunk defaultVm; + + // Should be invalid + EXPECT_FALSE(defaultVm); + EXPECT_EQ(defaultVm.status(), CUDAVirtualMemoryChunk::INVALID); + } + + CUdeviceptr address{}; + std::size_t constexpr size = 64 * 1024 * 1024; + TLLM_CU_CHECK(cuMemAddressReserve(&address, size, 0, {}, 0)); + // Test move semantic + { + + // Create original CUDAVirtualMemoryChunk + CUDAVirtualMemoryChunk::CreatorPtr creator + = std::make_unique>(CUmemAllocationProp{CU_MEM_ALLOCATION_TYPE_PINNED, + CU_MEM_HANDLE_TYPE_NONE, {CU_MEM_LOCATION_TYPE_DEVICE, 0}}, + size); + + CUDAVirtualMemoryChunk::Configurators configurators; + configurators.push_back(std::make_unique( + address, size, CUmemAccessDesc{{CU_MEM_LOCATION_TYPE_DEVICE, 0}, CU_MEM_ACCESS_FLAGS_PROT_READWRITE})); + + CUDAVirtualMemoryChunk original(std::move(creator), std::move(configurators)); + original.materialize(); + EXPECT_EQ(original.status(), CUDAVirtualMemoryChunk::MATERIALIZED); + + // Test move constructor + CUDAVirtualMemoryChunk moved{std::move(original)}; + EXPECT_FALSE(original); // Original should be invalid after move + EXPECT_TRUE(moved); // Moved-to object should be valid + EXPECT_EQ(moved.status(), CUDAVirtualMemoryChunk::MATERIALIZED); + + // Test move assignment + CUDAVirtualMemoryChunk assigned; + EXPECT_FALSE(assigned); // Default constructed, should be invalid + + assigned = std::move(moved); + EXPECT_FALSE(moved); // moved should be invalid after move + EXPECT_TRUE(assigned); // assigned should be valid + EXPECT_EQ(assigned.status(), CUDAVirtualMemoryChunk::MATERIALIZED); + + // Clean up + assigned.release(); + } +} + +// Test destructor +TEST_F(VirtualMemoryTest, TestDestructor) +{ + + // Dummy Creator for testing destructor behavior + class DummyCreator : public CUDAVirtualMemoryChunk::Creator + { + public: + bool& createCalledRef; + bool& releaseCalledRef; + CUmemGenericAllocationHandle createdHandle = 0; + + DummyCreator(bool& createRef, bool& releaseRef) + : createCalledRef(createRef) + , releaseCalledRef(releaseRef) + { + } + + CUmemGenericAllocationHandle create() override + { + createCalledRef = true; + createdHandle = 0xbaadf00dbaadf00d; + return createdHandle; + } + + void release(CUmemGenericAllocationHandle handle, bool destructing) override + { + releaseCalledRef = true; + 
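+            // release() must receive exactly the handle that create() returned,
+            // whether it is invoked manually or from the destructor.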
ASSERT_EQ(handle, createdHandle); + } + }; + + // Dummy Configurator for testing destructor behavior + class DummyConfigurator : public CUDAVirtualMemoryChunk::Configurator + { + public: + bool& setupCalledRef; + bool& teardownCalledRef; + std::string name; + + DummyConfigurator(std::string n, bool& setupRef, bool& teardownRef) + : setupCalledRef(setupRef) + , teardownCalledRef(teardownRef) + , name(std::move(n)) + { + } + + void setup(CUmemGenericAllocationHandle) override + { + setupCalledRef = true; + } + + void teardown(CUmemGenericAllocationHandle, bool) override + { + teardownCalledRef = true; + } + }; + + // Test destructor calls release automatically for materialized memory + { + bool createCalled = false; + bool releaseCalled = false; + bool setupCalled = false; + bool teardownCalled = false; + + auto creator = std::make_unique(createCalled, releaseCalled); + auto config1 = std::make_unique("config1", setupCalled, teardownCalled); + + CUDAVirtualMemoryChunk::Configurators configurators; + configurators.push_back(std::move(config1)); + + alignas(CUDAVirtualMemoryChunk) std::byte storage[sizeof(CUDAVirtualMemoryChunk)]; + CUDAVirtualMemoryChunk* vm = new (storage) CUDAVirtualMemoryChunk(std::move(creator), std::move(configurators)); + + vm->materialize(); + + // Verify materialize was called + EXPECT_TRUE(createCalled); + EXPECT_TRUE(setupCalled); + EXPECT_FALSE(releaseCalled); + EXPECT_FALSE(teardownCalled); + EXPECT_EQ(vm->status(), CUDAVirtualMemoryChunk::MATERIALIZED); + + vm->~CUDAVirtualMemoryChunk(); + + // Verify destructor called release + EXPECT_TRUE(releaseCalled); + EXPECT_TRUE(teardownCalled); + } + + // Test destructor doesn't double-release for manually released memory + { + // Local variables to track calls (persist after object destruction) + bool createCalled = false; + bool releaseCalled = false; + bool setupCalled = false; + bool teardownCalled = false; + + auto creator = std::make_unique(createCalled, releaseCalled); + auto config1 = std::make_unique("config1", setupCalled, teardownCalled); + + CUDAVirtualMemoryChunk::Configurators configurators; + configurators.push_back(std::move(config1)); + + alignas(CUDAVirtualMemoryChunk) std::byte storage[sizeof(CUDAVirtualMemoryChunk)]; + auto* vm = new (storage) CUDAVirtualMemoryChunk(std::move(creator), std::move(configurators)); + + vm->materialize(); + vm->release(); // Manual release + + // Verify manual release was called + EXPECT_TRUE(releaseCalled); + EXPECT_TRUE(teardownCalled); + EXPECT_EQ(vm->status(), CUDAVirtualMemoryChunk::RELEASED); + + // Reset flags to verify destructor doesn't call release again + releaseCalled = false; + teardownCalled = false; + + vm->~CUDAVirtualMemoryChunk(); + + // Verify destructor did NOT call release again (no double-release) + EXPECT_FALSE(releaseCalled); + EXPECT_FALSE(teardownCalled); + } + + // Test destructor behavior with ERRORED state + { + // Local variables to track calls (persist after object destruction) + bool createCalled = false; + bool releaseCalled = false; + bool config1SetupCalled = false; + bool config1TeardownCalled = false; + bool throwingSetupCalled = false; + bool throwingTeardownCalled = false; + + class ThrowingConfigurator : public CUDAVirtualMemoryChunk::Configurator + { + public: + bool& setupCalledRef; + bool& teardownCalledRef; + + ThrowingConfigurator(bool& setupRef, bool& teardownRef) + : setupCalledRef(setupRef) + , teardownCalledRef(teardownRef) + { + } + + void setup(CUmemGenericAllocationHandle) override + { + setupCalledRef = true; + throw 
DummyException(); + } + + void teardown(CUmemGenericAllocationHandle, bool) override + { + teardownCalledRef = true; + } + }; + + auto creator = std::make_unique(createCalled, releaseCalled); + auto config1 = std::make_unique("config1", config1SetupCalled, config1TeardownCalled); + auto throwingConfig = std::make_unique(throwingSetupCalled, throwingTeardownCalled); + + CUDAVirtualMemoryChunk::Configurators configurators; + configurators.push_back(std::move(config1)); + configurators.push_back(std::move(throwingConfig)); + + alignas(CUDAVirtualMemoryChunk) std::byte storage[sizeof(CUDAVirtualMemoryChunk)]; + auto* vm = new (storage) CUDAVirtualMemoryChunk(std::move(creator), std::move(configurators)); + + // Materialize should throw and leave VM in ERRORED state + EXPECT_THROW(vm->materialize(), DummyException); + EXPECT_EQ(vm->status(), CUDAVirtualMemoryChunk::ERRORED); + + // Verify partial setup occurred + EXPECT_TRUE(createCalled); + EXPECT_TRUE(config1SetupCalled); + EXPECT_TRUE(throwingSetupCalled); + EXPECT_FALSE(releaseCalled); + + vm->~CUDAVirtualMemoryChunk(); + + // Verify destructor called release to clean up the errored state + EXPECT_TRUE(releaseCalled); + EXPECT_TRUE(config1TeardownCalled); + // throwingConfig's teardown should NOT be called since setup failed + EXPECT_FALSE(throwingTeardownCalled); + } +} + +// Test edge cases and error scenarios +TEST_F(VirtualMemoryTest, TestEdgeCases) +{ + // Dummy Creator for testing edge cases + class DummyCreator : public CUDAVirtualMemoryChunk::Creator + { + public: + CUmemGenericAllocationHandle createdHandle = 0xbaadf00dbaadf00d; + + CUmemGenericAllocationHandle create() override + { + return createdHandle; + } + + void release(CUmemGenericAllocationHandle handle, bool destructing) override + { + ASSERT_EQ(handle, createdHandle); + } + }; + + // Test multiple materialize calls (should throw) + { + auto creator = std::make_unique(); + CUDAVirtualMemoryChunk vm(std::move(creator), {}); + + vm.materialize(); + EXPECT_EQ(vm.status(), CUDAVirtualMemoryChunk::MATERIALIZED); + + // Second materialize should throw + EXPECT_THROW(vm.materialize(), tc::TllmException); + EXPECT_EQ(vm.status(), CUDAVirtualMemoryChunk::MATERIALIZED); + + vm.release(); + } + + // Test multiple release calls (should throw) + { + auto creator = std::make_unique(); + CUDAVirtualMemoryChunk vm(std::move(creator), {}); + + vm.materialize(); + vm.release(); + EXPECT_EQ(vm.status(), CUDAVirtualMemoryChunk::RELEASED); + + // Second release should throw + EXPECT_THROW(vm.release(), tc::TllmException); + EXPECT_EQ(vm.status(), CUDAVirtualMemoryChunk::RELEASED); + } + + // Test release on RELEASED state (should throw) + { + auto creator = std::make_unique(); + CUDAVirtualMemoryChunk vm(std::move(creator), {}); + + EXPECT_EQ(vm.status(), CUDAVirtualMemoryChunk::RELEASED); + EXPECT_THROW(vm.release(), tc::TllmException); // Should throw on RELEASED state + EXPECT_EQ(vm.status(), CUDAVirtualMemoryChunk::RELEASED); + } + + // Test materialize on ERRORED state after exception recovery + { + // Create a VM that will go into ERRORED state + class ThrowingConfigurator : public CUDAVirtualMemoryChunk::Configurator + { + public: + bool shouldThrow = true; + + void setup(CUmemGenericAllocationHandle) override + { + if (shouldThrow) + { + throw DummyException(); + } + } + + void teardown(CUmemGenericAllocationHandle, bool) override {} + }; + + auto creator = std::make_unique(); + auto throwingConfig = std::make_unique(); + + CUDAVirtualMemoryChunk::Configurators configurators; + 
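+        // shouldThrow defaults to true, so the first materialize() below is expected
+        // to fail and leave the chunk in ERRORED state.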
configurators.push_back(std::move(throwingConfig)); + + CUDAVirtualMemoryChunk vm(std::move(creator), std::move(configurators)); + + // First materialize should throw and leave VM in ERRORED state + EXPECT_THROW(vm.materialize(), DummyException); + EXPECT_EQ(vm.status(), CUDAVirtualMemoryChunk::ERRORED); + + // Should be able to release from ERRORED state + vm.release(); + EXPECT_EQ(vm.status(), CUDAVirtualMemoryChunk::RELEASED); + } +} + +class VirtualMemoryManagerTest : public VirtualMemoryTestBase // NOLINT(cppcoreguidelines-pro-type-member-init) +{ + using Base = VirtualMemoryTestBase; + +protected: + auto& entries() + { + return mVMManager->mEntries; + } + + auto& memories() + { + return mVMManager->mMemories; + } + + auto& badHandles() + { + return mVMManager->mBadHandles; + } + + void SetUp() override + { + this->Base::SetUp(); + mVMManager = std::make_unique(); + } + + void TearDown() override + { + this->Base::TearDown(); + ASSERT_TRUE(!mVMManager || entries().size() == 0) << "Leftover memory in manager"; + } + + std::unique_ptr mVMManager = nullptr; +}; + +TEST_F(VirtualMemoryManagerTest, TestBasic) +{ + CUdeviceptr address{}; + std::size_t constexpr size = 256 * 1024 * 1024; + TLLM_CU_CHECK(cuMemAddressReserve(&address, size, 0, {}, 0)); + + uintptr_t handle = static_cast(address); + std::string tag = "test_tag"; + + CUDAVirtualMemoryChunk::CreatorPtr creator + = std::make_unique>(CUmemAllocationProp{CU_MEM_ALLOCATION_TYPE_PINNED, CU_MEM_HANDLE_TYPE_NONE, + { + CU_MEM_LOCATION_TYPE_DEVICE, + 0, + }}, + size); + + CUDAVirtualMemoryChunk::Configurators configurators; + configurators.push_back(std::make_unique(address, size, + CUmemAccessDesc{{ + CU_MEM_LOCATION_TYPE_DEVICE, + 0, + }, + CU_MEM_ACCESS_FLAGS_PROT_READWRITE})); + + auto memoryBegin = getCurrentProcessMemoryInfo(); + + // Add to manager - this automatically materializes + mVMManager->add(handle, tag, std::move(creator), std::move(configurators)); + + auto memoryMaterialized = getCurrentProcessMemoryInfo(); + if (memoryInfoAvailable()) + { + ASSERT_EQ(memoryBegin + size, memoryMaterialized) << "add/materialize does not allocate memory"; + } + + // Test memory access after materialization + auto result = cuMemsetD8_v2(address, 255, size); + ASSERT_EQ(result, CUDA_SUCCESS) << "Accessing memory returned failure (first materialize)"; + TLLM_CU_CHECK(cuStreamSynchronize(nullptr)); + + // Release memory through manager + auto releaseCount = mVMManager->releaseWithTag(tag); + ASSERT_EQ(releaseCount, 1) << "Expected to release 1 memory object"; + + auto memoryReleased = getCurrentProcessMemoryInfo(); + if (memoryInfoAvailable()) + { + ASSERT_EQ(memoryBegin, memoryReleased) << "releaseWithTag does not release memory"; + } + + // Materialize again through manager + auto materializeCount = mVMManager->materializeWithTag(tag); + ASSERT_EQ(materializeCount, 1) << "Expected to materialize 1 memory object"; + + auto memoryRematerialized = getCurrentProcessMemoryInfo(); + if (memoryInfoAvailable()) + { + ASSERT_EQ(memoryBegin + size, memoryRematerialized) << "materializeWithTag does not allocate memory"; + } + + // Test memory access after rematerialization + result = cuMemsetD8_v2(address, 255, size); + ASSERT_EQ(result, CUDA_SUCCESS) << "Accessing memory returned failure (second materialize)"; + TLLM_CU_CHECK(cuStreamSynchronize(nullptr)); + + // Clean up - remove from manager + { + auto removedMemory = mVMManager->remove(handle); + ASSERT_TRUE(removedMemory) << "Expected to successfully remove memory from manager"; + } + + auto 
memoryAfterRemove = getCurrentProcessMemoryInfo(); + if (memoryInfoAvailable()) + { + ASSERT_EQ(memoryBegin, memoryAfterRemove) << "remove does not release memory"; + } + + auto unknownMemory = mVMManager->remove(0); + ASSERT_FALSE(unknownMemory) << "Expect invalid memory for unknown handle"; +} + +TEST_F(VirtualMemoryManagerTest, TestTags) +{ + // Dummy Creator for testing tag functionality + class DummyCreator : public CUDAVirtualMemoryChunk::Creator + { + public: + bool createCalled = false; + bool releaseCalled = false; + CUmemGenericAllocationHandle createdHandle = 0xbaadf00dbaadf00d; + + CUmemGenericAllocationHandle create() override + { + createCalled = true; + return createdHandle; + } + + void release(CUmemGenericAllocationHandle handle, bool destructing) override + { + releaseCalled = true; + ASSERT_EQ(handle, createdHandle); + } + }; + + // Create creators for different virtual memories + auto creator1 = std::make_unique(); + auto creator2 = std::make_unique(); + auto creator3 = std::make_unique(); + auto creator4 = std::make_unique(); + + // Keep pointers to track state + auto* creator1Ptr = creator1.get(); + auto* creator2Ptr = creator2.get(); + auto* creator3Ptr = creator3.get(); + auto* creator4Ptr = creator4.get(); + + mVMManager->add(0x1000, "tag_A", std::move(creator1), {}); + mVMManager->add(0x2000, "tag_B", std::move(creator2), {}); + mVMManager->add(0x3000, "tag_A", std::move(creator3), {}); + mVMManager->add(0x4000, "tag_C", std::move(creator4), {}); + + // All should be materialized initially (since add() materializes automatically) + EXPECT_TRUE(creator1Ptr->createCalled); + EXPECT_TRUE(creator2Ptr->createCalled); + EXPECT_TRUE(creator3Ptr->createCalled); + EXPECT_TRUE(creator4Ptr->createCalled); + + // Reset create flags to test materializeWithTag later + creator1Ptr->createCalled = false; + creator2Ptr->createCalled = false; + creator3Ptr->createCalled = false; + creator4Ptr->createCalled = false; + + // Test releaseWithTag - should release only memories with "tag_A" + auto releaseCount = mVMManager->releaseWithTag("tag_A"); + EXPECT_EQ(releaseCount, 2); // Should release 2 memories with tag_A + + // Verify only tag_A memories were released + EXPECT_TRUE(creator1Ptr->releaseCalled); // tag_A + EXPECT_FALSE(creator2Ptr->releaseCalled); // tag_B + EXPECT_TRUE(creator3Ptr->releaseCalled); // tag_A + EXPECT_FALSE(creator4Ptr->releaseCalled); // tag_C + + // Test materializeWithTag - should materialize only memories with "tag_A" + auto materializeCount = mVMManager->materializeWithTag("tag_A"); + EXPECT_EQ(materializeCount, 2); // Should materialize 2 memories with tag_A + + // Verify only tag_A memories were materialized + EXPECT_TRUE(creator1Ptr->createCalled); // tag_A + EXPECT_FALSE(creator2Ptr->createCalled); // tag_B + EXPECT_TRUE(creator3Ptr->createCalled); // tag_A + EXPECT_FALSE(creator4Ptr->createCalled); // tag_C + + // Reset flags and test releasing with a different tag + creator2Ptr->releaseCalled = false; + releaseCount = mVMManager->releaseWithTag("tag_B"); + EXPECT_EQ(releaseCount, 1); // Should release 1 memory with tag_B + EXPECT_TRUE(creator2Ptr->releaseCalled); // tag_B should now be released + + // Test with non-existent tag + releaseCount = mVMManager->releaseWithTag("nonexistent_tag"); + EXPECT_EQ(releaseCount, 0); // Should release 0 memories + + materializeCount = mVMManager->materializeWithTag("nonexistent_tag"); + EXPECT_EQ(materializeCount, 0); // Should materialize 0 memories + + // Clean up - remove all memories + 
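+    // (Illustrative extra check) nothing threw in this test, so the manager should not
+    // have recorded any bad handles before cleanup.
+    EXPECT_TRUE(mVMManager->retrieveBadHandles().empty());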
mVMManager->remove(0x1000); + mVMManager->remove(0x2000); + mVMManager->remove(0x3000); + mVMManager->remove(0x4000); +} + +TEST_F(VirtualMemoryManagerTest, TestAddException) +{ + // Dummy Creator that succeeds + class DummyCreator : public CUDAVirtualMemoryChunk::Creator + { + public: + CUmemGenericAllocationHandle createdHandle = 0xbaadf00dbaadf00d; + + CUmemGenericAllocationHandle create() override + { + return createdHandle; + } + + void release(CUmemGenericAllocationHandle handle, bool destructing) override + { + ASSERT_EQ(handle, createdHandle); + } + }; + + // Dummy Configurator that throws during setup + class ThrowingConfigurator : public CUDAVirtualMemoryChunk::Configurator + { + public: + void setup(CUmemGenericAllocationHandle) override + { + throw DummyException(); + } + + void teardown(CUmemGenericAllocationHandle, bool) override + { + ASSERT_TRUE(false) << "Unreachable"; + } + }; + + uintptr_t handle = 0x12345678; + std::string tag = "test_tag"; + + // Verify initial state is clean + EXPECT_TRUE(memories().empty()); + EXPECT_TRUE(entries().empty()); + EXPECT_TRUE(badHandles().empty()); + + auto creator = std::make_unique(); + CUDAVirtualMemoryChunk::Configurators configurators; + configurators.push_back(std::make_unique()); + + // add() should throw because materialize() will fail due to ThrowingConfigurator + EXPECT_THROW(mVMManager->add(handle, tag, std::move(creator), std::move(configurators)), DummyException); + + // Verify that the manager state is clean after the exception + // The ScopeGuards in add() should have cleaned up properly + EXPECT_TRUE(memories().empty()) << "mMemories should be empty after failed add()"; + EXPECT_TRUE(entries().empty()) << "mEntries should be empty after failed add()"; + EXPECT_TRUE(badHandles().empty()) << "mBadHandles should be empty after failed add()"; + + // Test that we can successfully add a memory with the same handle after the failure + auto successCreator = std::make_unique(); + CUDAVirtualMemoryChunk::Configurators successConfigurators; // Empty configurators should work + + // This should succeed without throwing + EXPECT_NO_THROW(mVMManager->add(handle, tag, std::move(successCreator), std::move(successConfigurators))); + + // Verify that the manager now has the entry + EXPECT_EQ(memories().size(), 1); + EXPECT_EQ(entries().size(), 1); + EXPECT_TRUE(badHandles().empty()); + + // Clean up + auto removedMemory = mVMManager->remove(handle); + EXPECT_TRUE(removedMemory); +} + +TEST_F(VirtualMemoryManagerTest, TestMaterializeException) +{ + // State structure to track create/release order and can throw on a specific call + struct CreatorState + { + int& createCounter; // Reference to shared counter + int throwOnCreateIdx = 0; // 1-based index to throw on create + int myCreateIdx = INT_MAX; + bool createCalled = false; + bool releaseCalled = false; + CUmemGenericAllocationHandle createdHandle = 0xbaadf00dbaadf00d; + + CreatorState(int& sharedCounter) + : createCounter(sharedCounter) + { + } + }; + + // Dummy Creator that uses external state + class TestMatEx_DummyCreator : public CUDAVirtualMemoryChunk::Creator + { + public: + CreatorState& state; + + TestMatEx_DummyCreator(CreatorState& state) + : state(state) + { + } + + CUmemGenericAllocationHandle create() override + { + state.createCalled = true; + state.myCreateIdx = ++state.createCounter; + if (state.throwOnCreateIdx > 0 && state.myCreateIdx == state.throwOnCreateIdx) + { + throw DummyException(); + } + return state.createdHandle; + } + + void 
release(CUmemGenericAllocationHandle handle, bool destructing) override + { + state.releaseCalled = true; + ASSERT_EQ(handle, state.createdHandle); + } + }; + + // Create shared counter + int sharedCreateCounter = 0; + + // Create state objects for each creator + CreatorState state1(sharedCreateCounter); + CreatorState state2(sharedCreateCounter); + CreatorState state3(sharedCreateCounter); + + // We want the second memory (by create order) to throw + state1.throwOnCreateIdx = 2; + state2.throwOnCreateIdx = 2; + state3.throwOnCreateIdx = 2; + + // Create creators and configurators + auto creator1 = std::make_unique(state1); + auto creator2 = std::make_unique(state2); + auto creator3 = std::make_unique(state3); + + // Add memories to manager in RELEASED state (don't auto-materialize by constructing manually) + CUDAVirtualMemoryChunk vm1(std::move(creator1), {}); + CUDAVirtualMemoryChunk vm2(std::move(creator2), {}); + CUDAVirtualMemoryChunk vm3(std::move(creator3), {}); + + mVMManager->add(0x1000, "test_tag", std::move(vm1)); + mVMManager->add(0x2000, "test_tag", std::move(vm2)); + mVMManager->add(0x3000, "test_tag", std::move(vm3)); + + // Verify initial state is clean + EXPECT_TRUE(badHandles().empty()); + + // materializeWithTag should stop at the first exception (second memory by create order) + // and attempt to rollback the first memory that succeeded + EXPECT_THROW(mVMManager->materializeWithTag("test_tag"), DummyException); + + // Find which creators were called and in what order + std::vector> creators + = {{0x1000, &state1}, {0x2000, &state2}, {0x3000, &state3}}; + // Sort by myCreateIdx (nonzero means create was called) + std::sort(creators.begin(), creators.end(), + [](auto const& a, auto const& b) { return a.second->myCreateIdx < b.second->myCreateIdx; }); + + // The first memory (by create order) should have been materialized then released during rollback + auto* first = creators[0].second; + EXPECT_TRUE(first->createCalled); + EXPECT_TRUE(first->releaseCalled); // Rolled back + // The second memory should have thrown during setup, so creator was called but setup failed + auto* second = creators[1].second; + EXPECT_TRUE(second->createCalled); + EXPECT_FALSE(second->releaseCalled); + // The third memory should not have been touched (myCreateIdx == 0) + auto* third = creators[2].second; + EXPECT_FALSE(third->createCalled); + EXPECT_FALSE(third->releaseCalled); + + // The handle of the memory that threw should be the second one's handle + uintptr_t thrownHandle = creators[1].first; + + // Verify bad handles tracking - memories that threw exceptions should be removed + auto badHandles = mVMManager->retrieveBadHandles(); + EXPECT_EQ(badHandles.size(), 1); + EXPECT_EQ(badHandles[0], thrownHandle); + + // Verify the memory that threw was removed from the manager + auto removedMem = mVMManager->remove(thrownHandle); + EXPECT_FALSE(removedMem); // Should have been removed due to exception + + // The other two memories should still be in manager + for (int i = 0; i < 3; ++i) + { + if (creators[i].first != thrownHandle) + { + auto removed = mVMManager->remove(creators[i].first); + EXPECT_TRUE(removed); + } + } +} + +TEST_F(VirtualMemoryManagerTest, TestReleaseException) +{ + // State structure to track create/release calls + struct CreatorState + { + bool createCalled = false; + bool releaseCalled = false; + int& releaseCounter; + int throwOnReleaseCount; + CUmemGenericAllocationHandle createdHandle = 0xbaadf00dbaadf00d; + + CreatorState(int& counter, int throwCount) + : 
releaseCounter(counter) + , throwOnReleaseCount(throwCount) + { + } + }; + + // State structure to track setup/teardown calls + struct ConfiguratorState + { + bool setupCalled = false; + bool teardownCalled = false; + int& teardownCounter; + int throwOnTeardownCount; + + ConfiguratorState(int& counter, int throwCount) + : teardownCounter(counter) + , throwOnTeardownCount(throwCount) + { + } + }; + + // Dummy Creator that succeeds + class DummyCreator : public CUDAVirtualMemoryChunk::Creator + { + public: + CreatorState& state; + + DummyCreator(CreatorState& state) + : state(state) + { + } + + CUmemGenericAllocationHandle create() override + { + state.createCalled = true; + return state.createdHandle; + } + + void release(CUmemGenericAllocationHandle handle, bool destructing) override + { + state.releaseCalled = true; + ASSERT_EQ(handle, state.createdHandle); + if (++state.releaseCounter == state.throwOnReleaseCount) + { + throw DummyException(); + } + } + }; + + // Dummy Configurator that succeeds + class DummyConfigurator : public CUDAVirtualMemoryChunk::Configurator + { + public: + ConfiguratorState& state; + + DummyConfigurator(ConfiguratorState& state) + : state(state) + { + } + + void setup(CUmemGenericAllocationHandle) override + { + state.setupCalled = true; + } + + void teardown(CUmemGenericAllocationHandle, bool) override + { + state.teardownCalled = true; + if (++state.teardownCounter == state.throwOnTeardownCount) + { + throw DummyException(); + } + } + }; + + // Create counters for tracking release/teardown calls + int releaseCounter = 0; + int teardownCounter = 0; + + // Create state objects for each creator and configurator + CreatorState state1(releaseCounter, 2); // Throw on 2nd release + CreatorState state2(releaseCounter, 2); // Throw on 2nd release + CreatorState state3(releaseCounter, 2); // Throw on 2nd release + CreatorState state4(releaseCounter, 2); // Throw on 2nd release + + ConfiguratorState configState1(teardownCounter, 3); // Throw on 3rd teardown + ConfiguratorState configState2(teardownCounter, 3); // Throw on 3rd teardown + ConfiguratorState configState3(teardownCounter, 3); // Throw on 3rd teardown + ConfiguratorState configState4(teardownCounter, 3); // Throw on 3rd teardown + + // Create creators and configurators + auto creator1 = std::make_unique(state1); + auto creator2 = std::make_unique(state2); + auto creator3 = std::make_unique(state3); + auto creator4 = std::make_unique(state4); + + auto config1 = std::make_unique(configState1); + auto config2 = std::make_unique(configState2); + auto config3 = std::make_unique(configState3); + auto config4 = std::make_unique(configState4); + + CUDAVirtualMemoryChunk::Configurators configurators1; + configurators1.push_back(std::move(config1)); + + CUDAVirtualMemoryChunk::Configurators configurators2; + configurators2.push_back(std::move(config2)); + + CUDAVirtualMemoryChunk::Configurators configurators3; + configurators3.push_back(std::move(config3)); + + CUDAVirtualMemoryChunk::Configurators configurators4; + configurators4.push_back(std::move(config4)); + + mVMManager->add(0x1000, "test_tag", std::move(creator1), std::move(configurators1)); + mVMManager->add(0x2000, "test_tag", std::move(creator2), std::move(configurators2)); + mVMManager->add(0x3000, "test_tag", std::move(creator3), std::move(configurators3)); + mVMManager->add(0x4000, "other_tag", std::move(creator4), std::move(configurators4)); + + // Verify initial state + EXPECT_TRUE(badHandles().empty()); + + // releaseWithTag should call release on all 
memories with "test_tag" + // and continue despite exceptions + EXPECT_THROW(mVMManager->releaseWithTag("test_tag"), DummyException); + + // Verify behavior: + // - All memories with "test_tag" should have had release() attempted + EXPECT_TRUE(state1.releaseCalled); + EXPECT_TRUE(configState1.teardownCalled); + + EXPECT_TRUE(state2.releaseCalled); + EXPECT_TRUE(configState2.teardownCalled); + + EXPECT_TRUE(state3.releaseCalled); + EXPECT_TRUE(configState3.teardownCalled); + + // - Memory with different tag should not be affected + EXPECT_FALSE(state4.releaseCalled); + EXPECT_FALSE(configState4.teardownCalled); + + // Verify bad handles tracking - memories that threw exceptions should be removed + auto badHandles = mVMManager->retrieveBadHandles(); + EXPECT_EQ(badHandles.size(), 2); + EXPECT_NE(std::find(badHandles.begin(), badHandles.end(), 0x2000), badHandles.end()); + EXPECT_NE(std::find(badHandles.begin(), badHandles.end(), 0x3000), badHandles.end()); + + // Verify the memories were removed from the manager + auto removedMem1 = mVMManager->remove(0x1000); + auto removedMem2 = mVMManager->remove(0x2000); + auto removedMem3 = mVMManager->remove(0x3000); + auto removedMem4 = mVMManager->remove(0x4000); + + EXPECT_TRUE(removedMem1); // Should have been removed due to exception + EXPECT_FALSE(removedMem2); // Should have been removed due to exception + EXPECT_FALSE(removedMem3); // Should have been removed due to exception + EXPECT_TRUE(removedMem4); // Should still be in manager (different tag, not affected) +} + +TEST_F(VirtualMemoryManagerTest, TestCudaVirtualMemoryAllocator) +{ + std::size_t constexpr size = 64 * 1024 * 1024; // 64 MB + std::string tag = "test_allocator_tag"; + + // Create a CUDA stream for the allocator + CudaStream stream; + auto streamPtr = std::make_shared(std::move(stream)); + + // Create configuration for the virtual address allocator + auto config = std::make_shared( + *mVMManager.get(), tag, CudaVirtualMemoryAllocator::RestoreMode::NONE, streamPtr); + + auto memoryBegin = getCurrentProcessMemoryInfo(); + + // Create a buffer using the virtual address allocator + auto buffer = std::make_unique( + size, nvinfer1::DataType::kINT8, CudaVirtualMemoryAllocator{config}); + + auto memoryAfterAllocation = getCurrentProcessMemoryInfo(); + if (memoryInfoAvailable()) + { + ASSERT_EQ(memoryBegin + size, memoryAfterAllocation) << "Buffer allocation does not allocate memory"; + } + + // Test that we can access the buffer data + ASSERT_NE(buffer->data(), nullptr) << "Buffer data should not be null"; + ASSERT_EQ(buffer->getSize(), size) << "Buffer size should match requested size"; + ASSERT_EQ(buffer->getDataType(), nvinfer1::DataType::kINT8) << "Buffer data type should be INT8"; + ASSERT_EQ(buffer->getMemoryType(), MemoryType::kGPU) << "Buffer memory type should be GPU"; + + // Test memory access by setting memory to a known pattern + auto devicePtr = reinterpret_cast(buffer->data()); + auto result = cuMemsetD8_v2(devicePtr, 0xAB, size); + ASSERT_EQ(result, CUDA_SUCCESS) << "Memory access should succeed"; + TLLM_CU_CHECK(cuStreamSynchronize(nullptr)); + + // Test releasing memory with tag - this should free the virtual memory + auto releaseCount = mVMManager->releaseWithTag(tag); + ASSERT_EQ(releaseCount, 1) << "Expected to release 1 memory object"; + + auto memoryAfterRelease = getCurrentProcessMemoryInfo(); + if (memoryInfoAvailable()) + { + ASSERT_EQ(memoryBegin, memoryAfterRelease) << "Release should free the memory"; + } + + // Test materializing memory with tag - this 
should re-allocate the virtual memory + auto materializeCount = mVMManager->materializeWithTag(tag); + ASSERT_EQ(materializeCount, 1) << "Expected to materialize 1 memory object"; + + auto memoryAfterMaterialize = getCurrentProcessMemoryInfo(); + if (memoryInfoAvailable()) + { + ASSERT_EQ(memoryBegin + size, memoryAfterMaterialize) << "Materialize should allocate memory"; + } + + // Test memory access again after rematerialization + result = cuMemsetD8_v2(devicePtr, 0xCD, size); + ASSERT_EQ(result, CUDA_SUCCESS) << "Memory access should succeed after rematerialization"; + TLLM_CU_CHECK(cuStreamSynchronize(nullptr)); + + // Clean up by destroying the buffer (this should automatically clean up the virtual memory) + buffer.reset(); + + auto memoryAfterCleanup = getCurrentProcessMemoryInfo(); + if (memoryInfoAvailable()) + { + ASSERT_EQ(memoryBegin, memoryAfterCleanup) << "Buffer destruction should free memory"; + } +} diff --git a/tensorrt_llm/_torch/virtual_memory.py b/tensorrt_llm/_torch/virtual_memory.py new file mode 100644 index 00000000000..e2335286333 --- /dev/null +++ b/tensorrt_llm/_torch/virtual_memory.py @@ -0,0 +1,88 @@ +import functools +from contextlib import contextmanager +from typing import Generator + +import torch + +from tensorrt_llm.bindings.internal.runtime import \ + CudaVirtualMemoryAllocatorRestoreMode as RestoreMode +from tensorrt_llm.bindings.internal.runtime import ( + clear_virtual_memory_allocator, get_virtual_memory_manager, + set_virtual_memory_allocator) + +__all__ = [ + "RestoreMode", "maybe_scope", "scope", "release_with_tag", + "materialize_with_tag" +] + + +@functools.cache +def _get_torch_pluggable_virtual_memory_allocator(): + th_common = next(path for path in torch.classes.loaded_libraries + if 'th_common' in path) + virtual_memory_allocator = torch.cuda.CUDAPluggableAllocator( + th_common, 'tensorrt_llm_virtual_memory_alloc', + 'tensorrt_llm_virtual_memory_free') + return virtual_memory_allocator.allocator() + + +@contextmanager +def _virtual_memory_helper(tag: str, mode: RestoreMode): + stream = torch.cuda.current_stream() + set_virtual_memory_allocator(tag, mode, stream.cuda_stream) + try: + yield + finally: + clear_virtual_memory_allocator() + + +def _scope( + tag: str, + mode: RestoreMode = RestoreMode.NONE +) -> Generator[torch.cuda.MemPool, None, None]: + """A context manager that routes allocations to virtual memory allocator + using given tag and backed mode. 
+ + :param tag: The tag to reference the memory for release and materialize + :param mode: The backed mode to choose how the memory content is backed up + """ + pool = torch.cuda.MemPool(_get_torch_pluggable_virtual_memory_allocator()) + with _virtual_memory_helper(tag, mode), torch.cuda.use_mem_pool(pool): + yield pool + + +scope = contextmanager(_scope) + + +@contextmanager +def maybe_scope( + enable: bool, + tag: str, + mode: RestoreMode = RestoreMode.NONE +) -> Generator[torch.cuda.MemPool | None, None, None]: + if enable: + yield from _scope(tag, mode) + else: + yield + + +def release_with_tag(*tags: str) -> int: + """Release virtual memory allocated with given tags + + :param tags: The tag of the scope when the virtual memory is allocated + :return: Number of memory blobs released + """ + manager = get_virtual_memory_manager() + released_blobs = sum(manager.release_with_tag(tag) for tag in tags) + return released_blobs + + +def materialize_with_tag(*tags: str) -> int: + """Materialize virtual memory allocated with given tags + + :param tags: The tag of the scope when the virtual memory is allocated + :return: Number of memory blobs materialized + """ + manager = get_virtual_memory_manager() + materialized_blobs = sum(manager.materialize_with_tag(tag) for tag in tags) + return materialized_blobs diff --git a/tests/unittest/_torch/test_virtual_memory.py b/tests/unittest/_torch/test_virtual_memory.py new file mode 100644 index 00000000000..ce1b58ce701 --- /dev/null +++ b/tests/unittest/_torch/test_virtual_memory.py @@ -0,0 +1,254 @@ +import gc +import os +import warnings + +import pynvml +import pytest +import torch + +import tensorrt_llm +from tensorrt_llm._torch import virtual_memory +from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager +from tensorrt_llm.bindings.executor import KvCacheConfig +from tensorrt_llm.bindings.internal.batch_manager import CacheType +from tensorrt_llm.mapping import Mapping + + +@pytest.fixture(scope="function", autouse=True) +def cuda_sync_fixture(): + """ + Synchronizes CUDA to catch device errors. + """ + + torch.cuda.synchronize() + yield + torch.cuda.synchronize() + + +@pytest.fixture(scope="module") +def memory_info_available(): + """ + Checks if NVML can get per-process memory information. + """ + + # Allocate a small tensor to test memory tracking + tensor = torch.zeros(4096, dtype=torch.int32, device='cuda') + torch.cuda.synchronize() + + # Try to get memory usage + usage = get_current_process_memory_info() + + # Clean up + del tensor + torch.cuda.synchronize() + torch.cuda.empty_cache() + + if usage == 0: + warnings.warn("Per process memory information unavailable.") + return False + + return True + + +@pytest.fixture(scope="module", autouse=True) +def nvml_init(): + pynvml.nvmlInit() + + +def get_current_process_memory_info() -> int: + """ + Returns GPU memory usage for current process in bytes. 
+ """ + # Get current process ID + current_pid = os.getpid() + # Get device handle for GPU 0 + device_handle = pynvml.nvmlDeviceGetHandleByIndex(0) + + # Get running processes + processes = pynvml.nvmlDeviceGetComputeRunningProcesses(device_handle) + + # Find current process + for process in processes: + if process.pid == current_pid: + return process.usedGpuMemory + + return 0 + + +@pytest.fixture(scope="function", autouse=True) +def clean_cache(): + gc.collect() + torch.cuda.empty_cache() + yield + gc.collect() + torch.cuda.empty_cache() + + +def test_basic(memory_info_available): + memory_usage_begin = get_current_process_memory_info() + + alloc_size = 256 * 1024 * 1024 + tag = "test_tag" + + with virtual_memory.scope(tag) as pool: + tensor = torch.full([alloc_size], 42, dtype=torch.int8, device='cuda') + memory_usage_materialized = get_current_process_memory_info() + if memory_info_available: + assert memory_usage_begin + alloc_size == memory_usage_materialized + + assert tensor[0].item() == 42 + + torch.cuda.synchronize() + virtual_memory.release_with_tag(tag) + + memory_usage_released = get_current_process_memory_info() + if memory_info_available: + assert memory_usage_begin == memory_usage_released + + torch.cuda.synchronize() + virtual_memory.materialize_with_tag(tag) + + memory_usage_rematerialized = get_current_process_memory_info() + if memory_info_available: + assert memory_usage_begin + alloc_size == memory_usage_rematerialized + + torch.fill_(tensor, 24) + assert tensor[0].item() == 24 + + del tensor + del pool + + memory_usage_end = get_current_process_memory_info() + if memory_info_available: + assert memory_usage_begin == memory_usage_end + + +def test_restore(): + alloc_size = 1024 * 1024 + tag = "test_tag" + + with virtual_memory.scope(tag, virtual_memory.RestoreMode.PINNED) as pool: + tensor = torch.full([alloc_size], 42, dtype=torch.int8, device='cuda') + + assert tensor[0].item() == 42 + + torch.cuda.synchronize() + virtual_memory.release_with_tag(tag) + + torch.cuda.synchronize() + + virtual_memory.materialize_with_tag(tag) + torch.cuda.synchronize() + + assert tensor[0].item() == 42 + + del tensor + del pool + + +def test_kv_cache_manager(memory_info_available): + kv_cache_params = { + "kv_cache_config": KvCacheConfig(max_tokens=1024), + "kv_cache_type": CacheType.SELF, + "num_layers": 8, + "num_kv_heads": 256, + "head_dim": 64, + "tokens_per_block": 64, + "max_seq_len": 1024, + "max_batch_size": 1, + "mapping": Mapping(world_size=1, tp_size=1, rank=0), + "dtype": tensorrt_llm.bindings.DataType.FP8, + } + + mgr = KVCacheManager(**kv_cache_params) + mgr.shutdown() + del mgr + + memory_usage_begin = get_current_process_memory_info() + + tag = "test_tag" + cache_size = torch.empty( + [ + 2, # KV + 8, # Layers + 256, # Heads + 1024, # Tokens + 64, # Head dim + ], + dtype=torch.float8_e4m3fn, + device='meta') + + alloc_size = cache_size.nelement() + + with virtual_memory.scope(tag) as pool: + mgr = KVCacheManager(**kv_cache_params) + memory_usage_materialized = get_current_process_memory_info() + if memory_info_available: + assert memory_usage_begin + alloc_size == memory_usage_materialized + + torch.cuda.synchronize() + virtual_memory.release_with_tag(tag) + + memory_usage_released = get_current_process_memory_info() + if memory_info_available: + assert memory_usage_begin == memory_usage_released + + torch.cuda.synchronize() + virtual_memory.materialize_with_tag(tag) + + memory_usage_rematerialized = get_current_process_memory_info() + if memory_info_available: + assert 
memory_usage_begin + alloc_size == memory_usage_rematerialized + + mgr.shutdown() + del mgr + del pool + + memory_usage_end = get_current_process_memory_info() + if memory_info_available: + assert memory_usage_begin == memory_usage_end + + +def test_cuda_graph(memory_info_available): + + def work(input: torch.Tensor) -> torch.Tensor: + intermediate = input + input + output = input + intermediate + return output + + g = torch.cuda.CUDAGraph() + tag = "cuda_graph" + + with virtual_memory.scope(tag) as pool: + static_input = torch.ones(1024, dtype=torch.float32, device='cuda') + static_output = torch.zeros(1024, dtype=torch.float32, device='cuda') + + with torch.cuda.graph(g): + static_output.copy_(work(static_input)) + + torch.fill_(static_input, 1.0) + g.replay() + + torch.cuda.synchronize() + assert static_output[0].item() == 3.0 + + memory_usage_before = get_current_process_memory_info() + + torch.cuda.synchronize() + virtual_memory.release_with_tag(tag) + + memory_usage_released = get_current_process_memory_info() + if memory_info_available: + assert memory_usage_released < memory_usage_before + + torch.cuda.synchronize() + virtual_memory.materialize_with_tag(tag) + + torch.fill_(static_input, 1.0) + torch.fill_(static_output, 0.0) + g.replay() + + torch.cuda.synchronize() + assert static_output[0].item() == 3.0 + + del static_input, static_output, g, pool
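+
+
+def test_maybe_scope():
+    # Illustrative sketch of the `maybe_scope` helper: it behaves like `scope` when
+    # enabled and is a pass-through otherwise. The test name, tag, and the expected
+    # blob counts are assumptions made for this example.
+    alloc_size = 1024 * 1024
+    tag = "maybe_scope_tag"
+
+    # Disabled: allocations bypass the virtual memory allocator, so nothing is
+    # registered under the tag and no pool is returned.
+    with virtual_memory.maybe_scope(False, tag) as pool:
+        assert pool is None
+        tensor = torch.ones(alloc_size, dtype=torch.int8, device='cuda')
+    torch.cuda.synchronize()
+    assert virtual_memory.release_with_tag(tag) == 0
+    del tensor
+
+    # Enabled: behaves like `scope`, so the allocation can be released and
+    # rematerialized by tag.
+    with virtual_memory.maybe_scope(True, tag) as pool:
+        assert pool is not None
+        tensor = torch.ones(alloc_size, dtype=torch.int8, device='cuda')
+
+    torch.cuda.synchronize()
+    assert virtual_memory.release_with_tag(tag) >= 1
+    assert virtual_memory.materialize_with_tag(tag) >= 1
+    torch.cuda.synchronize()
+
+    del tensor, pool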