
Commit 0e57a0b

Fix

Signed-off-by: Yi Zhang <[email protected]>
1 parent: 5e50fcc

File tree: 2 files changed (+2, -3 lines)

cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp

Lines changed: 2 additions & 1 deletion
@@ -76,6 +76,7 @@ std::vector<torch::Tensor> fp4_block_scale_moe_runner(torch::Tensor const& routi
 
     TORCH_CHECK(num_experts % 4 == 0, "Routing kernel expects that num_experts must be divisible by 4");
     TORCH_CHECK(num_experts > top_k, "num_experts must be greater than top_k");
+    TORCH_CHECK(num_experts <= 256, "num_experts must be less than or equal to 256");
 
     tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::MoE::MoERunnerArgs args;
     tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::MoE::MoEWorkspace workspace;
@@ -123,7 +124,7 @@ std::vector<torch::Tensor> fp4_block_scale_moe_runner(torch::Tensor const& routi
         {args.num_tokens, args.top_k}, routing_bias_dtype, routing_logits.device(), std::nullopt);
     at::Tensor expert_indexes = at::detail::empty_cuda(
         {args.num_tokens, args.top_k}, at::ScalarType::Int, routing_logits.device(), std::nullopt);
-    at::Tensor expert_count_histogram = at::detail::empty_cuda({((num_experts * 2 + 255) / 256) * 256},
+    at::Tensor expert_count_histogram = at::detail::empty_cuda({2 * 256},
         at::ScalarType::Int, // 256 is the max number of threads per block and max number of experts
         routing_logits.device(), std::nullopt);
 
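Read together, the two hunks look like a sizing fix: once the new TORCH_CHECK caps num_experts at 256 (the in-diff comment notes that 256 is both the max threads per block and the max number of experts), expert_count_histogram can be a fixed 2 * 256 = 512 ints instead of a size derived from num_experts. Notably, the old formula yielded only 256 ints whenever num_experts <= 128. A minimal standalone sketch of the arithmetic (an illustration, not code from the commit; it assumes nothing beyond the formulas visible in the diff):

```cpp
// Standalone sketch: compare the old, num_experts-dependent histogram size
// with the new fixed size, over the range allowed by the new TORCH_CHECK.
#include <cassert>
#include <cstdio>

int main()
{
    int const kFixedSize = 2 * 256; // new allocation: 512 ints

    // The earlier check requires num_experts % 4 == 0, so step by 4.
    for (int num_experts = 4; num_experts <= 256; num_experts += 4)
    {
        // Old allocation: 2 * num_experts rounded up to a multiple of 256.
        int oldSize = ((num_experts * 2 + 255) / 256) * 256;

        // oldSize is 256 for num_experts <= 128 and 512 otherwise,
        // so the fixed 512-int allocation is never smaller.
        assert(oldSize <= kFixedSize);
    }

    std::printf("2 * 256 ints cover every num_experts allowed by the new check\n");
    return 0;
}
```

Compiled with any standard C++ compiler, the loop passes for every multiple of 4 up to 256, matching the divisible-by-4 check in the first hunk.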

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 0 additions & 2 deletions
@@ -871,8 +871,6 @@ def test_nvfp4_4gpus(self, fp8kv, attention_dp, cuda_graph,
             pytest.skip("https://nvbugs/5252559")
         if torch_compile and pp_size > 1:
             pytest.skip("PP with torch.compile is not supported yet.")
-        if not attention_dp and (tp_size > 1 or ep_size > 1):
-            pytest.skip("https://nvbugs/5336321")
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
         torch_compile_config = TorchCompileConfig(
             enable_fullgraph=True,
