@@ -1244,6 +1244,8 @@ void run(Data const& data, void* stream)
     {
         // Reset the global histograms (not used in single-cluster code path).
         // Cover both for the cooperative and two-kernel code paths.
+        TLLM_CHECK_WITH_INFO(
+            data.mPtrExpertCounts != nullptr, "When #tokens is large, `mPtrExpertCounts` is a required input.");
         TLLM_CUDA_CHECK(cudaMemsetAsync(
             data.mPtrExpertCounts, 0, static_cast<size_t>(2 * NumThreads) * sizeof(int32_t), (cudaStream_t) stream));
     }
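The added check makes the precondition explicit before the asynchronous reset that feeds the cooperative and two-kernel routing paths. A minimal standalone sketch of the same reset, assuming `NumThreads == 256` (the kernel block size) and substituting plain CUDA error handling for the `TLLM_CHECK_WITH_INFO`/`TLLM_CUDA_CHECK` macros:

```cpp
#include <cassert>
#include <cstdint>
#include <cstdlib>
#include <cuda_runtime.h>

// Zero two NumThreads-bin histograms before the routing kernels accumulate into them.
// Assumption: NumThreads == 256; the real code reports failures with a message instead
// of aborting.
void resetExpertCounts(int32_t* expertCounts, cudaStream_t stream)
{
    constexpr int32_t NumThreads = 256;
    assert(expertCounts != nullptr); // required input on the large-token path
    cudaError_t const err = cudaMemsetAsync(
        expertCounts, 0, static_cast<size_t>(2 * NumThreads) * sizeof(int32_t), stream);
    if (err != cudaSuccess)
    {
        std::abort(); // stand-in for TLLM_CUDA_CHECK's error reporting
    }
}
```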
@@ -17,7 +17,6 @@
 #pragma once
 
 #include "IntFastDiv.h"
-
 #include "tensorrt_llm/kernels/trtllmGenKernels/batchedGemm/trtllmGen_bmm_export/trtllm/gen/DtypeDecl.h"
 
 #include <cuda.h>
@@ -31,7 +30,6 @@ namespace moe::dev
 
 namespace routing
 {
-
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 namespace tg = batchedGemm::trtllm::gen;
cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp (2 additions, 1 deletion)
@@ -123,7 +123,8 @@ std::vector<torch::Tensor> fp4_block_scale_moe_runner(torch::Tensor const& routi
         {args.num_tokens, args.top_k}, routing_bias_dtype, routing_logits.device(), std::nullopt);
     at::Tensor expert_indexes = at::detail::empty_cuda(
         {args.num_tokens, args.top_k}, at::ScalarType::Int, routing_logits.device(), std::nullopt);
-    at::Tensor expert_count_histogram = at::detail::empty_cuda({((num_experts * 2 + 255) / 256) * 256},
+    int64_t const size_of_expert_count_histogram = std::max(num_experts * 2, int64_t(256 * 2));
+    at::Tensor expert_count_histogram = at::detail::empty_cuda({size_of_expert_count_histogram},
         at::ScalarType::Int, // 256 is the max number of threads per block and max number of experts
         routing_logits.device(), std::nullopt);

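A note on this sizing change (mirrored in fp8BlockScaleMoe.cpp below): the old fp4 expression only rounded `num_experts * 2` up to a multiple of 256, which yields 256 entries for `num_experts <= 128` even though the reset in the first hunk zeroes `2 * NumThreads == 512` counters, and the old fp8 allocation was a fixed `2 * 256`, which undersizes for `num_experts > 256`. The `std::max` covers both cases. A small sketch with worked values:

```cpp
#include <algorithm>
#include <cstdint>

// New rule: at least two 256-bin histograms (256 is the max number of threads per
// block and max number of experts), growing with the expert count past 256 experts.
int64_t histogramEntries(int64_t num_experts)
{
    return std::max(num_experts * 2, int64_t(256 * 2));
}

// Old fp4 rule for comparison: round num_experts * 2 up to a multiple of 256.
int64_t histogramEntriesOld(int64_t num_experts)
{
    return ((num_experts * 2 + 255) / 256) * 256;
}

// histogramEntriesOld(128) == 256, too small for the 512 counters the kernel resets;
// histogramEntries(128) == 512. For 384 experts both give 768.
```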
cpp/tensorrt_llm/thop/fp8BlockScaleMoe.cpp (2 additions, 1 deletion)
@@ -109,7 +109,8 @@ at::Tensor run_fp8_block_scale_moe(at::Tensor const& routing_logits, at::Tensor
         {args.num_tokens, args.top_k}, routing_bias.scalar_type(), routing_logits.device(), std::nullopt);
     at::Tensor expert_indexes = at::detail::empty_cuda(
         {args.num_tokens, args.top_k}, at::ScalarType::Int, routing_logits.device(), std::nullopt);
-    at::Tensor expert_count_histogram = at::detail::empty_cuda({2 * 256},
+    int64_t const size_of_expert_count_histogram = std::max(num_experts * 2, int64_t(256 * 2));
+    at::Tensor expert_count_histogram = at::detail::empty_cuda({size_of_expert_count_histogram},
         at::ScalarType::Int, // 256 is the max number of threads per block and max number of experts
         routing_logits.device(), std::nullopt);

cpp/tests/unit_tests/kernels/routing/routingDeepSeekTest.cpp (5 additions, 7 deletions)
@@ -246,15 +246,14 @@ class RoutingDeepSeekKernelTest : public RoutingKernelTest<T>
 
 TYPED_TEST_SUITE(RoutingDeepSeekKernelTest, Bf16Types);
 
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
 TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelization)
 {
     RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/10,
         /*numExperts=*/128, /*topK=*/8,
         /*expertParallelization=*/1, /*expertParallelizationId=*/0,
         /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
         /*usePdl=*/true, /*getExpWeights=*/true,
-        /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f);
+        /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
     this->runTest(param);
 };
 
@@ -265,20 +264,19 @@ TYPED_TEST(RoutingDeepSeekKernelTest, ClusterLevelParallelizationWithExpertParal
         /*expertParallelization=*/2, /*expertParallelizationId=*/1,
         /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
         /*usePdl=*/true, /*getExpWeights=*/true,
-        /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f);
+        /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
     this->runTest(param);
 };
 
-TYPED_TEST(RoutingDeepSeekKernelTest, DeviceLevelParallelization)
+TYPED_TEST(RoutingDeepSeekKernelTest, CooperativeLevelParallelization)
 {
     RoutingKernelTestParam param(RoutingMethodType::DeepSeekV3, /*numTokens=*/1030,
-        /*numExperts=*/128, /*topK=*/1,
+        /*numExperts=*/128, /*topK=*/8,
         /*expertParallelization=*/1, /*expertParallelizationId=*/0,
         /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
         /*usePdl=*/true, /*getExpWeights=*/true,
-        /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f);
+        /*nGroup*/ 8, /*topkGroup*/ 4, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 10);
     this->runTest(param);
 };
-#endif
 
 } // namespace
cpp/tests/unit_tests/kernels/routing/routingLlama4Test.cpp (4 additions, 5 deletions)
@@ -141,22 +141,21 @@ TYPED_TEST(RoutingLlama4KernelTest, WarpLevelParallelization)
         /*numExperts=*/128, /*topK=*/1,
         /*expertParallelization=*/1, /*expertParallelizationId=*/0,
         /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
-        /*usePdl=*/true, /*getExpWeights=*/true, /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 0.0f);
+        /*usePdl=*/true, /*getExpWeights=*/true, /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 0.0f,
+        /*requiredComputeCapability*/ 8);
     this->runTest(param);
 };
 
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
 TYPED_TEST(RoutingLlama4KernelTest, ClusterLevelParallelization)
 {
     RoutingKernelTestParam param(RoutingMethodType::Llama4, /*numTokens=*/100,
         /*numExperts=*/128, /*topK=*/1,
         /*expertParallelization=*/1, /*expertParallelizationId=*/0,
         /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
         /*usePdl=*/true, /*getExpWeights=*/true,
-        /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f);
+        /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
     this->runTest(param);
 };
-#endif
 
 TYPED_TEST(RoutingLlama4KernelTest, DeviceLevelParallelization)
 {
@@ -165,7 +164,7 @@ TYPED_TEST(RoutingLlama4KernelTest, DeviceLevelParallelization)
         /*expertParallelization=*/1, /*expertParallelizationId=*/0,
         /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
         /*usePdl=*/true, /*getExpWeights=*/true,
-        /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f);
+        /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 8);
     this->runTest(param);
 };
 
cpp/tests/unit_tests/kernels/routing/routingRenormalizeTest.cpp (4 additions, 6 deletions)
@@ -179,15 +179,14 @@ class RoutingRenormalizeKernelTest : public RoutingKernelTest<T>
 
 TYPED_TEST_SUITE(RoutingRenormalizeKernelTest, FloatAndBf16Types);
 
-#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
 TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelization)
 {
     RoutingKernelTestParam param(RoutingMethodType::Renormalize, /*numTokens=*/10,
         /*numExperts=*/128, /*topK=*/8,
         /*expertParallelization=*/1, /*expertParallelizationId=*/0,
         /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
         /*usePdl=*/true, /*getExpWeights=*/true,
-        /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f);
+        /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
     this->runTest(param);
 };
 
@@ -198,7 +197,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithExpertPa
         /*expertParallelization=*/2, /*expertParallelizationId=*/1,
         /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
         /*usePdl=*/true, /*getExpWeights=*/true,
-        /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f);
+        /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
     this->runTest(param);
 };
 
@@ -209,10 +208,9 @@ TYPED_TEST(RoutingRenormalizeKernelTest, ClusterLevelParallelizationWithRenormal
         /*expertParallelization=*/1, /*expertParallelizationId=*/0,
         /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
         /*usePdl=*/true, /*getExpWeights=*/true,
-        /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f);
+        /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 9);
     this->runTest(param);
 };
-#endif
 
 TYPED_TEST(RoutingRenormalizeKernelTest, DeviceLevelParallelization)
 {
@@ -221,7 +219,7 @@ TYPED_TEST(RoutingRenormalizeKernelTest, DeviceLevelParallelization)
         /*expertParallelization=*/1, /*expertParallelizationId=*/0,
         /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
         /*usePdl=*/true, /*getExpWeights=*/true,
-        /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f);
+        /*nGroup*/ 0, /*topkGroup*/ 0, /*routedScalingFactor*/ 1.0f, /*requiredComputeCapability*/ 8);
     this->runTest(param);
 };
 
cpp/tests/unit_tests/kernels/routing/routingTest.cpp (5 additions)
@@ -332,6 +332,11 @@ void RoutingKernelTest<T>::verifyResult(RoutingKernelTestParam const& param)
 template <typename T>
 void RoutingKernelTest<T>::runTest(RoutingKernelTestParam const& param)
 {
+    if (mDeviceProp.major < param.requiredComputeCapability)
+    {
+        GTEST_SKIP() << "Skip test due to compute capability requirement.";
+    }
+
     // Allocate buffers
     allocateBuffers(param);
     // Setup buffers
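This runtime guard replaces the `#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))` blocks deleted from the three test files above. `__CUDA_ARCH__` is undefined in the host compilation pass, so those host-side test bodies were compiled out regardless of the target GPU; checking `mDeviceProp.major` at run time instead lets one binary run warp- and device-level tests on capability 8 (SM80), cluster-level tests on capability 9 (SM90), and the cooperative DeepSeek test on capability 10 (SM100). A self-contained sketch of the pattern, assuming the fixture caches a `cudaDeviceProp` obtained via `cudaGetDeviceProperties`:

```cpp
#include <cuda_runtime.h>
#include <gtest/gtest.h>

TEST(RoutingGuardSketch, SkipsOnUnsupportedGpu)
{
    // Stand-in for the fixture's cached mDeviceProp.
    cudaDeviceProp deviceProp{};
    ASSERT_EQ(cudaGetDeviceProperties(&deviceProp, /*device=*/0), cudaSuccess);

    int const requiredComputeCapability = 9; // e.g. SM90 for cluster launches
    if (deviceProp.major < requiredComputeCapability)
    {
        GTEST_SKIP() << "Compute capability " << deviceProp.major << "." << deviceProp.minor
                     << " is below the required " << requiredComputeCapability << ".x";
    }
    // ... kernel-specific test body runs only on capable devices ...
}
```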
cpp/tests/unit_tests/kernels/routing/routingTest.h (3 additions, 1 deletion)
@@ -218,6 +218,7 @@ struct RoutingKernelTestParam
     bool usePdl{true};
     bool getExpWeights{true};
 
+    int requiredComputeCapability{9};
     // Special for renormalize routing method
     bool doSoftmaxBeforeTopK{false};
     bool normTopkProb{true};
@@ -242,7 +243,7 @@ struct RoutingKernelTestParam
     RoutingKernelTestParam(RoutingMethodType routingMethod, int32_t numTokens, int32_t numExperts, uint32_t topK,
         int32_t expertParallelization = 1, int32_t expertParallelizationId = 0, int32_t paddingLog2 = 3,
         int32_t localExpertsStrideLog2 = 0, bool usePdl = true, bool getExpWeights = true, int32_t nGroup = 1,
-        int32_t topkGroup = 1, float routedScalingFactor = 1.0f)
+        int32_t topkGroup = 1, float routedScalingFactor = 1.0f, int requiredComputeCapability = 9)
         : routingMethod(routingMethod)
         , numTokens(numTokens)
         , numExperts(numExperts)
@@ -254,6 +255,7 @@ struct RoutingKernelTestParam
         , nGroup(nGroup)
         , topkGroup(topkGroup)
         , routedScalingFactor(routedScalingFactor)
+        , requiredComputeCapability(requiredComputeCapability)
     {
         // Check the routing method
         if (routingMethod != RoutingMethodType::Renormalize && routingMethod != RoutingMethodType::RenormalizeNaive