17 changes: 15 additions & 2 deletions tensorrt_llm/_torch/distributed/ops.py
@@ -240,6 +240,11 @@ def reducescatter(
if isinstance(input, torch.Tensor):
assert input.shape[dim] == sum_split_size
else:
for val in input:
if val is not None and val.shape[dim] != sum_split_size:
print(
f"[reducescatter] val.shape={val.shape}, dim={dim}, val.shape[dim]={val.shape[dim]}, sum_split_size={sum_split_size}, sizes={sizes}"
)
Comment on lines +243 to +247
⚠️ Potential issue

Remove debug print statement from production code

Debug print statements should not be in production code. Consider using proper logging instead.

-            for val in input:
-                if val is not None and val.shape[dim] != sum_split_size:
-                    print(
-                        f"[reducescatter] val.shape={val.shape}, dim={dim}, val.shape[dim]={val.shape[dim]}, sum_split_size={sum_split_size}, sizes={sizes}"
-                    )
+            for val in input:
+                if val is not None and val.shape[dim] != sum_split_size:
+                    logger.debug(
+                        f"[reducescatter] val.shape={val.shape}, dim={dim}, val.shape[dim]={val.shape[dim]}, sum_split_size={sum_split_size}, sizes={sizes}"
+                    )
🧰 Tools
🪛 Ruff (0.12.2)

246-246: Line too long (156 > 120)

(E501)

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/distributed/ops.py around lines 243 to 247, remove the
debug print statement that outputs shape and size info; replace it with a proper
logging call using the module logger (e.g., logger.debug or logger.warning) and
include the same contextual message and variables so it's available in logs but
not printed to stdout; ensure the module imports and uses the logger
consistently and that the log level chosen is appropriate for diagnostics.
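
For reference, a minimal sketch of the converted block, assuming ops.py has a module-level logger available (this same diff already calls logger.debug a few hunks down; otherwise it can be imported from tensorrt_llm.logger). Splitting the message also keeps the line under the 120-character limit flagged by Ruff above:

# Assumed module-level import; ops.py already uses logger.debug elsewhere in this diff.
from tensorrt_llm.logger import logger

# Inside reducescatter(), mirroring the original shape check:
for val in input:
    if val is not None and val.shape[dim] != sum_split_size:
        logger.debug(
            f"[reducescatter] val.shape={val.shape}, dim={dim}, "
            f"sum_split_size={sum_split_size}, sizes={sizes}")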

assert all([
val.shape[dim] == sum_split_size for val in input
if val is not None
@@ -455,8 +460,16 @@ def __init__(self,
self.workspace = get_allreduce_workspace(self.mapping)

# Initialize MNNVL AllReduce if needed
if self.strategy == AllReduceStrategy.MNNVL:
if MNNVLAllReduce.is_mnnvl(self.mapping, dtype):
# if self.strategy == AllReduceStrategy.MNNVL:
# if MNNVLAllReduce.is_mnnvl(self.mapping, dtype):
if self.strategy in (AllReduceStrategy.AUTO,
AllReduceStrategy.MNNVL):
if self.mapping.tp_size != self.mapping.world_size:
logger.debug(
f"MNNVLAllReduce is disabled due to tp_size:{self.mapping.tp_size} "
f"!= world_size:{self.mapping.world_size}")
self.mnnvl_allreduce = None
elif MNNVLAllReduce.is_mnnvl(self.mapping, dtype):
try:
self.mnnvl_allreduce = MNNVLAllReduce(
self.mapping, dtype) if dtype else None
25 changes: 20 additions & 5 deletions tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -41,7 +41,7 @@

from tensorrt_llm._ipc_utils import can_access_peer
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.functional import PositionEmbeddingType
from tensorrt_llm.functional import AllReduceStrategy, PositionEmbeddingType
from tensorrt_llm.llmapi.utils import enable_llm_debug
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig
@@ -52,6 +52,7 @@
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams,
MoEAllReduce, MoEAllReduceParams, allgather)
from ..distributed.ops import MNNVLAllReduce
from ..model_config import ModelConfig
from ..modules.attention import MLA
from ..modules.decoder_layer import DecoderLayer
@@ -738,10 +739,24 @@ def _compute_mlp_tp_size(self, intermediate_size: int,
intermediate_size // block_size,
self.mapping.tp_size,
)
mlp_tp_size = math.gcd(
tp,
self.mapping.gpus_per_node,
) if tp > self.mapping.gpus_per_node else tp # Avoid costly inter-node TP
# mlp_tp_size = math.gcd(
# tp,
# self.mapping.gpus_per_node,
# ) if tp > self.mapping.gpus_per_node else tp # Avoid costly inter-node TP
if tp > self.mapping.gpus_per_node and (
self.model_config.allreduce_strategy not in (
AllReduceStrategy.AUTO,
AllReduceStrategy.MNNVL,
) or not MNNVLAllReduce.is_mnnvl(
self.mapping,
self.model_config.pretrained_config.torch_dtype)):
mlp_tp_size = math.gcd(
tp,
self.mapping.gpus_per_node,
) # Avoid costly inter-node TP when MNNVL is not supported and tp > gpus_per_node
else:
mlp_tp_size = tp

return mlp_tp_size

def forward(
112 changes: 90 additions & 22 deletions tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
@@ -5,7 +5,7 @@
import torch

from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo
from tensorrt_llm._utils import logger
from tensorrt_llm._utils import get_sm_version, logger
⚠️ Potential issue

Incorrect logger import; use tensorrt_llm.logger.

logger is not provided by tensorrt_llm._utils. Import from the canonical module.

-from tensorrt_llm._utils import get_sm_version, logger
+from tensorrt_llm._utils import get_sm_version
+from tensorrt_llm.logger import logger
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py around line 8, the
code imports logger from tensorrt_llm._utils but that module does not export
logger; replace the import to pull logger from the canonical module
tensorrt_llm.logger (i.e., import get_sm_version from tensorrt_llm._utils and
import logger from tensorrt_llm.logger) so the module uses the correct logger
object.

from tensorrt_llm.functional import AllReduceStrategy
from tensorrt_llm.mapping import Mapping

@@ -15,8 +15,10 @@
from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor
from .deep_ep_utils import buffer_pool, deep_ep_installed
from .interface import MoE
from .moe_backend import MoEBackend, MoEBackendSelection
from .moe_load_balancer import get_moe_load_balancer
from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod,
DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm,
FP8QDQFusedMoEMethod, MoEWeightLoadingMode,
NVFP4CutlassFusedMoEMethod,
UnquantizedFusedMoEMethod, WInt4AFP8FusedMoEMethod)
@@ -232,6 +234,9 @@ def __init__(
self.enable_dummy_allreduce = os.environ.get(
"TRTLLM_ENABLE_DUMMY_ALLREDUCE", "0") == "1"

# MoE backend will be lazily initialized when first accessed (see moe_backend property)
self._moe_backend_impl = None

def _check_configs(self):
assert self._weights_created

@@ -303,13 +308,20 @@ def calculate_num_chunks(self, all_rank_num_tokens: List[int]) -> int:
def can_use_alltoall(self, all_rank_num_tokens, all_rank_max_num_tokens):
# Disable alltoall when chunking is used
if self.calculate_num_chunks(all_rank_num_tokens) > 1:
print(
f"can not use alltoall due to chunking {self.calculate_num_chunks(all_rank_num_tokens)}"
)
return False

# For DeepEPLowLatency, check if tokens exceed the threshold
if (self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency
and all_rank_max_num_tokens > self.deep_ep_max_num_tokens):
print(
f"can not use alltoall due to deep_ep_max_num_tokens {all_rank_max_num_tokens} > {self.deep_ep_max_num_tokens}"
)
Comment on lines +319 to +321
⚠️ Potential issue

Remove debug print statement from production code

Debug print statements should be removed.

-            print(
-                f"can not use alltoall due to deep_ep_max_num_tokens {all_rank_max_num_tokens} > {self.deep_ep_max_num_tokens}"
-            )
+            logger.debug(
+                f"Cannot use alltoall due to deep_ep_max_num_tokens {all_rank_max_num_tokens} > {self.deep_ep_max_num_tokens}"
+            )
🧰 Tools
🪛 Ruff (0.12.2)

320-320: Line too long (127 > 120)

(E501)

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py around lines
319-321, remove the debug print statement; replace it with a proper logging call
(e.g., module logger or logging.warning) or drop it entirely if no runtime
notice is needed. If using logging, ensure an appropriate logger is available or
import logging and call logger.warning(f"cannot use alltoall due to
deep_ep_max_num_tokens {all_rank_max_num_tokens} >
{self.deep_ep_max_num_tokens}") to preserve context instead of printing to
stdout.

Comment on lines +311 to +321
💡 Verification agent

🧩 Analysis chain

Quick audit: ensure no stray print() calls remain in fused_moe modules.

Run this repo-local grep to confirm cleanup across the module.

Also applies to: 324-324, 333-345, 418-421, 812-815


🏁 Script executed:

#!/bin/bash
# Find any remaining print() in fused_moe Python sources (code, not tests).
rg -nP '^\s*print\(' tensorrt_llm/_torch/modules/fused_moe -g '!**/test/**' -S

Length of output: 1144


Remove all print() calls from fused_moe modules

The rg audit above shows the following remaining print() statements, which should be removed or replaced with proper logging:

  • tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py
    • Line 963: print( … )
  • tensorrt_llm/_torch/modules/fused_moe/moe_backend.py
    • Lines 893, 897: print("xxi select backend: …")
  • tensorrt_llm/_torch/modules/fused_moe/routing.py
    • Lines 77, 95, 102: debug routing prints
  • tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
    • Lines 311, 319, 324, 333, 337, 342: alltoall and chunking diagnostics

Please remove these print() calls. If retaining these messages is necessary for debugging or telemetry, convert them to the module’s logger (e.g., logging.debug or logging.info) at the appropriate level.

🧰 Tools
🪛 Ruff (0.12.2)

320-320: Line too long (127 > 120)

(E501)

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py around lines 311
to 321, remove the inline print() calls and replace them with the module logger
(e.g., logger.debug or logger.info) preserving the original message
text/formatting; ensure a logger is available (import logging and create logger
= logging.getLogger(__name__) if not already present) and keep control
flow/return behavior unchanged so the function still returns False where
intended.
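
Putting that together, a minimal sketch of can_use_alltoall with the diagnostics routed through the project logger (assuming logger is imported from tensorrt_llm.logger, per the import fix above); the control flow and return values are unchanged, and calculate_num_chunks is now evaluated only once:

def can_use_alltoall(self, all_rank_num_tokens, all_rank_max_num_tokens):
    # Disable alltoall when chunking is used
    num_chunks = self.calculate_num_chunks(all_rank_num_tokens)
    if num_chunks > 1:
        logger.debug(f"Cannot use alltoall due to chunking: {num_chunks}")
        return False

    # For DeepEPLowLatency, check if tokens exceed the threshold
    if (self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency
            and all_rank_max_num_tokens > self.deep_ep_max_num_tokens):
        logger.debug(
            f"Cannot use alltoall: all_rank_max_num_tokens {all_rank_max_num_tokens} "
            f"> deep_ep_max_num_tokens {self.deep_ep_max_num_tokens}")
        return False

    return self.enable_alltoall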

return False

print(f"all to all type {self.alltoall_method_type}")
⚠️ Potential issue

Remove debug print statement from production code

Debug print statement should be removed.

-        print(f"all to all type {self.alltoall_method_type}")
+        logger.debug(f"Alltoall type: {self.alltoall_method_type}")
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py around line 324,
there is a debug print statement printing the all-to-all method type; remove
this print and replace it with either no output or a proper logger call if
runtime visibility is required (use the module's logger or Python logging at an
appropriate log level). Ensure no stray print() calls remain in production code
and run tests/linters to confirm removal.

return self.enable_alltoall

def _get_quant_method(self):
@@ -318,7 +330,19 @@ def _get_quant_method(self):
if self.quant_config.layer_quant_mode.has_fp8_qdq():
return FP8QDQFusedMoEMethod()
elif self.quant_config.layer_quant_mode.has_fp8_block_scales():
return DeepSeekFP8BlockScalesFusedMoEMethod()
print(
f"wide_ep _get_quant_method: get_sm_version()={get_sm_version()}"
)
if get_sm_version() == 100:
print(
f"wide_ep _get_quant_method: use DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm"
)
return DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm()
else:
print(
f"wide_ep _get_quant_method: use DeepSeekFP8BlockScalesFusedMoEMethod"
)
return DeepSeekFP8BlockScalesFusedMoEMethod()
Comment on lines +333 to +345
⚠️ Potential issue

Remove debug print statements from production code

Multiple debug print statements in the quantization method selection should be removed.

-                print(
-                    f"wide_ep _get_quant_method: get_sm_version()={get_sm_version()}"
-                )
+                logger.debug(
+                    f"wide_ep _get_quant_method: get_sm_version()={get_sm_version()}"
+                )
                 if get_sm_version() == 100:
-                    print(
-                        f"wide_ep _get_quant_method: use DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm"
-                    )
+                    logger.debug(
+                        "wide_ep _get_quant_method: use DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm"
+                    )
                     return DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm()
                 else:
-                    print(
-                        f"wide_ep _get_quant_method: use DeepSeekFP8BlockScalesFusedMoEMethod"
-                    )
+                    logger.debug(
+                        "wide_ep _get_quant_method: use DeepSeekFP8BlockScalesFusedMoEMethod"
+                    )
                     return DeepSeekFP8BlockScalesFusedMoEMethod()
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py around lines 333
to 345 there are debug print() calls in _get_quant_method; remove these print
statements so production code has no stdout debug noise, and if project logging
is available, replace them with appropriate logger.debug(...) calls, preserving
the same conditional flow and returned classes (no behavioral change).
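
If no runtime diagnostic is needed here at all, the branch collapses to a sketch like the following (same conditional flow and returned classes, no logging or stdout output):

elif self.quant_config.layer_quant_mode.has_fp8_block_scales():
    # SM version 100 selects the DeepGemm variant; behavior is otherwise unchanged.
    if get_sm_version() == 100:
        return DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm()
    return DeepSeekFP8BlockScalesFusedMoEMethod()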

elif self.quant_config.layer_quant_mode.has_nvfp4():
return NVFP4CutlassFusedMoEMethod()
elif self.quant_config.layer_quant_mode.is_int4_weight_only_per_group(
Expand All @@ -341,6 +365,19 @@ def create_weights(self):
self._weights_created = True
self._check_configs()

@property
def moe_backend_impl(self) -> MoEBackend:
"""
Lazily initialize and return the MoE backend.

The backend is selected based on hardware capabilities and quantization
configuration, which are only available after weights are created.
"""
if self._moe_backend_impl is None:
assert self._weights_created, "Weights must be created before accessing moe_backend"
self._moe_backend_impl = MoEBackendSelection.select_backend(self)
return self._moe_backend_impl

def dummy_allreduce(self):
"""
Debug function for eliminating imbalance during performance analysis.
@@ -388,11 +425,13 @@ def forward_chunk(

is_first_call, is_last_call = repeating_info

# print(
# f"xxi shape 1: enter wide_ep forward_chunk: layer_load_balancer={self.layer_load_balancer}, is_first_call={is_first_call}, is_last_call={is_last_call}, x shape: {getattr(x, 'shape', None)}, router_logits shape: {getattr(router_logits, 'shape', None)}, use_all_to_all: {use_all_to_all}, all_rank_num_tokens: {all_rank_num_tokens}, all_rank_max_num_tokens: {all_rank_max_num_tokens}, use_dp_padding: {use_dp_padding}, repeating_info: {repeating_info}"
# )

if self.layer_load_balancer and is_first_call:
self.layer_load_balancer.start_wait_gpu_stage()

use_deepseek_fp8_block_scale = False
use_w4_group_scaling = False
weight_dtype = self.w3_w1_weight.dtype

token_selected_experts, token_final_scales = self.routing_method.apply(
@@ -464,7 +503,7 @@ def forward_chunk(
self.dummy_allreduce()
token_count = x.shape[0]
alltoall_info = None
if is_last_call:
if self.layer_load_balancer and is_last_call:
loadbalancer_local_statistic_info = self.layer_load_balancer.get_local_statistic_tensor(
)
else:
@@ -547,9 +586,8 @@ def forward_chunk(
x_sf = x_sf.view((x_row, -1))

elif self.has_deepseek_fp8_block_scales:
use_deepseek_fp8_block_scale = True
pass
elif self.has_w4afp8:
use_w4_group_scaling = True
weight_dtype = torch.quint4x2
else:
raise ValueError(
Expand All @@ -572,12 +610,12 @@ def forward_chunk(
sizes=None if use_dp_padding else all_rank_num_tokens)
x_row = x.shape[0]

ep_size = self.ep_size
ep_rank = self.ep_rank
# ep_size = self.ep_size
# ep_rank = self.ep_rank
w3_w1_weight = self.w3_w1_weight
w2_weight = self.w2_weight
cluster_size = self.cluster_size
cluster_rank = self.cluster_rank
# cluster_size = self.cluster_size
# cluster_rank = self.cluster_rank
quant_scales = self.quant_scales

if use_postquant_alltoall:
@@ -638,7 +676,37 @@ def forward_chunk(
f"Not available alltoall method type: {self.alltoall_method_type!r}"
)

final_hidden_states = torch.ops.trtllm.fused_moe(
# Original fused_moe call (preserved as reference)
# final_hidden_states = torch.ops.trtllm.fused_moe(
# x,
# token_selected_slots,
# token_final_scales,
# w3_w1_weight.view(weight_dtype),
# None, # w3_w1_bias
# w2_weight.view(weight_dtype),
# None, # w2_bias
# output_dtype,
# quant_scales=quant_scales,
# input_sf=x_sf,
# swizzled_input_sf=False,
# tp_size=self.tp_size,
# tp_rank=self.tp_rank,
# ep_size=ep_size,
# ep_rank=ep_rank,
# cluster_size=cluster_size,
# cluster_rank=cluster_rank,
# enable_alltoall=use_all_to_all,
# use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
# use_w4_group_scaling=use_w4_group_scaling,
# min_latency_mode=False,
# tune_max_num_tokens=self.tune_max_num_tokens,
# tuner_num_tokens=tuner_num_tokens,
# tuner_top_k=tuner_top_k,
# )

# Use backend interface with module as first parameter for automatic configuration extraction
final_hidden_states = self.moe_backend_impl.run_moe(
self, # Module as first parameter
x,
token_selected_slots,
token_final_scales,
@@ -650,21 +718,17 @@ def forward_chunk(
quant_scales=quant_scales,
input_sf=x_sf,
swizzled_input_sf=False,
tp_size=self.tp_size,
tp_rank=self.tp_rank,
ep_size=ep_size,
ep_rank=ep_rank,
cluster_size=cluster_size,
cluster_rank=cluster_rank,
enable_alltoall=use_all_to_all,
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale,
use_w4_group_scaling=use_w4_group_scaling,
# Only need to pass runtime-variable parameters
min_latency_mode=False,
tune_max_num_tokens=self.tune_max_num_tokens,
use_fused_finalize=True,
tuner_num_tokens=tuner_num_tokens,
tuner_top_k=tuner_top_k,
)

# print(
# f"xxi shape 4 after moe backend : {getattr(x, 'shape', None)}, final_hidden_states shape: {getattr(final_hidden_states, 'shape', None)}, token_selected_slots shape: {getattr(token_selected_slots, 'shape', None)}, token_final_scales shape: {getattr(token_final_scales, 'shape', None)}, w3_w1_weight shape: {getattr(w3_w1_weight, 'shape', None)}, w2_weight shape: {getattr(w2_weight, 'shape', None)}, quant_scales: {getattr(quant_scales, 'shape', None)}, input_sf: {getattr(x_sf, 'shape', None)}, swizzled_input_sf: False, tp_size: {self.tp_size}, tp_rank: {self.tp_rank}, ep_size: {ep_size}, ep_rank: {ep_rank}, cluster_size: {cluster_size}, cluster_rank: {cluster_rank}, enable_alltoall: {use_all_to_all}, use_deepseek_fp8_block_scale: {use_deepseek_fp8_block_scale}, use_w4_group_scaling: {use_w4_group_scaling}, min_latency_mode: False, tune_max_num_tokens: {self.tune_max_num_tokens}, tuner_num_tokens: {tuner_num_tokens}, tuner_top_k: {tuner_top_k}"
# )

if self.layer_load_balancer and is_last_call:
self.layer_load_balancer.start_set_cpu_stage()

@@ -743,6 +807,10 @@ def forward(
all_rank_max_num_tokens=all_rank_max_num_tokens,
use_dp_padding=use_dp_padding,
repeating_info=(is_first_call, is_last_call))
# Print all information in one line
# print(
# f"xxi x.shape: {getattr(x, 'shape', None)}, use_all_to_all: {use_all_to_all}, all_rank_num_tokens: {all_rank_num_tokens}, all_rank_num_tokens_padded: {all_rank_num_tokens_padded}, all_rank_max_num_tokens: {all_rank_max_num_tokens}, use_dp_padding: {use_dp_padding}, outputs.shape: {getattr(outputs, 'shape', None)}, use_dp_padding(again): {use_dp_padding}"
# )
outputs = self.reducescatter_or_allreduce(
outputs,
use_all_to_all,