2 changes: 1 addition & 1 deletion .devcontainer/docker-compose.yml
@@ -1,7 +1,7 @@
version: "3.9"
services:
tensorrt_llm-dev:
image: urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202505211401-4539
image: urm.nvidia.com/sw-tensorrt-docker/tensorrt-llm:pytorch-25.04-py3-x86_64-ubuntu24.04-trt10.10.0.31-skip-tritondevel-202506021004-9420
network_mode: host
ipc: host

17 changes: 16 additions & 1 deletion cpp/conanfile.py
@@ -1,3 +1,6 @@
import os
import sys

from conan import ConanFile
from conan.tools.cmake import CMakeDeps, CMakeToolchain

@@ -9,10 +12,22 @@ class TensorRT_LLM(ConanFile):
virtualrunenv = False

def requirements(self):
pass # TODO add dependencies here
self.requires("libnuma/system")

def generate(self):
cmake = CMakeDeps(self)
cmake.generate()
tc = CMakeToolchain(self)
tc.generate()

def build_requirements(self):
# register libnuma_conan.py for conan
base_dir = os.path.dirname(os.path.abspath(__file__))
libnuma_path = os.path.join(base_dir, "libnuma_conan.py")
conan_bin = os.path.abspath(sys.argv[0])
if not os.path.isfile(conan_bin) or not os.access(conan_bin, os.X_OK):
raise RuntimeError(f"Conan binary not found {sys.argv[0]}")

self.run(
f"{conan_bin} export {libnuma_path} --name=libnuma --version=system"
)
36 changes: 36 additions & 0 deletions cpp/libnuma_conan.py
@@ -0,0 +1,36 @@
from conan import ConanFile


class LibnumaSystemConan(ConanFile):
name = "libnuma"
version = "system"
package_type = "shared-library"
settings = "os", "arch"

def package_info(self):
if self.settings.os == "Windows":
self.output.info("libnuma not needed on Windows.")
return

self.cpp_info.includedirs = ["/usr/include"]
libdirs = []

arch = str(self.settings.arch)
os_name = str(self.settings.os)

if os_name == "Linux":
if arch == "x86_64":
libdirs.append("/usr/lib/x86_64-linux-gnu")
elif arch in ["armv8", "aarch64"]:
libdirs.append("/usr/lib/aarch64-linux-gnu")
else:
self.output.warn(
f"Unrecognized architecture: {arch}, falling back to /usr/lib"
)
libdirs.append("/usr/lib")
else:
self.output.warn(f"Unsupported OS: {os_name}, assuming /usr/lib")
libdirs.append("/usr/lib")

self.cpp_info.libdirs = libdirs
self.cpp_info.system_libs = ["numa"]
172 changes: 168 additions & 4 deletions cpp/tensorrt_llm/kernels/moeLoadBalance/moeLoadBalanceKernels.cu
@@ -245,12 +245,166 @@ void moeStatisticDevice(MoeLoadBalanceMetaInfo metaInfo, MoeLoadBalanceStatistic
}
}

template <int MAX_EXPERT_COUNT = 1024, int THREAD_COUNT = 256, int ITEM_PER_THREAD = 4>
__global__ void moeComputeRouteNoRedundantKernel(MoeLoadBalanceMetaInfo metaInfo, MoePlacementInfo placementInfo,
int* const tokenSelectedExperts, int* tokenRoutedSlotIds, int tokenCount)
{
extern __shared__ int16_t sharedGlobalSlotIdsInfo[];
int expertIds[ITEM_PER_THREAD];
int slotIds[ITEM_PER_THREAD];
for (int slotId = threadIdx.x; slotId < metaInfo.epSize * metaInfo.slotCountPerRank; slotId += THREAD_COUNT)
{
sharedGlobalSlotIdsInfo[slotId] = placementInfo.globalSlotIds[slotId];
}

int blockOffset = blockIdx.x * THREAD_COUNT * ITEM_PER_THREAD;

for (; blockOffset < tokenCount * metaInfo.topK; blockOffset += gridDim.x * THREAD_COUNT * ITEM_PER_THREAD)
{
int tokenIdxBase = blockOffset + threadIdx.x;
#pragma unroll
for (int i = 0; i < ITEM_PER_THREAD; i++)
{
int tokenIdx = tokenIdxBase + i * THREAD_COUNT;
expertIds[i]
= tokenIdx < tokenCount * metaInfo.topK ? tokenSelectedExperts[tokenIdx] : metaInfo.expertCount;
}
#pragma unroll
for (int i = 0; i < ITEM_PER_THREAD; i++)
{
if (expertIds[i] < 0 || expertIds[i] >= metaInfo.expertCount)
{
expertIds[i] = metaInfo.expertCount;
}
}
if (blockOffset == blockIdx.x * THREAD_COUNT * ITEM_PER_THREAD)
{
__syncthreads();
}
#pragma unroll
for (int i = 0; i < ITEM_PER_THREAD; i++)
{
slotIds[i] = expertIds[i] < metaInfo.expertCount ? sharedGlobalSlotIdsInfo[expertIds[i]]
: metaInfo.epSize * metaInfo.slotCountPerRank;
}
#pragma unroll
for (int i = 0; i < ITEM_PER_THREAD; i++)
{
int tokenIdx = tokenIdxBase + i * THREAD_COUNT;
if (tokenIdx < tokenCount * metaInfo.topK)
{
tokenRoutedSlotIds[tokenIdx] = slotIds[i];
}
}
}
}
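
This fast path applies when `expertCount == metaInfo.epSize * metaInfo.slotCountPerRank`, i.e. every expert owns exactly one global slot, so routing reduces to a table lookup per selected expert. A minimal host-side C++ sketch of the same per-element mapping (function and parameter names are illustrative, not part of the PR):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch only: mirrors moeComputeRouteNoRedundantKernel's per-element logic on the
// host, assuming exactly one global slot per expert (no redundant replicas).
std::vector<int> routeNoRedundant(std::vector<int> const& tokenSelectedExperts,
    std::vector<std::int16_t> const& globalSlotIds, int expertCount, int totalSlotCount)
{
    std::vector<int> routedSlotIds(tokenSelectedExperts.size());
    for (std::size_t i = 0; i < tokenSelectedExperts.size(); ++i)
    {
        int expertId = tokenSelectedExperts[i];
        // Out-of-range or padded expert ids are routed to the invalid slot id
        // totalSlotCount, matching the kernel's handling of such entries.
        routedSlotIds[i] = (expertId >= 0 && expertId < expertCount)
            ? static_cast<int>(globalSlotIds[expertId])
            : totalSlotCount;
    }
    return routedSlotIds;
}
```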

template <int MAX_EXPERT_COUNT = 1024, int THREAD_COUNT = 256, int ITEM_PER_THREAD = 4>
__global__ void moeComputeRouteKernel(MoeLoadBalanceMetaInfo metaInfo, MoePlacementInfo placementInfo,
int* const tokenSelectedExperts, int* tokenRoutedSlotIds, int tokenCount, bool offsetByEpRank)
{
int warpId = threadIdx.x / 32;
int laneId = threadIdx.x % 32;
static int const kWarpCount = THREAD_COUNT / 32;
extern __shared__ int16_t sharedGlobalSlotIdsInfo[];
__shared__ int sharedExpertReplicaCountAndStartOffset[MAX_EXPERT_COUNT];

__shared__ int sharedArbitrateExpertId[THREAD_COUNT * ITEM_PER_THREAD];
__shared__ int sharedExpertCount[MAX_EXPERT_COUNT];
for (int expertIdx = threadIdx.x; expertIdx < metaInfo.expertCount; expertIdx += THREAD_COUNT)
{
int replicaCount = placementInfo.expertReplicaCount[expertIdx];
int replicaStartOffset = placementInfo.expertReplicaStartOffset[expertIdx];
sharedExpertReplicaCountAndStartOffset[expertIdx] = (replicaCount << 16) | replicaStartOffset;
sharedExpertCount[expertIdx] = 0;
}
for (int slotId = threadIdx.x; slotId < metaInfo.epSize * metaInfo.slotCountPerRank; slotId += THREAD_COUNT)
{
sharedGlobalSlotIdsInfo[slotId] = placementInfo.globalSlotIds[slotId];
}

int expertIds[ITEM_PER_THREAD];
int tokenIdxBase = blockIdx.x * THREAD_COUNT * ITEM_PER_THREAD + threadIdx.x;
#pragma unroll
for (int i = 0; i < ITEM_PER_THREAD; i++)
{
int tokenIdx = tokenIdxBase + i * THREAD_COUNT;
expertIds[i] = tokenIdx < tokenCount * metaInfo.topK ? tokenSelectedExperts[tokenIdx] : metaInfo.expertCount;
}
#pragma unroll
for (int i = 0; i < ITEM_PER_THREAD; i++)
{
if (expertIds[i] < 0 || expertIds[i] >= metaInfo.expertCount)
{
expertIds[i] = metaInfo.expertCount;
}
}
__syncthreads();
#pragma unroll
for (int i = 0; i < ITEM_PER_THREAD; i++)
{
int countAndStart
= expertIds[i] < metaInfo.expertCount ? sharedExpertReplicaCountAndStartOffset[expertIds[i]] : (1 << 16);
int arbitrateExpertId = (countAndStart >> 16) > 1 ? expertIds[i] : metaInfo.expertCount;
sharedArbitrateExpertId[threadIdx.x + i * THREAD_COUNT] = arbitrateExpertId;
}
__syncthreads();
int baseOffset = blockIdx.x + (offsetByEpRank ? metaInfo.epRank : 0);
if (warpId == 0)
{
#pragma unroll
for (int i = 0; i < kWarpCount * ITEM_PER_THREAD; ++i)
{
int expertId = sharedArbitrateExpertId[laneId + i * 32];
__syncwarp();
unsigned match = __match_any_sync(0xFFFFFFFF, expertId);
int leader = __ffs(match) - 1;
int matchCount = __popc(match);
int oldVal = 0;
if (laneId == leader && expertId < metaInfo.expertCount)
{
oldVal = atomicAdd_block(&sharedExpertCount[expertId], matchCount);
}
__syncwarp();
oldVal = __shfl_sync(0XFFFFFFFF, oldVal, leader);
unsigned lowerMask = match & ((1u << laneId) - 1);
int rankInGroup = __popc(lowerMask);
int offset = oldVal + rankInGroup;
offset += baseOffset;
sharedArbitrateExpertId[laneId + i * 32] = offset;
}
}
__syncthreads();
int targetGlobalSlotId[ITEM_PER_THREAD];
#pragma unroll
for (int i = 0; i < ITEM_PER_THREAD; i++)
{
int countAndStart
= expertIds[i] < metaInfo.expertCount ? sharedExpertReplicaCountAndStartOffset[expertIds[i]] : (1 << 16);
int count = countAndStart >> 16;
int offset = countAndStart & 0xFFFF;
int arbitratedIndex = sharedArbitrateExpertId[threadIdx.x + i * THREAD_COUNT];
offset += arbitratedIndex % count;
targetGlobalSlotId[i] = expertIds[i] < metaInfo.expertCount ? sharedGlobalSlotIdsInfo[offset]
: metaInfo.epSize * metaInfo.slotCountPerRank;
}
#pragma unroll
for (int i = 0; i < ITEM_PER_THREAD; i++)
{
int tokenIdx = tokenIdxBase + i * THREAD_COUNT;
if (tokenIdx < tokenCount * metaInfo.topK)
{
tokenRoutedSlotIds[tokenIdx] = targetGlobalSlotId[i];
}
}
}
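
The `warpId == 0` loop above uses warp-aggregated counting: lanes that selected the same expert are grouped with `__match_any_sync`, one leader lane issues a single `atomicAdd_block` for the whole group, and the previous count is broadcast back with `__shfl_sync`, so each thread ends up with a consecutive index within its expert group. A standalone sketch of that pattern, assuming a block size that is a multiple of 32 and non-negative keys (kernel and names are illustrative, not part of the PR):

```cuda
#include <cuda_runtime.h>

// Sketch only: warp-aggregated counting as used inside moeComputeRouteKernel.
// Threads in a warp that hold the same key get consecutive ranks (base + 0, base + 1, ...)
// while issuing just one atomicAdd per distinct key per warp.
__global__ void warpAggregatedRank(int const* keys, int* counters, int* ranks, int n)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int laneId = threadIdx.x % 32;
    // Keep all 32 lanes active (as the route kernel does) by padding with a sentinel key.
    int key = idx < n ? keys[idx] : -1;

    unsigned match = __match_any_sync(0xFFFFFFFF, key); // lanes holding the same key
    int leader = __ffs(match) - 1;                      // lowest lane of this group
    int groupSize = __popc(match);

    int base = 0;
    if (laneId == leader && key >= 0)
    {
        base = atomicAdd(&counters[key], groupSize);    // one atomic per group
    }
    base = __shfl_sync(0xFFFFFFFF, base, leader);       // broadcast the leader's old count

    unsigned lowerLanes = match & ((1u << laneId) - 1); // matching lanes below this one
    if (idx < n)
    {
        ranks[idx] = base + __popc(lowerLanes);         // consecutive rank within the group
    }
}
```

moeComputeRouteKernel uses `atomicAdd_block` because its counters live in block-shared memory; a global-memory version like this sketch needs the device-wide `atomicAdd`.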

template <int MAX_EXPERT_COUNT = 1024, int THREAD_COUNT = 256, int ITEM_PER_THREAD = 4>
__global__ void moeComputeRouteSortKernel(MoeLoadBalanceMetaInfo metaInfo, MoePlacementInfo placementInfo,
int* const tokenSelectedExperts, int* tokenRoutedSlotIds, int tokenCount, bool offsetByEpRank)
{
using BlockSort = cub::BlockRadixSort<int, THREAD_COUNT, 1>;
extern __shared__ int sharedGlobalSlotIdsInfo[];
extern __shared__ int16_t sharedGlobalSlotIdsInfo[];

__shared__ typename BlockSort::TempStorage tempStorage;

@@ -361,9 +515,19 @@ void moeComputeRouteDevice(MoeLoadBalanceMetaInfo metaInfo, MoePlacementInfo pla
constexpr int kThreadCount = 256;
constexpr int kEltPerThread = 4;
int blockCount = (tokenCount * metaInfo.topK + kThreadCount * kEltPerThread - 1) / (kThreadCount * kEltPerThread);
int dynamicShmSize = sizeof(int) * metaInfo.epSize * metaInfo.slotCountPerRank;
moeComputeRouteKernel<1024, kThreadCount, kEltPerThread><<<blockCount, kThreadCount, dynamicShmSize, stream>>>(
metaInfo, placementInfo, tokenSelectedExperts, tokenRoutedSlotIds, tokenCount, offsetByEpRank);
int dynamicShmSize = sizeof(int16_t) * metaInfo.epSize * metaInfo.slotCountPerRank;
if (metaInfo.expertCount == metaInfo.epSize * metaInfo.slotCountPerRank)
{
// no redundant experts, so we don't need complex routing; just assign to the correct slot.
moeComputeRouteNoRedundantKernel<1024, kThreadCount, kEltPerThread>
<<<blockCount, kThreadCount, dynamicShmSize, stream>>>(
metaInfo, placementInfo, tokenSelectedExperts, tokenRoutedSlotIds, tokenCount);
}
else
{
moeComputeRouteKernel<1024, kThreadCount, kEltPerThread><<<blockCount, kThreadCount, dynamicShmSize, stream>>>(
metaInfo, placementInfo, tokenSelectedExperts, tokenRoutedSlotIds, tokenCount, offsetByEpRank);
}
}
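
moeComputeRouteKernel packs each expert's replica count and replica start offset into one 32-bit word (count in the high 16 bits, offset in the low 16 bits) so a single shared-memory load yields both, and then picks a replica round-robin via `offset + arbitratedIndex % count`. A worked example of that arithmetic (plain C++, illustrative values):

```cpp
#include <cassert>

int main()
{
    // Pack as in the kernel: high 16 bits = replica count, low 16 bits = start offset.
    int replicaCount = 3;
    int replicaStartOffset = 40;
    int countAndStart = (replicaCount << 16) | replicaStartOffset;

    // Unpack.
    int count = countAndStart >> 16;     // 3
    int offset = countAndStart & 0xFFFF; // 40

    // Round-robin replica selection from the arbitrated per-token index.
    int arbitratedIndex = 7;                               // base offset + rank within the warp group
    int slotTableIndex = offset + arbitratedIndex % count; // 40 + 7 % 3 = 41

    assert(count == 3 && offset == 40 && slotTableIndex == 41);
    return 0;
}
```

This packing assumes both fields fit in 16 bits, consistent with the switch of `sharedGlobalSlotIdsInfo` to `int16_t` and the `int16_t`-sized dynamic shared memory computed in `moeComputeRouteDevice`.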

void moeWaitSignalForCpuStageHost(MoeLoadBalanceSingleLayerSignal* signal)
4 changes: 3 additions & 1 deletion cpp/tensorrt_llm/pybind/runtime/moeBindings.cpp
@@ -16,7 +16,7 @@
*/

#include "moeBindings.h"
#include "tensorrt_llm/runtime/moeLoadBalancer.h"
#include "tensorrt_llm/runtime/moeLoadBalancer/moeLoadBalancer.h"
#include <pybind11/functional.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
@@ -98,6 +98,8 @@ void initMoeBindings(pybind11::module_& m)
py::class_<tr::MoeLoadBalancer>(m, "MoeLoadBalancer")
.def(py::init<int, int, int>(), py::arg("ep_rank"), py::arg("ep_size"), py::arg("layer_updates_per_iter"),
"Initialize the MoeLoadBalancer with the specified expert parallel rank, size, and update frequency")
.def("set_use_gpu_memcpy", &tr::MoeLoadBalancer::setUseGpuMemcpy, py::arg("use_gpu_memcpy"),
"Set whether to use GPU memcpy for weight updates")
.def("add_layer", &tr::MoeLoadBalancer::AddLayer, py::arg("expert_count"), py::arg("top_k"),
py::arg("slot_count_per_rank"), "Add a new MOE layer to the load balancer")
.def("finalize_model", &tr::MoeLoadBalancer::finalizeModel,
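
For context, the C++ surface behind the new `set_use_gpu_memcpy` binding can be exercised roughly as below. This is a sketch only: the fully qualified namespace behind the `tr` alias, the stack construction, and all argument values are assumptions; only the method names and the header path come from this diff.

```cpp
#include "tensorrt_llm/runtime/moeLoadBalancer/moeLoadBalancer.h"

namespace tr = tensorrt_llm::runtime; // assumed expansion of the alias used in the bindings

void configureLoadBalancer()
{
    // ep_rank, ep_size, layer_updates_per_iter (matches the py::init<int, int, int> binding)
    tr::MoeLoadBalancer balancer(/*epRank=*/0, /*epSize=*/4, /*layerUpdatesPerIter=*/1);

    // New in this change: choose GPU memcpy for weight updates.
    balancer.setUseGpuMemcpy(true);

    // expert_count, top_k, slot_count_per_rank
    balancer.AddLayer(/*expertCount=*/256, /*topK=*/8, /*slotCountPerRank=*/72);

    balancer.finalizeModel();
}
```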
27 changes: 26 additions & 1 deletion cpp/tensorrt_llm/runtime/CMakeLists.txt
@@ -43,7 +43,8 @@ set(SRCS
ipcNvlsMemory.cpp
mcastDeviceMemory.cpp
memoryCounters.cpp
moeLoadBalancer.cpp
moeLoadBalancer/moeLoadBalancer.cpp
moeLoadBalancer/topologyDetector.cpp
ncclCommunicator.cpp
promptTuningParams.cpp
runtimeKernels.cu
@@ -80,3 +81,27 @@ target_include_directories(runtime_src PRIVATE ${MPI_C_INCLUDE_DIRS})
if(ENABLE_MULTI_DEVICE)
target_link_libraries(runtime_src PUBLIC ${NCCL_LIB})
endif()

if(NOT WIN32)
find_package(libnuma QUIET CONFIG)

if(NOT libnuma_FOUND)
message(
STATUS "libnuma not found via Conan, falling back to system libnuma")
find_path(NUMA_INCLUDE_DIR numa.h)
find_library(NUMA_LIBRARY numa)

if(NUMA_INCLUDE_DIR AND NUMA_LIBRARY)
add_library(libnuma::libnuma UNKNOWN IMPORTED)
set_target_properties(
libnuma::libnuma
PROPERTIES IMPORTED_LOCATION "${NUMA_LIBRARY}"
INTERFACE_INCLUDE_DIRECTORIES "${NUMA_INCLUDE_DIR}")
else()
message(FATAL_ERROR "NUMA library not found, please install libnuma-dev")
endif()
else()
message(STATUS "libnuma found.")
endif()
target_link_libraries(runtime_src PUBLIC libnuma::libnuma)
endif()
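
The fallback logic above resolves the system libnuma that the new `moeLoadBalancer/topologyDetector.cpp` links against. As a minimal illustration of what that dependency provides (not code from this PR), a probe like the following builds once `libnuma::libnuma` or `-lnuma` is on the link line:

```cpp
// Build sketch: g++ numa_probe.cpp -lnuma
#include <cstdio>
#include <numa.h>

int main()
{
    if (numa_available() < 0)
    {
        std::printf("NUMA is not available on this system\n");
        return 0;
    }
    // Report how many NUMA nodes the topology exposes.
    std::printf("configured NUMA nodes: %d (max node id: %d)\n",
        numa_num_configured_nodes(), numa_max_node());
    return 0;
}
```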