diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.cu b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.cu index 94251935005..9b260c04bd5 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.cu +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.cu @@ -1244,6 +1244,8 @@ void run(Data const& data, void* stream) { // Reset the global histograms (not used in single-cluster code path). // Cover both for the cooperative and two-kernel code paths. + TLLM_CHECK_WITH_INFO( + data.mPtrExpertCounts != nullptr, "When #tokens is large, `mPtrExpertCounts` is a required input."); TLLM_CUDA_CHECK(cudaMemsetAsync( data.mPtrExpertCounts, 0, static_cast(2 * NumThreads) * sizeof(int32_t), (cudaStream_t) stream)); } diff --git a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.h b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.h index 2846703f6b4..ecd7ce7654b 100644 --- a/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.h +++ b/cpp/tensorrt_llm/kernels/trtllmGenKernels/blockScaleMoe/RoutingKernel.h @@ -17,7 +17,6 @@ #pragma once #include "IntFastDiv.h" - #include "tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h" #include @@ -31,7 +30,6 @@ namespace moe::dev namespace routing { - //////////////////////////////////////////////////////////////////////////////////////////////////// namespace tg = batchedGemm::trtllm::gen; diff --git a/cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp b/cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp index b9b232255d2..1c00d6fdbbf 100644 --- a/cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp +++ b/cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp @@ -123,7 +123,8 @@ std::vector fp4_block_scale_moe_runner(torch::Tensor const& routi {args.num_tokens, args.top_k}, routing_bias_dtype, routing_logits.device(), std::nullopt); at::Tensor expert_indexes = at::detail::empty_cuda( {args.num_tokens, args.top_k}, at::ScalarType::Int, routing_logits.device(), std::nullopt); - at::Tensor expert_count_histogram = at::detail::empty_cuda({((num_experts * 2 + 255) / 256) * 256}, + int64_t const size_of_expert_count_histogram = std::max(num_experts * 2, int64_t(256 * 2)); + at::Tensor expert_count_histogram = at::detail::empty_cuda({size_of_expert_count_histogram}, at::ScalarType::Int, // 256 is the max number of threads per block and max number of experts routing_logits.device(), std::nullopt); diff --git a/cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp b/cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp index 476afa928e2..bb72c0e87de 100644 --- a/cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp +++ b/cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp @@ -109,7 +109,8 @@ at::Tensor run_fp8_block_scale_moe(at::Tensor const& routing_logits, at::Tensor {args.num_tokens, args.top_k}, routing_bias.scalar_type(), routing_logits.device(), std::nullopt); at::Tensor expert_indexes = at::detail::empty_cuda( {args.num_tokens, args.top_k}, at::ScalarType::Int, routing_logits.device(), std::nullopt); - at::Tensor expert_count_histogram = at::detail::empty_cuda({2 * 256}, + int64_t const size_of_expert_count_histogram = std::max(num_experts * 2, int64_t(256 * 2)); + at::Tensor expert_count_histogram = at::detail::empty_cuda({size_of_expert_count_histogram}, at::ScalarType::Int, // 256 is the max number of threads per block and max number of experts routing_logits.device(), std::nullopt); diff --git a/cpp/tests/unit_tests/kernels/routing/routingDeepSeekTest.cpp b/cpp/tests/unit_tests/kernels/routing/routingDeepSeekTest.cpp index 2a2c1c9a766..8d4de68e9c3 100644 --- a/cpp/tests/unit_tests/kernels/routing/routingDeepSeekTest.cpp +++ b/cpp/tests/unit_tests/kernels/routing/routingDeepSeekTest.cpp @@ -246,7 +246,6 @@ class RoutingDeepSeekKernelTest : public RoutingKernelTest TYPED_TEST_SUITE(RoutingDeepSeekKernelTest, Bf16Types); -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization) { RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/10, @@ -254,7 +253,7 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization) /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, /*usePdl=*/true, /*getExpWeights=*/true, - /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f); + /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); this->runTest(param); }; @@ -265,20 +264,19 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationWithExpertParal /*expertParallelization=*/2, /*expertParallelizationId=*/1, /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, /*usePdl=*/true, /*getExpWeights=*/true, - /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f); + /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); this->runTest(param); }; -TYPED_TEST(RoutingDeepSeekKernelTest, DeviceLevelParallelization) +TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelization) { RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/1030, - /*numExperts=*/128, /*topK=*/1, + /*numExperts=*/128, /*topK=*/8, /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, /*usePdl=*/true, /*getExpWeights=*/true, - /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f); + /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10); this->runTest(param); }; -#endif } // namespace diff --git a/cpp/tests/unit_tests/kernels/routing/routingLlama4Test.cpp b/cpp/tests/unit_tests/kernels/routing/routingLlama4Test.cpp index 484eec69cfd..1996c038449 100644 --- a/cpp/tests/unit_tests/kernels/routing/routingLlama4Test.cpp +++ b/cpp/tests/unit_tests/kernels/routing/routingLlama4Test.cpp @@ -141,11 +141,11 @@ TYPED_TEST(RoutingLlama4KernelTest, WarpLevelParallelization) /*numExperts=*/128, /*topK=*/1, /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, - /*usePdl=*/true, /*getExpWeights=*/true, /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 0.0f); + /*usePdl=*/true, /*getExpWeights=*/true, /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 0.0f, + /*requiredComputeCapability*/ 8); this->runTest(param); }; -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) TYPED_TEST(RoutingLlama4KernelTest, ClusterLevelParallelization) { RoutingKernelTestParam param(RoutingMethodType::Llama4, /*numTokens=*/100, @@ -153,10 +153,9 @@ TYPED_TEST(RoutingLlama4KernelTest, ClusterLevelParallelization) /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, /*usePdl=*/true, /*getExpWeights=*/true, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f); + /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); this->runTest(param); }; -#endif TYPED_TEST(RoutingLlama4KernelTest, DeviceLevelParallelization) { @@ -165,7 +164,7 @@ TYPED_TEST(RoutingLlama4KernelTest, DeviceLevelParallelization) /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, /*usePdl=*/true, /*getExpWeights=*/true, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f); + /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 8); this->runTest(param); }; diff --git a/cpp/tests/unit_tests/kernels/routing/routingRenormalizeTest.cpp b/cpp/tests/unit_tests/kernels/routing/routingRenormalizeTest.cpp index 0889e2a8ae6..f4e2b7c7d4b 100644 --- a/cpp/tests/unit_tests/kernels/routing/routingRenormalizeTest.cpp +++ b/cpp/tests/unit_tests/kernels/routing/routingRenormalizeTest.cpp @@ -179,7 +179,6 @@ class RoutingRenormalizeKernelTest : public RoutingKernelTest TYPED_TEST_SUITE(RoutingRenormalizeKernelTest, FloatAndBf16Types); -#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelization) { RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/10, @@ -187,7 +186,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelization) /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, /*usePdl=*/true, /*getExpWeights=*/true, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f); + /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); this->runTest(param); }; @@ -198,7 +197,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithExpertPa /*expertParallelization=*/2, /*expertParallelizationId=*/1, /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, /*usePdl=*/true, /*getExpWeights=*/true, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f); + /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); this->runTest(param); }; @@ -209,10 +208,9 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithRenormal /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, /*usePdl=*/true, /*getExpWeights=*/true, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f); + /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9); this->runTest(param); }; -#endif TYPED_TEST(RoutingRenormalizeKernelTest, DeviceLevelParallelization) { @@ -221,7 +219,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, DeviceLevelParallelization) /*expertParallelization=*/1, /*expertParallelizationId=*/0, /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0, /*usePdl=*/true, /*getExpWeights=*/true, - /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f); + /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 8); this->runTest(param); }; diff --git a/cpp/tests/unit_tests/kernels/routing/routingTest.cpp b/cpp/tests/unit_tests/kernels/routing/routingTest.cpp index 3485e9e3c3f..8941a8fb18c 100644 --- a/cpp/tests/unit_tests/kernels/routing/routingTest.cpp +++ b/cpp/tests/unit_tests/kernels/routing/routingTest.cpp @@ -332,6 +332,11 @@ void RoutingKernelTest::verifyResult(RoutingKernelTestParam const& param) template void RoutingKernelTest::runTest(RoutingKernelTestParam const& param) { + if (mDeviceProp.major < param.requiredComputeCapability) + { + GTEST_SKIP() << "Skip test due to compute capability requirement."; + } + // Allocate buffers allocateBuffers(param); // Setup buffers diff --git a/cpp/tests/unit_tests/kernels/routing/routingTest.h b/cpp/tests/unit_tests/kernels/routing/routingTest.h index 67bccd36c6d..890bae74627 100644 --- a/cpp/tests/unit_tests/kernels/routing/routingTest.h +++ b/cpp/tests/unit_tests/kernels/routing/routingTest.h @@ -218,6 +218,7 @@ struct RoutingKernelTestParam bool usePdl{true}; bool getExpWeights{true}; + int requiredComputeCapability{9}; // Special for renormalize routing method bool doSoftmaxBeforeTopK{false}; bool normTopkProb{true}; @@ -242,7 +243,7 @@ struct RoutingKernelTestParam RoutingKernelTestParam(RoutingMethodType routingMethod, int32_t numTokens, int32_t numExperts, uint32_t topK, int32_t expertParallelization = 1, int32_t expertParallelizationId = 0, int32_t paddingLog2 = 3, int32_t localExpertsStrideLog2 = 0, bool usePdl = true, bool getExpWeights = true, int32_t nGroup = 1, - int32_t topkGroup = 1, float routedScalingFactor = 1.0f) + int32_t topkGroup = 1, float routedScalingFactor = 1.0f, int requiredComputeCapability = 9) : routingMethod(routingMethod) , numTokens(numTokens) , numExperts(numExperts) @@ -254,6 +255,7 @@ struct RoutingKernelTestParam , nGroup(nGroup) , topkGroup(topkGroup) , routedScalingFactor(routedScalingFactor) + , requiredComputeCapability(requiredComputeCapability) { // Check the routing method if (routingMethod != RoutingMethodType::Renormalize && routingMethod != RoutingMethodType::RenormalizeNaive