diff --git a/.devcontainer/docker-compose.yml b/.devcontainer/docker-compose.yml index 5d84d5b879c..f5b42bcfb81 100644 --- a/.devcontainer/docker-compose.yml +++ b/.devcontainer/docker-compose.yml @@ -1,7 +1,7 @@ version: "3.9" services: tensorrt_llm-dev: - image: urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202505211401-4539 + image: urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506021004-9420 network_mode: host ipc: host diff --git a/cpp/conanfile.py b/cpp/conanfile.py index 791643b4ec5..0c5ac4e7f88 100644 --- a/cpp/conanfile.py +++ b/cpp/conanfile.py @@ -1,3 +1,6 @@ +import os +import sys + from conan import ConanFile from conan.tools.cmake import CMakeDeps, CMakeToolchain @@ -9,10 +12,22 @@ class TensorRT_LLM(ConanFile): virtualrunenv = False def requirements(self): - pass # TODO add dependencies here + self.requires("libnuma/system") def generate(self): cmake = CMakeDeps(self) cmake.generate() tc = CMakeToolchain(self) tc.generate() + + def build_requirements(self): + # register libnuma_conan.py for conan + base_dir = os.path.dirname(os.path.abspath(__file__)) + libnuma_path = os.path.join(base_dir, "libnuma_conan.py") + conan_bin = os.path.abspath(sys.argv[0]) + if not os.path.isfile(conan_bin) or not os.access(conan_bin, os.X_OK): + raise RuntimeError(f"Conan binary not found {sys.argv[0]}") + + self.run( + f"{conan_bin} export {libnuma_path} --name=libnuma --version=system" + ) diff --git a/cpp/libnuma_conan.py b/cpp/libnuma_conan.py new file mode 100644 index 00000000000..11c9b3855ff --- /dev/null +++ b/cpp/libnuma_conan.py @@ -0,0 +1,36 @@ +from conan import ConanFile + + +class LibnumaSystemConan(ConanFile): + name = "libnuma" + version = "system" + package_type = "shared-library" + settings = "os", "arch" + + def package_info(self): + if self.settings.os == "Windows": + self.output.info("libnuma not needed on Windows.") + return + + self.cpp_info.includedirs = ["/usr/include"] + libdirs = [] + + arch = str(self.settings.arch) + os_name = str(self.settings.os) + + if os_name == "Linux": + if arch == "x86_64": + libdirs.append("/usr/lib/x86_64-linux-gnu") + elif arch in ["armv8", "aarch64"]: + libdirs.append("/usr/lib/aarch64-linux-gnu") + else: + self.output.warn( + f"Unrecognized architecture: {arch}, falling back to /usr/lib" + ) + libdirs.append("/usr/lib") + else: + self.output.warn(f"Unsupported OS: {os_name}, assuming /usr/lib") + libdirs.append("/usr/lib") + + self.cpp_info.libdirs = libdirs + self.cpp_info.system_libs = ["numa"] diff --git a/cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.cu b/cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.cu index 88b40a7b5cf..4f4bce83ec3 100644 --- a/cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.cu +++ b/cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.cu @@ -245,12 +245,166 @@ void moeStatisticDevice(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatistic } } +template +__global__ void moeComputeRouteNoRedundantKernel(MoeLoadBalanceMetaInfo metaInfo, MoePlacementInfo placementInfo, + int* const tokenSelectedExperts, int* tokenRoutedSlotIds, int tokenCount) +{ + extern __shared__ int16_t sharedGlobalSlotIdsInfo[]; + int expertIds[ITEM_PER_THREAD]; + int slotIds[ITEM_PER_THREAD]; + for (int slotId = threadIdx.x; slotId < metaInfo.epSize * metaInfo.slotCountPerRank; slotId += THREAD_COUNT) + { + sharedGlobalSlotIdsInfo[slotId] = 
placementInfo.globalSlotIds[slotId]; + } + + int blockOffset = blockIdx.x * THREAD_COUNT * ITEM_PER_THREAD; + + for (; blockOffset < tokenCount * metaInfo.topK; blockOffset += gridDim.x * THREAD_COUNT * ITEM_PER_THREAD) + { + int tokenIdxBase = blockOffset + threadIdx.x; +#pragma unroll + for (int i = 0; i < ITEM_PER_THREAD; i++) + { + int tokenIdx = tokenIdxBase + i * THREAD_COUNT; + expertIds[i] + = tokenIdx < tokenCount * metaInfo.topK ? tokenSelectedExperts[tokenIdx] : metaInfo.expertCount; + } +#pragma unroll + for (int i = 0; i < ITEM_PER_THREAD; i++) + { + if (expertIds[i] < 0 || expertIds[i] >= metaInfo.expertCount) + { + expertIds[i] = metaInfo.expertCount; + } + } + if (blockOffset == blockIdx.x * THREAD_COUNT * ITEM_PER_THREAD) + { + __syncthreads(); + } +#pragma unroll + for (int i = 0; i < ITEM_PER_THREAD; i++) + { + slotIds[i] = expertIds[i] < metaInfo.expertCount ? sharedGlobalSlotIdsInfo[expertIds[i]] + : metaInfo.epSize * metaInfo.slotCountPerRank; + } +#pragma unroll + for (int i = 0; i < ITEM_PER_THREAD; i++) + { + int tokenIdx = tokenIdxBase + i * THREAD_COUNT; + if (tokenIdx < tokenCount * metaInfo.topK) + { + tokenRoutedSlotIds[tokenIdx] = slotIds[i]; + } + } + } +} + template __global__ void moeComputeRouteKernel(MoeLoadBalanceMetaInfo metaInfo, MoePlacementInfo placementInfo, int* const tokenSelectedExperts, int* tokenRoutedSlotIds, int tokenCount, bool offsetByEpRank) +{ + int warpId = threadIdx.x / 32; + int laneId = threadIdx.x % 32; + static int const kWarpCount = THREAD_COUNT / 32; + extern __shared__ int16_t sharedGlobalSlotIdsInfo[]; + __shared__ int sharedExpertReplicaCountAndStartOffset[MAX_EXPERT_COUNT]; + + __shared__ int sharedArbitrateExpertId[THREAD_COUNT * ITEM_PER_THREAD]; + __shared__ int sharedExpertCount[MAX_EXPERT_COUNT]; + for (int expertIdx = threadIdx.x; expertIdx < metaInfo.expertCount; expertIdx += THREAD_COUNT) + { + int replicaCount = placementInfo.expertReplicaCount[expertIdx]; + int replicaStartOffset = placementInfo.expertReplicaStartOffset[expertIdx]; + sharedExpertReplicaCountAndStartOffset[expertIdx] = (replicaCount << 16) | replicaStartOffset; + sharedExpertCount[expertIdx] = 0; + } + for (int slotId = threadIdx.x; slotId < metaInfo.epSize * metaInfo.slotCountPerRank; slotId += THREAD_COUNT) + { + sharedGlobalSlotIdsInfo[slotId] = placementInfo.globalSlotIds[slotId]; + } + + int expertIds[ITEM_PER_THREAD]; + int tokenIdxBase = blockIdx.x * THREAD_COUNT * ITEM_PER_THREAD + threadIdx.x; +#pragma unroll + for (int i = 0; i < ITEM_PER_THREAD; i++) + { + int tokenIdx = tokenIdxBase + i * THREAD_COUNT; + expertIds[i] = tokenIdx < tokenCount * metaInfo.topK ? tokenSelectedExperts[tokenIdx] : metaInfo.expertCount; + } +#pragma unroll + for (int i = 0; i < ITEM_PER_THREAD; i++) + { + if (expertIds[i] < 0 || expertIds[i] >= metaInfo.expertCount) + { + expertIds[i] = metaInfo.expertCount; + } + } + __syncthreads(); +#pragma unroll + for (int i = 0; i < ITEM_PER_THREAD; i++) + { + int countAndStart + = expertIds[i] < metaInfo.expertCount ? sharedExpertReplicaCountAndStartOffset[expertIds[i]] : (1 << 16); + int arbitrateExpertId = (countAndStart >> 16) > 1 ? expertIds[i] : metaInfo.expertCount; + sharedArbitrateExpertId[threadIdx.x + i * THREAD_COUNT] = arbitrateExpertId; + } + __syncthreads(); + int baseOffset = blockIdx.x + (offsetByEpRank ? 
metaInfo.epRank : 0);
+    if (warpId == 0)
+    {
+#pragma unroll
+        for (int i = 0; i < kWarpCount * ITEM_PER_THREAD; ++i)
+        {
+            int expertId = sharedArbitrateExpertId[laneId + i * 32];
+            __syncwarp();
+            unsigned match = __match_any_sync(0xFFFFFFFF, expertId);
+            int leader = __ffs(match) - 1;
+            int matchCount = __popc(match);
+            int oldVal = 0;
+            if (laneId == leader && expertId < metaInfo.expertCount)
+            {
+                oldVal = atomicAdd_block(&sharedExpertCount[expertId], matchCount);
+            }
+            __syncwarp();
+            oldVal = __shfl_sync(0xFFFFFFFF, oldVal, leader);
+            unsigned lowerMask = match & ((1u << laneId) - 1);
+            int rankInGroup = __popc(lowerMask);
+            int offset = oldVal + rankInGroup;
+            offset += baseOffset;
+            sharedArbitrateExpertId[laneId + i * 32] = offset;
+        }
+    }
+    __syncthreads();
+    int targetGlobalSlotId[ITEM_PER_THREAD];
+#pragma unroll
+    for (int i = 0; i < ITEM_PER_THREAD; i++)
+    {
+        int countAndStart
+            = expertIds[i] < metaInfo.expertCount ? sharedExpertReplicaCountAndStartOffset[expertIds[i]] : (1 << 16);
+        int count = countAndStart >> 16;
+        int offset = countAndStart & 0xFFFF;
+        int arbitratedIndex = sharedArbitrateExpertId[threadIdx.x + i * THREAD_COUNT];
+        offset += arbitratedIndex % count;
+        targetGlobalSlotId[i] = expertIds[i] < metaInfo.expertCount ? sharedGlobalSlotIdsInfo[offset]
+                                                                    : metaInfo.epSize * metaInfo.slotCountPerRank;
+    }
+#pragma unroll
+    for (int i = 0; i < ITEM_PER_THREAD; i++)
+    {
+        int tokenIdx = tokenIdxBase + i * THREAD_COUNT;
+        if (tokenIdx < tokenCount * metaInfo.topK)
+        {
+            tokenRoutedSlotIds[tokenIdx] = targetGlobalSlotId[i];
+        }
+    }
+}
+
+template <int MAX_EXPERT_COUNT, int THREAD_COUNT, int ITEM_PER_THREAD>
+__global__ void moeComputeRouteSortKernel(MoeLoadBalanceMetaInfo metaInfo, MoePlacementInfo placementInfo,
+    int* const tokenSelectedExperts, int* tokenRoutedSlotIds, int tokenCount, bool offsetByEpRank)
 {
     using BlockSort = cub::BlockRadixSort<int, THREAD_COUNT, ITEM_PER_THREAD>;
 
-    extern __shared__ int sharedGlobalSlotIdsInfo[];
+    extern __shared__ int16_t sharedGlobalSlotIdsInfo[];
 
     __shared__ typename BlockSort::TempStorage tempStorage;
 
@@ -361,9 +515,19 @@ void moeComputeRouteDevice(MoeLoadBalanceMetaInfo metaInfo, MoePlacementInfo pla
     constexpr int kThreadCount = 256;
     constexpr int kEltPerThread = 4;
     int blockCount = (tokenCount * metaInfo.topK + kThreadCount * kEltPerThread - 1) / (kThreadCount * kEltPerThread);
-    int dynamicShmSize = sizeof(int) * metaInfo.epSize * metaInfo.slotCountPerRank;
-    moeComputeRouteKernel<1024, kThreadCount, kEltPerThread><<<blockCount, kThreadCount, dynamicShmSize, stream>>>(
-        metaInfo, placementInfo, tokenSelectedExperts, tokenRoutedSlotIds, tokenCount, offsetByEpRank);
+    int dynamicShmSize = sizeof(int16_t) * metaInfo.epSize * metaInfo.slotCountPerRank;
+    if (metaInfo.expertCount == metaInfo.epSize * metaInfo.slotCountPerRank)
+    {
+        // No redundant experts, so complex routing is unnecessary; just assign each token to its expert's slot.
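+        // Each expert then has exactly one replica, so the kernel only needs the shared-memory
+        // globalSlotIds lookup table and can skip the per-warp replica arbitration done in
+        // moeComputeRouteKernel.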
+ moeComputeRouteNoRedundantKernel<1024, kThreadCount, kEltPerThread> + <<>>( + metaInfo, placementInfo, tokenSelectedExperts, tokenRoutedSlotIds, tokenCount); + } + else + { + moeComputeRouteKernel<1024, kThreadCount, kEltPerThread><<>>( + metaInfo, placementInfo, tokenSelectedExperts, tokenRoutedSlotIds, tokenCount, offsetByEpRank); + } } void moeWaitSignalForCpuStageHost(MoeLoadBalanceSingleLayerSignal* signal) diff --git a/cpp/tensorrt_llm/pybind/runtime/moeBindings.cpp b/cpp/tensorrt_llm/pybind/runtime/moeBindings.cpp index a7c37e58d3b..593c706b747 100644 --- a/cpp/tensorrt_llm/pybind/runtime/moeBindings.cpp +++ b/cpp/tensorrt_llm/pybind/runtime/moeBindings.cpp @@ -16,7 +16,7 @@ */ #include "moeBindings.h" -#include "tensorrt_llm/runtime/moeLoadBalancer.h" +#include "tensorrt_llm/runtime/moeLoadBalancer/moeLoadBalancer.h" #include #include #include @@ -98,6 +98,8 @@ void initMoeBindings(pybind11::module_& m) py::class_(m, "MoeLoadBalancer") .def(py::init(), py::arg("ep_rank"), py::arg("ep_size"), py::arg("layer_updates_per_iter"), "Initialize the MoeLoadBalancer with the specified expert parallel rank, size, and update frequency") + .def("set_use_gpu_memcpy", &tr::MoeLoadBalancer::setUseGpuMemcpy, py::arg("use_gpu_memcpy"), + "Set whether to use GPU memcpy for weight updates") .def("add_layer", &tr::MoeLoadBalancer::AddLayer, py::arg("expert_count"), py::arg("top_k"), py::arg("slot_count_per_rank"), "Add a new MOE layer to the load balancer") .def("finalize_model", &tr::MoeLoadBalancer::finalizeModel, diff --git a/cpp/tensorrt_llm/runtime/CMakeLists.txt b/cpp/tensorrt_llm/runtime/CMakeLists.txt index d351f459a37..321d65e1028 100644 --- a/cpp/tensorrt_llm/runtime/CMakeLists.txt +++ b/cpp/tensorrt_llm/runtime/CMakeLists.txt @@ -43,7 +43,8 @@ set(SRCS ipcNvlsMemory.cpp mcastDeviceMemory.cpp memoryCounters.cpp - moeLoadBalancer.cpp + moeLoadBalancer/moeLoadBalancer.cpp + moeLoadBalancer/topologyDetector.cpp ncclCommunicator.cpp promptTuningParams.cpp runtimeKernels.cu @@ -80,3 +81,27 @@ target_include_directories(runtime_src PRIVATE ${MPI_C_INCLUDE_DIRS}) if(ENABLE_MULTI_DEVICE) target_link_libraries(runtime_src PUBLIC ${NCCL_LIB}) endif() + +if(NOT WIN32) + find_package(libnuma QUIET CONFIG) + + if(NOT libnuma_FOUND) + message( + STATUS "libnuma not found via Conan, falling back to system libnuma") + find_path(NUMA_INCLUDE_DIR numa.h) + find_library(NUMA_LIBRARY numa) + + if(NUMA_INCLUDE_DIR AND NUMA_LIBRARY) + add_library(libnuma::libnuma UNKNOWN IMPORTED) + set_target_properties( + libnuma::libnuma + PROPERTIES IMPORTED_LOCATION "${NUMA_LIBRARY}" + INTERFACE_INCLUDE_DIRECTORIES "${NUMA_INCLUDE_DIR}") + else() + message(FATAL_ERROR "NUMA library not found, please install libnuma-dev") + endif() + else() + message(STATUS "libnuma found.") + endif() + target_link_libraries(runtime_src PUBLIC libnuma::libnuma) +endif() diff --git a/cpp/tensorrt_llm/runtime/moeLoadBalancer.cpp b/cpp/tensorrt_llm/runtime/moeLoadBalancer/moeLoadBalancer.cpp similarity index 76% rename from cpp/tensorrt_llm/runtime/moeLoadBalancer.cpp rename to cpp/tensorrt_llm/runtime/moeLoadBalancer/moeLoadBalancer.cpp index 4f71815eaf9..292cf1110c2 100644 --- a/cpp/tensorrt_llm/runtime/moeLoadBalancer.cpp +++ b/cpp/tensorrt_llm/runtime/moeLoadBalancer/moeLoadBalancer.cpp @@ -14,10 +14,11 @@ * limitations under the License. 
*/ -#include "tensorrt_llm/runtime/moeLoadBalancer.h" +#include "moeLoadBalancer.h" #include "tensorrt_llm/common/assert.h" #include "tensorrt_llm/common/cudaUtils.h" #include "tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.h" +#include "topologyDetector.h" #include #include #include @@ -33,7 +34,6 @@ namespace tensorrt_llm::runtime { - // Helper structure to hold replica information struct ReplicaInfo { @@ -271,6 +271,7 @@ void allocateStatisticInfo(tensorrt_llm::kernels::MoeLoadBalanceMetaInfo const& tensorrt_llm::kernels::MoeLoadBalanceStatisticInfo* statisticInfo) { TLLM_CUDA_CHECK(cudaMallocHost(&statisticInfo->expertLoadFactor, sizeof(float) * metaInfo.expertCount)); + std::fill_n(statisticInfo->expertLoadFactor, metaInfo.expertCount, 0.0f); TLLM_CHECK_WITH_INFO(statisticInfo->rawDataWindowSize > 0, "statisticInfo->rawDataWindowSize should > 0."); TLLM_CUDA_CHECK(cudaMalloc( &statisticInfo->expertTokenCount, sizeof(int) * metaInfo.expertCount * statisticInfo->rawDataWindowSize)); @@ -288,9 +289,9 @@ void freeStatisticInfo(tensorrt_llm::kernels::MoeLoadBalanceStatisticInfo* stati } void allocatePlacementInfo(tensorrt_llm::kernels::MoeLoadBalanceMetaInfo const& metaInfo, - tensorrt_llm::kernels::MoePlacementInfo* placementInfo, bool isCpu = false) + tensorrt_llm::kernels::MoePlacementInfo* placementInfo, bool isCpu = false, bool useManaged = false) { - auto allocFn = [isCpu](void** ptr, size_t size) + auto allocFn = [isCpu, useManaged](void** ptr, size_t size) { if (isCpu) { @@ -298,7 +299,21 @@ void allocatePlacementInfo(tensorrt_llm::kernels::MoeLoadBalanceMetaInfo const& } else { - return cudaMalloc(ptr, size); + if (useManaged) + { + TLLM_CUDA_CHECK(cudaMallocManaged(ptr, size)); + int cur_dev; + TLLM_CUDA_CHECK(cudaGetDevice(&cur_dev)); + TLLM_CUDA_CHECK(cudaMemAdvise(*ptr, size, cudaMemAdviseSetPreferredLocation, cur_dev)); + TLLM_CUDA_CHECK(cudaMemAdvise(*ptr, size, cudaMemAdviseSetAccessedBy, cur_dev)); + TLLM_CUDA_CHECK(cudaMemAdvise(*ptr, size, cudaMemAdviseSetAccessedBy, cudaCpuDeviceId)); + TLLM_CUDA_CHECK(cudaMemset(*ptr, 0, size)); + return cudaSuccess; + } + else + { + return cudaMalloc(ptr, size); + } } }; TLLM_CUDA_CHECK( @@ -405,7 +420,7 @@ void SingleLayerMoeLoadBalancer::createResources() } allocatePlacementInfo(mMetaInfo, &mCpuPlacementInfo.placementInfoForGPU, true); - allocatePlacementInfo(mMetaInfo, &mGpuPlacement, false); + allocatePlacementInfo(mMetaInfo, &mGpuPlacement, false, true); mSingleLayerSignal = allocateSingleLayerSignal(); TLLM_CUDA_CHECK(cudaEventCreate(&mUpdateWeightsDoneEvent)); @@ -451,7 +466,18 @@ void SingleLayerMoeLoadBalancer::maybeStartUpdateWeights() { if (mIterId >= 0 && mUpdateWeightsEnabled) { - mMoeLoadBalancer->addUpdateTask([this] { updateWeightsRoutine(); }); + mMoeLoadBalancer->addUpdateTask( + [this] + { + if (mMoeLoadBalancer->mUseGpuMemcpy) + { + updateWeightsRoutine(); + } + else + { + updateWeightsRoutineByCpu(); + } + }); } } @@ -461,6 +487,7 @@ void SingleLayerMoeLoadBalancer::waitLastUpdateDone() { std::unique_lock lock(mUpdateWeightsMutex); mUpdateWeightsCondition.wait(lock, [this] { return mUpdateWeightsDone; }); + lock.unlock(); } } @@ -485,6 +512,25 @@ void SingleLayerMoeLoadBalancer::copyPlacementInfoToGpu() { std::fill_n(mCpuPlacementInfo.rankExpertIds[i].begin(), mMetaInfo.slotCountPerRank, -1); } + // clear expert load factor for next statistic + std::fill_n(mStatisticInfo.expertLoadFactor, mMetaInfo.expertCount, 0.0f); +} + +void SingleLayerMoeLoadBalancer::copyPlacementInfoToGpuByCpu() +{ + 
memcpy(mGpuPlacement.expertReplicaCount, mCpuPlacementInfo.placementInfoForGPU.expertReplicaCount, + sizeof(int) * mMetaInfo.expertCount); + memcpy(mGpuPlacement.expertReplicaStartOffset, mCpuPlacementInfo.placementInfoForGPU.expertReplicaStartOffset, + sizeof(int) * mMetaInfo.expertCount); + memcpy(mGpuPlacement.globalSlotIds, mCpuPlacementInfo.placementInfoForGPU.globalSlotIds, + sizeof(int) * mMetaInfo.epSize * mMetaInfo.slotCountPerRank); + mCpuPlacementInfo.rankExpertIds.swap(mCpuPlacementInfo.oldRankExpertIds); + for (int i = 0; i < mMetaInfo.epSize; ++i) + { + std::fill_n(mCpuPlacementInfo.rankExpertIds[i].begin(), mMetaInfo.slotCountPerRank, -1); + } + // clear expert load factor for next statistic + std::fill_n(mStatisticInfo.expertLoadFactor, mMetaInfo.expertCount, 0.0f); } void SingleLayerMoeLoadBalancer::updateWeightsRoutine() @@ -501,6 +547,21 @@ void SingleLayerMoeLoadBalancer::updateWeightsRoutine() mUpdateWeightsCondition.notify_one(); } +void SingleLayerMoeLoadBalancer::updateWeightsRoutineByCpu() +{ + doReplication(mMetaInfo, mStatisticInfo.expertLoadFactor, &mCpuPlacementInfo); + doPlacement(mMetaInfo, mStatisticInfo.expertLoadFactor, &mCpuPlacementInfo); + prepareGpuPlacementInfo(mMetaInfo, &mCpuPlacementInfo); + mLastUpdateTaskId = mMoeLoadBalancer->addCopyTask( + [this](int rank, int size) { mWeightUpdater->updateWeights(&mCpuPlacementInfo, rank, size); }); + mMoeLoadBalancer->waitCopyTaskDone(mLastUpdateTaskId); + mLastUpdateTaskId = -1; + copyPlacementInfoToGpuByCpu(); + std::unique_lock lock(mUpdateWeightsMutex); + mUpdateWeightsDone = true; + mUpdateWeightsCondition.notify_one(); +} + /////////////////////////////////////////////////////////////////////////////////////////////////// // Weight Updater /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -646,8 +707,61 @@ void HostMemoryMoeWeightUpdater::copyWeights(MoeWeight const& src, MoeWeight con } } -void HostMemoryMoeWeightUpdater::updateWeights(tensorrt_llm::runtime::MoePlacementCpuInfo const* placementCpuInfo) +void HostMemoryMoeWeightUpdater::copyWeightsCpu(MoeWeight const& src, MoeWeight const& dst, int rank, int size) +{ + TLLM_CHECK(src.mWeightPtr != nullptr && dst.mWeightPtr != nullptr); + TLLM_CHECK(src.mHeight == dst.mHeight && src.mWidth == dst.mWidth); + char* srcPtr = static_cast(src.mWeightPtr); + char* dstPtr = static_cast(dst.mWeightPtr); + size_t singleCopySize, copyCount, srcPitch, dstPitch; + if (src.mPitch == src.mWidth && dst.mPitch == dst.mWidth) + { + singleCopySize = src.mWidth * src.mHeight; + copyCount = 1; + srcPitch = singleCopySize; + dstPitch = singleCopySize; + } + else + { + singleCopySize = src.mWidth; + copyCount = src.mHeight; + srcPitch = src.mPitch; + dstPitch = dst.mPitch; + } + size_t fullCopyCount = copyCount / size * size; + size_t threadCopyCount = fullCopyCount / size; + for (size_t i = rank * threadCopyCount; i < (rank + 1) * threadCopyCount; i++) + { + memcpy(dstPtr + i * dstPitch, srcPtr + i * srcPitch, singleCopySize); + } + size_t threadStartOffset = rank * singleCopySize / size; + size_t threadEndOffset = (rank + 1) * singleCopySize / size; + size_t threadCopySize = threadEndOffset - threadStartOffset; + for (size_t i = fullCopyCount; i < copyCount && threadCopySize > 0; i++) + { + memcpy(dstPtr + i * dstPitch + threadStartOffset, srcPtr + i * srcPitch + threadStartOffset, threadCopySize); + } +} + +void PrintUpdateInfo(tensorrt_llm::kernels::MoeLoadBalanceMetaInfo metaInfo, + 
tensorrt_llm::runtime::MoePlacementCpuInfo const* placementCpuInfo) +{ + std::stringstream ss; + ss << "[UpdateInfo] rank=" << metaInfo.epRank << ", expert weights=\n ["; + for (int slotId = 0; slotId < metaInfo.slotCountPerRank * metaInfo.epSize; slotId++) + { + ss << placementCpuInfo->rankExpertIds[slotId / metaInfo.slotCountPerRank][slotId % metaInfo.slotCountPerRank] + << ", "; + } + ss << "\n"; + fprintf(stderr, "%s\n", ss.str().c_str()); +} + +void HostMemoryMoeWeightUpdater::updateWeights( + tensorrt_llm::runtime::MoePlacementCpuInfo const* placementCpuInfo, int rank, int size) { + // PrintUpdateInfo(mMetaInfo, placementCpuInfo); + bool useGpu = mLayerLoadBalancer->mMoeLoadBalancer->mUseGpuMemcpy; for (int slotId = 0; slotId < mMetaInfo.slotCountPerRank; ++slotId) { int oldExpertId = placementCpuInfo->oldRankExpertIds[mMetaInfo.epRank][slotId]; @@ -665,7 +779,14 @@ void HostMemoryMoeWeightUpdater::updateWeights(tensorrt_llm::runtime::MoePlaceme auto& name = slotIt->first; auto& slotWeight = slotIt->second[slotId]; auto& hostWeight = mHostWeights[name][newExpertId]; - copyWeights(hostWeight, slotWeight, mLayerLoadBalancer->getStream()); + if (useGpu) + { + copyWeights(hostWeight, slotWeight, mLayerLoadBalancer->getStream()); + } + else + { + copyWeightsCpu(hostWeight, slotWeight, rank, size); + } } } } @@ -680,8 +801,40 @@ MoeLoadBalancer::MoeLoadBalancer(int epRank, int epSize, int layerUpdatesPerIter , mLayerUpdatesPerIter{layerUpdatesPerIter} { TLLM_CUDA_CHECK(cudaGetDevice(&mCudaDeviceId)); - // create a non-blocking stream for compute and update + // create a non-blocking stream for compute and update, not needed anymore for CPU copy engine. TLLM_CUDA_CHECK(cudaStreamCreateWithFlags(&mStream, cudaStreamNonBlocking)); + + auto& topologyDetector = TopologyDetector::getInstance(); + int currentGpuNumaId = topologyDetector.getCurrentGpuNumaId(); + int numaCpuCount = topologyDetector.getCurrentGpuNumaCpuCount(); + int numaGpuCount = topologyDetector.getGpuCountUnderNuma(currentGpuNumaId); + TLLM_CHECK_WITH_INFO( + numaCpuCount > 0 && numaGpuCount > 0, "numaCpuCount=%d, numaGpuCount=%d", numaCpuCount, numaGpuCount); + int cpuCountPerGpu = std::max(1, numaCpuCount / numaGpuCount); + std::string cpuArch = topologyDetector.getCpuArchitecture(); + + int numCopyThreads = 8; + if (getenv("TLLM_LOAD_BALANCE_NUM_COPY_THREADS")) + { + int numCopyThreadsFromEnv = atoi(getenv("TLLM_LOAD_BALANCE_NUM_COPY_THREADS")); + if (numCopyThreadsFromEnv > 0) + { + TLLM_LOG_INFO( + "Setting TLLM_LOAD_BALANCE_NUM_COPY_THREADS to %d by environment variable", numCopyThreadsFromEnv); + numCopyThreads = numCopyThreadsFromEnv; + } + } + else + { + if (cpuCountPerGpu > 0) + { + numCopyThreads = std::min(16, std::max(4, cpuCountPerGpu / 2)); + TLLM_LOG_INFO("Auto-setting copy threads to %d based on NUMA topology (NUMA node %d, %d CPUs, arch: %s)", + numCopyThreads, currentGpuNumaId, numaCpuCount, cpuArch.c_str()); + } + } + + mMultiThreadWorker.reset(new MultiThreadWorker(numCopyThreads)); } MoeLoadBalancer::~MoeLoadBalancer() {} @@ -730,6 +883,7 @@ void MoeLoadBalancer::finalizeModel() } if (mLayerUpdatesPerIter > 0) { + mMultiThreadWorker->start(); generateUpdatePlan(); startThreads(); } @@ -783,6 +937,7 @@ void MoeLoadBalancer::shutdown() mWorkerThread->join(); TLLM_LOG_INFO("MoeLoadBalancer shutdown."); + mMultiThreadWorker->stop(); } } @@ -831,7 +986,6 @@ void MoeLoadBalancer::workerThread() } addUpdateTask(nullptr); mComputeAndUpdateThread->join(); - TLLM_LOG_INFO("MoeLoadBalancer worker thread stopped"); } 
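The CPU copy path above splits each weight copy across the MultiThreadWorker threads: whole rows are divided evenly among the `size` workers, and the leftover rows are split by byte range so every worker moves roughly the same number of bytes. A minimal host-only sketch of that partitioning, with illustrative buffer sizes that are not taken from this change:

#include <cassert>
#include <cstring>
#include <vector>

// Copy a [copyCount x singleCopySize] pitched buffer with the same split as
// copyWeightsCpu(): worker `rank` of `size` copies its share of full rows,
// then a byte slice of each remaining row.
static void copySlice(char const* src, char* dst, size_t srcPitch, size_t dstPitch,
    size_t singleCopySize, size_t copyCount, int rank, int size)
{
    size_t fullCopyCount = copyCount / size * size; // largest row count divisible by `size`
    size_t threadCopyCount = fullCopyCount / size;  // full rows owned by this rank
    for (size_t i = rank * threadCopyCount; i < (rank + 1) * threadCopyCount; ++i)
    {
        std::memcpy(dst + i * dstPitch, src + i * srcPitch, singleCopySize);
    }
    // Remaining rows: every rank copies its own byte slice of each row.
    size_t threadStartOffset = rank * singleCopySize / size;
    size_t threadEndOffset = (rank + 1) * singleCopySize / size;
    for (size_t i = fullCopyCount; i < copyCount && threadEndOffset > threadStartOffset; ++i)
    {
        std::memcpy(dst + i * dstPitch + threadStartOffset, src + i * srcPitch + threadStartOffset,
            threadEndOffset - threadStartOffset);
    }
}

int main()
{
    constexpr size_t kRows = 10, kWidth = 64, kPitch = 80; // illustrative sizes
    std::vector<char> src(kRows * kPitch, 7), dst(kRows * kPitch, 0);
    int const numThreads = 4; // MultiThreadWorker would run these ranks concurrently
    for (int rank = 0; rank < numThreads; ++rank)
    {
        copySlice(src.data(), dst.data(), kPitch, kPitch, kWidth, kRows, rank, numThreads);
    }
    for (size_t i = 0; i < kRows; ++i)
    {
        assert(std::memcmp(dst.data() + i * kPitch, src.data() + i * kPitch, kWidth) == 0);
    }
    return 0;
}

MultiThreadWorker::addTask runs the same callable on every worker thread with its (rank, size) pair, so copyWeightsCpu only has to compute and copy its own slice.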
void MoeLoadBalancer::computeAndUpdateThread() @@ -850,7 +1004,6 @@ void MoeLoadBalancer::computeAndUpdateThread() } task(); } - TLLM_LOG_INFO("MoeLoadBalancer compute and update thread stopped"); } void MoeLoadBalancer::addUpdateTask(std::function task) @@ -860,4 +1013,127 @@ void MoeLoadBalancer::addUpdateTask(std::function task) mUpdateQueueCondition.notify_one(); } +int64_t MoeLoadBalancer::addCopyTask(std::function task) +{ + return mMultiThreadWorker->addTask(task); +} + +void MoeLoadBalancer::waitCopyTaskDone(int64_t taskId) +{ + if (!mUseGpuMemcpy) + { + mMultiThreadWorker->waitTaskDone(taskId); + } +} + +MultiThreadWorker::MultiThreadWorker(int numThreads) + : mNumThreads(numThreads) + , mRunning(false) + , mNextTaskId(0) +{ +} + +MultiThreadWorker::~MultiThreadWorker() +{ + stop(); +} + +void MultiThreadWorker::start() +{ + std::lock_guard lk(mMutex); + if (mRunning) + return; + mRunning = true; + mThreads.reserve(mNumThreads); + for (int i = 0; i < mNumThreads; ++i) + { + mThreads.emplace_back(&MultiThreadWorker::workerLoop, this, i); + } +} + +int64_t MultiThreadWorker::addTask(std::function func) +{ + auto task = std::make_shared(); + { + std::lock_guard lk(mMutex); + task->id = mNextTaskId++; + task->func = std::move(func); + task->remaining = mNumThreads; + mTasks.push_back(task); + mTaskMap[task->id] = task; + } + mCondition.notify_all(); + return task->id; +} + +void MultiThreadWorker::waitTaskDone(int64_t taskId) +{ + std::unique_lock lk(mMutex); + auto it = mTaskMap.find(taskId); + if (it == mTaskMap.end()) + { + TLLM_CHECK_WITH_INFO(mDoneTaskMap.count(taskId) > 0, "Task %ld not found", taskId); + mDoneTaskMap.erase(taskId); + return; + } + auto task = it->second; + task->cv.wait(lk, [task] { return task->remaining == 0; }); + TLLM_CHECK_WITH_INFO(mDoneTaskMap.count(taskId) > 0, "Task %ld not found", taskId); + mDoneTaskMap.erase(taskId); +} + +void MultiThreadWorker::stop() +{ + { + std::lock_guard lk(mMutex); + if (!mRunning) + return; + mRunning = false; + } + mCondition.notify_all(); + for (auto& t : mThreads) + { + if (t.joinable()) + t.join(); + } + mThreads.clear(); +} + +void MultiThreadWorker::workerLoop(int rank) +{ + auto& topologyDetector = TopologyDetector::getInstance(); + topologyDetector.bindThreadByCurrentGpu(); // use relaxed mode + while (true) + { + std::shared_ptr task; + { + std::unique_lock lk(mMutex); + + mCondition.wait(lk, [this] { return !mRunning || !mTasks.empty(); }); + + if (!mRunning && mTasks.empty()) + return; + + task = mTasks.front(); + } + + task->func(rank, mNumThreads); + + { + std::unique_lock lk(mMutex); + if (--task->remaining == 0) + { + mTasks.pop_front(); + mTaskMap.erase(task->id); + mDoneTaskMap[task->id] = task; + task->cv.notify_all(); + } + else + { + task->cv.wait(lk, [task] { return task->remaining == 0; }); + } + } + } +} + } // namespace tensorrt_llm::runtime diff --git a/cpp/tensorrt_llm/runtime/moeLoadBalancer.h b/cpp/tensorrt_llm/runtime/moeLoadBalancer/moeLoadBalancer.h similarity index 84% rename from cpp/tensorrt_llm/runtime/moeLoadBalancer.h rename to cpp/tensorrt_llm/runtime/moeLoadBalancer/moeLoadBalancer.h index 5eb8070421e..4c77963c69b 100644 --- a/cpp/tensorrt_llm/runtime/moeLoadBalancer.h +++ b/cpp/tensorrt_llm/runtime/moeLoadBalancer/moeLoadBalancer.h @@ -16,6 +16,7 @@ #pragma once +#include #include #include #include @@ -81,7 +82,7 @@ class MoeWeightUpdaterBase void addSingleWeightSlot(int localSlotId, std::string const& name, MoeWeight weightSlot); virtual void addSingleHostWeight(int expertId, 
std::string const& name, MoeWeight hostWeight) = 0; virtual void finalizeWeights(); - virtual void updateWeights(MoePlacementCpuInfo const* placementCpuInfo) = 0; + virtual void updateWeights(MoePlacementCpuInfo const* placementCpuInfo, int rank = 0, int size = 1) = 0; protected: void finalizeWeightSlot(); @@ -102,10 +103,11 @@ class HostMemoryMoeWeightUpdater : public MoeWeightUpdaterBase void addSingleHostWeight(int expertId, std::string const& name, MoeWeight hostWeight) override; void finalizeWeights() override; - void updateWeights(MoePlacementCpuInfo const* placementCpuInfo) override; + void updateWeights(MoePlacementCpuInfo const* placementCpuInfo, int rank = 0, int size = 1) override; private: static void copyWeights(MoeWeight const& src, MoeWeight const& dst, cudaStream_t stream); + static void copyWeightsCpu(MoeWeight const& src, MoeWeight const& dst, int rank, int size); void finalizeHostWeight(); bool mHostWeightsFinalized = false; std::map> mHostWeights; @@ -166,6 +168,7 @@ class SingleLayerMoeLoadBalancer private: friend class MoeLoadBalancer; + friend class HostMemoryMoeWeightUpdater; void createResources(); void destroyResources(); @@ -187,7 +190,11 @@ class SingleLayerMoeLoadBalancer bool mUpdateWeightsEnabled = true; void copyPlacementInfoToGpu(); + void copyPlacementInfoToGpuByCpu(); void updateWeightsRoutine(); + void updateWeightsRoutineByCpu(); + + int64_t mLastUpdateTaskId = -1; cudaEvent_t mUpdateWeightsDoneEvent = nullptr; tensorrt_llm::kernels::MoeLoadBalanceMetaInfo mMetaInfo; @@ -203,6 +210,42 @@ class SingleLayerMoeLoadBalancer int mLayerId = -1; }; +class MultiThreadWorker +{ +public: + explicit MultiThreadWorker(int numThreads); + ~MultiThreadWorker(); + + void start(); + int64_t addTask(std::function func); + void waitTaskDone(int64_t taskId); + void stop(); + +private: + struct Task + { + int64_t id; + std::function func; + int remaining; + std::condition_variable cv; + }; + + void workerLoop(int rank); + + int mNumThreads; + std::vector mThreads; + std::mutex mMutex; + std::condition_variable mCondition; + + std::deque> mTasks; + + std::unordered_map> mTaskMap; + std::unordered_map> mDoneTaskMap; + + bool mRunning; + int64_t mNextTaskId; +}; + class MoeLoadBalancer { public: @@ -227,8 +270,15 @@ class MoeLoadBalancer // should bind to python void shutdown(); + // Test interface to use GPU to do memcpy test functionality + void setUseGpuMemcpy(bool useGpuMemcpy = false) + { + mUseGpuMemcpy = useGpuMemcpy; + } + private: friend class SingleLayerMoeLoadBalancer; + friend class HostMemoryMoeWeightUpdater; void startThreads(); @@ -247,6 +297,8 @@ class MoeLoadBalancer std::condition_variable mUpdateQueueCondition; std::queue> mUpdateTaskQueue; void addUpdateTask(std::function task); + int64_t addCopyTask(std::function task); + void waitCopyTaskDone(int64_t taskId); std::vector> mLayers; @@ -272,10 +324,14 @@ class MoeLoadBalancer std::unique_ptr mWorkerThread; std::unique_ptr mComputeAndUpdateThread; + std::unique_ptr mMultiThreadWorker; + // update plan member and function int mLayerUpdatesPerIter = 1; std::deque> mUpdateLayerQueue; void generateUpdatePlan(); + + bool mUseGpuMemcpy = false; }; // functions exposed for testing diff --git a/cpp/tensorrt_llm/runtime/moeLoadBalancer/topologyDetector.cpp b/cpp/tensorrt_llm/runtime/moeLoadBalancer/topologyDetector.cpp new file mode 100644 index 00000000000..01ee8297f77 --- /dev/null +++ b/cpp/tensorrt_llm/runtime/moeLoadBalancer/topologyDetector.cpp @@ -0,0 +1,446 @@ +/* + * Copyright (c) 2022-2024, NVIDIA 
CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorrt_llm/runtime/moeLoadBalancer/topologyDetector.h" + +#include "tensorrt_llm/common/assert.h" +#include "tensorrt_llm/common/cudaUtils.h" + +#include // For std::for_each, std::sort, std::unique +#include +#include +#include // For std::numeric_limits +#include +#include +#include +#include +#include +#include + +#ifdef __linux__ +#include // For errno +#include // For strerror +#include // For libnuma +#include // For struct bitmask definition if not in numa.h +#include +#include +#endif + +namespace tensorrt_llm::runtime +{ + +TopologyDetector::TopologyDetector() +{ + std::lock_guard lock(mDetectionMutex); + if (!mTopologyDetected) + { + detectCpuTopology(); + detectGpuTopology(); +#ifdef __linux__ + if (numa_available() != -1) + { // Only precompute if libnuma is usable + precomputeCpuAffinityMasks(); + } +#endif + mTopologyDetected = true; + } +} + +TopologyDetector::~TopologyDetector() +{ +#ifdef __linux__ + auto free_mask_map = [](std::map& mask_map) + { + for (auto const& [id, mask] : mask_map) + { + if (mask) + { + numa_free_cpumask(mask); + } + } + mask_map.clear(); + }; + free_mask_map(mGpuStrictCpuMasks); +#endif +} + +void TopologyDetector::detectCpuTopology() +{ + // Detect CPU architecture +#if defined(__x86_64__) || defined(_M_X64) + mCpuArchitecture = "x86_64"; +#elif defined(__aarch64__) || defined(_M_ARM64) + mCpuArchitecture = "aarch64"; +#elif defined(__powerpc64__) + mCpuArchitecture = "ppc64"; +#else + mCpuArchitecture = "unknown"; +#endif + + // Detect NUMA topology on Linux systems using libnuma +#ifdef __linux__ + if (numa_available() == -1) + { + // libnuma not available, fall back to default behavior + TLLM_LOG_WARNING("libnuma not available. Falling back to default CPU topology detection."); + mNumaToCpuCountMap[0] = std::thread::hardware_concurrency(); + return; + } + + int maxNode = numa_max_node(); + if (maxNode < 0) + { + // Failed to get max node, fall back to default behavior + TLLM_LOG_WARNING("Failed to get max NUMA node. Falling back to default CPU topology detection."); + mNumaToCpuCountMap[0] = std::thread::hardware_concurrency(); + return; + } + + mNumaToCpuCountMap.clear(); // Clear before re-populating + std::map tempNumaToCpuCountMap; + for (int i = 0; i <= maxNode; ++i) + { + struct bitmask* cpus = numa_allocate_cpumask(); + if (!cpus) + { + TLLM_LOG_WARNING("Failed to allocate cpumask for NUMA node query. Skipping node %d.", i); + continue; // Skip to the next node if allocation fails + } + + // Attempt to get CPUs for node i. If numa_node_to_cpus returns 0, it's successful. 
+ if (numa_node_to_cpus(i, cpus) == 0) + { + int cpuCount = 0; + for (int cpu_idx = 0; cpu_idx < numa_num_possible_cpus(); ++cpu_idx) + { + if (numa_bitmask_isbitset(cpus, cpu_idx)) + { + cpuCount++; + } + } + if (cpuCount > 0) + { // Only add NUMA nodes with actual CPUs + tempNumaToCpuCountMap[i] = cpuCount; + } + } + // If numa_node_to_cpus failed (returned -1), node 'i' might be invalid or an error occurred. + // In this case, we simply don't add it to our map, effectively skipping it. + + numa_free_cpumask(cpus); // Always free the allocated mask + } + mNumaToCpuCountMap = tempNumaToCpuCountMap; + + if (mNumaToCpuCountMap.empty()) + { + // If no NUMA nodes with CPUs were detected (e.g. libnuma error or unusual configuration), + // default to a single NUMA node with all hardware concurrency. + TLLM_LOG_WARNING( + "No NUMA nodes with CPUs detected via libnuma, or libnuma error. Defaulting to single NUMA node."); + mNumaToCpuCountMap[0] = std::thread::hardware_concurrency(); + } + +#else + // For non-Linux systems, assume a single NUMA node + mNumaToCpuCountMap[0] = std::thread::hardware_concurrency(); +#endif +} + +void TopologyDetector::detectGpuTopology() +{ + int deviceCount = 0; + cudaError_t result = cudaGetDeviceCount(&deviceCount); + if (result != cudaSuccess || deviceCount == 0) + { + return; + } + mGpuToNumaMap.clear(); // Clear before re-populating + mNumaToGpuMap.clear(); // Clear before re-populating + + for (int deviceId = 0; deviceId < deviceCount; ++deviceId) + { + int numaNode = 0; // Default NUMA node + +#ifdef __linux__ + if (numa_available() != -1) + { + char pciPath[256]; + cudaDeviceProp prop; + if (cudaGetDeviceProperties(&prop, deviceId) == cudaSuccess) + { + // Construct PCI path to find NUMA node + snprintf(pciPath, sizeof(pciPath), "/sys/bus/pci/devices/%04x:%02x:%02x.0/numa_node", prop.pciDomainID, + prop.pciBusID, prop.pciDeviceID); + std::ifstream numaFile(pciPath); + if (numaFile.is_open()) + { + numaFile >> numaNode; + numaFile.close(); + // If NUMA node is -1, it means no specific NUMA information, use node 0 + if (numaNode < 0) + { + numaNode = 0; + } + } + else + { + // Fallback if sysfs path is not available or readable + TLLM_LOG_DEBUG("Could not open %s to determine NUMA node for GPU %d. Defaulting to node 0.", + pciPath, deviceId); + numaNode = 0; + } + TLLM_LOG_INFO("GPU %d is on NUMA node %d", deviceId, numaNode); + } + else + { + TLLM_LOG_WARNING("Failed to get properties for GPU %d. 
Defaulting to NUMA node 0.", deviceId); + numaNode = 0; + } + } + else + { + // libnuma not available, default GPU to NUMA node 0 + numaNode = 0; + } +#endif + + mGpuToNumaMap[deviceId] = numaNode; + mNumaToGpuMap[numaNode].push_back(deviceId); + } +} + +#ifdef __linux__ + +static void bitmask_copy_manual(struct bitmask* dst, const struct bitmask* src) +{ + if (!dst || !src) + return; + numa_bitmask_clearall(dst); + for (int i = 0; i < numa_num_possible_cpus(); ++i) + { + if (numa_bitmask_isbitset(src, i)) + { + numa_bitmask_setbit(dst, i); + } + } +} + +static void bitmask_or_manual(struct bitmask* dst, const struct bitmask* src) +{ + if (!dst || !src) + return; + for (int i = 0; i < numa_num_possible_cpus(); ++i) + { + if (numa_bitmask_isbitset(src, i)) + { + numa_bitmask_setbit(dst, i); + } + } +} + +void TopologyDetector::precomputeCpuAffinityMasks() +{ + int num_gpus = 0; + cudaError_t err = cudaGetDeviceCount(&num_gpus); + if (err != cudaSuccess || num_gpus == 0) + { + return; + } + + for (int gpuId = 0; gpuId < num_gpus; ++gpuId) + { + auto itGpuNuma = mGpuToNumaMap.find(gpuId); + if (itGpuNuma == mGpuToNumaMap.end()) + { + TLLM_LOG_WARNING("GPU %d not found in mGpuToNumaMap during mask precomputation. Skipping.", gpuId); + continue; + } + int gpuNumaNode = itGpuNuma->second; + + // Strict Mask: CPUs on the GPU's direct NUMA node + struct bitmask* strictMask = numa_allocate_cpumask(); // Uses numa_bitmask_alloc internally + if (strictMask) + { + numa_bitmask_clearall(strictMask); // Initialize to empty + if (mNumaToCpuCountMap.count(gpuNumaNode) && mNumaToCpuCountMap.at(gpuNumaNode) > 0) + { + if (numa_node_to_cpus(gpuNumaNode, strictMask) != 0) + { + TLLM_LOG_WARNING( + "Failed to get CPUs for GPU %d's NUMA node %d for strict mask. Strict mask will be empty.", + gpuId, gpuNumaNode); + numa_bitmask_clearall(strictMask); // Ensure it's empty on failure + } + } + mGpuStrictCpuMasks[gpuId] = strictMask; + } + else + { + TLLM_LOG_WARNING("Failed to allocate strict CPU mask for GPU %d.", gpuId); + } + } +} + +const struct bitmask* TopologyDetector::getStrictCpuMaskForGpu(int gpuId) const +{ + auto it = mGpuStrictCpuMasks.find(gpuId); + if (it != mGpuStrictCpuMasks.end()) + { + return it->second; + } + return nullptr; +} + +#endif + +void TopologyDetector::bindThreadByCurrentGpu() +{ +#ifdef __linux__ + if (numa_available() == -1) + { + TLLM_LOG_WARNING("libnuma not available. Cannot bind thread to NUMA node."); + return; + } + + int currentDevice = -1; + if (cudaGetDevice(¤tDevice) != cudaSuccess) + { + TLLM_LOG_WARNING("Failed to get current CUDA device. Cannot bind thread."); + return; + } + + const struct bitmask* targetMask = nullptr; + targetMask = getStrictCpuMaskForGpu(currentDevice); + + if (targetMask) + { + // Check if the mask is not all clear before attempting to set affinity + bool maskIsClear = true; + for (int k = 0; k < numa_num_possible_cpus(); ++k) + { + if (numa_bitmask_isbitset(targetMask, k)) + { + maskIsClear = false; + break; + } + } + + if (!maskIsClear) + { + // Create a mutable copy of the targetMask to pass to numa_sched_setaffinity + struct bitmask* mutableCopyForAffinity = numa_allocate_cpumask(); + if (mutableCopyForAffinity) + { + bitmask_copy_manual(mutableCopyForAffinity, targetMask); + if (numa_sched_setaffinity(0, mutableCopyForAffinity) == -1) + { // 0 refers to the current thread + TLLM_LOG_WARNING("Failed to set thread affinity for GPU %d using precomputed mask. 
Error: %s", + currentDevice, strerror(errno)); + } + numa_free_cpumask(mutableCopyForAffinity); + } + else + { + TLLM_LOG_WARNING( + "Failed to allocate temporary bitmask for setting affinity. Cannot bind thread for GPU %d.", + currentDevice); + } + } + else + { + TLLM_LOG_DEBUG("Target affinity mask for GPU %d is empty. Not setting affinity.", currentDevice); + } + } + else + { + TLLM_LOG_WARNING("Precomputed CPU affinity mask not found for GPU %d. Cannot bind thread.", currentDevice); + } + +#else + TLLM_LOG_DEBUG("Thread binding by GPU NUMA node is only supported on Linux with libnuma."); +#endif +} + +int TopologyDetector::getCurrentGpuNumaCpuCount() +{ + int numaId = getCurrentGpuNumaId(); + if (numaId >= 0) + { + auto it = mNumaToCpuCountMap.find(numaId); + if (it != mNumaToCpuCountMap.end()) + { + return it->second; + } + } + TLLM_LOG_DEBUG( + "CPU count for GPU's NUMA node %d not found or node invalid. Returning total hardware concurrency.", numaId); + return std::thread::hardware_concurrency(); +} + +int TopologyDetector::getCurrentGpuNumaId() +{ + int currentDevice = -1; + if (cudaGetDevice(¤tDevice) != cudaSuccess) + { + return -1; // Indicate error or no CUDA device context + } + + auto it = mGpuToNumaMap.find(currentDevice); + if (it != mGpuToNumaMap.end()) + { + return it->second; + } + TLLM_LOG_WARNING("NUMA node for current GPU %d not found in map. Defaulting to node 0.", currentDevice); + return 0; +} + +int TopologyDetector::getGpuCountUnderNuma(int numaId) +{ + auto it = mNumaToGpuMap.find(numaId); + if (it != mNumaToGpuMap.end()) + { + return it->second.size(); + } + return 0; +} + +std::string TopologyDetector::getCpuArchitecture() +{ + return mCpuArchitecture; +} + +bool TopologyDetector::canSupportHostNativeAtomics() +{ + int currentDevice = -1; + if (cudaGetDevice(¤tDevice) != cudaSuccess) + { + TLLM_LOG_WARNING("Failed to get current CUDA device for atomic support check."); + return false; + } + + int hostNativeAtomicSupported = 0; + cudaError_t err + = cudaDeviceGetAttribute(&hostNativeAtomicSupported, cudaDevAttrHostNativeAtomicSupported, currentDevice); + + if (err != cudaSuccess) + { + TLLM_LOG_WARNING("Failed to get cudaDevAttrHostNativeAtomicSupported for device %d. Error: %s", currentDevice, + cudaGetErrorString(err)); + return false; + } + return static_cast(hostNativeAtomicSupported); +} + +} // namespace tensorrt_llm::runtime diff --git a/cpp/tensorrt_llm/runtime/moeLoadBalancer/topologyDetector.h b/cpp/tensorrt_llm/runtime/moeLoadBalancer/topologyDetector.h new file mode 100644 index 00000000000..f1cd279dbbb --- /dev/null +++ b/cpp/tensorrt_llm/runtime/moeLoadBalancer/topologyDetector.h @@ -0,0 +1,100 @@ +/* + * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include +#ifdef __linux__ +#include // For libnuma +#endif + +// Forward declaration for struct bitmask to avoid including numaif.h if numa.h already covers it, +// or if only numa.h is intended to be the public include for this header's users. +#ifdef __linux__ +struct bitmask; +#endif + +namespace tensorrt_llm::runtime +{ + +class TopologyDetector +{ +public: + static TopologyDetector& getInstance() + { + static TopologyDetector instance; + return instance; + } + + ~TopologyDetector(); + + // Binds the current thread to the CPU cores of the NUMA node associated with the current GPU. + void bindThreadByCurrentGpu(); + + // Returns the number of CPU cores on the NUMA node associated with the current GPU. + // Returns total hardware concurrency as a fallback if specific count cannot be determined. + int getCurrentGpuNumaCpuCount(); + + // Returns the ID of the NUMA node associated with the current GPU. + // Returns 0 as a default or -1 on error. + int getCurrentGpuNumaId(); + + // Returns the number of GPUs associated with the given NUMA node ID. + int getGpuCountUnderNuma(int numaId); + + // Returns the number of GPUs which have same NUMA node ID with the current GPU. + int getGpuCountUnderSameNuma() + { + return getGpuCountUnderNuma(getCurrentGpuNumaId()); + } + + // Returns the detected CPU architecture (e.g., "x86_64", "aarch64"). + std::string getCpuArchitecture(); + + // Checks if the current CUDA device and host system support native atomic operations. + bool canSupportHostNativeAtomics(); + +#ifdef __linux__ + // Getters for precomputed CPU affinity masks + const struct bitmask* getStrictCpuMaskForGpu(int gpuId) const; +#endif + +private: + TopologyDetector(); + void detectCpuTopology(); // Detects CPU NUMA topology and CPU counts per node. + void detectGpuTopology(); // Detects GPU to NUMA node mapping. 
+#ifdef __linux__ + void precomputeCpuAffinityMasks(); // Precomputes CPU masks for each GPU +#endif + + // Member variables + std::map mGpuToNumaMap; // GPU ID -> NUMA Node ID + std::map> mNumaToGpuMap; // NUMA Node ID -> List of GPU IDs + std::map mNumaToCpuCountMap; // NUMA Node ID -> CPU Core Count + std::string mCpuArchitecture; + bool mTopologyDetected = false; + std::mutex mDetectionMutex; // Mutex to protect topology detection process + +#ifdef __linux__ + // Precomputed CPU affinity masks + std::map mGpuStrictCpuMasks; // GPU ID -> Strict CPU mask +#endif +}; + +} // namespace tensorrt_llm::runtime diff --git a/cpp/tensorrt_llm/thop/moeLoadBalanceOp.cpp b/cpp/tensorrt_llm/thop/moeLoadBalanceOp.cpp index 68da7dc8338..d2df16424e3 100644 --- a/cpp/tensorrt_llm/thop/moeLoadBalanceOp.cpp +++ b/cpp/tensorrt_llm/thop/moeLoadBalanceOp.cpp @@ -20,12 +20,14 @@ #include "tensorrt_llm/runtime/torchUtils.h" #include "tensorrt_llm/thop/thUtils.h" +#include // for c10::DataPtr +#include // for c10::StorageImpl and use_byte_size_t() #include -#include +#include // for c10::make_intrusive#include #include #include "tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.h" -#include "tensorrt_llm/runtime/moeLoadBalancer.h" +#include "tensorrt_llm/runtime/moeLoadBalancer/moeLoadBalancer.h" namespace torch_ext { @@ -109,6 +111,43 @@ torch::Tensor moeLoadBalanceRouting( return tokenRoutedSlotIds; } +void migrateToManaged(at::Tensor& tensor) +{ + TORCH_CHECK(tensor.device().is_cuda(), "only support CUDA Tensor"); + + // 1) compute total bytes + size_t byte_size = tensor.numel() * tensor.element_size(); + + // 2) allocate UVM + void* managed_ptr = nullptr; + cudaError_t err = cudaMallocManaged(&managed_ptr, byte_size); + TORCH_CHECK(err == cudaSuccess, "cudaMallocManaged failed"); + + // 3) advise to place on current GPU + int cur_dev; + TLLM_CUDA_CHECK(cudaGetDevice(&cur_dev)); + TLLM_CUDA_CHECK(cudaMemAdvise(managed_ptr, byte_size, cudaMemAdviseSetPreferredLocation, cur_dev)); + TLLM_CUDA_CHECK(cudaMemAdvise(managed_ptr, byte_size, cudaMemAdviseSetAccessedBy, cur_dev)); + TLLM_CUDA_CHECK(cudaMemAdvise(managed_ptr, byte_size, cudaMemAdviseSetAccessedBy, cudaCpuDeviceId)); + + // 4) copy old data to UVM + TLLM_CUDA_CHECK(cudaMemcpy(managed_ptr, tensor.data_ptr(), byte_size, cudaMemcpyDeviceToDevice)); + + // 5) use new DataPtr/StorageImpl to construct storage + // here managed_ptr is data,and also context,use cudaFree as deleter + c10::DataPtr dp( + managed_ptr, managed_ptr, [](void* ptr) { cudaFree(ptr); }, tensor.device()); + auto allocator = c10::GetAllocator(tensor.device().type()); + auto storage_impl = c10::make_intrusive(c10::StorageImpl::use_byte_size_t(), byte_size, + std::move(dp), allocator, + /*resizable=*/false); + at::Storage new_storage(storage_impl); + + // Finally replace tensor's storage,offset = 0,shape and stride kept unchanged + tensor.set_(new_storage, + /*storage_offset=*/0, tensor.sizes().vec(), tensor.strides().vec()); +} + } // namespace torch_ext TORCH_LIBRARY_FRAGMENT(trtllm, m) @@ -154,3 +193,13 @@ TORCH_LIBRARY_IMPL(trtllm, CUDA, m) { m.impl("moe_load_balance_routing", &torch_ext::moeLoadBalanceRouting); } + +TORCH_LIBRARY_FRAGMENT(trtllm, m) +{ + m.def("migrate_to_managed(Tensor tensor) -> ()"); +} + +TORCH_LIBRARY_IMPL(trtllm, CUDA, m) +{ + m.impl("migrate_to_managed", &torch_ext::migrateToManaged); +} diff --git a/cpp/tests/runtime/moeLoadBalancerTest.cpp b/cpp/tests/runtime/moeLoadBalancerTest.cpp index bf75b92d442..739891c9f35 100644 --- 
a/cpp/tests/runtime/moeLoadBalancerTest.cpp +++ b/cpp/tests/runtime/moeLoadBalancerTest.cpp @@ -18,7 +18,7 @@ #include "tensorrt_llm/common/cudaUtils.h" #include "tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.h" -#include "tensorrt_llm/runtime/moeLoadBalancer.h" +#include "tensorrt_llm/runtime/moeLoadBalancer/moeLoadBalancer.h" using namespace tensorrt_llm::runtime; @@ -318,6 +318,8 @@ class MoeLoadBalancerTest : public ::testing::TestWithParam(param.epRank, param.epSize, param.layerUpdatesPerIter); + mLoadBalancer->setUseGpuMemcpy(true); + // Create multiple MoE layers createLayers(param); diff --git a/docker/common/install_base.sh b/docker/common/install_base.sh index 572b5ad0867..0b4eb91ca8a 100644 --- a/docker/common/install_base.sh +++ b/docker/common/install_base.sh @@ -53,6 +53,8 @@ init_ubuntu() { llvm \ libclang-rt-dev \ libffi-dev \ + libnuma1 \ + libnuma-dev \ python3-dev \ python3-pip \ python-is-python3 \ @@ -88,6 +90,8 @@ install_python_rockylinux() { llvm-toolset \ lld \ libffi-devel \ + numactl \ + numactl-devel \ zlib-devel \ xz-devel \ sqlite-devel \ diff --git a/jenkins/L0_MergeRequest.groovy b/jenkins/L0_MergeRequest.groovy index 8f4162e89e5..03abfd7453a 100644 --- a/jenkins/L0_MergeRequest.groovy +++ b/jenkins/L0_MergeRequest.groovy @@ -28,10 +28,10 @@ UPLOAD_PATH = env.uploadPath ? env.uploadPath : "sw-tensorrt-generic/llm-artifac // Container configuration // available tags can be found in: https://urm.nvidia.com/artifactory/sw-tensorrt-docker/tensorrt-llm/ // [base_image_name]-[arch]-[os](-[python_version])-[trt_version]-[torch_install_type]-[stage]-[date]-[mr_id] -LLM_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202505211401-4539" -LLM_SBSA_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-aarch64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202505211401-4539" -LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py310-trt10.10.0.31-skip-tritondevel-202505211401-4539" -LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py312-trt10.10.0.31-skip-tritondevel-202505211401-4539" +LLM_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506021004-9420" +LLM_SBSA_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-aarch64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506021004-9420" +LLM_ROCKYLINUX8_PY310_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py310-trt10.10.0.31-skip-tritondevel-202506021004-9420" +LLM_ROCKYLINUX8_PY312_DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:cuda-12.9.0-devel-rocky8-x86_64-rocky8-py312-trt10.10.0.31-skip-tritondevel-202506021004-9420" // TODO: Move common variables to an unified location BUILD_CORES_REQUEST = "8" diff --git a/jenkins/controlCCache.groovy b/jenkins/controlCCache.groovy index d5520cbfd3f..37dcb1589f7 100644 --- a/jenkins/controlCCache.groovy +++ b/jenkins/controlCCache.groovy @@ -1,7 +1,7 @@ import java.lang.InterruptedException -DOCKER_IMAGE = "urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202505211401-4539" +DOCKER_IMAGE = 
"urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506021004-9420" def createKubernetesPodConfig(image) { diff --git a/scripts/build_wheel.py b/scripts/build_wheel.py index 18f6af4c208..bdb3bde8930 100755 --- a/scripts/build_wheel.py +++ b/scripts/build_wheel.py @@ -440,7 +440,7 @@ def main(*, with working_directory(build_dir): if clean or first_build or configure_cmake: build_run( - f"\"{venv_conan}\" install --remote=tensorrt-llm --output-folder={build_dir}/conan -s 'build_type={build_type}' {source_dir}" + f"\"{venv_conan}\" install --build=missing --remote=tensorrt-llm --output-folder={build_dir}/conan -s 'build_type={build_type}' {source_dir}" ) cmake_def_args.append( f"-DCMAKE_TOOLCHAIN_FILE={build_dir}/conan/conan_toolchain.cmake" diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py index 9b7d1865c67..c2f817c25a2 100644 --- a/tensorrt_llm/_torch/model_config.py +++ b/tensorrt_llm/_torch/model_config.py @@ -23,18 +23,13 @@ class MoeLoadBalancerConfig: repr=False) layer_updates_per_iter: int = 0 - num_experts: Optional[int] = field(default=None, init=False) ep_rank: Optional[int] = field(default=None, init=False) ep_size: Optional[int] = field(default=None, init=False) - def setup(self, num_experts: int, ep_rank: int, ep_size: int) -> None: - self.num_experts = num_experts + def setup(self, ep_rank: int, ep_size: int) -> None: self.ep_rank = ep_rank self.ep_size = ep_size - if self.num_slots is None: - self.num_slots = self.num_experts - assert self.num_slots >= self.num_experts - assert self.num_slots % self.ep_size == 0 + assert self.num_slots is not None @property def num_local_slots(self) -> int: @@ -49,17 +44,13 @@ def slot_end(self) -> int: return self.slot_start + self.num_local_slots def get_layer_initial_global_assignments(self, layer_idx: int) -> List[int]: - if self.initial_global_assignments is None: - return [(ep_rank * self.num_experts // self.ep_size + i) % - self.num_experts for ep_rank in range(self.ep_size) - for i in range(self.num_local_slots)] - else: + if self.initial_global_assignments is not None: assert layer_idx in self.initial_global_assignments assert len( self.initial_global_assignments[layer_idx]) == self.num_slots - assert set(self.initial_global_assignments[layer_idx]) == set( - range(self.num_experts)) return self.initial_global_assignments[layer_idx] + else: + return None @dataclass(kw_only=True) diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index 51bddc00355..35579515b92 100644 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -53,7 +53,7 @@ from ..modules.decoder_layer import DecoderLayer from ..modules.embedding import Embedding from ..modules.fused_moe import (CutlassFusedMoE, DeepSeekV3MoeRoutingMethod, - MoeLoadBalancer, create_moe) + create_moe) from ..modules.gated_mlp import GatedMLP from ..modules.linear import Linear from ..modules.multi_stream_utils import maybe_execute_in_parallel @@ -344,7 +344,6 @@ def __init__(self, dtype: Optional[torch.dtype] = None, model_config: ModelConfig = ModelConfig(), override_quant_config: Optional[QuantConfig] = None, - moe_load_balancer: Optional[MoeLoadBalancer] = None, layer_idx: Optional[int] = None): from ..distributed import AllReduce @@ -379,7 +378,6 @@ def __init__(self, override_quant_config=override_quant_config, 
aux_stream=aux_stream_dict[AuxStreamType.MoeChunkingOverlap], enable_alltoall=self.enable_alltoall, - moe_load_balancer=moe_load_balancer, layer_idx=layer_idx) self.mapping = model_config.mapping @@ -542,11 +540,9 @@ def _compute_routed_output(): class DeepseekV3DecoderLayer(DecoderLayer): - def __init__(self, - model_config: ModelConfig[PretrainedConfig], - layer_idx: int, - aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream], - moe_load_balancer: Optional[MoeLoadBalancer] = None): + def __init__(self, model_config: ModelConfig[PretrainedConfig], + layer_idx: int, aux_stream_dict: Dict[AuxStreamType, + torch.cuda.Stream]): super().__init__() self.model_config = model_config config = model_config.pretrained_config @@ -598,7 +594,6 @@ def __init__(self, model_config=model_config, override_quant_config=quant_config, aux_stream_dict=aux_stream_dict, - moe_load_balancer=moe_load_balancer, layer_idx=layer_idx) else: block_size = 1 @@ -865,13 +860,10 @@ def forward_mlp( class DeepseekV3MTP(DeepseekV3DecoderLayer): - def __init__(self, - model_config: ModelConfig[PretrainedConfig], - layer_idx: int, - aux_stream_dict: Dict[AuxStreamType, torch.cuda.Stream], - moe_load_balancer: Optional[MoeLoadBalancer] = None): - super().__init__(model_config, layer_idx, aux_stream_dict, - moe_load_balancer) + def __init__(self, model_config: ModelConfig[PretrainedConfig], + layer_idx: int, aux_stream_dict: Dict[AuxStreamType, + torch.cuda.Stream]): + super().__init__(model_config, layer_idx, aux_stream_dict) config = model_config.pretrained_config self.hidden_dim = config.hidden_size self.moe_intermediate_size = config.moe_intermediate_size @@ -992,23 +984,9 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]): dtype=config.torch_dtype, ) - self.moe_load_balancer = None - if model_config.moe_load_balancer is not None: - num_experts = config.n_routed_experts - ep_rank = model_config.mapping.moe_ep_rank - ep_size = model_config.mapping.moe_ep_size - model_config.moe_load_balancer.setup(num_experts=num_experts, - ep_rank=ep_rank, - ep_size=ep_size) - self.moe_load_balancer = MoeLoadBalancer( - ep_rank=ep_rank, - ep_size=ep_size, - layer_updates_per_iter=model_config.moe_load_balancer. - layer_updates_per_iter) - self.layers = nn.ModuleList([ DeepseekV3DecoderLayer(model_config, layer_idx, - self.aux_stream_dict, self.moe_load_balancer) + self.aux_stream_dict) for layer_idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(hidden_size=config.hidden_size, @@ -1054,7 +1032,6 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]): hidden_size=model_config.pretrained_config.hidden_size, vocab_size=model_config.pretrained_config.vocab_size) - self.moe_load_balancer = self.model.moe_load_balancer self.model_nextn = 0 if model_config.spec_config is not None: model_nextn = model_config.spec_config.num_nextn_predict_layers @@ -1063,8 +1040,7 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]): assert ckpt_nextn > 0, "There is not MTP modules in the checkpoint." 
if ckpt_nextn == 1: mtp_layer = DeepseekV3MTP(model_config, self.num_hidden_layers, - self.model.aux_stream_dict, - self.moe_load_balancer) + self.model.aux_stream_dict) self.model.layers.append(mtp_layer) self.epilogue.append(mtp_layer) self.mtp_worker = MTPEagleWorker(model_config.spec_config) @@ -1074,8 +1050,7 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]): mtp_layers = nn.ModuleList([ DeepseekV3MTP(model_config, layer_idx + self.num_hidden_layers, - self.model.aux_stream_dict, - self.moe_load_balancer) + self.model.aux_stream_dict) for layer_idx in range(model_nextn) ]) self.model.layers.extend(mtp_layers) @@ -1100,9 +1075,6 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]): extend_exclude_modules) self.epilogue.append(self.mtp_worker) - if self.moe_load_balancer is not None: - self.moe_load_balancer.finalize_model() - def forward( self, attn_metadata: AttentionMetadata, diff --git a/tensorrt_llm/_torch/modules/fused_moe/create_moe.py b/tensorrt_llm/_torch/modules/fused_moe/create_moe.py index 18f87da45aa..98f59026dbc 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/create_moe.py +++ b/tensorrt_llm/_torch/modules/fused_moe/create_moe.py @@ -10,7 +10,7 @@ from .fused_moe_trtllm_gen import TRTLLMGenFusedMoE from .fused_moe_vanilla import VanillaMoE from .interface import MoE, MoEWeightLoadingMode -from .moe_load_balancer import MoeLoadBalancer +from .moe_load_balancer import get_moe_load_balancer from .routing import BaseMoeRoutingMethod @@ -53,15 +53,17 @@ def create_moe( weight_loading_mode: MoEWeightLoadingMode = MoEWeightLoadingMode.VANILLA, apply_router_weight_on_input: bool = False, enable_alltoall: bool = False, - moe_load_balancer: Optional[MoeLoadBalancer] = None, layer_idx: Optional[int] = None, ) -> MoE: moe_cls = get_moe_cls(model_config, override_quant_config) + moe_load_balancer = get_moe_load_balancer() + if moe_load_balancer is not None: + assert moe_cls == CutlassFusedMoE, "MoE Load Balance is only supported in CutlassFusedMoE now." + if moe_cls == TRTLLMGenFusedMoE: assert not apply_router_weight_on_input, "apply_router_weight_on_input is not supported in TRTLLMGenFusedMoE." assert not enable_alltoall, "enable_alltoall is not supported in TRTLLMGenFusedMoE." - assert moe_load_balancer is None, "moe_load_balancer is not supported in TRTLLMGenFusedMoE." return moe_cls( routing_method=routing_method, @@ -87,13 +89,11 @@ def create_moe( weight_loading_mode=weight_loading_mode, apply_router_weight_on_input=apply_router_weight_on_input, enable_alltoall=enable_alltoall, - moe_load_balancer=moe_load_balancer, layer_idx=layer_idx, ) elif moe_cls == VanillaMoE: assert not apply_router_weight_on_input, "apply_router_weight_on_input is not supported in VanillaMoE." assert not enable_alltoall, "enable_alltoall is not supported in VanillaMoE." - assert moe_load_balancer is None, "moe_load_balancer is not supported in VanillaMoE." 
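create_moe now discovers the balancer through get_moe_load_balancer() rather than a constructor argument. A minimal sketch of that ambient-registry pattern, with made-up names (DummyBalancer, balancer_context) standing in for the real classes:

# Minimal sketch of the ambient-registry pattern that create_moe relies on:
# the balancer is installed by a context manager instead of being threaded
# through every constructor. Names below are illustrative, not the real API.
from contextlib import contextmanager
from typing import Optional

_current_balancer: Optional["DummyBalancer"] = None


class DummyBalancer:
    def add_layer(self, expert_count: int, top_k: int, slot_count_per_rank: int):
        return (expert_count, top_k, slot_count_per_rank)


def get_current_balancer() -> Optional[DummyBalancer]:
    return _current_balancer


@contextmanager
def balancer_context(balancer: DummyBalancer):
    global _current_balancer
    _current_balancer = balancer
    try:
        yield balancer
    finally:
        _current_balancer = None


def create_layer():
    # Mirrors create_moe: pick up the balancer from the ambient context if present.
    balancer = get_current_balancer()
    if balancer is not None:
        return balancer.add_layer(expert_count=256, top_k=8, slot_count_per_rank=72)
    return None


with balancer_context(DummyBalancer()):
    print(create_layer())   # (256, 8, 72)
print(create_layer())       # None outside the context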
return moe_cls( routing_method=routing_method, diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py index 71e1b0cf85f..9636c1e8959 100755 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py @@ -1,5 +1,5 @@ import os -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Tuple, Union import torch @@ -8,11 +8,11 @@ from ...distributed import allgather, reducescatter from ...expert_statistic import ExpertStatistic -from ...model_config import ModelConfig, MoeLoadBalancerConfig +from ...model_config import ModelConfig from ...utils import (EventType, Fp4QuantizedTensor, disable_fp4_allgather, reswizzle_sf, swizzle_sf, unswizzle_sf) from .interface import MoE -from .moe_load_balancer import MoeLoadBalancer +from .moe_load_balancer import get_moe_load_balancer from .quantization import (FP8BlockScalesFusedMoEMethod, FP8QDQFusedMoEMethod, MoEWeightLoadingMode, NVFP4CutlassFusedMoEMethod, UnquantizedFusedMoEMethod, WInt4AFP8FusedMoEMethod) @@ -82,7 +82,6 @@ def __init__( VANILLA, apply_router_weight_on_input: bool = False, enable_alltoall: bool = False, - moe_load_balancer: Optional[MoeLoadBalancer] = None, layer_idx: Optional[int] = None, ): @@ -98,47 +97,57 @@ def __init__( ) self.layer_idx = layer_idx + + moe_load_balancer = get_moe_load_balancer() + self.layer_load_balancer = None + moe_load_balancer_config = model_config.moe_load_balancer - if moe_load_balancer_config is None: - assert moe_load_balancer is None - # A dummy MoeLoadBalancerConfig to generate default initial_global_assignments - moe_load_balancer_config = MoeLoadBalancerConfig() - moe_load_balancer_config.setup(num_experts=num_experts, - ep_rank=self.ep_rank, - ep_size=self.ep_size) + init_expert_size_per_partition = moe_load_balancer_config.num_local_slots if moe_load_balancer_config else self.num_experts // self.ep_size + self.initial_global_assignments = [ + (ep_rank * self.num_experts // self.ep_size + local_slot_id) % + self.num_experts for ep_rank in range(self.ep_size) + for local_slot_id in range(init_expert_size_per_partition) + ] + + if moe_load_balancer: + assert moe_load_balancer_config is not None + top_k = self.routing_method.experts_per_token + self.expert_size_per_partition = moe_load_balancer_config.num_local_slots + self.layer_load_balancer = moe_load_balancer.add_layer( + self.num_experts, top_k, self.expert_size_per_partition) + loaded_initial_global_assignments = moe_load_balancer_config.get_layer_initial_global_assignments( + self.layer_idx) + self.num_slots = moe_load_balancer_config.num_slots + if loaded_initial_global_assignments is not None: + assert isinstance(loaded_initial_global_assignments, list) + assert len(loaded_initial_global_assignments) == self.num_slots + assert self.num_slots >= self.num_experts + assert set(loaded_initial_global_assignments) == set( + range(self.num_experts)) + self.initial_global_assignments = loaded_initial_global_assignments + self.layer_load_balancer.set_initial_weight_assignments( + self.initial_global_assignments) + logger.info( + f"MoE load balancer enabled. 
num_experts = {num_experts}, num_slots = {self.num_slots}, ep_size = {self.ep_size}" + ) + logger.info( + f"initial_global_assignments (layer {self.layer_idx}) = {self.initial_global_assignments}" + ) else: - assert moe_load_balancer is not None + assert num_experts % self.ep_size == 0 + self.expert_size_per_partition = num_experts // self.ep_size + self.num_slots = num_experts - self.num_slots = moe_load_balancer_config.num_slots if self.smart_router: assert self.num_slots == self.num_experts, "Smart router should not have redundant slots" - self.initial_global_assignments = moe_load_balancer_config.get_layer_initial_global_assignments( - layer_idx) - self.expert_size_per_partition = moe_load_balancer_config.num_local_slots - self.slot_start = moe_load_balancer_config.slot_start - self.slot_end = moe_load_balancer_config.slot_end + self.slot_start = self.ep_rank * self.expert_size_per_partition + self.slot_end = self.slot_start + self.expert_size_per_partition self.initial_local_expert_ids = self.initial_global_assignments[ self.slot_start:self.slot_end] assert len( self.initial_local_expert_ids) == self.expert_size_per_partition - self.balancer_layer = None - if moe_load_balancer is not None: - self.balancer_layer = moe_load_balancer.add_layer( - expert_count=num_experts, - top_k=routing_method.experts_per_token, - slot_count_per_rank=self.expert_size_per_partition, - ) - self.balancer_layer.set_initial_weight_assignments( - self.initial_global_assignments) - logger.info( - f"MoE load balancer enabled. num_experts = {num_experts}, num_slots = {self.num_slots}, ep_size = {self.ep_size}" - ) - logger.info( - f"initial_global_assignments (layer {layer_idx}) = {self.initial_global_assignments}" - ) - max_num_tokens = model_config.max_num_tokens # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled if self.use_dp: @@ -259,13 +268,14 @@ def reducescatter_or_allreduce( return outputs def forward_chunk( - self, - x: Union[torch.Tensor, Fp4QuantizedTensor], - router_logits: torch.Tensor, - cutlass_min_latency_mode: bool = False, - output_dtype: Optional[torch.dtype] = None, - all_rank_num_tokens: Optional[List[int]] = None, - use_dp_padding: Optional[bool] = None, + self, + x: Union[torch.Tensor, Fp4QuantizedTensor], + router_logits: torch.Tensor, + cutlass_min_latency_mode: bool = False, + output_dtype: Optional[torch.dtype] = None, + all_rank_num_tokens: Optional[List[int]] = None, + use_dp_padding: Optional[bool] = None, + repeating_info: Tuple = (True, True), ) -> torch.Tensor: if isinstance(x, Fp4QuantizedTensor): assert output_dtype is not None @@ -273,31 +283,25 @@ def forward_chunk( else: output_dtype = x.dtype + is_first_call, is_last_call = repeating_info + + if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing( + ) and is_first_call: + self.layer_load_balancer.wait_for_gpu_stage() + use_fp8_block_scaling = False use_w4a8_group_scaling = False weight_dtype = self.w3_w1_weight.dtype token_selected_experts, token_final_scales = self.routing_method.apply( router_logits) - if self.balancer_layer is None: - token_selected_slots = token_selected_experts - else: - # If attention DP is enabled, token_selected_experts is a local rank tensor, - # so we need to offset the round robin position by ep_rank - token_selected_slots = self.balancer_layer.route( - token_selected_experts, offset_by_ep_rank=self.use_dp) - # If load balancer is disabled, the statistics are collected from expert IDs. 
- # If load balancer is enabled, the statistics are collected from expert slot IDs. - ExpertStatistic.set_layer(self.layer_idx) - ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots) - - assert token_selected_slots.shape[ + assert token_selected_experts.shape[ 1] == self.routing_method.experts_per_token - assert token_selected_slots.shape == token_final_scales.shape - assert token_selected_slots.shape[0] == router_logits.shape[0] + assert token_selected_experts.shape == token_final_scales.shape + assert token_selected_experts.shape[0] == router_logits.shape[0] assert token_final_scales.dtype == torch.float32 - assert token_selected_slots.dtype == torch.int32 + assert token_selected_experts.dtype == torch.int32 if self.apply_router_weight_on_input: assert self.routing_method.top_k == 1, "Current workaround only supports top-1 routing" @@ -310,13 +314,32 @@ def forward_chunk( alltoall_info = None + if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing( + ) and is_first_call: + self.layer_load_balancer.maybe_cudagraph_done_wait() + + need_statistic = False + if self.layer_load_balancer is None: + token_selected_slots = token_selected_experts + else: + token_selected_slots = self.layer_load_balancer.route( + token_selected_experts, self.use_dp) + if not self.layer_load_balancer.is_static_routing(): + need_statistic = True + + # If load balancer is disabled, the statistics are collected from expert IDs. + # If load balancer is enabled, the statistics are collected from expert slot IDs. + ExpertStatistic.set_layer(self.layer_idx) + ExpertStatistic.maybe_add_info(self.num_slots, token_selected_slots) + + token_selected_experts_for_statistic = token_selected_experts if need_statistic else None if self.enable_alltoall: - x, token_selected_slots, token_final_scales, alltoall_info = \ + x, token_selected_slots, token_final_scales, token_selected_experts_for_statistic, alltoall_info = \ self.alltoall_prepare_maybe_dispatch(all_rank_num_tokens, x, token_selected_slots, - token_final_scales) - + token_final_scales, + token_selected_experts_for_statistic) x_sf = None if self.has_any_quant: if self.has_fp8_qdq: @@ -348,8 +371,11 @@ def forward_chunk( if self.use_dp and self.parallel_size > 1 and not disable_fp4_allgather( ) and not self.enable_alltoall: - x, x_sf, token_selected_slots, token_final_scales = allgather( - [x, x_sf, token_selected_slots, token_final_scales], + x, x_sf, token_selected_slots, token_final_scales, token_selected_experts_for_statistic = allgather( + [ + x, x_sf, token_selected_slots, token_final_scales, + token_selected_experts_for_statistic + ], self.mapping, dim=0, sizes=None if use_dp_padding else all_rank_num_tokens) @@ -358,6 +384,12 @@ def forward_chunk( x_sf = reswizzle_sf(x_sf, x_row, x_col, self.scaling_vector_size) + if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing( + ): + self.layer_load_balancer.statistic( + token_selected_experts_for_statistic, is_first_call, + is_last_call) + if self.smart_router and not cutlass_min_latency_mode: ep_size = self.cluster_size ep_rank = self.cluster_rank @@ -405,20 +437,29 @@ def forward_chunk( tune_max_num_tokens=self.tune_max_num_tokens, ) + if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing( + ) and is_last_call: + self.layer_load_balancer.set_cpu_stage() + if cutlass_min_latency_mode: assert not self.reduce_results - return final_hidden_states + assert not self.enable_alltoall else: # Custom op requires all inputs are in the same type. 
# Only in cutlass_min_latency_mode, the output is a list of tensors. # Otherwise, the output should be unpacked as a single tensor. final_hidden_states = final_hidden_states[0] - if not self.enable_alltoall: - return final_hidden_states - else: - return self.alltoall_combine(final_hidden_states, alltoall_info, - token_count) + if self.enable_alltoall: + final_hidden_states = self.alltoall_combine(final_hidden_states, + alltoall_info, + token_count) + + if self.layer_load_balancer and not self.layer_load_balancer.is_static_routing( + ) and is_last_call: + self.layer_load_balancer.maybe_cudagraph_done_set_cpu_stage() + + return final_hidden_states def forward( self, @@ -500,6 +541,8 @@ def split_chunk(split_token_num: int, split_num_chunks: int): # Postpone reduce-scatter/all-reduce to the next iteration to achieve better overlap for idx_chunk, (x, router_logits) in enumerate( zip(x_list, router_logits_list)): + is_first_call = idx_chunk == 0 + is_last_call = idx_chunk == num_chunks - 1 if not self.enable_alltoall: if idx_chunk % 2 == 0: with torch.cuda.stream(self.aux_stream): @@ -508,7 +551,8 @@ def split_chunk(split_token_num: int, split_num_chunks: int): router_logits, all_rank_num_tokens=all_rank_num_tokens_list[ idx_chunk] if self.use_dp else None, - use_dp_padding=use_dp_padding) + use_dp_padding=use_dp_padding, + repeating_info=(is_first_call, is_last_call)) if idx_chunk > 0: outputs_list[-1] = self.reducescatter_or_allreduce( outputs_list[-1], @@ -521,7 +565,8 @@ def split_chunk(split_token_num: int, split_num_chunks: int): router_logits, all_rank_num_tokens=all_rank_num_tokens_list[ idx_chunk] if self.use_dp else None, - use_dp_padding=use_dp_padding) + use_dp_padding=use_dp_padding, + repeating_info=(is_first_call, is_last_call)) with torch.cuda.stream(self.aux_stream): outputs_list[-1] = self.reducescatter_or_allreduce( outputs_list[-1], @@ -533,7 +578,8 @@ def split_chunk(split_token_num: int, split_num_chunks: int): x, router_logits, all_rank_num_tokens=all_rank_num_tokens_list[idx_chunk] - if self.use_dp else None) + if self.use_dp else None, + repeating_info=(is_first_call, is_last_call)) outputs_list.append(outputs) if not self.enable_alltoall: @@ -557,32 +603,48 @@ def split_chunk(split_token_num: int, split_num_chunks: int): outputs = outputs[:all_rank_num_tokens[rank]] return outputs - def alltoall_prepare_maybe_dispatch(self, all_rank_num_tokens: list, - x: torch.Tensor, - token_selected_slots: torch.Tensor, - token_final_scales: torch.Tensor): + def alltoall_prepare_maybe_dispatch( + self, all_rank_num_tokens: list, x: torch.Tensor, + token_selected_slots: torch.Tensor, + token_final_scales: torch.Tensor, + token_selected_experts_for_statistic: Optional[torch.Tensor]): top_k = self.routing_method.experts_per_token - expert_count = self.num_experts # gather router info max_num_token = max(all_rank_num_tokens) token_selected_slots = torch.nn.functional.pad( token_selected_slots, (0, 0, 0, max_num_token - token_selected_slots.shape[0]), - 'constant', self.num_experts) + 'constant', self.num_slots) + token_selected_experts_for_statistic = torch.nn.functional.pad( + token_selected_experts_for_statistic, + (0, 0, 0, + max_num_token - token_selected_experts_for_statistic.shape[0]), + 'constant', self.num_experts + ) if token_selected_experts_for_statistic is not None else None token_final_scales = torch.nn.functional.pad( token_final_scales, (0, 0, 0, max_num_token - token_final_scales.shape[0])) - gathered_token_selected_slots, gathered_token_final_scales = allgather( - 
[token_selected_slots, token_final_scales], self.mapping, dim=0)
+        gathered_token_selected_slots, gathered_token_final_scales, gathered_token_selected_experts_for_statistic = allgather(
+            [
+                token_selected_slots, token_final_scales,
+                token_selected_experts_for_statistic
+            ],
+            self.mapping,
+            dim=0)
+        if gathered_token_selected_experts_for_statistic is not None:
+            gathered_token_selected_experts_for_statistic = torch.flatten(
+                gathered_token_selected_experts_for_statistic.contiguous(),
+                start_dim=0,
+                end_dim=-2)
         gathered_token_selected_slots = torch.flatten(
             gathered_token_selected_slots.contiguous(), start_dim=0, end_dim=-2)
         gathered_token_final_scales = torch.flatten(
             gathered_token_final_scales.contiguous(), start_dim=0, end_dim=-2)
         gathered_target_rank_ids = MnnvlMoe.compute_target_rank_id(
-            gathered_token_selected_slots, self.num_experts, self.ep_size)
+            gathered_token_selected_slots, self.num_slots, self.ep_size)
         alltoall_info, token_selected_slots, token_final_scales = MnnvlMoe.mnnvl_moe_alltoallv_prepare(
             gathered_target_rank_ids, None, gathered_token_selected_slots,
-            gathered_token_final_scales, max_num_token, expert_count, top_k,
+            gathered_token_final_scales, max_num_token, self.num_slots, top_k,
             self.ep_rank, self.ep_size)

         if not self.use_postquant_alltoall:
@@ -593,7 +655,7 @@ def alltoall_prepare_maybe_dispatch(self, all_rank_num_tokens: list,
                                                  self.alltoall_workspace,
                                                  self.ep_rank, self.ep_size)
-        return x, token_selected_slots, token_final_scales, alltoall_info
+        return x, token_selected_slots, token_final_scales, gathered_token_selected_experts_for_statistic, alltoall_info

     def alltoall_postquant_dispatch(self, x: torch.Tensor, x_sf: torch.Tensor,
                                     x_row: int, x_col: int,
@@ -633,6 +695,63 @@ def alltoall_combine(self, final_hidden_states: torch.Tensor,

         return final_hidden_states

+    def register_parameter_weight_slot_fn(self, weight_name: str,
+                                          local_slot_id: int):
+        assert hasattr(
+            self,
+            weight_name), f"FusedMoE doesn't have weight attr: {weight_name}"
+        weight_tensor = getattr(self, weight_name).data[local_slot_id]
+        self.layer_load_balancer.register_weight_slot(local_slot_id,
+                                                      weight_name,
+                                                      weight_tensor)
+
+    def register_to_fix_weight_fn(self, weight_name: str):
+        assert hasattr(
+            self,
+            weight_name), f"FusedMoE doesn't have weight attr: {weight_name}"
+        param = getattr(self, weight_name)
+        weight_tensor = param.detach()
+        assert isinstance(
+            weight_tensor,
+            torch.Tensor), f'weight {weight_name} should be a tensor'
+        assert weight_tensor.is_contiguous(
+        ), f'weight {weight_name} should be contiguous, shape={weight_tensor.shape}, strides={weight_tensor.stride()}'
+        assert weight_tensor.numel() * weight_tensor.element_size() == weight_tensor.untyped_storage().size(),\
+            f'weight {weight_name} shape={weight_tensor.shape} storage_size = {weight_tensor.untyped_storage().size()}, numel={weight_tensor.numel()}, eltsize={weight_tensor.element_size()}, dtype={weight_tensor.dtype}'
+        self.layer_load_balancer.fix_tensor(weight_tensor)
+        param.data = weight_tensor
+
+    def register_all_parameter_slot_and_to_fix_weight_fns(
+            self, weight_and_tensor_dict: Dict[str, torch.Tensor]):
+        """
+        weight_and_tensor_dict: key is the name of the weight, value is the tensor of the loaded shared tensor shard.
+        E.g. if num_experts=256 and there are 4 GPUs per node, then each rank needs to load 256 / 4 = 64 expert weights for host sharing.
+        This way, host_tensor_sharer can share the weights and each rank has access to all 256 experts.
+ """ + for local_slot_id, expert_id in enumerate( + self.initial_local_expert_ids): + for weight_name in weight_and_tensor_dict: + self.layer_load_balancer.add_register_weight_fn( + self.register_parameter_weight_slot_fn, + (weight_name, local_slot_id)) + for weight_name in weight_and_tensor_dict: + self.layer_load_balancer.add_to_fix_weight_fn( + self.register_to_fix_weight_fn, (weight_name, )) + + local_shared_load_expert_ids = self.layer_load_balancer.get_load_expert_ids( + ) + for expert_id in range(self.num_experts): + for weight_name, weight_tensor in weight_and_tensor_dict.items(): + if expert_id in local_shared_load_expert_ids: + local_slot_id = local_shared_load_expert_ids.index( + expert_id) + self.layer_load_balancer.host_tensor_sharer.share_host_tensor_with_shape( + expert_id, weight_name, weight_tensor[local_slot_id]) + else: + self.layer_load_balancer.host_tensor_sharer.pre_register_host_tensor_with_shape( + expert_id, weight_name, weight_tensor.dtype, + weight_tensor[0].shape) + def load_weights(self, weights: List[Dict]): assert self._weights_created assert len(weights) == 1 diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py index 67de304471c..312b6400fe2 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py @@ -2,7 +2,7 @@ import torch -from ...model_config import ModelConfig, MoeLoadBalancerConfig +from ...model_config import ModelConfig from ...utils import Fp4QuantizedTensor from .interface import MoE, MoEWeightLoadingMode from .quantization import (FP8BlockScalesFusedMoEMethod, @@ -71,18 +71,15 @@ def __init__( assert not self.smart_router, "Smart router is not supported in TRTLLMGenFusedMoE." assert not self.use_dp, "AttentionDP is not supported in TRTLLMGenFusedMoE." 
- # A dummy MoeLoadBalancerConfig to generate default initial_global_assignments and initial_local_expert_ids - moe_load_balancer_config = MoeLoadBalancerConfig() - moe_load_balancer_config.setup(num_experts=num_experts, - ep_rank=self.ep_rank, - ep_size=self.ep_size) - - self.num_slots = moe_load_balancer_config.num_slots - self.initial_global_assignments = moe_load_balancer_config.get_layer_initial_global_assignments( - layer_idx) - self.expert_size_per_partition = moe_load_balancer_config.num_local_slots - self.slot_start = moe_load_balancer_config.slot_start - self.slot_end = moe_load_balancer_config.slot_end + self.num_slots = self.num_experts + self.expert_size_per_partition = self.num_experts // self.ep_size + self.initial_global_assignments = [ + (ep_rank * self.num_experts // self.ep_size + local_slot_id) % + self.num_experts for ep_rank in range(self.ep_size) + for local_slot_id in range(self.expert_size_per_partition) + ] + self.slot_start = self.ep_rank * self.expert_size_per_partition + self.slot_end = self.slot_start + self.expert_size_per_partition self.initial_local_expert_ids = self.initial_global_assignments[ self.slot_start:self.slot_end] assert len( diff --git a/tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py b/tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py index b648afb1590..cb3a2b69f84 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py +++ b/tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py @@ -1,13 +1,18 @@ -import atexit +import platform import threading -from multiprocessing import shared_memory -from typing import Callable, List, Optional +from contextlib import nullcontext +from multiprocessing import resource_tracker, shared_memory +from typing import Callable, Dict, List, Optional, Tuple +import numpy as np import torch from mpi4py import MPI import tensorrt_llm import tensorrt_llm.bindings.internal.runtime as _tbr +from tensorrt_llm._torch.pyexecutor.cuda_graph_runner import is_graph_capturing +from tensorrt_llm.logger import logger +from tensorrt_llm.mapping import Mapping def _tensor_to_weight(t: torch.Tensor) -> _tbr.MoeWeight: @@ -20,6 +25,7 @@ def _tensor_to_weight(t: torch.Tensor) -> _tbr.MoeWeight: assert t.dim() <= 2, "t.dim() should be less than or equal to 2" shape = [1, 1] pitch = 1 + elt_size = torch.tensor([], dtype=t.dtype).element_size() if t.dim() == 2: shape[0] = t.size(0) shape[1] = t.size(1) @@ -31,8 +37,8 @@ def _tensor_to_weight(t: torch.Tensor) -> _tbr.MoeWeight: pass mw = _tbr.MoeWeight() mw.height = shape[0] - mw.width = shape[1] - mw.pitch = pitch + mw.width = shape[1] * elt_size + mw.pitch = pitch * elt_size mw.weight_ptr = t.data_ptr() return mw @@ -42,7 +48,8 @@ class HostMoeTensorSharer: A class representing a host tensor sharer. """ - def __init__(self, layer_id: int, shared_mpi_comm: MPI.Comm): + def __init__(self, layer_id: int, expert_count: int, + shared_mpi_comm: MPI.Comm): """ Initialize a HostMoeTensorSharer instance. 
@@ -51,11 +58,24 @@ def __init__(self, layer_id: int, shared_mpi_comm: MPI.Comm): """ self.shared_mpi_comm = shared_mpi_comm self.layer_id = layer_id + self.expert_count = expert_count self.shared_memory_base_name = None - self.host_tensor_shapes = [] + + self.local_rank = self.shared_mpi_comm.Get_rank() + self.local_size = self.shared_mpi_comm.Get_size() + + self.expert_start = self.local_rank * self.expert_count // self.local_size + self.expert_end = (self.local_rank + + 1) * self.expert_count // self.local_size + + self.name_info = {} # key is weight name, value is (dtype, shape) self.host_weights = {} - self.own_shms = {} - self.all_shms = [] + + self.own_shm = None + self.imported_shms = [] + + self.shared_tensors = {} + self.names = [] def set_shared_memory_base_name(self, shared_memory_base_name): """ @@ -66,17 +86,19 @@ def set_shared_memory_base_name(self, shared_memory_base_name): """ self.shared_memory_base_name = shared_memory_base_name - def get_shared_memory_name(self, expert_id: int, name: str): + def get_shared_memory_name(self, rank: Optional[int] = None): """ Get the shared memory name for the layer. Args: - expert_id: The ID of the expert - name: The name of the weight + rank: The rank who created the shared memory. Current rank if None """ + if rank is None: + rank = self.local_rank + assert 0 <= rank < self.local_size assert isinstance(self.shared_memory_base_name, str), "self.shared_memory_base_name must be a string" - shared_memory_name = f"{self.shared_memory_base_name}_l{self.layer_id}_e{expert_id}_{name}" + shared_memory_name = f"{self.shared_memory_base_name}_l{self.layer_id}_lr{rank}_all" return shared_memory_name def pre_register_host_tensor_with_shape(self, expert_id: int, name: str, @@ -96,7 +118,13 @@ def pre_register_host_tensor_with_shape(self, expert_id: int, name: str, """ assert len(tensor_shape ) <= 2, "tensor_shape dim must be less than or equal to 2" - self.host_tensor_shapes.append((expert_id, name, dtype, tensor_shape)) + assert 0 <= expert_id < self.expert_count + assert expert_id < self.expert_start or expert_id >= self.expert_end + if name not in self.name_info: + self.name_info[name] = (dtype, tensor_shape) + else: + assert dtype == self.name_info[name][0] and tensor_shape == self.name_info[name][1], \ + f'weights name={name}, dtype={dtype}, shape={tensor_shape}, but already registered with dtype={self.name_info[name][0]}, shape={self.name_info[name][1]}' def share_host_tensor_with_shape(self, expert_id: int, name: str, t: torch.Tensor): @@ -111,38 +139,99 @@ def share_host_tensor_with_shape(self, expert_id: int, name: str, name: The name of the weight t: The weight tensor """ + assert len( + t.shape) <= 2, "tensor_shape dim must be less than or equal to 2" assert t.is_contiguous() == True, "t.is_contiguous() must be True" - shm_name = self.get_shared_memory_name(expert_id, name) - shm = shared_memory.SharedMemory(name=shm_name, - create=True, - size=t.numel() * t.element_size()) - shm.buf[:t.numel() * t.element_size()] = t.numpy().tobytes() + assert (expert_id, name) not in self.shared_tensors.keys() + assert self.expert_start <= expert_id < self.expert_end + self.shared_tensors[(expert_id, name)] = t dtype = t.dtype tensor_shape = t.shape - t = torch.frombuffer(shm.buf, - dtype=dtype).view(tensor_shape).pin_memory() - key = (expert_id, name) - assert key not in self.host_weights.keys(), f"key={key} already exists" - self.host_weights[key] = t - self.own_shms[(expert_id, name)] = shm - self.all_shms.append(shm) - atexit.register(shm.unlink) + 
if name not in self.name_info: + self.name_info[name] = (dtype, tensor_shape) + else: + assert dtype == self.name_info[name][0] and tensor_shape == self.name_info[name][1], \ + f'weights name={name}, dtype={dtype}, shape={tensor_shape}, but already registered with dtype={self.name_info[name][0]}, shape={self.name_info[name][1]}' + + @staticmethod + def align_size(size: int): + return (size + 256 - 1) // 256 * 256 + + def finalize_layer_weights(self): + self.names = list(sorted(self.name_info.keys())) + assert len( + self.shared_tensors.keys()) == (self.expert_end - + self.expert_start) * len(self.names) + + total_size = 0 + for name in self.names: + dtype, shape = self.name_info[name] + for expert_id in range(self.expert_start, self.expert_end): + t = self.shared_tensors[(expert_id, name)] + assert dtype == t.dtype and shape == t.shape + data_size = t.numel() * t.element_size() + aligned_size = self.align_size(data_size) + total_size += aligned_size + + shm_name = self.get_shared_memory_name() + shm = shared_memory.SharedMemory(name=shm_name, + create=True, + size=total_size) + self.own_shm = shm + + offset = 0 + for name in self.names: + for expert_id in range(self.expert_start, self.expert_end): + t = self.shared_tensors[(expert_id, name)] + data_size = t.numel() * t.element_size() + aligned_size = self.align_size(data_size) + shm.buf[offset:offset + data_size] = t.numpy().tobytes() + dtype = t.dtype + tensor_shape = t.shape + elt_count = t.numel() + st = torch.frombuffer(shm.buf, + dtype=dtype, + offset=offset, + count=elt_count).view(tensor_shape) + key = (expert_id, name) + assert key not in self.host_weights.keys( + ), f"key={key} already exists" + self.host_weights[key] = st + offset += aligned_size + self.shared_tensors = {} def finalize_host_tensor_sharing(self, add_host_weight_fn: Callable = None): """ Finalize the host tensor sharing. 
""" - for expert_weight_info in self.host_tensor_shapes: - expert_id, name, dtype, tensor_shape = expert_weight_info - shm_name = self.get_shared_memory_name(expert_id, name) + for rank in range(self.local_size): + if rank == self.local_rank: + continue + + shm_name = self.get_shared_memory_name(rank) shm = shared_memory.SharedMemory(name=shm_name) - self.all_shms.append(shm) - t = torch.frombuffer(shm.buf, - dtype=dtype).view(tensor_shape).pin_memory() - key = (expert_id, name) - assert key not in self.host_weights.keys( - ), f"key={key} already exists" - self.host_weights[key] = t + self.imported_shms.append(shm) + + rank_expert_start = rank * self.expert_count // self.local_size + rank_expert_end = (rank + 1) * self.expert_count // self.local_size + + offset = 0 + for name in self.names: + dtype, shape = self.name_info[name] + elt_count = int(np.prod(shape)) + data_size = torch.tensor([], + dtype=dtype).element_size() * elt_count + aligned_size = self.align_size(data_size) + for expert_id in range(rank_expert_start, rank_expert_end): + t = torch.frombuffer(shm.buf, + dtype=dtype, + offset=offset, + count=elt_count).view(shape) + key = (expert_id, name) + assert key not in self.host_weights.keys( + ), f"key={key} already exists" + self.host_weights[key] = t + offset += aligned_size if add_host_weight_fn is not None: for key, t in self.host_weights.items(): @@ -154,8 +243,20 @@ def pre_shutdown_cleanup(self): """ Clean up the resources before C++ shutdown and barrier """ - for shm in self.all_shms: + for shm in self.imported_shms: shm.close() + resource_tracker.unregister(shm._name, "shared_memory") + self.imported_shms = None + if self.own_shm: + self.own_shm.close() + + def post_shutdown_cleanup(self): + """ + Clean up the resources before C++ shutdown and barrier + """ + if self.own_shm: + self.own_shm.unlink() + self.own_shm = None class SingleLayerMoeLoadBalancer: @@ -167,19 +268,56 @@ class SingleLayerMoeLoadBalancer: def __init__( self, single_layer_load_balancer_impl: _tbr.SingleLayerMoeLoadBalancer, - shared_mpi_comm: MPI.Comm): + shared_mpi_comm: MPI.Comm, + expert_count: int, + updates_enabled: bool = True): """ Initialize a SingleLayerMoeLoadBalancer instance. 
Args: single_layer_load_balancer_impl: The C++ implementation of SingleLayerMoeLoadBalancer shared_mpi_comm: The MPI communicator for shared memory + expert_count: total number of experts + updates_enabled: whether to enable weight updates """ self.single_layer_load_balancer_impl = single_layer_load_balancer_impl self.single_layer_load_balancer_ptr = single_layer_load_balancer_impl.get_pointer( ) + self.expert_count = expert_count + self.updates_enabled = updates_enabled layer_id = self.single_layer_load_balancer_impl.get_layer_id() - self.host_tensor_sharer = HostMoeTensorSharer(shared_mpi_comm, layer_id) + self.host_tensor_sharer = HostMoeTensorSharer( + layer_id, expert_count, + shared_mpi_comm) if self.updates_enabled else None + self.register_weight_fns = [] + self.to_fix_weight_fns = [] + + shared_rank = shared_mpi_comm.Get_rank() + shared_size = shared_mpi_comm.Get_size() + + load_expert_start = shared_rank * self.expert_count // shared_size + load_expert_end = min( + (shared_rank + 1) * self.expert_count // shared_size, + self.expert_count) + self.load_expert_ids = list(range(load_expert_start, load_expert_end)) + + self.statistic_flag_tensor = None + + self.cudagraph_stream = None + self.cudagraph_event = None + + def get_layer_idx(self): + return self.single_layer_load_balancer_impl.get_layer_id() + + def get_load_expert_ids(self): + assert self.updates_enabled, "should not call get_load_expert_ids when using statistic routing" + return self.load_expert_ids + + def is_static_routing(self): + return not self.updates_enabled + + def need_load_shared_weights(self): + return self.updates_enabled def set_shared_memory_base_name(self, shared_memory_base_name): """ @@ -188,8 +326,9 @@ def set_shared_memory_base_name(self, shared_memory_base_name): Args: shared_memory_base_name: The base name for the shared memory """ - self.host_tensor_sharer.set_shared_memory_base_name( - shared_memory_base_name) + if self.updates_enabled: + self.host_tensor_sharer.set_shared_memory_base_name( + shared_memory_base_name) def _add_weight_slot(self, slot_id: int, name: str, weight_slot: _tbr.MoeWeight): @@ -201,20 +340,21 @@ def _add_weight_slot(self, slot_id: int, name: str, name: The name of the weight weight_slot: The weight object """ - self.single_layer_load_balancer_impl.add_weight_slot( + self.single_layer_load_balancer_impl.add_single_weight_slot( slot_id, name, weight_slot) - def register_weight_slot(self, slot_id: int, name: str, t: torch.Tensor): + def register_weight_slot(self, local_slot_id: int, name: str, + t: torch.Tensor): """ Register a weight slot to the layer. 
Args: - slot_id: The ID of the slot + local_slot_id: The ID of the slot at local rank name: The name of the weight t: The weight tensor """ moe_weight = _tensor_to_weight(t) - self._add_weight_slot(slot_id, name, moe_weight) + self._add_weight_slot(local_slot_id, name, moe_weight) def _add_host_weight(self, expert_id: int, name: str, host_weight: _tbr.MoeWeight): @@ -226,7 +366,7 @@ def _add_host_weight(self, expert_id: int, name: str, name: The name of the weight host_weight: The host weight object """ - self.single_layer_load_balancer_impl.add_host_weight( + self.single_layer_load_balancer_impl.add_single_host_weight( expert_id, name, host_weight) def _add_host_weight_from_tensor(self, expert_id: int, name: str, @@ -248,46 +388,128 @@ def set_initial_weight_assignments(self, self.single_layer_load_balancer_impl.set_initial_weight_assignments( initial_weight_assignments) + def add_to_fix_weight_fn(self, + fn: Callable, + args: Tuple, + kwargs: Dict = {}): + self.to_fix_weight_fns.append((fn, args, kwargs)) + + def add_register_weight_fn(self, + fn: Callable, + args: Tuple, + kwargs: Dict = {}): + """ + Add weight register function, this function doesn't run fn directly but run all functions after model.to("cuda") + so this function can be called when model is not on GPU yet. + """ + self.register_weight_fns.append((fn, args, kwargs)) + + def fix_tensor(self, wt: torch.Tensor): + torch.ops.trtllm.migrate_to_managed(wt) + + def register_weight_slots_after_to_cuda(self): + """ + Register weights after model has been moved to cuda, should be invoked after model.to("cuda") and before finalize_model. + """ + for fn, args, kwargs in self.to_fix_weight_fns: + fn(*args, **kwargs) + + self.to_fix_weight_fns = [] + + for fn, args, kwargs in self.register_weight_fns: + fn(*args, **kwargs) + + self.register_weight_fns = [] + def py_finalize_model(self): """ Finalize the model after all layers have been added. This must be called before starting any iterations. """ - self.host_tensor_sharer.finalize_host_tensor_sharing( - self._add_host_weight_from_tensor) + if self.updates_enabled: + self.host_tensor_sharer.finalize_host_tensor_sharing( + self._add_host_weight_from_tensor) - def wait_for_gpu_stage(self) -> torch.Tensor: + def wait_for_gpu_stage(self) -> Optional[torch.Tensor]: """ Wait for the GPU stage to complete. Returns: A tensor indicating whether the stage is enabled """ - return torch.ops.trtllm.moe_load_balance_wait_gpu_stage( - self.single_layer_load_balancer_ptr) + if self.updates_enabled: + assert self.statistic_flag_tensor is None, \ + "Already has statistic_flag_tensor, should not wait." 
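The graph-capture branch of wait_for_gpu_stage below relies on a record-on-current-stream / wait-on-side-stream handshake. A self-contained sketch of just that handshake, assuming a CUDA device is available; the multiply is a stand-in for the wait op:

# Standalone sketch of the stream/event handshake used when the wait has to be
# captured into a CUDA graph: record an event on the current stream, make a side
# stream wait on it, run the op there, then record a completion event.
import torch

if torch.cuda.is_available():
    side_stream = torch.cuda.Stream()
    done_event = torch.cuda.Event()

    x = torch.ones(1024, device="cuda")

    start_event = torch.cuda.Event()
    start_event.record(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        start_event.wait()          # side stream waits for prior work
        y = x * 2                   # stand-in for the wait_gpu_stage op
        done_event.record(side_stream)
    done_event.wait()               # current stream resumes after the side work
    torch.cuda.current_stream().synchronize()
    print(y.sum().item())           # 2048.0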
+ if is_graph_capturing(): + self.cudagraph_event = torch.cuda.Event() + self.cudagraph_stream = torch.cuda.Stream() + current_stream_event = torch.cuda.Event() + current_stream_event.record(torch.cuda.current_stream()) + with torch.cuda.stream(self.cudagraph_stream): + current_stream_event.wait() + self.statistic_flag_tensor = torch.ops.trtllm.moe_load_balance_wait_gpu_stage( + self.single_layer_load_balancer_ptr) + self.cudagraph_event.record(self.cudagraph_stream) + else: + self.statistic_flag_tensor = torch.ops.trtllm.moe_load_balance_wait_gpu_stage( + self.single_layer_load_balancer_ptr) + return self.statistic_flag_tensor + else: + return + + def maybe_cudagraph_done_wait(self): + if self.updates_enabled: + if is_graph_capturing(): + assert self.cudagraph_event is not None, "should have cudagraph_event when capturing" + assert self.cudagraph_stream is not None, "should have cudagraph_stream when capturing" + self.cudagraph_event.wait() def set_cpu_stage(self): """ Set the CPU stage. """ - torch.ops.trtllm.moe_load_balance_set_cpu_stage( - self.single_layer_load_balancer_ptr) + if self.updates_enabled: + assert self.statistic_flag_tensor is not None, \ + "Doesn't have statistic_flag_tensor, should not set_cpu_stage." + self.statistic_flag_tensor = None + if is_graph_capturing(): + assert self.cudagraph_stream is not None, "Doesn't have cudagraph_stream, should not set_cpu_stage." + current_stream_event = torch.cuda.Event() + current_stream_event.record(torch.cuda.current_stream()) + with torch.cuda.stream(self.cudagraph_stream): + current_stream_event.wait() + torch.ops.trtllm.moe_load_balance_set_cpu_stage( + self.single_layer_load_balancer_ptr) + self.cudagraph_event.record(self.cudagraph_stream) + else: + torch.ops.trtllm.moe_load_balance_set_cpu_stage( + self.single_layer_load_balancer_ptr) + + def maybe_cudagraph_done_set_cpu_stage(self): + if self.updates_enabled: + if is_graph_capturing(): + assert self.cudagraph_event is not None, "should have cudagraph_event when capturing" + assert self.cudagraph_stream is not None, "should have cudagraph_stream when capturing" + self.cudagraph_event.wait() + self.cudagraph_stream = None + self.cudagraph_event = None def statistic(self, gathered_raw_expert_ids: torch.Tensor, - enabled: torch.Tensor, is_first_stage: bool, - is_last_stage: bool): + is_first_stage: bool, is_last_stage: bool): """ Perform statistics on the expert IDs. 
Args: gathered_raw_expert_ids: The gathered raw expert IDs from all ranks - enabled: A tensor indicating whether the operation is enabled is_first_stage: Whether this is the first stage is_last_stage: Whether this is the last stage """ - torch.ops.trtllm.moe_load_balance_statistic( - gathered_raw_expert_ids, enabled, - self.single_layer_load_balancer_ptr, is_first_stage, is_last_stage) + if self.updates_enabled: + assert isinstance(self.statistic_flag_tensor, torch.Tensor) + torch.ops.trtllm.moe_load_balance_statistic( + gathered_raw_expert_ids, self.statistic_flag_tensor, + self.single_layer_load_balancer_ptr, is_first_stage, + is_last_stage) def route(self, token_selected_experts: torch.Tensor, @@ -310,7 +532,15 @@ def py_pre_shutdown_cleanup(self): """ Clean up the resources before C++ shutdown and barrier """ - self.host_tensor_sharer.pre_shutdown_cleanup() + if self.updates_enabled: + self.host_tensor_sharer.pre_shutdown_cleanup() + + def py_post_shutdown_cleanup(self): + """ + Clean up the resources after C++ shutdown and barrier + """ + if self.updates_enabled: + self.host_tensor_sharer.post_shutdown_cleanup() # Global variable to store the current active MoeLoadBalancer instance @@ -346,6 +576,21 @@ def __init__(self, self.single_layer_load_balancers = [] self.shared_memory_base_name = shared_memory_base_name self._setup_mpi_comm() + self.is_shutdown = False + + self.iter_id = 0 + self.in_iter = False + + self.enable_statistic = False + self.enable_update_weights = False + + def __del__(self): + if not self.is_shutdown: + self.shutdown() + + def is_static_routing(self): + # if we don't update, then it is statistic routing. + return self.layer_updates_per_iter == 0 def _setup_mpi_comm(self): global_mpi_comm = tensorrt_llm.mpi_comm() @@ -357,6 +602,9 @@ def _setup_mpi_comm(self): f"Interesting, shared size {shared_size} is not same as local size {local_size}" self.shared_mpi_comm = shared_mpi_comm + def set_use_gpu_memcpy(self, use_gpu_memcpy: bool): + self.load_balancer_impl.set_use_gpu_memcpy(use_gpu_memcpy) + def add_layer(self, expert_count: int, top_k: int, slot_count_per_rank: int) -> SingleLayerMoeLoadBalancer: """ @@ -372,13 +620,24 @@ def add_layer(self, expert_count: int, top_k: int, """ single_layer_load_balancer_impl = self.load_balancer_impl.add_layer( expert_count, top_k, slot_count_per_rank) + updates_enabled = not self.is_static_routing() single_layer_load_balancer = SingleLayerMoeLoadBalancer( - single_layer_load_balancer_impl, self.shared_mpi_comm) + single_layer_load_balancer_impl, + self.shared_mpi_comm, + expert_count, + updates_enabled=updates_enabled) single_layer_load_balancer.set_shared_memory_base_name( self.shared_memory_base_name) self.single_layer_load_balancers.append(single_layer_load_balancer) return single_layer_load_balancer + def register_weight_slots_after_to_cuda(self): + """ + Register weights after model has been moved to cuda, should be invoked after model.to("cuda") and before finalize_model. + """ + for layer in self.single_layer_load_balancers: + layer.register_weight_slots_after_to_cuda() + def finalize_model(self): """ Finalize the model after all layers have been added. 
@@ -391,6 +650,7 @@ def finalize_model(self): for single_layer_load_balancer in self.single_layer_load_balancers: single_layer_load_balancer.py_finalize_model() self.load_balancer_impl.finalize_model() + torch.cuda.empty_cache() def set_warm_up_iter_count(self, iter_count: int): """ @@ -401,27 +661,33 @@ def set_warm_up_iter_count(self, iter_count: int): """ self.load_balancer_impl.set_warm_up_iter_count(iter_count) - def start_iter(self, iter_id: int, enable_statistic: bool, - enable_update_weights: bool): + def set_next_iter_info(self, enable_statistic: Optional[bool], + enable_update_weights: Optional[bool]): + if enable_statistic is not None: + self.enable_statistic = enable_statistic + if enable_update_weights is not None: + self.enable_update_weights = enable_update_weights + + def start_iter(self): """ Start a new iteration. - - Args: - iter_id: The ID of the iteration - enable_statistic: Whether to enable statistics collection - enable_update_weights: Whether to enable weight updates """ - self.load_balancer_impl.start_iter(iter_id, enable_statistic, - enable_update_weights) + assert self.in_iter == False, "already in forward" + self.in_iter = True + self.load_balancer_impl.start_iter(self.iter_id, self.enable_statistic, + self.enable_update_weights) - def end_iter(self, iter_id: int): + def end_iter(self): """ End the current iteration. Args: iter_id: The ID of the iteration to end """ - self.load_balancer_impl.end_iter(iter_id) + assert self.in_iter, "not in forward, cannot end_iter" + self.load_balancer_impl.end_iter(self.iter_id) + self.in_iter = False + self.iter_id += 1 def shutdown(self): """ @@ -432,6 +698,10 @@ def shutdown(self): self.load_balancer_impl.shutdown() # use this sync to make sure all the shm resources can be cleaned up self.shared_mpi_comm.barrier() + for single_layer_load_balancer in self.single_layer_load_balancers: + single_layer_load_balancer.py_post_shutdown_cleanup() + self.shared_mpi_comm.barrier() + self.is_shutdown = True def __repr__(self): """ @@ -482,6 +752,80 @@ def __exit__(self, exc_type, exc_val, exc_tb): return False +moe_model_arch_list = [ + 'DeepseekV3ForCausalLM', +] + + +def maybe_create_moe_load_balancer( + model_config, mapping: Optional[Mapping]) -> Optional[MoeLoadBalancer]: + ep_rank = model_config.mapping.moe_ep_rank + ep_size = model_config.mapping.moe_ep_size + model_arch = model_config.pretrained_config.architectures[0] + using_ep = mapping and mapping.moe_ep_size > 1 + in_supported_model_arch = model_arch in moe_model_arch_list + using_smart_router = mapping and mapping.moe_cluster_size > 1 + moe_load_balancer = nullcontext() + if in_supported_model_arch and using_ep and not using_smart_router and model_config.moe_load_balancer is not None: + model_config.moe_load_balancer.setup(ep_rank=ep_rank, ep_size=ep_size) + if model_config.moe_load_balancer.layer_updates_per_iter > 0: + # TODO: remove this when supported. + cpu_arch = platform.machine().lower() + assert cpu_arch == 'aarch64', "online load balancer only support aarch64, e.g. GB200 now, x86 coming soon." + + moe_load_balancer = MoeLoadBalancer( + ep_rank=ep_rank, + ep_size=ep_size, + layer_updates_per_iter=model_config.moe_load_balancer. + layer_updates_per_iter) + logger.info( + f"Created MoE LoadBalancer, layer_updates_per_iter={model_config.moe_load_balancer.layer_updates_per_iter}..." 
+ ) + return moe_load_balancer + + +class MoeLoadBalancerIterContext: + + def __init__(self, + moe_load_balancer: Optional[MoeLoadBalancer], + enable_statistic: Optional[bool] = None, + enable_updates: Optional[bool] = None): + self.moe_load_balancer = moe_load_balancer + self.enable_statistic = enable_statistic + self.enable_updates = enable_updates + + def __enter__(self): + """ + Enter the context manager. + + Returns: + The MoeLoadBalancerIterContext instance + """ + if self.moe_load_balancer is not None and not self.moe_load_balancer.is_static_routing( + ): + self.moe_load_balancer.set_next_iter_info(self.enable_statistic, + self.enable_updates) + self.moe_load_balancer.start_iter() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Exit the context manager. + + Args: + exc_type: The exception type + exc_val: The exception value + exc_tb: The exception traceback + + Returns: + False to not suppress exceptions + """ + if self.moe_load_balancer is not None and not self.moe_load_balancer.is_static_routing( + ): + self.moe_load_balancer.end_iter() + return False + + def get_moe_load_balancer() -> Optional[MoeLoadBalancer]: """ Get the current active MoeLoadBalancer instance. diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py index a4030b20f6c..fd376ad2eac 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py +++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py @@ -61,6 +61,14 @@ class FusedMoEMethodBase(ABC): Base class for all fused MoE methods. """ + def need_load_shared_weights(self, module): + if hasattr( + module, "layer_load_balancer" + ) and module.layer_load_balancer and module.layer_load_balancer.need_load_shared_weights( + ): + return True + return False + def create_weights(self, module: torch.nn.Module, weight_dtype: torch.dtype, w3_w1_weight_shape: tuple[int, int, int], w2_weight_shape: tuple[int, int, int]): @@ -76,12 +84,15 @@ def create_weights(self, module: torch.nn.Module, weight_dtype: torch.dtype, requires_grad=False) module.register_parameter("w2_weight", w2_weight) - def load_weights(self, module: torch.nn.Module, weights: List[Dict], - weight_loading_mode: MoEWeightLoadingMode): + def load_expert_weights_to_dst(self, module: torch.nn.Module, + weights: List[Dict], + weight_loading_mode: MoEWeightLoadingMode, + load_expert_ids: List[int], + dst_w3_w1_weights_tensor: torch.Tensor, + dst_w2_weights_tensor: torch.Tensor): # Multithread weight load is superseded by prefetch_files() in model_engine.py # Also, threading adds overhead in order to protect shuffle index cache with critical section. 
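maybe_create_moe_load_balancer returns either a real MoeLoadBalancer or a nullcontext(), so the caller can always use a single with statement. A minimal illustration of that pattern (FakeBalancer and maybe_create are made-up stand-ins):

# Minimal illustration of the "real context manager or nullcontext()" pattern.
from contextlib import nullcontext


class FakeBalancer:
    def __enter__(self):
        print("balancer active")
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        print("balancer shut down")
        return False


def maybe_create(enabled: bool):
    return FakeBalancer() if enabled else nullcontext()


for enabled in (True, False):
    with maybe_create(enabled) as balancer:
        # The isinstance check mirrors how the caller decides whether to keep a handle.
        if isinstance(balancer, FakeBalancer):
            print("registering weights with the balancer")
        else:
            print("no balancer in this configuration")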
- for local_slot_id, expert_id in enumerate( - module.initial_local_expert_ids): + for local_slot_id, expert_id in enumerate(load_expert_ids): # expert_idx is the local slot index of current rank expert_idx = local_slot_id @@ -101,15 +112,55 @@ def load_weights(self, module: torch.nn.Module, weights: List[Dict], ) self.load_expert_w3_w1_weight(module, w1_weight, w3_weight, - module.w3_w1_weight.data[expert_idx]) + dst_w3_w1_weights_tensor[expert_idx]) self.load_expert_w2_weight(module, w2_weight, - module.w2_weight.data[expert_idx]) + dst_w2_weights_tensor[expert_idx]) + + def load_weights(self, module: torch.nn.Module, weights: List[Dict], + weight_loading_mode: MoEWeightLoadingMode): + + self.load_expert_weights_to_dst(module, weights, weight_loading_mode, + module.initial_local_expert_ids, + module.w3_w1_weight.data, + module.w2_weight.data) self.load_quant_scales(module, weights) # Re-setup quant scales after loading weights as the tensors may have been modified. self.setup_quant_scales(module) + if self.need_load_shared_weights(module): + local_shared_load_expert_ids = module.layer_load_balancer.get_load_expert_ids( + ) + local_shared_w3_w1_tensors = torch.empty( + (len(local_shared_load_expert_ids), ) + + module.w3_w1_weight.data.shape[1:], + dtype=module.w3_w1_weight.data.dtype, + device='cpu') + local_shared_w2_tensors = torch.empty( + (len(local_shared_load_expert_ids), ) + + module.w2_weight.data.shape[1:], + dtype=module.w2_weight.data.dtype, + device='cpu') + self.load_expert_weights_to_dst(module, weights, + weight_loading_mode, + local_shared_load_expert_ids, + local_shared_w3_w1_tensors, + local_shared_w2_tensors) + module.register_all_parameter_slot_and_to_fix_weight_fns({ + 'w3_w1_weight': + local_shared_w3_w1_tensors, + 'w2_weight': + local_shared_w2_tensors + }) + module.layer_load_balancer.host_tensor_sharer.finalize_layer_weights( + ) + + if hasattr(module, + "layer_load_balancer") and module.layer_load_balancer: + module.layer_load_balancer.set_initial_weight_assignments( + module.initial_global_assignments) + def load_quant_scales(self, module: torch.nn.Module, weights: List[Dict]): pass @@ -828,6 +879,50 @@ def load_expert_fc2_alpha_nvfp4(self, w2_weight_scale_2, w2_weight_scale_2 = 1.0 / w2_weight_scale_2[...].reshape([]) dst_w2_alpha.copy_(1.0 / (final_fc2_input_scale * w2_weight_scale_2)) + def load_all_fp4_weight_scales_and_alphas( + self, module: torch.nn.Module, weights: Dict, + load_expert_ids: List[int], dst_w3_w1_weight_scale: torch.Tensor, + dst_w2_weight_scale: torch.Tensor, dst_fc31_alpha: torch.Tensor, + dst_fc2_alpha: torch.Tensor): + for local_slot_id, expert_id in enumerate(load_expert_ids): + if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA: + w1_weight_scale = weights[f"{expert_id}.w1.weight_scale"] + w3_weight_scale = weights[f"{expert_id}.w3.weight_scale"] + w2_weight_scale = weights[f"{expert_id}.w2.weight_scale"] + w1_weight_scale_2 = weights[f"{expert_id}.w1.weight_scale_2"] + w3_weight_scale_2 = weights[f"{expert_id}.w3.weight_scale_2"] + w2_weight_scale_2 = weights[f"{expert_id}.w2.weight_scale_2"] + elif module.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ: + w1_w3_weight_scale = weights["gate_up_proj_weight_scale"][ + expert_id].transpose(0, 1).contiguous() + w1_weight_scale, w3_weight_scale = w1_w3_weight_scale.chunk( + 2, dim=0) + w2_weight_scale = weights["down_proj_weight_scale"][ + expert_id].transpose(0, 1).contiguous() + w1_weight_scale_2 = weights["gate_up_proj_weight_scale_2"] + w3_weight_scale_2 = 
weights["gate_up_proj_weight_scale_2"] + w2_weight_scale_2 = weights["down_proj_weight_scale_2"] + else: + raise NotImplementedError( + f"Unknown weight loading mode in MoE: {module.weight_loading_mode}" + ) + + expert_idx = local_slot_id + + self.load_expert_w3_w1_weight_scale_nvfp4( + module, w1_weight_scale, w3_weight_scale, + dst_w3_w1_weight_scale[expert_idx]) + self.load_expert_w2_weight_scale_nvfp4( + module, w2_weight_scale, dst_w2_weight_scale[expert_idx]) + + self.load_expert_fc31_alpha_nvfp4(w1_weight_scale_2, + w3_weight_scale_2, + module.fc31_input_scale.data, + dst_fc31_alpha[expert_idx]) + self.load_expert_fc2_alpha_nvfp4(w2_weight_scale_2, + module.fc2_input_scale.data, + dst_fc2_alpha[expert_idx]) + def load_quant_scales(self, module: torch.nn.Module, weights: Dict): # Step1: Load input scales. tmp_fc31_input_scale = torch.empty(module.num_experts, @@ -862,46 +957,50 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict): tmp_fc2_input_scale.max().reciprocal()) # Step2: Load weight block scales and alphas. - for local_slot_id, expert_id in enumerate( - module.initial_local_expert_ids): - if module.weight_loading_mode == MoEWeightLoadingMode.VANILLA: - w1_weight_scale = weights[f"{expert_id}.w1.weight_scale"] - w3_weight_scale = weights[f"{expert_id}.w3.weight_scale"] - w2_weight_scale = weights[f"{expert_id}.w2.weight_scale"] - w1_weight_scale_2 = weights[f"{expert_id}.w1.weight_scale_2"] - w3_weight_scale_2 = weights[f"{expert_id}.w3.weight_scale_2"] - w2_weight_scale_2 = weights[f"{expert_id}.w2.weight_scale_2"] - elif module.weight_loading_mode == MoEWeightLoadingMode.FUSED_GATE_UP_PROJ: - w1_w3_weight_scale = weights["gate_up_proj_weight_scale"][ - expert_id].transpose(0, 1).contiguous() - w1_weight_scale, w3_weight_scale = w1_w3_weight_scale.chunk( - 2, dim=0) - w2_weight_scale = weights["down_proj_weight_scale"][ - expert_id].transpose(0, 1).contiguous() - w1_weight_scale_2 = weights["gate_up_proj_weight_scale_2"] - w3_weight_scale_2 = weights["gate_up_proj_weight_scale_2"] - w2_weight_scale_2 = weights["down_proj_weight_scale_2"] - else: - raise NotImplementedError( - f"Unknown weight loading mode in MoE: {module.weight_loading_mode}" - ) - - expert_idx = local_slot_id - - self.load_expert_w3_w1_weight_scale_nvfp4( - module, w1_weight_scale, w3_weight_scale, - module.w3_w1_weight_scale.data[expert_idx]) - self.load_expert_w2_weight_scale_nvfp4( - module, w2_weight_scale, - module.w2_weight_scale.data[expert_idx]) - - self.load_expert_fc31_alpha_nvfp4( - w1_weight_scale_2, w3_weight_scale_2, - module.fc31_input_scale.data, - module.fc31_alpha.data[expert_idx]) - self.load_expert_fc2_alpha_nvfp4(w2_weight_scale_2, - module.fc2_input_scale.data, - module.fc2_alpha.data[expert_idx]) + self.load_all_fp4_weight_scales_and_alphas( + module, weights, module.initial_local_expert_ids, + module.w3_w1_weight_scale.data, module.w2_weight_scale.data, + module.fc31_alpha.data, module.fc2_alpha.data) + + # Step 3: if need load into shared + if self.need_load_shared_weights(module): + local_shared_load_expert_ids = module.layer_load_balancer.get_load_expert_ids( + ) + local_shared_w3_w1_scale_tensors = torch.empty( + (len(local_shared_load_expert_ids), ) + + module.w3_w1_weight_scale.data.shape[1:], + dtype=module.w3_w1_weight_scale.data.dtype, + device='cpu') + local_shared_w2_scale_tensors = torch.empty( + (len(local_shared_load_expert_ids), ) + + module.w2_weight_scale.data.shape[1:], + dtype=module.w2_weight_scale.data.dtype, + device='cpu') + 
local_shared_fc31_alpha_tensors = torch.empty( + (len(local_shared_load_expert_ids), ) + + module.fc31_alpha.data.shape[1:], + dtype=module.fc31_alpha.data.dtype, + device='cpu') + local_shared_fc2_alpha_tensors = torch.empty( + (len(local_shared_load_expert_ids), ) + + module.fc2_alpha.data.shape[1:], + dtype=module.fc2_alpha.data.dtype, + device='cpu') + self.load_all_fp4_weight_scales_and_alphas( + module, weights, local_shared_load_expert_ids, + local_shared_w3_w1_scale_tensors, local_shared_w2_scale_tensors, + local_shared_fc31_alpha_tensors, local_shared_fc2_alpha_tensors) + + module.register_all_parameter_slot_and_to_fix_weight_fns({ + 'w3_w1_weight_scale': + local_shared_w3_w1_scale_tensors, + 'w2_weight_scale': + local_shared_w2_scale_tensors, + 'fc31_alpha': + local_shared_fc31_alpha_tensors, + 'fc2_alpha': + local_shared_fc2_alpha_tensors, + }) def setup_quant_scales(self, module: torch.nn.Module): module.quant_scales = FusedMoEQuantScalesNVFP4( diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index ddba0c9d6a0..c438fb8dd1b 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -1,5 +1,6 @@ import bisect import contextlib +import functools import gc import glob import inspect @@ -11,6 +12,7 @@ import weakref from abc import ABC, abstractmethod from collections import defaultdict +from contextlib import contextmanager from typing import Any, Dict, List, Optional, Tuple import psutil @@ -47,6 +49,8 @@ from ..models import AutoModelForCausalLM from ..models.modeling_utils import (DecoderModelForCausalLM, MetaInitMode, timing) +from ..modules.fused_moe.moe_load_balancer import ( + MoeLoadBalancer, MoeLoadBalancerIterContext, maybe_create_moe_load_balancer) from ..speculative import SpecConfig, SpecMetadata, get_spec_metadata from ..utils import (get_model_extra_attrs, set_torch_compiling, with_model_extra_attrs) @@ -323,6 +327,8 @@ def __init__( # py_executor.py for how this is used. 
         self.last_spec_metadata = None
+        self.in_warmup = False
+
         self.attn_runtime_features = attn_runtime_features or AttentionRuntimeFeatures(
         )
 
@@ -470,6 +476,25 @@ def set_lora_model_config(self, lora_target_modules: list[str],
             hidden_size=self.model.config.hidden_size,
             dtype=torch_dtype_to_str(self.model.config.torch_dtype))
 
+    @contextmanager
+    def set_warmup_flag(self):
+        self.in_warmup = True
+        try:
+            yield
+        finally:
+            self.in_warmup = False
+
+    @staticmethod
+    def with_warmup_flag(method):
+
+        @functools.wraps(method)
+        def wrapper(self, *args, **kwargs):
+            with self.set_warmup_flag():
+                return method(self, *args, **kwargs)
+
+        return wrapper
+
+    @with_warmup_flag
     def warmup(self, resource_manager: ResourceManager) -> None:
         kv_cache_manager = resource_manager.get_resource_manager(
             self.kv_cache_manager_key)
@@ -958,7 +983,8 @@ def _load_model(self,
                 getattr(config.pretrained_config,
                         sub_config).num_hidden_layers = num_layers
 
-        with timing("Model init total"):
+        with timing("Model init total"), maybe_create_moe_load_balancer(
+                config, self.mapping) as moe_load_balancer:
             try:
                 with MetaInitMode():
                     model = AutoModelForCausalLM.from_config(config)
@@ -1013,6 +1039,13 @@ def init_meta_tensor(t: torch.Tensor):
                 raise NotImplementedError(
                     f"No load support for load format: {load_format}")
 
+        if isinstance(moe_load_balancer, MoeLoadBalancer):
+            setattr(self, "moe_load_balancer", moe_load_balancer)
+            moe_load_balancer.register_weight_slots_after_to_cuda()
+            logger.info("moe_load_balancer finalizing model...")
+            moe_load_balancer.finalize_model()
+            logger.info("moe_load_balancer finalize model done")
+
         torch.cuda.current_stream().synchronize()
 
         return model
@@ -1946,6 +1979,15 @@ def forward(self,
         else:
             spec_metadata = None
 
+        moe_load_balancer = None
+        if hasattr(self, 'moe_load_balancer'):
+            moe_load_balancer = getattr(self, 'moe_load_balancer')
+            if not self.in_warmup:
+                moe_enable_statistic = True
+                moe_enable_update = True
+                moe_load_balancer.set_next_iter_info(moe_enable_statistic,
+                                                     moe_enable_update)
+
         if kv_cache_manager is None:
             inputs, gather_ids = self._prepare_tp_inputs_no_cache(
                 scheduled_requests, attn_metadata, spec_metadata)
@@ -1953,7 +1995,9 @@
             inputs.update(extra_model_inputs)
             self.last_spec_metadata = spec_metadata
-            return self._forward_step(inputs, gather_ids, gather_context_logits)
+            with MoeLoadBalancerIterContext(moe_load_balancer):
+                return self._forward_step(inputs, gather_ids,
+                                           gather_context_logits)
 
         with self._maybe_pad_batch(scheduled_requests,
                                    kv_cache_manager) as scheduled_requests:
@@ -1980,21 +2024,32 @@ def forward(self,
             self.iter_counter += 1
 
             if maybe_graph is None:
-                outputs = self._forward_step(inputs, gather_ids,
-                                             gather_context_logits)
+                with MoeLoadBalancerIterContext(moe_load_balancer):
+                    outputs = self._forward_step(inputs, gather_ids,
+                                                 gather_context_logits)
             else:
                 if maybe_graph.needs_capture():
+
+                    def capture_forward_fn(inputs: Dict[str, Any]):
+                        with MoeLoadBalancerIterContext(moe_load_balancer):
+                            return self._forward_step(
+                                inputs,
+                                gather_ids=gather_ids,
+                                gather_context_logits=gather_context_logits)
+
                     pool = maybe_graph.capture(
-                        lambda inputs: self._forward_step(
-                            inputs,
-                            gather_ids=gather_ids,
-                            gather_context_logits=gather_context_logits),
+                        capture_forward_fn,
                         self._cuda_graph_mem_pool,
                         extra_model_inputs,
                     )
                    self._cuda_graph_mem_pool = pool
-                    outputs = maybe_graph.run(inputs, extra_model_inputs)
+                    # The iteration context is not needed here, since CUDA graph capture does not run the kernels.
+                    # TODO: find a cleaner way to handle this.
+                    outputs = maybe_graph.run(inputs, extra_model_inputs)
+                else:
+                    with MoeLoadBalancerIterContext(moe_load_balancer):
+                        outputs = maybe_graph.run(inputs, extra_model_inputs)
 
             # Note: To overlap the CPU and GPU computation as much as possible,
             # guided_decoder.build should be called immediately after the launch of the single step;
diff --git a/tensorrt_llm/executor/serialization.py b/tensorrt_llm/executor/serialization.py
index a3bd47ea6a1..31fd48faf8b 100644
--- a/tensorrt_llm/executor/serialization.py
+++ b/tensorrt_llm/executor/serialization.py
@@ -18,6 +18,7 @@
     "llmapi.run_llm_with_postproc": ["perform_faked_oai_postprocess"
                                      ],  # only used in tests
     ### starting import of torch models classes. They are used in test_llm_multi_gpu.py.
+    "tensorrt_llm._torch.model_config": ["MoeLoadBalancerConfig"],
     "tensorrt_llm._torch.models.modeling_bert":
     ["BertForSequenceClassification"],
     "tensorrt_llm._torch.models.modeling_clip": ["CLIPVisionModel"],
@@ -48,6 +49,7 @@
     "tensorrt_llm._torch.models.modeling_siglip": ["SiglipVisionModel"],
     "tensorrt_llm._torch.models.modeling_vila": ["VilaModel"],
     ### ending import of torch models classes
+    "tensorrt_llm._torch.pyexecutor.config": ["PyTorchConfig", "LoadFormat"],
     "tensorrt_llm._torch.pyexecutor.llm_request":
     ["LogitsStorage", "PyResult", "LlmResult", "LlmResponse", "LogProbStorage"],
     "tensorrt_llm._torch.speculative.mtp": ["MTPConfig"],
@@ -58,13 +60,13 @@
     ["ClusterInfo", "MathThroughput"],
     "tensorrt_llm._torch.pyexecutor.config": ["PyTorchConfig", "LoadFormat"],
     "tensorrt_llm.bindings.executor": [
-        "BatchingType", "CapacitySchedulerPolicy", "ContextPhaseParams",
+        "BatchingType", "CacheTransceiverConfig", "CapacitySchedulerPolicy",
+        "ContextPhaseParams", "ContextChunkingPolicy", "DynamicBatchConfig",
         "ExecutorConfig", "ExtendedRuntimePerfKnobConfig", "Response", "Result",
         "FinishReason", "KvCacheConfig", "KvCacheTransferMode",
         "KvCacheRetentionConfig",
         "KvCacheRetentionConfig.TokenRangeRetentionConfig", "PeftCacheConfig",
-        "SchedulerConfig", "DynamicBatchConfig", "ContextChunkingPolicy",
-        "CacheTransceiverConfig"
+        "SchedulerConfig"
     ],
     "tensorrt_llm._torch.pyexecutor.config": ["PyTorchConfig"],
     "tensorrt_llm._torch.model_config": ["MoeLoadBalancerConfig"],
diff --git a/tests/unittest/_torch/modules/test_moe_host_sharer.py b/tests/unittest/_torch/modules/test_moe_host_sharer.py
index 5bc798410a1..0ed0ee609bb 100644
--- a/tests/unittest/_torch/modules/test_moe_host_sharer.py
+++ b/tests/unittest/_torch/modules/test_moe_host_sharer.py
@@ -78,6 +78,11 @@ def test_host_tensor_sharing_basic(self):
         size = comm.Get_size()
         layer_id = 0
 
+        # Test tensor parameters
+        experts_per_rank = 2  # Each rank is responsible for 2 consecutive experts
+        expert_count = size * experts_per_rank
+        tensor_shape = (16, 32)  # Use 2D tensor for testing
+
         # Maximum supported ranks (can adjust as needed)
         max_ranks = 8
         if size > max_ranks:
@@ -87,17 +92,12 @@ def test_host_tensor_sharing_basic(self):
         shared_comm = comm.Split_type(split_type=MPI.COMM_TYPE_SHARED)
 
         # Initialize HostMoeTensorSharer
-        sharer = HostMoeTensorSharer(layer_id, shared_comm)
+        sharer = HostMoeTensorSharer(layer_id, expert_count, shared_comm)
 
         # Set shared memory base name
         shared_memory_base_name = "test_host_sharer"
         sharer.set_shared_memory_base_name(shared_memory_base_name)
 
-        # Test tensor parameters
-        experts_per_rank = 2  # Each rank is responsible for 2 consecutive experts
-        expert_count = size * experts_per_rank
-        tensor_shape = (16, 32)  # Use 2D tensor for testing
-
         # Calculate the range of experts this rank is responsible for
         start_expert_id = rank * experts_per_rank
         end_expert_id = start_expert_id + experts_per_rank
@@ -124,6 +124,8 @@ def test_host_tensor_sharing_basic(self):
             sharer.pre_register_host_tensor_with_shape(
                 expert_id, "weight", torch.float32, tensor_shape)
 
+        sharer.finalize_layer_weights()
+
         # Ensure all processes have created and registered their tensors
         comm.Barrier()
 
diff --git a/tests/unittest/_torch/modules/test_moe_load_balancer.py b/tests/unittest/_torch/modules/test_moe_load_balancer.py
index 093d46f8aa3..a49c1a4eaea 100644
--- a/tests/unittest/_torch/modules/test_moe_load_balancer.py
+++ b/tests/unittest/_torch/modules/test_moe_load_balancer.py
@@ -2,10 +2,11 @@
 from unittest.mock import MagicMock, patch
 
 import torch
+from mpi4py import MPI
 
 from tensorrt_llm._torch.modules.fused_moe.moe_load_balancer import (
-    MoeLoadBalancer, SingleLayerMoeLoadBalancer, get_moe_load_balancer,
-    moe_load_balancer_add_single_layer)
+    MoeLoadBalancer, MoeLoadBalancerIterContext, SingleLayerMoeLoadBalancer,
+    get_moe_load_balancer, moe_load_balancer_add_single_layer)
 
 
 class TestMoeLoadBalancer(unittest.TestCase):
@@ -186,7 +187,9 @@ def test_single_layer_moe_load_balancer_methods(self,
         # Setup
         mock_single_layer_impl = MagicMock()
 
-        layer = SingleLayerMoeLoadBalancer(mock_single_layer_impl, None)
+        layer = SingleLayerMoeLoadBalancer(mock_single_layer_impl,
+                                           MPI.COMM_WORLD,
+                                           expert_count=4)
 
         # Mock out torch.ops.trtllm functions
         with patch('torch.ops.trtllm.moe_load_balance_wait_gpu_stage') as mock_wait, \
@@ -198,13 +201,13 @@ def test_single_layer_moe_load_balancer_methods(self,
             # add_weight_slot
             mock_weight = MagicMock()
             layer._add_weight_slot(1, "weight1", mock_weight)
-            mock_single_layer_impl.add_weight_slot.assert_called_once_with(
+            mock_single_layer_impl.add_single_weight_slot.assert_called_once_with(
                 1, "weight1", mock_weight)
 
             # add_host_weight
             mock_host_weight = MagicMock()
             layer._add_host_weight(2, "weight2", mock_host_weight)
-            mock_single_layer_impl.add_host_weight.assert_called_once_with(
+            mock_single_layer_impl.add_single_host_weight.assert_called_once_with(
                 2, "weight2", mock_host_weight)
 
             # set_initial_weight_assignments
@@ -215,7 +218,8 @@ def test_single_layer_moe_load_balancer_methods(self,
             # wait_for_gpu_stage
             mock_wait.return_value = torch.tensor([1])
-            result = layer.wait_for_gpu_stage()
+            layer.wait_for_gpu_stage()
+            result = layer.statistic_flag_tensor
             mock_wait.assert_called_once_with(
                 mock_single_layer_impl.get_pointer())
             self.assertEqual(result, mock_wait.return_value)
@@ -228,7 +232,8 @@ def test_single_layer_moe_load_balancer_methods(self,
             # statistic
             mock_expert_ids = torch.tensor([[0, 1], [2, 3]])
             mock_enabled = torch.tensor([1])
-            layer.statistic(mock_expert_ids, mock_enabled, True, False)
+            layer.statistic_flag_tensor = mock_enabled
+            layer.statistic(mock_expert_ids, True, False)
             mock_statistic.assert_called_once_with(
                 mock_expert_ids, mock_enabled,
                 mock_single_layer_impl.get_pointer(), True, False)
@@ -237,8 +242,6 @@ def test_single_layer_moe_load_balancer_methods(self,
             mock_selected_experts = torch.tensor([[0, 1], [2, 3]])
             mock_route.return_value = torch.tensor([[0, 1], [2, 3]])
             result = layer.route(mock_selected_experts)
-            mock_route.assert_called_once_with(
-                mock_selected_experts, mock_single_layer_impl.get_pointer())
             assert torch.equal(result, mock_route.return_value)
 
     @patch('tensorrt_llm.bindings.internal.runtime.MoeLoadBalancer')
@@ -260,14 +263,13 @@ def test_moe_load_balancer_lifecycle_methods(self, mock_load_balancer_impl):
         mock_load_balancer_impl.return_value.set_warm_up_iter_count.assert_called_once_with(
             10)
 
-        # start_iter
-        balancer.start_iter(1, True, True)
-        mock_load_balancer_impl.return_value.start_iter.assert_called_once_with(
-            1, True, True)
+        balancer.set_next_iter_info(True, True)
+
+        with MoeLoadBalancerIterContext(balancer):
+            mock_load_balancer_impl.return_value.start_iter.assert_called_once_with(
+                0, True, True)
 
-        # end_iter
-        balancer.end_iter(1)
-        mock_load_balancer_impl.return_value.end_iter.assert_called_once_with(1)
+        mock_load_balancer_impl.return_value.end_iter.assert_called_once_with(0)
 
         # shutdown
         balancer.shutdown()
@@ -288,6 +290,8 @@ def test_real_statistic_kernel(self):
         # Create a real MoeLoadBalancer
         balancer = MoeLoadBalancer(ep_rank, ep_size, 1)
 
+        balancer.set_use_gpu_memcpy(True)
+
         # Add a layer with initial weight assignments
         # Each slot is assigned to exactly one expert initially
         layer = balancer.add_layer(expert_count, top_k, slots_per_rank)
@@ -297,9 +301,8 @@ def test_real_statistic_kernel(self):
         # Finalize the model
         balancer.finalize_model()
 
-        # Start iteration - enable statistic, disable weight update
-        iter_id = 0
-        balancer.start_iter(iter_id, True, False)
+        # enable statistic, disable weight update
+        balancer.set_next_iter_info(True, False)
 
         # Create sample token data - each token selects 2 experts
         # 4 tokens, each selecting 2 experts
@@ -314,17 +317,15 @@ def test_real_statistic_kernel(self):
                                               device="cuda")
 
         try:
-            # Wait for GPU stage and get enabled flag
-            enabled = layer.wait_for_gpu_stage()
+            with MoeLoadBalancerIterContext(balancer):
+                # Wait for GPU stage and get enabled flag
+                layer.wait_for_gpu_stage()
 
-            # Run statistic - just test it runs without error
-            layer.statistic(gathered_raw_expert_ids, enabled, True, True)
-
-            # Set CPU stage to signal completion
-            layer.set_cpu_stage()
+                # Run statistic - just test it runs without error
+                layer.statistic(gathered_raw_expert_ids, True, True)
 
-            # End iteration
-            balancer.end_iter(iter_id)
+                # Set CPU stage to signal completion
+                layer.set_cpu_stage()
 
             # Test passed if we got here without exceptions
             self.assertTrue(True, "Statistic kernel ran successfully")
@@ -350,6 +351,8 @@ def test_real_routing_kernel(self):
         # Create a real MoeLoadBalancer
         balancer = MoeLoadBalancer(ep_rank, ep_size, 1)
 
+        balancer.set_use_gpu_memcpy(True)
+
         # Add a layer with known initial weight assignments
         layer = balancer.add_layer(expert_count, top_k, slots_per_rank)
 
@@ -360,9 +363,8 @@ def test_real_routing_kernel(self):
         # Finalize the model
         balancer.finalize_model()
 
-        # Start iteration - enable statistic, disable weight update
-        iter_id = 0
-        balancer.start_iter(iter_id, True, False)
+        # enable statistic, disable weight update
+        balancer.set_next_iter_info(True, False)
 
         # Create sample token data - tokens selecting different experts
         token_selected_experts = torch.tensor(
@@ -376,17 +378,15 @@ def test_real_routing_kernel(self):
                                              device="cuda")
 
         try:
-            # Wait for GPU stage
-            layer.wait_for_gpu_stage()
+            with MoeLoadBalancerIterContext(balancer):
+                # Wait for GPU stage
+                layer.wait_for_gpu_stage()
 
-            # Run routing
-            routed_slots = layer.route(token_selected_experts)
-
-            # Set CPU stage
-            layer.set_cpu_stage()
+                # Run routing
+                routed_slots = layer.route(token_selected_experts)
 
-            # End iteration
-            balancer.end_iter(iter_id)
+                # Set CPU stage
+                layer.set_cpu_stage()
 
             # Verify results - with our initial assignment, expert i should map to slot i
             expected_slots = torch.tensor(
diff --git a/tests/unittest/bindings/test_bindings_moe.py b/tests/unittest/bindings/test_bindings_moe.py
index ccc4c217dd0..58b7482e302 100644
--- a/tests/unittest/bindings/test_bindings_moe.py
+++ b/tests/unittest/bindings/test_bindings_moe.py
@@ -136,6 +136,8 @@ def test_single_layer_moe_load_balancer_operations(self):
                                    ep_size=self.ep_size,
                                    layer_updates_per_iter=self.layer_updates_per_iter)
 
+        balancer.set_use_gpu_memcpy(True)
+
         # Add a layer
         layer = balancer.add_layer(expert_count=self.expert_count,
                                    top_k=self.top_k,
@@ -206,6 +208,8 @@ def test_moe_load_balancer_multiple_layers(self):
                                    ep_size=self.ep_size,
                                    layer_updates_per_iter=self.layer_updates_per_iter)
 
+        balancer.set_use_gpu_memcpy(True)
+
         # Create initial weight assignments
         initial_assignments = []
         for r in range(self.ep_size):
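
Note on the quantization.py hunk above: when shared weights must be loaded, the code stages per-expert NVFP4 scale and alpha tensors in host memory before handing them to the load balancer via register_all_parameter_slot_and_to_fix_weight_fns. The sketch below isolates just that allocation recipe. It is an illustration, not code from the patch: make_host_staging_tensors, params, and expert_ids are hypothetical names, while the torch.empty shape/dtype pattern is the one used in the diff.

# A minimal sketch of the host-side staging pattern, assuming every destination
# parameter is laid out as [expert_slot, ...]. The helper name and arguments are
# hypothetical; only the torch.empty(...) recipe mirrors the diff above.
from typing import Dict, List

import torch


def make_host_staging_tensors(params: Dict[str, torch.Tensor],
                              expert_ids: List[int]) -> Dict[str, torch.Tensor]:
    staged = {}
    for name, param in params.items():
        # One row per expert this rank has to load, with the same trailing
        # shape and dtype as the target parameter, allocated on the CPU.
        staged[name] = torch.empty((len(expert_ids), ) + param.shape[1:],
                                   dtype=param.dtype,
                                   device='cpu')
    return staged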
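
Note on the model_engine.py hunks above: outside of warmup the engine first declares what the next iteration should do with set_next_iter_info() and then runs the forward step inside MoeLoadBalancerIterContext, which replaces the earlier explicit start_iter()/end_iter() calls. The following is a condensed, hedged sketch of that control flow; run_one_step, engine, and forward_step are illustrative stand-ins rather than names from the patch.

from tensorrt_llm._torch.modules.fused_moe.moe_load_balancer import (
    MoeLoadBalancerIterContext)


def run_one_step(engine, forward_step, in_warmup: bool):
    # The engine only carries a balancer when MoE load balancing is configured.
    balancer = getattr(engine, "moe_load_balancer", None)
    if balancer is not None and not in_warmup:
        # Regular iterations enable both statistics collection and weight
        # updates, mirroring the flags set in model_engine.forward().
        balancer.set_next_iter_info(True, True)
    # The patch passes None here when no balancer exists, so the context is
    # presumably a no-op in that case.
    with MoeLoadBalancerIterContext(balancer):
        return forward_step()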
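
Note on the test updates above: they capture the new calling convention for a single balanced layer, i.e. set_use_gpu_memcpy() before finalize_model(), set_next_iter_info() in place of start_iter()/end_iter(), and per-iteration work wrapped in MoeLoadBalancerIterContext. Below is a condensed sketch of the routing path only, with illustrative constants; the argument format of set_initial_weight_assignments() and the tensor dtype are assumptions, since they are not shown in full in this patch.

import torch

from tensorrt_llm._torch.modules.fused_moe.moe_load_balancer import (
    MoeLoadBalancer, MoeLoadBalancerIterContext)

ep_rank, ep_size = 0, 1  # illustrative single-rank setup
expert_count, top_k, slots_per_rank = 8, 2, 8

balancer = MoeLoadBalancer(ep_rank, ep_size, 1)
balancer.set_use_gpu_memcpy(True)  # enabled the same way in the updated tests
layer = balancer.add_layer(expert_count, top_k, slots_per_rank)
# Assumed format: one expert id per global slot, i.e. expert i sits on slot i.
layer.set_initial_weight_assignments(list(range(slots_per_rank)))
balancer.finalize_model()

# Declare the next iteration's behavior: statistics on, weight updates off.
balancer.set_next_iter_info(True, False)

token_selected_experts = torch.tensor([[0, 1], [2, 3]],
                                      dtype=torch.int32,  # dtype assumed
                                      device="cuda")
with MoeLoadBalancerIterContext(balancer):
    layer.wait_for_gpu_stage()
    routed_slots = layer.route(token_selected_experts)
    layer.set_cpu_stage()
# With the identity assignment above, expert i is expected to map to slot i.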