diff --git a/cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp b/cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp index b9b232255d2..73274f18043 100644 --- a/cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp +++ b/cpp/tensorrt_llm/thop/fp4BlockScaleMoe.cpp @@ -76,6 +76,7 @@ std::vector fp4_block_scale_moe_runner(torch::Tensor const& routi TORCH_CHECK(num_experts % 4 == 0, "Routing kernel expects that num_experts must be divisible by 4"); TORCH_CHECK(num_experts > top_k, "num_experts must be greater than top_k"); + TORCH_CHECK(num_experts <= 256, "num_experts must be less than or equal to 256"); tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::MoE::MoERunnerArgs args; tensorrt_llm::kernels::trtllmGenFp8BlockScaleMoe::MoE::MoEWorkspace workspace; @@ -123,7 +124,7 @@ std::vector fp4_block_scale_moe_runner(torch::Tensor const& routi {args.num_tokens, args.top_k}, routing_bias_dtype, routing_logits.device(), std::nullopt); at::Tensor expert_indexes = at::detail::empty_cuda( {args.num_tokens, args.top_k}, at::ScalarType::Int, routing_logits.device(), std::nullopt); - at::Tensor expert_count_histogram = at::detail::empty_cuda({((num_experts * 2 + 255) / 256) * 256}, + at::Tensor expert_count_histogram = at::detail::empty_cuda({2 * 256}, at::ScalarType::Int, // 256 is the max number of threads per block and max number of experts routing_logits.device(), std::nullopt); diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index ad674d948d2..8685bab0cb6 100644 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -724,10 +724,9 @@ def _compute_mlp_tp_size(self, intermediate_size: int, return mlp_tp_size def _enable_min_latency_mode(self, num_tokens: int): - return (num_tokens <= 128 and self.fusion_config.POST_MOE_FUSION - and self.is_nvfp4 and self.model_config.moe_backend == 'CUTLASS' - and not self.mapping.is_multi_node() - and self.allreduce.mnnvl_allreduce is None) + # Disable cutlass min latency mode since it will cause illegal memory access + # Will use trtllm-gen moe in future + return False def forward( self, diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index b820358de63..23d7a468269 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -882,8 +882,7 @@ def test_nvfp4_4gpus(self, fp8kv, attention_dp, cuda_graph, pytest.skip("https://nvbugs/5252559") if torch_compile and pp_size > 1: pytest.skip("PP with torch.compile is not supported yet.") - if not attention_dp and (tp_size > 1 or ep_size > 1): - pytest.skip("https://nvbugs/5336321") + kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9) torch_compile_config = TorchCompileConfig( enable_fullgraph=True,