@@ -319,19 +319,19 @@ __global__ void computeCumsumDevice(int* sendCountsCumsum, int* recvCountsCumsum
319
319
}
320
320
}
321
321
322
- template <typename STEP_COMMUNICATOR_TYPE >
322
+ template <typename PipelineConfig >
323
323
class PacketPipeline
324
324
{
325
325
public:
326
326
__device__ __inline__ PacketPipeline (
327
- void * bufferBase, STEP_COMMUNICATOR_TYPE * stepCommunicator, int * sharedNewStepPtr, bool isSender)
327
+ void * bufferBase, StepCommunicatorBase * stepCommunicator, int * sharedNewStepPtr, bool isSender)
328
328
: bufferBase(bufferBase)
329
329
, stepCommunicator(stepCommunicator)
330
330
, shared_new_step(sharedNewStepPtr)
331
331
{
332
332
step = 0 ;
333
333
needRelease = false ;
334
- packetId = isSender ? 0 : PACKET_PER_STEP - 1 ;
334
+ packetId = isSender ? 0 : PipelineConfig:: PACKET_PER_STEP - 1 ;
335
335
}
336
336
337
337
__device__ __forceinline__ void * getFirstSendPacket ()
@@ -343,9 +343,10 @@ public:
343
343
{
344
344
345
345
packetId++;
346
- if (packetId < PACKET_PER_STEP)
346
+ if (packetId < PipelineConfig:: PACKET_PER_STEP)
347
347
{
348
- return acquireNewStep ? bufferBase + step * PACKET_PER_STEP * PACKET_SIZE + packetId * PACKET_SIZE
348
+ return acquireNewStep ? bufferBase + step * PipelineConfig::PACKET_PER_STEP * PipelineConfig::PACKET_SIZE
349
+ + packetId * PipelineConfig::PACKET_SIZE
349
350
: nullptr ;
350
351
}
351
352
@@ -365,7 +366,7 @@ public:
365
366
{
366
367
step = *(shared_new_step);
367
368
packetId = 0 ;
368
- return bufferBase + step * PACKET_SIZE * PACKET_PER_STEP;
369
+ return bufferBase + step * PipelineConfig:: PACKET_SIZE * PipelineConfig:: PACKET_PER_STEP;
369
370
}
370
371
371
372
return nullptr ;
@@ -382,9 +383,10 @@ public:
382
383
__device__ __inline__ void * getNewRecvPacket ()
383
384
{
384
385
packetId++;
385
- if (packetId < PACKET_PER_STEP)
386
+ if (packetId < PipelineConfig:: PACKET_PER_STEP)
386
387
{
387
- return bufferBase + step * PACKET_PER_STEP * PACKET_SIZE + packetId * PACKET_SIZE;
388
+ return bufferBase + step * PipelineConfig::PACKET_PER_STEP * PipelineConfig::PACKET_SIZE
389
+ + packetId * PipelineConfig::PACKET_SIZE;
388
390
}
389
391
390
392
__syncthreads ();
@@ -401,7 +403,7 @@ public:
401
403
__syncthreads ();
402
404
packetId = 0 ;
403
405
step = *(shared_new_step);
404
- void * packetPtr = bufferBase + step * PACKET_SIZE * PACKET_PER_STEP;
406
+ void * packetPtr = bufferBase + step * PipelineConfig:: PACKET_SIZE * PipelineConfig:: PACKET_PER_STEP;
405
407
406
408
return packetPtr;
407
409
}
@@ -415,14 +417,14 @@ public:
415
417
}
416
418
417
419
void * bufferBase;
418
- STEP_COMMUNICATOR_TYPE * stepCommunicator;
420
+ StepCommunicatorBase * stepCommunicator;
419
421
int step;
420
422
int packetId;
421
423
bool needRelease;
422
424
int * shared_new_step;
423
425
};
424
426
425
- template <typename STEP_COMMUNICATOR_TYPE >
427
+ template <typename PipelineConfig, typename ExpertType, typename ScaleType >
426
428
__global__ void allToAllMetadataDevice (int * sendExperts, int * recvExperts, float * sendScales, float * recvScales,
427
429
int * localExpertStatics, int * gatheredExpertStatics, MoeCommWorkspace workspace, int * sendCountsCumsum,
428
430
int * localSendIndice, int * recvCountsCumsum, int * localRecvIndice, int tokenCount, int maxTokenCountPerRank,
@@ -431,22 +433,21 @@ __global__ void allToAllMetadataDevice(int* sendExperts, int* recvExperts, float
431
433
bool isSender = (blockIdx .y == 0 );
432
434
int targetRankId = blockIdx .x ;
433
435
int slotCountPerRank = slotCount / rankCount;
434
- int groupSize = topK / UNIT_SIZE;
435
- int groupId = threadIdx .x % groupSize;
436
+ int groupSize = topK / PipelineConfig::UNIT_SIZE;
436
437
437
438
__shared__ int sharedNewStep;
438
- __align__ (16 ) int experts[UNIT_SIZE];
439
- __align__ (16 ) float scales[UNIT_SIZE];
439
+ __align__ (16 ) int experts[PipelineConfig:: UNIT_SIZE];
440
+ __align__ (16 ) float scales[PipelineConfig:: UNIT_SIZE];
440
441
441
442
uint8_t * bufferBase = (uint8_t *) (workspace.getFifoBasePtr (isSender, rankId, targetRankId, 0 , 1 ));
442
- STEP_COMMUNICATOR_TYPE stepCommunicator (workspace.getFifoConnInfo (isSender, rankId, targetRankId, 0 , rankCount, 1 ));
443
- PacketPipeline<STEP_COMMUNICATOR_TYPE > pipeline (bufferBase, &stepCommunicator, &sharedNewStep, isSender);
443
+ StepCommunicatorBase stepCommunicator (workspace.getFifoConnInfo (isSender, rankId, targetRankId, 0 , rankCount, 1 ));
444
+ PacketPipeline<PipelineConfig > pipeline (bufferBase, &stepCommunicator, &sharedNewStep, isSender);
444
445
445
446
if (isSender)
446
447
{
447
448
int baseCumsum = targetRankId == 0 ? 0 : *(sendCountsCumsum + targetRankId - 1 );
448
449
int sendTokenCount = *(sendCountsCumsum + targetRankId) - baseCumsum;
449
- int unitCount = sendTokenCount * topK / UNIT_SIZE;
450
+ int unitCount = sendTokenCount * topK / PipelineConfig:: UNIT_SIZE;
450
451
451
452
void * packPtr = pipeline.getFirstSendPacket ();
452
453
int indexBase = 0 ;
@@ -457,13 +458,15 @@ __global__ void allToAllMetadataDevice(int* sendExperts, int* recvExperts, float
457
458
if (threadIdx .x < UNIT_PER_ITER)
458
459
{
459
460
int index = indexBase + threadIdx .x ;
461
+ int groupId = index % groupSize;
460
462
if (index < unitCount)
461
463
{
462
464
int tokenId = *(localSendIndice + maxTokenCountPerRank * targetRankId + (index / groupSize));
463
- *((int4 *) (experts)) = *(int4 *) (sendExperts + tokenId * topK + groupId * UNIT_SIZE);
465
+ *((ExpertType*) (experts))
466
+ = *(ExpertType*) (sendExperts + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE);
464
467
465
468
#pragma unroll
466
- for (int j = 0 ; j < UNIT_SIZE; j++)
469
+ for (int j = 0 ; j < PipelineConfig:: UNIT_SIZE; j++)
467
470
{
468
471
int expertId = experts[j];
469
472
if (expertId / slotCountPerRank != targetRankId)
@@ -472,14 +475,15 @@ __global__ void allToAllMetadataDevice(int* sendExperts, int* recvExperts, float
472
475
}
473
476
}
474
477
475
- int * expertsPtr = (int *) (packPtr) + threadIdx .x * UNIT_SIZE;
476
- *((int4 *) (expertsPtr)) = *((int4 *) (experts));
478
+ int * expertsPtr = (int *) (packPtr) + threadIdx .x * PipelineConfig:: UNIT_SIZE;
479
+ *((ExpertType *) (expertsPtr)) = *((ExpertType *) (experts));
477
480
if (sendScales != nullptr )
478
481
{
479
- *((float4 *) (scales)) = *(float4 *) (sendScales + tokenId * topK + groupId * UNIT_SIZE);
480
- float * scaleBasePtr = (float *) (packPtr + SCALE_OFFSET);
481
- float * scalesPtr = (float *) (scaleBasePtr) + threadIdx .x * UNIT_SIZE;
482
- *((float4 *) (scalesPtr)) = *((float4 *) (scales));
482
+ *((ScaleType*) (scales))
483
+ = *(ScaleType*) (sendScales + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE);
484
+ float * scaleBasePtr = (float *) (packPtr + PipelineConfig::SCALE_OFFSET);
485
+ float * scalesPtr = (float *) (scaleBasePtr) + threadIdx .x * PipelineConfig::UNIT_SIZE;
486
+ *((ScaleType*) (scalesPtr)) = *((ScaleType*) (scales));
483
487
}
484
488
}
485
489
}
@@ -488,7 +492,7 @@ __global__ void allToAllMetadataDevice(int* sendExperts, int* recvExperts, float
488
492
int staticCopyIdx = threadIdx .x - UNIT_PER_ITER;
489
493
if (staticCopyBase + staticCopyIdx * 4 < expertCount)
490
494
{
491
- int4 * staticBasePtr = (int4 *) (packPtr + STATIC_COPY_OFFSET);
495
+ int4 * staticBasePtr = (int4 *) (packPtr + PipelineConfig:: STATIC_COPY_OFFSET);
492
496
int4 staticData = *(int4 *) (localExpertStatics + staticCopyBase + staticCopyIdx * 4 );
493
497
*(staticBasePtr + staticCopyIdx) = staticData;
494
498
}
@@ -521,18 +525,21 @@ __global__ void allToAllMetadataDevice(int* sendExperts, int* recvExperts, float
521
525
if (threadIdx .x < packetUnitCount)
522
526
{
523
527
int tokenId = baseCumsum + (unitIdBase + threadIdx .x ) / groupSize;
524
- int * expertsPtr = (int *) (packetPtr) + threadIdx .x * UNIT_SIZE;
525
- *((int4 *) (experts)) = *((int4 *) (expertsPtr));
526
- int4 * dstExpertsPtr = (int4 *) (recvExperts + tokenId * topK + groupId * UNIT_SIZE);
527
- *dstExpertsPtr = *((int4 *) (experts));
528
+ int groupId = (unitIdBase + threadIdx .x ) % groupSize;
529
+ int * expertsPtr = (int *) (packetPtr) + threadIdx .x * PipelineConfig::UNIT_SIZE;
530
+ *((ExpertType*) (experts)) = *((ExpertType*) (expertsPtr));
531
+ ExpertType* dstExpertsPtr
532
+ = (ExpertType*) (recvExperts + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE);
533
+ *dstExpertsPtr = *((ExpertType*) (experts));
528
534
529
535
if (recvScales != nullptr )
530
536
{
531
- float * scaleBasePtr = (float *) (packetPtr + SCALE_OFFSET);
532
- float * scalesPtr = scaleBasePtr + threadIdx .x * UNIT_SIZE;
533
- *((float4 *) (scales)) = *((float4 *) (scalesPtr));
534
- float4 * dstScalesPtr = (float4 *) (recvScales + tokenId * topK + groupId * UNIT_SIZE);
535
- *dstScalesPtr = *((float4 *) (scales));
537
+ float * scaleBasePtr = (float *) (packetPtr + PipelineConfig::SCALE_OFFSET);
538
+ float * scalesPtr = scaleBasePtr + threadIdx .x * PipelineConfig::UNIT_SIZE;
539
+ *((ScaleType*) (scales)) = *((ScaleType*) (scalesPtr));
540
+ ScaleType* dstScalesPtr
541
+ = (ScaleType*) (recvScales + tokenId * topK + groupId * PipelineConfig::UNIT_SIZE);
542
+ *dstScalesPtr = *((ScaleType*) (scales));
536
543
}
537
544
}
538
545
}
@@ -541,7 +548,7 @@ __global__ void allToAllMetadataDevice(int* sendExperts, int* recvExperts, float
541
548
int staticCopyIdx = threadIdx .x - UNIT_PER_ITER;
542
549
if (staticCopyBase + staticCopyIdx * 4 < expertCount)
543
550
{
544
- int4 * staticBasePtr = (int4 *) (packetPtr + STATIC_COPY_OFFSET);
551
+ int4 * staticBasePtr = (int4 *) (packetPtr + PipelineConfig:: STATIC_COPY_OFFSET);
545
552
int4 staticData = *(staticBasePtr + staticCopyIdx);
546
553
*(int4 *) (gatheredExpertStatics + targetRankId * expertCount + staticCopyBase + staticCopyIdx * 4 )
547
554
= staticData;
@@ -630,10 +637,28 @@ void allToAllMetadata(int* sendExperts, int* recvExperts, float* sendScales, flo
630
637
dim3 block (block_size);
631
638
dim3 grid (rankCount, 2 );
632
639
633
- allToAllMetadataDevice<StepCommunicatorBase><<<grid, block, 0 , stream>>> (sendExperts, recvExperts, sendScales,
634
- recvScales, localExpertStatics, gatheredExpertStatics, workspace, sendCountsCumsum, localSendIndice,
635
- recvCountsCumsum, localRecvIndice, tokenCount, maxTokenCountPerRank, topK, expertCount, slotCount, rankId,
636
- rankCount);
640
+ if (topK % 4 == 0 )
641
+ {
642
+ using PipelineConfig = PipelineConfig<4 , 16 >;
643
+ static_assert (
644
+ PipelineConfig::PACKET_SIZE_IN_U64 * PipelineConfig::PACKET_PER_STEP * STEP_DEPTH <= FIFO_SIZE_IN_U64,
645
+ " FIFO size is too small" );
646
+ allToAllMetadataDevice<PipelineConfig, int4 , float4 ><<<grid, block, 0 , stream>>> (sendExperts, recvExperts,
647
+ sendScales, recvScales, localExpertStatics, gatheredExpertStatics, workspace, sendCountsCumsum,
648
+ localSendIndice, recvCountsCumsum, localRecvIndice, tokenCount, maxTokenCountPerRank, topK, expertCount,
649
+ slotCount, rankId, rankCount);
650
+ }
651
+ else
652
+ {
653
+ using PipelineConfig = PipelineConfig<1 , 64 >;
654
+ static_assert (
655
+ PipelineConfig::PACKET_SIZE_IN_U64 * PipelineConfig::PACKET_PER_STEP * STEP_DEPTH <= FIFO_SIZE_IN_U64,
656
+ " FIFO size is too small" );
657
+ allToAllMetadataDevice<PipelineConfig, int , float ><<<grid, block, 0 , stream>>> (sendExperts, recvExperts,
658
+ sendScales, recvScales, localExpertStatics, gatheredExpertStatics, workspace, sendCountsCumsum,
659
+ localSendIndice, recvCountsCumsum, localRecvIndice, tokenCount, maxTokenCountPerRank, topK, expertCount,
660
+ slotCount, rankId, rankCount);
661
+ }
637
662
638
663
int smCount = tensorrt_llm::common::getMultiProcessorCount ();
639
664
memsetExpertIdsDevice<<<smCount, 256 , 0 , stream>>> (
@@ -642,7 +667,7 @@ void allToAllMetadata(int* sendExperts, int* recvExperts, float* sendScales, flo
642
667
643
668
size_t getMoePrepareWorkspaceSize (int epSize)
644
669
{
645
- return (STEP_DEPTH * PACKET_PER_STEP * PACKET_SIZE + StepCommunicatorBase::META_SIZE) * epSize;
670
+ return (FIFO_SIZE_IN_U64 * 8 + StepCommunicatorBase::META_SIZE) * epSize;
646
671
}
647
672
648
673
} // namespace moe_prepare
0 commit comments