Skip to content

Commit a608b00

Browse files
authored
Fix mPtrExpertCounts allocation in MoE TRT-LLM backend (nvfp4) (#5519)
Signed-off-by: Christina Zhang <[email protected]>
1 parent 7f1893f commit a608b00

File tree

9 files changed

+27 additions, -23 deletions

cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.cu

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1244,6 +1244,8 @@ void run(Data const& data, void* stream)
12441244
{
12451245
// Reset the global histograms (not used in single-cluster code path).
12461246
// Cover both for the cooperative and two-kernel code paths.
1247+
TLLM_CHECK_WITH_INFO(
1248+
data.mPtrExpertCounts != nullptr, "When #tokens is large, `mPtrExpertCounts` is a required input.");
12471249
TLLM_CUDA_CHECK(cudaMemsetAsync(
12481250
data.mPtrExpertCounts, 0, static_cast<size_t>(2 * NumThreads) * sizeof(int32_t), (cudaStream_t) stream));
12491251
}

cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
#pragma once
1818

1919
#include "IntFastDiv.h"
20-
2120
#include "tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h"
2221

2322
#include <cuda.h>
@@ -31,7 +30,6 @@ namespace moe::dev
3130

3231
namespace routing
3332
{
34-
3533
////////////////////////////////////////////////////////////////////////////////////////////////////
3634

3735
namespace tg = batchedGemm::trtllm::gen;

cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,8 @@ std::vector<torch::Tensor> fp4_block_scale_moe_runner(torch::Tensor const& routi
123123
{args.num_tokens, args.top_k}, routing_bias_dtype, routing_logits.device(), std::nullopt);
124124
at::Tensor expert_indexes = at::detail::empty_cuda(
125125
{args.num_tokens, args.top_k}, at::ScalarType::Int, routing_logits.device(), std::nullopt);
126-
at::Tensor expert_count_histogram = at::detail::empty_cuda({((num_experts * 2 + 255) / 256) * 256},
126+
int64_t const size_of_expert_count_histogram = std::max(num_experts * 2, int64_t(256 * 2));
127+
at::Tensor expert_count_histogram = at::detail::empty_cuda({size_of_expert_count_histogram},
127128
at::ScalarType::Int, // 256 is the max number of threads per block and max number of experts
128129
routing_logits.device(), std::nullopt);
129130

cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,8 @@ at::Tensor run_fp8_block_scale_moe(at::Tensor const& routing_logits, at::Tensor
109109
{args.num_tokens, args.top_k}, routing_bias.scalar_type(), routing_logits.device(), std::nullopt);
110110
at::Tensor expert_indexes = at::detail::empty_cuda(
111111
{args.num_tokens, args.top_k}, at::ScalarType::Int, routing_logits.device(), std::nullopt);
112-
at::Tensor expert_count_histogram = at::detail::empty_cuda({2 * 256},
112+
int64_t const size_of_expert_count_histogram = std::max(num_experts * 2, int64_t(256 * 2));
113+
at::Tensor expert_count_histogram = at::detail::empty_cuda({size_of_expert_count_histogram},
113114
at::ScalarType::Int, // 256 is the max number of threads per block and max number of experts
114115
routing_logits.device(), std::nullopt);
115116

cpp/tests/unit_tests/kernels/routing/routingDeepSeekTest.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -246,15 +246,14 @@ class RoutingDeepSeekKernelTest : public RoutingKernelTest<T>
246246

247247
TYPED_TEST_SUITE(RoutingDeepSeekKernelTest, Bf16Types);
248248

249-
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
250249
TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization)
251250
{
252251
RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/10,
253252
/*numExperts=*/128, /*topK=*/8,
254253
/*expertParallelization=*/1, /*expertParallelizationId=*/0,
255254
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
256255
/*usePdl=*/true, /*getExpWeights=*/true,
257-
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f);
256+
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
258257
this->runTest(param);
259258
};
260259

@@ -265,20 +264,19 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationWithExpertParal
265264
/*expertParallelization=*/2, /*expertParallelizationId=*/1,
266265
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
267266
/*usePdl=*/true, /*getExpWeights=*/true,
268-
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f);
267+
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
269268
this->runTest(param);
270269
};
271270

272-
TYPED_TEST(RoutingDeepSeekKernelTest, DeviceLevelParallelization)
271+
TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelization)
273272
{
274273
RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/1030,
275-
/*numExperts=*/128, /*topK=*/1,
274+
/*numExperts=*/128, /*topK=*/8,
276275
/*expertParallelization=*/1, /*expertParallelizationId=*/0,
277276
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
278277
/*usePdl=*/true, /*getExpWeights=*/true,
279-
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f);
278+
/*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10);
280279
this->runTest(param);
281280
};
282-
#endif
283281

284282
} // namespace

cpp/tests/unit_tests/kernels/routing/routingLlama4Test.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -141,22 +141,21 @@ TYPED_TEST(RoutingLlama4KernelTest, WarpLevelParallelization)
141141
/*numExperts=*/128, /*topK=*/1,
142142
/*expertParallelization=*/1, /*expertParallelizationId=*/0,
143143
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
144-
/*usePdl=*/true, /*getExpWeights=*/true, /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 0.0f);
144+
/*usePdl=*/true, /*getExpWeights=*/true, /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 0.0f,
145+
/*requiredComputeCapability*/ 8);
145146
this->runTest(param);
146147
};
147148

148-
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
149149
TYPED_TEST(RoutingLlama4KernelTest, ClusterLevelParallelization)
150150
{
151151
RoutingKernelTestParam param(RoutingMethodType::Llama4, /*numTokens=*/100,
152152
/*numExperts=*/128, /*topK=*/1,
153153
/*expertParallelization=*/1, /*expertParallelizationId=*/0,
154154
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
155155
/*usePdl=*/true, /*getExpWeights=*/true,
156-
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f);
156+
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
157157
this->runTest(param);
158158
};
159-
#endif
160159

161160
TYPED_TEST(RoutingLlama4KernelTest, DeviceLevelParallelization)
162161
{
@@ -165,7 +164,7 @@ TYPED_TEST(RoutingLlama4KernelTest, DeviceLevelParallelization)
165164
/*expertParallelization=*/1, /*expertParallelizationId=*/0,
166165
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
167166
/*usePdl=*/true, /*getExpWeights=*/true,
168-
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f);
167+
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 8);
169168
this->runTest(param);
170169
};
171170

cpp/tests/unit_tests/kernels/routing/routingRenormalizeTest.cpp

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -179,15 +179,14 @@ class RoutingRenormalizeKernelTest : public RoutingKernelTest<T>
179179

180180
TYPED_TEST_SUITE(RoutingRenormalizeKernelTest, FloatAndBf16Types);
181181

182-
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
183182
TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelization)
184183
{
185184
RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/10,
186185
/*numExperts=*/128, /*topK=*/8,
187186
/*expertParallelization=*/1, /*expertParallelizationId=*/0,
188187
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
189188
/*usePdl=*/true, /*getExpWeights=*/true,
190-
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f);
189+
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
191190
this->runTest(param);
192191
};
193192

@@ -198,7 +197,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithExpertPa
198197
/*expertParallelization=*/2, /*expertParallelizationId=*/1,
199198
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
200199
/*usePdl=*/true, /*getExpWeights=*/true,
201-
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f);
200+
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
202201
this->runTest(param);
203202
};
204203

@@ -209,10 +208,9 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithRenormal
209208
/*expertParallelization=*/1, /*expertParallelizationId=*/0,
210209
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
211210
/*usePdl=*/true, /*getExpWeights=*/true,
212-
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f);
211+
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
213212
this->runTest(param);
214213
};
215-
#endif
216214

217215
TYPED_TEST(RoutingRenormalizeKernelTest, DeviceLevelParallelization)
218216
{
@@ -221,7 +219,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, DeviceLevelParallelization)
221219
/*expertParallelization=*/1, /*expertParallelizationId=*/0,
222220
/*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
223221
/*usePdl=*/true, /*getExpWeights=*/true,
224-
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f);
222+
/*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 8);
225223
this->runTest(param);
226224
};
227225

cpp/tests/unit_tests/kernels/routing/routingTest.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,11 @@ void RoutingKernelTest<T>::verifyResult(RoutingKernelTestParam const& param)
332332
template <typename T>
333333
void RoutingKernelTest<T>::runTest(RoutingKernelTestParam const& param)
334334
{
335+
if (mDeviceProp.major < param.requiredComputeCapability)
336+
{
337+
GTEST_SKIP() << "Skip test due to compute capability requirement.";
338+
}
339+
335340
// Allocate buffers
336341
allocateBuffers(param);
337342
// Setup buffers

cpp/tests/unit_tests/kernels/routing/routingTest.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ struct RoutingKernelTestParam
218218
bool usePdl{true};
219219
bool getExpWeights{true};
220220

221+
int requiredComputeCapability{9};
221222
// Special for renormalize routing method
222223
bool doSoftmaxBeforeTopK{false};
223224
bool normTopkProb{true};
@@ -242,7 +243,7 @@ struct RoutingKernelTestParam
242243
RoutingKernelTestParam(RoutingMethodType routingMethod, int32_t numTokens, int32_t numExperts, uint32_t topK,
243244
int32_t expertParallelization = 1, int32_t expertParallelizationId = 0, int32_t paddingLog2 = 3,
244245
int32_t localExpertsStrideLog2 = 0, bool usePdl = true, bool getExpWeights = true, int32_t nGroup = 1,
245-
int32_t topkGroup = 1, float routedScalingFactor = 1.0f)
246+
int32_t topkGroup = 1, float routedScalingFactor = 1.0f, int requiredComputeCapability = 9)
246247
: routingMethod(routingMethod)
247248
, numTokens(numTokens)
248249
, numExperts(numExperts)
@@ -254,6 +255,7 @@ struct RoutingKernelTestParam
254255
, nGroup(nGroup)
255256
, topkGroup(topkGroup)
256257
, routedScalingFactor(routedScalingFactor)
258+
, requiredComputeCapability(requiredComputeCapability)
257259
{
258260
// Check the routing method
259261
if (routingMethod != RoutingMethodType::Renormalize && routingMethod != RoutingMethodType::RenormalizeNaive

0 commit comments

Comments
 (0)