diff --git a/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp b/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp
index 9a86c5d5c08..d17ed835c53 100644
--- a/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp
+++ b/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp
@@ -18,12 +18,176 @@
 #include "cacheTransBuffer.h"
 #include "tensorrt_llm/common/envUtils.h"
 #include "tensorrt_llm/common/logger.h"
+#include "tensorrt_llm/common/opUtils.h"
+#include "tensorrt_llm/executor/executor.h"
 #include
 #include
 
 namespace tensorrt_llm::batch_manager::kv_cache_manager
 {
 
+class FabricMemory::Impl
+{
+public:
+    Impl(size_t size)
+        : mSize(size)
+    {
+        TLLM_CUDA_CHECK(cudaGetDevice(&mDeviceIdx));
+        CUmemAllocationHandleType const handle_type = CU_MEM_HANDLE_TYPE_FABRIC;
+        CUmemAllocationProp prop = {};
+        prop.requestedHandleTypes = handle_type;
+        prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
+        prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+        prop.location.id = mDeviceIdx;
+        prop.allocFlags.gpuDirectRDMACapable = 1;
+
+        size_t granularity{0};
+        TLLM_CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
+        mGranularity = granularity;
+        mAllocSize = (size + granularity - 1) / granularity * granularity;
+        TLLM_CU_CHECK(cuMemCreate(&mHandle, mAllocSize, &prop, 0));
+        TLLM_CU_CHECK(cuMemAddressReserve(&mDevicePtr, mAllocSize, mGranularity, 0, 0));
+        mPtr = reinterpret_cast<void*>(mDevicePtr);
+        CUmemAccessDesc accessDesc = {};
+        accessDesc.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+        accessDesc.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
+        accessDesc.location.id = mDeviceIdx;
+        TLLM_CU_CHECK(cuMemMap(mDevicePtr, mAllocSize, 0, mHandle, 0));
+        TLLM_CU_CHECK(cuMemSetAccess(mDevicePtr, mAllocSize, &accessDesc, 1));
+        TLLM_LOG_DEBUG("FabricMemory::Impl::Impl mAllocSize:%ld", mAllocSize);
+    }
+
+    ~Impl()
+    {
+        TLLM_LOG_DEBUG("FabricMemory::Impl::~Impl mAllocSize:%ld", mAllocSize);
+        TLLM_CU_CHECK(cuMemUnmap(mDevicePtr, mAllocSize));
+        TLLM_CU_CHECK(cuMemRelease(mHandle));
+        TLLM_CU_CHECK(cuMemAddressFree(mDevicePtr, mAllocSize));
+    }
+
+    [[nodiscard]] void* getPtr() const
+    {
+        return mPtr;
+    }
+
+    [[nodiscard]] size_t getSize() const
+    {
+        return mSize;
+    }
+
+private:
+    size_t mSize;
+    size_t mAllocSize;
+    size_t mGranularity;
+    void* mPtr;
+    CUdeviceptr mDevicePtr;
+    CUmemGenericAllocationHandle mHandle;
+    int mDeviceIdx;
+};
+
+FabricMemory::FabricMemory(size_t size)
+    : pImpl(std::make_unique<Impl>(size))
+{
+}
+
+FabricMemory::~FabricMemory() = default;
+
+FabricMemory::FabricMemory(FabricMemory&&) noexcept = default;
+FabricMemory& FabricMemory::operator=(FabricMemory&&) noexcept = default;
+
+void* FabricMemory::getPtr() const
+{
+    return pImpl->getPtr();
+}
+
+size_t FabricMemory::getSize() const
+{
+    return pImpl->getSize();
+}
+
+size_t FabricMemory::getAlignedSize(size_t size)
+{
+    auto alignedSizeFun = []()
+    {
+        int deviceIdx = -1;
+        TLLM_CUDA_CHECK(cudaGetDevice(&deviceIdx));
+        CUmemAllocationHandleType const handle_type = CU_MEM_HANDLE_TYPE_FABRIC;
+        CUmemAllocationProp prop = {};
+        prop.requestedHandleTypes = handle_type;
+        prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
+        prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+        prop.location.id = deviceIdx;
+        prop.allocFlags.gpuDirectRDMACapable = 1;
+
+        size_t granularity{0};
+        TLLM_CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
+        return granularity;
+    };
+    static size_t granularity = alignedSizeFun();
+
+    return (size + granularity - 1) / granularity * granularity;
+}
+
+bool FabricMemory::supportFabricMemory()
+{
+#ifdef __aarch64__
+    auto support_fun = []()
+    {
+        int fabric_handle_supported{0};
+        int gpu_direct_rdma_with_cuda_vmm_supported{0};
+        int deviceIdx = 0;
+        TLLM_CUDA_CHECK(cudaGetDevice(&deviceIdx));
+        CUresult ret0 = cuDeviceGetAttribute(
+            &fabric_handle_supported, CU_DEVICE_ATTRIBUTE_HANDLE_TYPE_FABRIC_SUPPORTED, deviceIdx);
+
+        CUresult ret1 = cuDeviceGetAttribute(&gpu_direct_rdma_with_cuda_vmm_supported,
+            CU_DEVICE_ATTRIBUTE_GPU_DIRECT_RDMA_WITH_CUDA_VMM_SUPPORTED, deviceIdx);
+        TLLM_LOG_DEBUG("FabricMemory::supportFabricMemory fabric_handle_supported:%d", fabric_handle_supported);
+        TLLM_LOG_DEBUG("FabricMemory::supportFabricMemory gpu_direct_rdma_with_cuda_vmm_supported:%d",
+            gpu_direct_rdma_with_cuda_vmm_supported);
+        if (ret0 != CUresult::CUDA_SUCCESS || ret1 != CUresult::CUDA_SUCCESS || fabric_handle_supported == 0
+            || gpu_direct_rdma_with_cuda_vmm_supported == 0)
+        {
+            return false;
+        }
+
+        CUmemAllocationHandleType const handle_type = CU_MEM_HANDLE_TYPE_FABRIC;
+        CUmemAllocationProp prop = {};
+        prop.requestedHandleTypes = handle_type;
+        prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
+        prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
+        prop.location.id = deviceIdx;
+        prop.allocFlags.gpuDirectRDMACapable = 1;
+
+        size_t granularity{0};
+        TLLM_CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
+        CUmemGenericAllocationHandle handle;
+
+        auto cuRet = cuMemCreate(&handle, granularity, &prop, 0);
+
+        if (cuRet == CUresult::CUDA_SUCCESS)
+        {
+            TLLM_CU_CHECK(cuMemRelease(handle));
+            return true;
+        }
+        if (cuRet == CUresult::CUDA_ERROR_NOT_PERMITTED)
+        {
+            TLLM_LOG_WARNING("Creating fabric memory failed; setting up an IMEX channel may be required");
+            return false;
+        }
+        TLLM_CU_CHECK(cuRet);
+
+        return false;
+    };
+    static bool support = support_fun();
+    return support;
+
+#else
+    return false;
+#endif
+}
+
 CacheTransBufferManager::CacheTransBufferManager(
     KVCacheManager::BaseKVCacheManager* cacheManager, std::optional<size_t> maxNumTokens)
     : mCacheManager{cacheManager}
@@ -47,12 +211,19 @@ CacheTransBufferManager::CacheTransBufferManager(
     mOnlyUseDynamicBuffer = mTransferBufferSize == 0;
     mRecvBufferCount = common::getEnvRequestKVCacheConcurrent() ? common::getEnvKVCacheRecvBufferCount() : 1;
     mSendBufferCount = common::getEnvParallelCacheSend() ? common::getEnvKVCacheSendMaxConcurrenceNum() : 1;
+    mUseFabricMemory = !(common::getEnvKVCacheTransferUseSyncBuffer() || common::getEnvKVCacheTransferUseAsyncBuffer())
+        && FabricMemory::supportFabricMemory();
+    if (mUseFabricMemory)
+    {
+        mTransferBufferSize = FabricMemory::getAlignedSize(mTransferBufferSize);
+    }
     mPreAllocBufferSize = mTransferBufferSize * (mRecvBufferCount + mSendBufferCount);
     TLLM_LOG_INFO(
         "CacheTransBufferManager: mMaxNumTokens:%ld, mRecvBufferCount:%ld, "
-        "mSendBufferCount:%ld,mTransferBufferSize:%ld, mPreAllocBufferSize:%ld,mOnlyUseDynamicBuffer:%d",
+        "mSendBufferCount:%ld, mTransferBufferSize:%ld, mPreAllocBufferSize:%ld, mOnlyUseDynamicBuffer:%d, "
+        "mUseFabricMemory:%d",
         maxNumTokens.has_value() ? maxNumTokens.value() : 0, mRecvBufferCount, mSendBufferCount, mTransferBufferSize,
-        mPreAllocBufferSize, mOnlyUseDynamicBuffer);
+        mPreAllocBufferSize, mOnlyUseDynamicBuffer, mUseFabricMemory);
     bool to_allocate = common::getEnvUseMPIKvCache() || common::getEnvUseUCXKvCache() || common::getEnvUseNixlKvCache();
     TLLM_CHECK_WITH_INFO(to_allocate, "CacheTransBufferManager: to_allocate is false");
@@ -76,6 +247,12 @@
     {
         TransferBufferSize = maxNumTokens.value() * kvCacheSizePerToken.value();
     }
+    bool useFabricMemory = FabricMemory::supportFabricMemory()
+        && (!(common::getEnvKVCacheTransferUseSyncBuffer() || common::getEnvKVCacheTransferUseAsyncBuffer()));
+    if (useFabricMemory)
+    {
+        TransferBufferSize = FabricMemory::getAlignedSize(TransferBufferSize);
+    }
     size_t RecvBufferCount = common::getEnvRequestKVCacheConcurrent() ? common::getEnvKVCacheRecvBufferCount() : 1;
     size_t SendBufferCount = common::getEnvParallelCacheSend() ? common::getEnvKVCacheSendMaxConcurrenceNum() : 1;
     size_t PreAllocBufferSize = TransferBufferSize * (RecvBufferCount + SendBufferCount);
@@ -122,7 +299,6 @@ runtime::ITensor::SharedPtr CacheTransBufferManager::getSendBuffer(std::optional
     if (bufferId.has_value())
     {
         TLLM_CHECK(static_cast<size_t>(bufferId.value()) < mSendBufferCount);
-        // TLLM_CHECK(mConcurrenceSendResource.mBufferIndexFlag[bufferId.value()] == 1);
         return mConcurrenceSendResource.mBuffers[bufferId.value()];
     }
     return nullptr;
@@ -193,7 +369,23 @@ void CacheTransBufferManager::allocateBuffer()
     mBufferEleSize = mTransferBufferSize / common::getDTypeSize(mDataType);
     mConcurrenceSendResource.mBufferIndexFlag.resize(mSendBufferCount, 0);
     mConcurrenceRecvResource.mBufferIndexFlag.resize(mRecvBufferCount, 0);
-    if (common::getEnvKVCacheTransferUseAsyncBuffer())
+    if (mUseFabricMemory)
+    {
+        mFabricMemory.reserve(mSendBufferCount + mRecvBufferCount);
+        for (size_t i = 0; i < mSendBufferCount; i++)
+        {
+            mFabricMemory.emplace_back(std::make_unique<FabricMemory>(mTransferBufferSize));
+            mConcurrenceSendResource.mBuffers[i] = runtime::ITensor::wrap(mFabricMemory.back()->getPtr(), mDataType,
+                runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), mBufferEleSize);
+        }
+        for (size_t i = 0; i < mRecvBufferCount; i++)
+        {
+            mFabricMemory.emplace_back(std::make_unique<FabricMemory>(mTransferBufferSize));
+            mConcurrenceRecvResource.mBuffers[i] = runtime::ITensor::wrap(mFabricMemory.back()->getPtr(), mDataType,
+                runtime::ITensor::makeShape({static_cast<int64_t>(mBufferEleSize)}), mBufferEleSize);
+        }
+    }
+    else if (common::getEnvKVCacheTransferUseAsyncBuffer())
     {
         for (size_t i = 0; i < mSendBufferCount; i++)
         {
diff --git a/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.h b/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.h
index 7a2248ddd6b..e18bc810803 100644
--- a/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.h
+++ b/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.h
@@ -30,6 +30,29 @@
 namespace tensorrt_llm::batch_manager::kv_cache_manager
 {
 
+class FabricMemory
+{
+public:
+    explicit FabricMemory(size_t size);
+    ~FabricMemory();
+
+    FabricMemory(FabricMemory const&) = delete;
+    FabricMemory& operator=(FabricMemory const&) = delete;
+
+    FabricMemory(FabricMemory&&) noexcept;
+    FabricMemory& operator=(FabricMemory&&) noexcept;
+
+    void* getPtr() const;
+    size_t getSize() const;
+
+    static size_t getAlignedSize(size_t size);
+    static bool supportFabricMemory();
+
+private:
+    class Impl;
+    std::unique_ptr<Impl> pImpl;
+};
+
 class CacheTransBufferManager
 {
 public:
@@ -81,12 +104,14 @@ class CacheTransBufferManager
     size_t mSendBufferCount;
     size_t mTransferBufferSize;
     bool mOnlyUseDynamicBuffer;
+    bool mUseFabricMemory;
     size_t mBufferEleSize;
     nvinfer1::DataType mDataType;
     ConcurrenceResource mConcurrenceSendResource;
     ConcurrenceResource mConcurrenceRecvResource;
     KVCacheManager::BaseKVCacheManager* mCacheManager;
     runtime::BufferManager mBufferManager;
+    std::vector<std::unique_ptr<FabricMemory>> mFabricMemory;
 };
 
 } // namespace tensorrt_llm::batch_manager::kv_cache_manager
diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp
index f0c12c3f232..88f1380e0b3 100644
--- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp
+++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp
@@ -101,6 +101,7 @@ class DataResponder::Impl
     {
         TLLM_CHECK(mSender);
         TLLM_CUDA_CHECK(cudaGetDevice(&mDeviceId));
+        mCurrentRequest = std::nullopt;
         mResponseFuture = std::async(std::launch::async, &Impl::response, this);
     }
 
@@ -220,13 +221,10 @@ class DataResponder::Impl
         }
         else
         {
-            if (mCurrentRequest.has_value())
-            {
-                TLLM_LOG_ERROR(
-                    "This executor does not have a prepared KV cache for request ID: %zu, and the "
-                    "mReadyResponses size is: %zu.",
-                    mCurrentRequest.value(), mReadyResponses.size());
-            }
+            TLLM_CHECK_WITH_INFO(!mCurrentRequest.has_value(),
+                "This executor does not have a prepared KV cache for request ID: %zu, and the "
+                "mReadyResponses size is: %zu. mpi rank :%d ",
+                mCurrentRequest.value(), mReadyResponses.size(), mpi::MpiComm::world().getRank());
             std::unique_lock lk(mCondMutex);
             mResponderCv.wait(lk, [this]() { return (mAnyReady || mTerminate); });
         }
diff --git a/cpp/tensorrt_llm/common/envUtils.cpp b/cpp/tensorrt_llm/common/envUtils.cpp
index 665ba032ad2..0981a40efc8 100644
--- a/cpp/tensorrt_llm/common/envUtils.cpp
+++ b/cpp/tensorrt_llm/common/envUtils.cpp
@@ -365,6 +365,12 @@ bool getEnvKVCacheTransferUseAsyncBuffer()
     return useAsyncBuffer;
 }
 
+bool getEnvKVCacheTransferUseSyncBuffer()
+{
+    static bool const useSyncBuffer = getBoolEnv("TRTLLM_KVCACHE_TRANSFER_USE_SYNC_BUFFER");
+    return useSyncBuffer;
+}
+
 size_t getEnvKVCacheSendMaxConcurrenceNum()
 {
 
diff --git a/cpp/tensorrt_llm/common/envUtils.h b/cpp/tensorrt_llm/common/envUtils.h
index 5bd250c6f6a..54801f8344e 100644
--- a/cpp/tensorrt_llm/common/envUtils.h
+++ b/cpp/tensorrt_llm/common/envUtils.h
@@ -92,6 +92,8 @@ size_t getEnvKVCacheRecvBufferCount();
 
 bool getEnvKVCacheTransferUseAsyncBuffer();
 
+bool getEnvKVCacheTransferUseSyncBuffer();
+
 size_t getEnvKVCacheSendMaxConcurrenceNum();
 
 size_t getEnvMemSizeForKVCacheTransferBuffer();
diff --git a/cpp/tests/batch_manager/cacheTransceiverTest.cpp b/cpp/tests/batch_manager/cacheTransceiverTest.cpp
index 8981960da11..d6f14525ded 100644
--- a/cpp/tests/batch_manager/cacheTransceiverTest.cpp
+++ b/cpp/tests/batch_manager/cacheTransceiverTest.cpp
@@ -1203,7 +1203,8 @@ TEST_P(AsymmetricalCacheTestWithDP, TestCase)
     int requestId = 0;
     for (auto len : {30, 10, 60, 30, 60, 10})
     {
-        requests.emplace_back(makeLlmRequestWithDP(len, requestId++, requestId % contextTp));
+        requests.emplace_back(makeLlmRequestWithDP(len, requestId, requestId % contextTp));
+        requestId++;
     }
     std::vector> contextFutures;
     std::vector> generationFutures;
@@ -1216,7 +1217,7 @@
     {
         for (int i = 0; i < requests.size(); i++)
         {
-            if (i % mTpSize == mTpRank)
+            if ((i) % mTpSize == mTpRank)
            {
                // round robin
                contextRequests.push_back(requests[i]);
            }
@@ -1239,7 +1240,7 @@
     {
         for (int i = 0; i < requests.size(); i++)
        {
-            if (i % mTpSize == mTpRank)
+            if ((i) % mTpSize == mTpRank)
            {
                generationRequests.push_back(requests[i]);
            }
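
For reference, below is a minimal, self-contained sketch (not part of the diff) of how the FabricMemory API introduced above could be exercised on its own. The standalone main() and the 512 MiB size are illustrative assumptions; the probe/align/allocate sequence mirrors what CacheTransBufferManager now does when neither TRTLLM_KVCACHE_TRANSFER_USE_SYNC_BUFFER nor the async-buffer variable is set and the device supports fabric handles.

#include "tensorrt_llm/batch_manager/cacheTransBuffer.h"

using tensorrt_llm::batch_manager::kv_cache_manager::FabricMemory;

int main()
{
    // The fabric path is only taken when the device reports fabric-handle and
    // GDR-with-VMM support; a CUDA_ERROR_NOT_PERMITTED probe result usually
    // means an IMEX channel has not been set up.
    if (!FabricMemory::supportFabricMemory())
    {
        return 0; // fall back to the existing pinned/device buffer paths
    }

    // CacheTransBufferManager rounds the transfer buffer size up to the CUDA VMM
    // allocation granularity before sizing the pre-allocated pool.
    size_t const requested = 512ull << 20; // illustrative 512 MiB transfer buffer
    size_t const aligned = FabricMemory::getAlignedSize(requested);

    FabricMemory buffer(aligned); // maps a fabric-backed, GDR-capable VA range
    void* ptr = buffer.getPtr();  // pointer handed to ITensor::wrap() in allocateBuffer()
    (void) ptr;
    return 0; // unmap and release happen in ~FabricMemory()
}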