Commit f8f48ce

remove all_rank_max_num_tokens
Signed-off-by: qgai <[email protected]>
1 parent: 26e38f4

File tree

tensorrt_llm/_torch/models/modeling_deepseekv3.py
tensorrt_llm/_torch/models/modeling_speculative.py

2 files changed: +7 -8 lines

tensorrt_llm/_torch/models/modeling_deepseekv3.py
(file mode changed: 100644 → 100755; 7 additions, 5 deletions)

@@ -395,8 +395,7 @@ def split_kv_b_proj(kv_b_proj: torch.Tensor,
                     p.data.copy_(module_weights[n][:])
 
             if self.model_config.quant_config.layer_quant_mode.has_fp8_block_scales(
-            ) and is_sm_100f() and hasattr(
-                    module, "weight_scale"):
+            ) and is_sm_100f() and hasattr(module, "weight_scale"):
                 weight, weight_scale = resmooth_to_fp8_e8m0(
                     module.weight, module.weight_scale)
                 transfromed_scale = transform_sf_into_required_layout(
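
Aside on the guarded branch above: on SM100-class GPUs the FP8 block scales must be in e8m0 form, i.e. pure powers of two. A minimal conceptual sketch of what a resmooth step does — this is not the TensorRT-LLM implementation of resmooth_to_fp8_e8m0, and the per-block bookkeeping is omitted:

    import torch

    def resmooth_to_e8m0_sketch(weight: torch.Tensor,
                                weight_scale: torch.Tensor):
        # e8m0 stores only an exponent, so round each scale up to the
        # nearest power of two; rounding up keeps requantized values in range.
        new_scale = torch.exp2(torch.ceil(torch.log2(weight_scale)))
        # Fold the leftover ratio into the weights so that weight * scale
        # is preserved.
        return weight * (weight_scale / new_scale), new_scale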

@@ -805,8 +804,9 @@ def __init__(self,
             for key in [EventType.Main, EventType.MoeShared]
         }
 
-    def _compute_shared_expert_tp_size(self, intermediate_size: int,
-                                       block_size: int) -> int:
+    def _compute_shared_expert_tp_size(
+            self, intermediate_size: int,
+            block_size: int) -> tuple[int, float | None]:
         """
         In the case of Deepseek-R1, the TP size of MLP is capped by intermediate_size // block_size.
         For example, when the intermediate_size is 2048 and block scaling size is 128,
@@ -818,7 +818,9 @@ def _compute_shared_expert_tp_size(self, intermediate_size: int,
         it's 128. For NVFP4, it's 16.
 
         Returns:
-            int: The computed tp_size.
+            tuple[int, float | None]: A tuple containing (shared_tp_size, shared_output_scale).
+                - shared_tp_size: The computed TP size.
+                - shared_output_scale: The output scale factor, or None if not needed.
         """
 
         assert intermediate_size % block_size == 0, "intermediate_size must be divisible by block_size."
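
The updated docstring fixes the new contract of _compute_shared_expert_tp_size. A hedged sketch of a helper that satisfies it, built only from what the docstring states; the mapping.tp_size attribute and the gcd choice are assumptions, not taken from this diff:

    import math

    def _compute_shared_expert_tp_size(self, intermediate_size: int,
                                       block_size: int) -> tuple[int, float | None]:
        assert intermediate_size % block_size == 0, \
            "intermediate_size must be divisible by block_size."
        # Every rank must own at least one whole scaling block, so e.g.
        # intermediate_size=2048 with block_size=128 caps TP at 16.
        max_tp = intermediate_size // block_size
        # gcd keeps the result a divisor of the model's TP size.
        shared_tp_size = math.gcd(max_tp, self.mapping.tp_size)
        if shared_tp_size == self.mapping.tp_size:
            # Shared experts run at full TP; no correction needed.
            shared_output_scale = None
        else:
            # At a smaller TP each rank holds a duplicated shard, so the
            # all-reduced output must be scaled down to compensate.
            shared_output_scale = shared_tp_size / self.mapping.tp_size
        return shared_tp_size, shared_output_scale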

tensorrt_llm/_torch/models/modeling_speculative.py
(file mode changed: 100644 → 100755; 0 additions, 3 deletions)

@@ -393,7 +393,6 @@ def forward(
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
         all_rank_num_tokens: Optional[List[int]] = None,
-        all_rank_max_num_tokens: Optional[int] = None,
         **kwargs,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         hidden_states = self.layers(
@@ -403,7 +402,6 @@ def forward(
             embed_tokens=self.embed_tokens,
             attn_metadata=attn_metadata,
             all_rank_num_tokens=all_rank_num_tokens,
-            all_rank_max_num_tokens=all_rank_max_num_tokens,
         )
 
         return hidden_states
@@ -458,7 +456,6 @@ def forward(self,
             hidden_states=hidden_states,
             attn_metadata=attn_metadata,
             all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
-            all_rank_max_num_tokens=attn_metadata.all_rank_max_num_tokens,
             **kwargs)
         return self.logits_processor.forward(
             output,
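
Note that the removed value is derivable from the list that remains: all_rank_max_num_tokens is presumably just the maximum of all_rank_num_tokens, so dropping the parameter loses no information. A tiny sketch (the helper name is made up for illustration):

    from typing import List, Optional

    def max_num_tokens(all_rank_num_tokens: Optional[List[int]]) -> Optional[int]:
        # Any consumer that still needs the per-rank maximum can recover it
        # from the list instead of threading a separate argument through.
        return max(all_rank_num_tokens) if all_rank_num_tokens else None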
