Skip to content

Commit fddb7f1

Browse files
authored
feat: moe prepare support topk % 4 != 0 (#5742)
Signed-off-by: Fred Wei <[email protected]>
1 parent eb5cb5b commit fddb7f1

File tree

3 files changed

+83
-53
lines changed

3 files changed

+83
-53
lines changed

cpp/tensorrt_llm/kernels/moePrepareKernels.cu

Lines changed: 67 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -319,19 +319,19 @@ __global__ void computeCumsumDevice(int* sendCountsCumsum, int* recvCountsCumsum
319319
}
320320
}
321321

322-
template <typename STEP_COMMUNICATOR_TYPE>
322+
template <typename PipelineConfig>
323323
class PacketPipeline
324324
{
325325
public:
326326
__device__ __inline__ PacketPipeline(
327-
void* bufferBase, STEP_COMMUNICATOR_TYPE* stepCommunicator, int* sharedNewStepPtr, bool isSender)
327+
void* bufferBase, StepCommunicatorBase* stepCommunicator, int* sharedNewStepPtr, bool isSender)
328328
: bufferBase(bufferBase)
329329
, stepCommunicator(stepCommunicator)
330330
, shared_new_step(sharedNewStepPtr)
331331
{
332332
step = 0;
333333
needRelease = false;
334-
packetId = isSender ? 0 : PACKET_PER_STEP - 1;
334+
packetId = isSender ? 0 : PipelineConfig::PACKET_PER_STEP - 1;
335335
}
336336

337337
__device__ __forceinline__ void* getFirstSendPacket()
@@ -343,9 +343,10 @@ public:
343343
{
344344

345345
packetId++;
346-
if (packetId < PACKET_PER_STEP)
346+
if (packetId < PipelineConfig::PACKET_PER_STEP)
347347
{
348-
return acquireNewStep ? bufferBase + step * PACKET_PER_STEP * PACKET_SIZE + packetId * PACKET_SIZE
348+
return acquireNewStep ? bufferBase + step * PipelineConfig::PACKET_PER_STEP * PipelineConfig::PACKET_SIZE
349+
+ packetId * PipelineConfig::PACKET_SIZE
349350
: nullptr;
350351
}
351352

@@ -365,7 +366,7 @@ public:
365366
{
366367
step = *(shared_new_step);
367368
packetId = 0;
368-
return bufferBase + step * PACKET_SIZE * PACKET_PER_STEP;
369+
return bufferBase + step * PipelineConfig::PACKET_SIZE * PipelineConfig::PACKET_PER_STEP;
369370
}
370371

371372
return nullptr;
@@ -382,9 +383,10 @@ public:
382383
__device__ __inline__ void* getNewRecvPacket()
383384
{
384385
packetId++;
385-
if (packetId < PACKET_PER_STEP)
386+
if (packetId < PipelineConfig::PACKET_PER_STEP)
386387
{
387-
return bufferBase + step * PACKET_PER_STEP * PACKET_SIZE + packetId * PACKET_SIZE;
388+
return bufferBase + step * PipelineConfig::PACKET_PER_STEP * PipelineConfig::PACKET_SIZE
389+
+ packetId * PipelineConfig::PACKET_SIZE;
388390
}
389391

390392
__syncthreads();
@@ -401,7 +403,7 @@ public:
401403
__syncthreads();
402404
packetId = 0;
403405
step = *(shared_new_step);
404-
void* packetPtr = bufferBase + step * PACKET_SIZE * PACKET_PER_STEP;
406+
void* packetPtr = bufferBase + step * PipelineConfig::PACKET_SIZE * PipelineConfig::PACKET_PER_STEP;
405407

406408
return packetPtr;
407409
}
@@ -415,14 +417,14 @@ public:
415417
}
416418

417419
void* bufferBase;
418-
STEP_COMMUNICATOR_TYPE* stepCommunicator;
420+
StepCommunicatorBase* stepCommunicator;
419421
int step;
420422
int packetId;
421423
bool needRelease;
422424
int* shared_new_step;
423425
};
424426

425-
template <typename STEP_COMMUNICATOR_TYPE>
427+
template <typename PipelineConfig, typename ExpertType, typename ScaleType>
426428
__global__ void allToAllMetadataDevice(int* sendExperts, int* recvExperts, float* sendScales, float* recvScales,
427429
int* localExpertStatics, int* gatheredExpertStatics, MoeCommWorkspace workspace, int* sendCountsCumsum,
428430
int* localSendIndice, int* recvCountsCumsum, int* localRecvIndice, int tokenCount, int maxTokenCountPerRank,
@@ -431,22 +433,21 @@ __global__ void allToAllMetadataDevice(int* sendExperts, int* recvExperts, float
431433
bool isSender = (blockIdx.y == 0);
432434
int targetRankId = blockIdx.x;
433435
int slotCountPerRank = slotCount / rankCount;
434-
int groupSize = topK / UNIT_SIZE;
435-
int groupId = threadIdx.x % groupSize;
436+
int groupSize = topK / PipelineConfig::UNIT_SIZE;
436437

437438
__shared__ int sharedNewStep;
438-
__align__(16) int experts[UNIT_SIZE];
439-
__align__(16) float scales[UNIT_SIZE];
439+
__align__(16) int experts[PipelineConfig::UNIT_SIZE];
440+
__align__(16) float scales[PipelineConfig::UNIT_SIZE];
440441

441442
uint8_t* bufferBase = (uint8_t*) (workspace.getFifoBasePtr(isSender, rankId, targetRankId, 0, 1));
442-
STEP_COMMUNICATOR_TYPE stepCommunicator(workspace.getFifoConnInfo(isSender, rankId, targetRankId, 0, rankCount, 1));
443-
PacketPipeline<STEP_COMMUNICATOR_TYPE> pipeline(bufferBase, &stepCommunicator, &sharedNewStep, isSender);
443+
StepCommunicatorBase stepCommunicator(workspace.getFifoConnInfo(isSender, rankId, targetRankId, 0, rankCount, 1));
444+
PacketPipeline<PipelineConfig> pipeline(bufferBase, &stepCommunicator, &sharedNewStep, isSender);
444445

445446
if (isSender)
446447
{
447448
int baseCumsum = targetRankId == 0 ? 0 : *(sendCountsCumsum + targetRankId - 1);
448449
int sendTokenCount = *(sendCountsCumsum + targetRankId) - baseCumsum;
449-
int unitCount = sendTokenCount * topK / UNIT_SIZE;
450+
int unitCount = sendTokenCount * topK / PipelineConfig::UNIT_SIZE;
450451

451452
void* packPtr = pipeline.getFirstSendPacket();
452453
int indexBase = 0;
@@ -457,13 +458,15 @@ __global__ void allToAllMetadataDevice(int* sendExperts, int* recvExperts, float
457458
if (threadIdx.x < UNIT_PER_ITER)
458459
{
459460
int index = indexBase + threadIdx.x;
461+
int groupId = index % groupSize;
460462
if (index < unitCount)
461463
{
462464
int tokenId = *(localSendIndice + maxTokenCountPerRank * targetRankId + (index / groupSize));
463-
*((int4*) (experts)) = *(int4*) (sendExperts + tokenId * topK + groupId * UNIT_SIZE);
465+
*((ExpertType*) (experts))
466+
= *(ExpertType*) (sendExperts + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE);
464467

465468
#pragma unroll
466-
for (int j = 0; j < UNIT_SIZE; j++)
469+
for (int j = 0; j < PipelineConfig::UNIT_SIZE; j++)
467470
{
468471
int expertId = experts[j];
469472
if (expertId / slotCountPerRank != targetRankId)
@@ -472,14 +475,15 @@ __global__ void allToAllMetadataDevice(int* sendExperts, int* recvExperts, float
472475
}
473476
}
474477

475-
int* expertsPtr = (int*) (packPtr) + threadIdx.x * UNIT_SIZE;
476-
*((int4*) (expertsPtr)) = *((int4*) (experts));
478+
int* expertsPtr = (int*) (packPtr) + threadIdx.x * PipelineConfig::UNIT_SIZE;
479+
*((ExpertType*) (expertsPtr)) = *((ExpertType*) (experts));
477480
if (sendScales != nullptr)
478481
{
479-
*((float4*) (scales)) = *(float4*) (sendScales + tokenId * topK + groupId * UNIT_SIZE);
480-
float* scaleBasePtr = (float*) (packPtr + SCALE_OFFSET);
481-
float* scalesPtr = (float*) (scaleBasePtr) + threadIdx.x * UNIT_SIZE;
482-
*((float4*) (scalesPtr)) = *((float4*) (scales));
482+
*((ScaleType*) (scales))
483+
= *(ScaleType*) (sendScales + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE);
484+
float* scaleBasePtr = (float*) (packPtr + PipelineConfig::SCALE_OFFSET);
485+
float* scalesPtr = (float*) (scaleBasePtr) + threadIdx.x * PipelineConfig::UNIT_SIZE;
486+
*((ScaleType*) (scalesPtr)) = *((ScaleType*) (scales));
483487
}
484488
}
485489
}
@@ -488,7 +492,7 @@ __global__ void allToAllMetadataDevice(int* sendExperts, int* recvExperts, float
488492
int staticCopyIdx = threadIdx.x - UNIT_PER_ITER;
489493
if (staticCopyBase + staticCopyIdx * 4 < expertCount)
490494
{
491-
int4* staticBasePtr = (int4*) (packPtr + STATIC_COPY_OFFSET);
495+
int4* staticBasePtr = (int4*) (packPtr + PipelineConfig::STATIC_COPY_OFFSET);
492496
int4 staticData = *(int4*) (localExpertStatics + staticCopyBase + staticCopyIdx * 4);
493497
*(staticBasePtr + staticCopyIdx) = staticData;
494498
}
@@ -521,18 +525,21 @@ __global__ void allToAllMetadataDevice(int* sendExperts, int* recvExperts, float
521525
if (threadIdx.x < packetUnitCount)
522526
{
523527
int tokenId = baseCumsum + (unitIdBase + threadIdx.x) / groupSize;
524-
int* expertsPtr = (int*) (packetPtr) + threadIdx.x * UNIT_SIZE;
525-
*((int4*) (experts)) = *((int4*) (expertsPtr));
526-
int4* dstExpertsPtr = (int4*) (recvExperts + tokenId * topK + groupId * UNIT_SIZE);
527-
*dstExpertsPtr = *((int4*) (experts));
528+
int groupId = (unitIdBase + threadIdx.x) % groupSize;
529+
int* expertsPtr = (int*) (packetPtr) + threadIdx.x * PipelineConfig::UNIT_SIZE;
530+
*((ExpertType*) (experts)) = *((ExpertType*) (expertsPtr));
531+
ExpertType* dstExpertsPtr
532+
= (ExpertType*) (recvExperts + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE);
533+
*dstExpertsPtr = *((ExpertType*) (experts));
528534

529535
if (recvScales != nullptr)
530536
{
531-
float* scaleBasePtr = (float*) (packetPtr + SCALE_OFFSET);
532-
float* scalesPtr = scaleBasePtr + threadIdx.x * UNIT_SIZE;
533-
*((float4*) (scales)) = *((float4*) (scalesPtr));
534-
float4* dstScalesPtr = (float4*) (recvScales + tokenId * topK + groupId * UNIT_SIZE);
535-
*dstScalesPtr = *((float4*) (scales));
537+
float* scaleBasePtr = (float*) (packetPtr + PipelineConfig::SCALE_OFFSET);
538+
float* scalesPtr = scaleBasePtr + threadIdx.x * PipelineConfig::UNIT_SIZE;
539+
*((ScaleType*) (scales)) = *((ScaleType*) (scalesPtr));
540+
ScaleType* dstScalesPtr
541+
= (ScaleType*) (recvScales + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE);
542+
*dstScalesPtr = *((ScaleType*) (scales));
536543
}
537544
}
538545
}
@@ -541,7 +548,7 @@ __global__ void allToAllMetadataDevice(int* sendExperts, int* recvExperts, float
541548
int staticCopyIdx = threadIdx.x - UNIT_PER_ITER;
542549
if (staticCopyBase + staticCopyIdx * 4 < expertCount)
543550
{
544-
int4* staticBasePtr = (int4*) (packetPtr + STATIC_COPY_OFFSET);
551+
int4* staticBasePtr = (int4*) (packetPtr + PipelineConfig::STATIC_COPY_OFFSET);
545552
int4 staticData = *(staticBasePtr + staticCopyIdx);
546553
*(int4*) (gatheredExpertStatics + targetRankId * expertCount + staticCopyBase + staticCopyIdx * 4)
547554
= staticData;
@@ -630,10 +637,28 @@ void allToAllMetadata(int* sendExperts, int* recvExperts, float* sendScales, flo
630637
dim3 block(block_size);
631638
dim3 grid(rankCount, 2);
632639

633-
allToAllMetadataDevice<StepCommunicatorBase><<<grid, block, 0, stream>>>(sendExperts, recvExperts, sendScales,
634-
recvScales, localExpertStatics, gatheredExpertStatics, workspace, sendCountsCumsum, localSendIndice,
635-
recvCountsCumsum, localRecvIndice, tokenCount, maxTokenCountPerRank, topK, expertCount, slotCount, rankId,
636-
rankCount);
640+
if (topK % 4 == 0)
641+
{
642+
using PipelineConfig = PipelineConfig<4, 16>;
643+
static_assert(
644+
PipelineConfig::PACKET_SIZE_IN_U64 * PipelineConfig::PACKET_PER_STEP * STEP_DEPTH <= FIFO_SIZE_IN_U64,
645+
"FIFO size is too small");
646+
allToAllMetadataDevice<PipelineConfig, int4, float4><<<grid, block, 0, stream>>>(sendExperts, recvExperts,
647+
sendScales, recvScales, localExpertStatics, gatheredExpertStatics, workspace, sendCountsCumsum,
648+
localSendIndice, recvCountsCumsum, localRecvIndice, tokenCount, maxTokenCountPerRank, topK, expertCount,
649+
slotCount, rankId, rankCount);
650+
}
651+
else
652+
{
653+
using PipelineConfig = PipelineConfig<1, 64>;
654+
static_assert(
655+
PipelineConfig::PACKET_SIZE_IN_U64 * PipelineConfig::PACKET_PER_STEP * STEP_DEPTH <= FIFO_SIZE_IN_U64,
656+
"FIFO size is too small");
657+
allToAllMetadataDevice<PipelineConfig, int, float><<<grid, block, 0, stream>>>(sendExperts, recvExperts,
658+
sendScales, recvScales, localExpertStatics, gatheredExpertStatics, workspace, sendCountsCumsum,
659+
localSendIndice, recvCountsCumsum, localRecvIndice, tokenCount, maxTokenCountPerRank, topK, expertCount,
660+
slotCount, rankId, rankCount);
661+
}
637662

638663
int smCount = tensorrt_llm::common::getMultiProcessorCount();
639664
memsetExpertIdsDevice<<<smCount, 256, 0, stream>>>(
@@ -642,7 +667,7 @@ void allToAllMetadata(int* sendExperts, int* recvExperts, float* sendScales, flo
642667

643668
size_t getMoePrepareWorkspaceSize(int epSize)
644669
{
645-
return (STEP_DEPTH * PACKET_PER_STEP * PACKET_SIZE + StepCommunicatorBase::META_SIZE) * epSize;
670+
return (FIFO_SIZE_IN_U64 * 8 + StepCommunicatorBase::META_SIZE) * epSize;
646671
}
647672

648673
} // namespace moe_prepare

cpp/tensorrt_llm/kernels/moePrepareKernels.h

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ namespace moe_prepare
2929
{
3030

3131
#define STEP_DEPTH 2
32-
#define PACKET_PER_STEP 16
3332
#define THREADS_PER_UNIT 1
3433
#define UNIT_PER_PIPELINE 128
3534
#define PIPELINE_PER_CTA 4
@@ -39,21 +38,26 @@ namespace moe_prepare
3938
#define BYTES_COUNTER 8
4039
#define CUMSUM_THREADS_PER_BLOCK 128
4140

42-
#define UNIT_SIZE 4
4341
#define UNIT_PER_ITER 256
4442
#define STATIC_COPY_PER_ITER 128
45-
#define MAX_TOKEN_SIZE 8192
4643

47-
static constexpr int UNIT_BYTES_SIZE = EXPERT_BYTES_PER_UNIT + SCALE_BYTES_PER_UNIT;
4844
static constexpr int THREADS_PER_PIPELINE = THREADS_PER_UNIT * UNIT_PER_PIPELINE;
4945
static constexpr int THREADS_PER_CTA = THREADS_PER_PIPELINE * PIPELINE_PER_CTA;
5046

51-
static constexpr int SCALE_OFFSET = UNIT_SIZE * UNIT_PER_ITER * sizeof(int);
52-
static constexpr int STATIC_COPY_OFFSET = UNIT_SIZE * UNIT_PER_ITER * (sizeof(int) + sizeof(float));
53-
static constexpr int PACKET_SIZE
54-
= UNIT_SIZE * UNIT_PER_ITER * (sizeof(int) + sizeof(float)) + STATIC_COPY_PER_ITER * 4 * sizeof(int);
55-
static constexpr int PACKET_SIZE_IN_U64 = (PACKET_SIZE / 8);
56-
static constexpr int FIFO_SIZE_IN_U64 = PACKET_SIZE_IN_U64 * PACKET_PER_STEP * STEP_DEPTH;
47+
template <int UNIT_SIZE_INPUT, int PACKET_PER_STEP_INPUT>
48+
struct PipelineConfig
49+
{
50+
static constexpr int UNIT_SIZE = UNIT_SIZE_INPUT;
51+
static constexpr int PACKET_PER_STEP = PACKET_PER_STEP_INPUT;
52+
static constexpr int UNIT_BYTES_SIZE = UNIT_SIZE * UNIT_PER_ITER * (sizeof(int) + sizeof(float));
53+
static constexpr int SCALE_OFFSET = UNIT_SIZE * UNIT_PER_ITER * sizeof(int);
54+
static constexpr int STATIC_COPY_OFFSET = UNIT_SIZE * UNIT_PER_ITER * (sizeof(int) + sizeof(float));
55+
static constexpr int PACKET_SIZE = UNIT_BYTES_SIZE + STATIC_COPY_PER_ITER * 4 * sizeof(int);
56+
static constexpr int PACKET_SIZE_IN_U64 = (PACKET_SIZE / 8);
57+
};
58+
59+
// 1MB FIFO size
60+
static constexpr int FIFO_SIZE_IN_U64 = 1024 * 1024 / 8;
5761

5862
#ifdef __CUDACC__
5963
#define ALIGN_256 __align__(256)

tests/unittest/_torch/thop/test_moe_alltoall.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -471,12 +471,13 @@ def test_moe_local_gather(self, ep_rank: int, ep_size: int,
471471

472472
@parameterized.expand([
473473
(0, 2, 16, 20, 8, 512),
474-
(0, 2, 16, 16, 4, 8),
474+
(0, 2, 16, 16, 3, 300),
475475
(0, 4, 20, 24, 8, 4000),
476476
(0, 8, 96, 96, 8, 1000),
477477
(3, 8, 128, 128, 8, 1000),
478478
(3, 8, 128, 144, 8, 1),
479479
(0, 4, 72, 80, 4, 2256),
480+
(0, 4, 72, 80, 6, 3333),
480481
# Hang with stream count > 8
481482
#(0, 9, 90, 8, 100),
482483
])

0 commit comments

Comments
 (0)