-
Notifications
You must be signed in to change notification settings - Fork 1.8k
Don't review - Testing: Support fp8 block wide ep #7209
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -5,7 +5,7 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||
import torch | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
from tensorrt_llm._mnnvl_utils import MnnvlMemory, MnnvlMoe, MoEAlltoallInfo | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
from tensorrt_llm._utils import logger | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
from tensorrt_llm._utils import get_sm_version, logger | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Incorrect logger import; use tensorrt_llm.logger. logger is not provided by tensorrt_llm._utils. Import from the canonical module. -from tensorrt_llm._utils import get_sm_version, logger
+from tensorrt_llm._utils import get_sm_version
+from tensorrt_llm.logger import logger 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
from tensorrt_llm.functional import AllReduceStrategy | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
from tensorrt_llm.mapping import Mapping | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -15,8 +15,10 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||
from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
from .deep_ep_utils import buffer_pool, deep_ep_installed | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
from .interface import MoE | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
from .moe_backend import MoEBackend, MoEBackendSelection | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
from .moe_load_balancer import get_moe_load_balancer | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethod, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
FP8QDQFusedMoEMethod, MoEWeightLoadingMode, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
NVFP4CutlassFusedMoEMethod, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
UnquantizedFusedMoEMethod, WInt4AFP8FusedMoEMethod) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -232,6 +234,9 @@ def __init__( | |||||||||||||||||||||||||||||||||||||||||||||||||||||
self.enable_dummy_allreduce = os.environ.get( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
"TRTLLM_ENABLE_DUMMY_ALLREDUCE", "0") == "1" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
# MoE backend will be lazily initialized when first accessed (see moe_backend property) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
self._moe_backend_impl = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
def _check_configs(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
assert self._weights_created | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -303,13 +308,20 @@ def calculate_num_chunks(self, all_rank_num_tokens: List[int]) -> int: | |||||||||||||||||||||||||||||||||||||||||||||||||||||
def can_use_alltoall(self, all_rank_num_tokens, all_rank_max_num_tokens): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Disable alltoall when chunking is used | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if self.calculate_num_chunks(all_rank_num_tokens) > 1: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
print( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
f"can not use alltoall due to chunking {self.calculate_num_chunks(all_rank_num_tokens)}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
coderabbitai[bot] marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
return False | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
# For DeepEPLowLatency, check if tokens exceed the threshold | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (self.alltoall_method_type == AlltoallMethodType.DeepEPLowLatency | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
and all_rank_max_num_tokens > self.deep_ep_max_num_tokens): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
print( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
f"can not use alltoall due to deep_ep_max_num_tokens {all_rank_max_num_tokens} > {self.deep_ep_max_num_tokens}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+319
to
+321
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove debug print statement from production code Debug print statements should be removed. - print(
- f"can not use alltoall due to deep_ep_max_num_tokens {all_rank_max_num_tokens} > {self.deep_ep_max_num_tokens}"
- )
+ logger.debug(
+ f"Cannot use alltoall due to deep_ep_max_num_tokens {all_rank_max_num_tokens} > {self.deep_ep_max_num_tokens}"
+ ) 📝 Committable suggestion
Suggested change
🧰 Tools🪛 Ruff (0.12.2)320-320: Line too long (127 > 120) (E501) 🤖 Prompt for AI Agents
Comment on lines
+311
to
+321
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 💡 Verification agent 🧩 Analysis chainQuick audit: ensure no stray print() calls remain in fused_moe modules. Run this repo-local grep to confirm cleanup across the module. Also applies to: 324-324, 333-345, 418-421, 812-815 🏁 Script executed: #!/bin/bash
# Find any remaining print() in fused_moe Python sources (code, not tests).
rg -nP '^\s*print\(' tensorrt_llm/_torch/modules/fused_moe -g '!**/test/**' -S Length of output: 1144 Remove all Audit via the provided
Please remove these 🧰 Tools🪛 Ruff (0.12.2)320-320: Line too long (127 > 120) (E501) 🤖 Prompt for AI Agents
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
return False | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
print(f"all to all type {self.alltoall_method_type}") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove debug print statement from production code Debug print statement should be removed. - print(f"all to all type {self.alltoall_method_type}")
+ logger.debug(f"Alltoall type: {self.alltoall_method_type}") 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
return self.enable_alltoall | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
def _get_quant_method(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -318,7 +330,19 @@ def _get_quant_method(self): | |||||||||||||||||||||||||||||||||||||||||||||||||||||
if self.quant_config.layer_quant_mode.has_fp8_qdq(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
return FP8QDQFusedMoEMethod() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
elif self.quant_config.layer_quant_mode.has_fp8_block_scales(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
return DeepSeekFP8BlockScalesFusedMoEMethod() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
print( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
f"wide_ep _get_quant_method: get_sm_version()={get_sm_version()}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if get_sm_version() == 100: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
print( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
f"wide_ep _get_quant_method: use DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
return DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
print( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
f"wide_ep _get_quant_method: use DeepSeekFP8BlockScalesFusedMoEMethod" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
return DeepSeekFP8BlockScalesFusedMoEMethod() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+333
to
+345
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Remove debug print statements from production code Multiple debug print statements in the quantization method selection should be removed. - print(
- f"wide_ep _get_quant_method: get_sm_version()={get_sm_version()}"
- )
+ logger.debug(
+ f"wide_ep _get_quant_method: get_sm_version()={get_sm_version()}"
+ )
if get_sm_version() == 100:
- print(
- f"wide_ep _get_quant_method: use DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm"
- )
+ logger.debug(
+ "wide_ep _get_quant_method: use DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm"
+ )
return DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm()
else:
- print(
- f"wide_ep _get_quant_method: use DeepSeekFP8BlockScalesFusedMoEMethod"
- )
+ logger.debug(
+ "wide_ep _get_quant_method: use DeepSeekFP8BlockScalesFusedMoEMethod"
+ )
return DeepSeekFP8BlockScalesFusedMoEMethod() 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
elif self.quant_config.layer_quant_mode.has_nvfp4(): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
return NVFP4CutlassFusedMoEMethod() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
elif self.quant_config.layer_quant_mode.is_int4_weight_only_per_group( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -341,6 +365,19 @@ def create_weights(self): | |||||||||||||||||||||||||||||||||||||||||||||||||||||
self._weights_created = True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
self._check_configs() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
@property | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
def moe_backend_impl(self) -> MoEBackend: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Lazily initialize and return the MoE backend. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
The backend is selected based on hardware capabilities and quantization | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
configuration, which are only available after weights are created. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if self._moe_backend_impl is None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
assert self._weights_created, "Weights must be created before accessing moe_backend" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
self._moe_backend_impl = MoEBackendSelection.select_backend(self) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
return self._moe_backend_impl | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
def dummy_allreduce(self): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
Debug function for eliminating imbalance during performance analysis. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -388,11 +425,13 @@ def forward_chunk( | |||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
is_first_call, is_last_call = repeating_info | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
# print( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# f"xxi shape 1: enter wide_ep forward_chunk: layer_load_balancer={self.layer_load_balancer}, is_first_call={is_first_call}, is_last_call={is_last_call}, x shape: {getattr(x, 'shape', None)}, router_logits shape: {getattr(router_logits, 'shape', None)}, use_all_to_all: {use_all_to_all}, all_rank_num_tokens: {all_rank_num_tokens}, all_rank_max_num_tokens: {all_rank_max_num_tokens}, use_dp_padding: {use_dp_padding}, repeating_info: {repeating_info}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
if self.layer_load_balancer and is_first_call: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.layer_load_balancer.start_wait_gpu_stage() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
use_deepseek_fp8_block_scale = False | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
use_w4_group_scaling = False | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
weight_dtype = self.w3_w1_weight.dtype | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
token_selected_experts, token_final_scales = self.routing_method.apply( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -464,7 +503,7 @@ def forward_chunk( | |||||||||||||||||||||||||||||||||||||||||||||||||||||
self.dummy_allreduce() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
token_count = x.shape[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
alltoall_info = None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if is_last_call: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
if self.layer_load_balancer and is_last_call: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
loadbalancer_local_statistic_info = self.layer_load_balancer.get_local_statistic_tensor( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -547,9 +586,8 @@ def forward_chunk( | |||||||||||||||||||||||||||||||||||||||||||||||||||||
x_sf = x_sf.view((x_row, -1)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
elif self.has_deepseek_fp8_block_scales: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
use_deepseek_fp8_block_scale = True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
pass | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
elif self.has_w4afp8: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
use_w4_group_scaling = True | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
weight_dtype = torch.quint4x2 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -572,12 +610,12 @@ def forward_chunk( | |||||||||||||||||||||||||||||||||||||||||||||||||||||
sizes=None if use_dp_padding else all_rank_num_tokens) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
x_row = x.shape[0] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
ep_size = self.ep_size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
ep_rank = self.ep_rank | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# ep_size = self.ep_size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# ep_rank = self.ep_rank | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
w3_w1_weight = self.w3_w1_weight | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
w2_weight = self.w2_weight | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
cluster_size = self.cluster_size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
cluster_rank = self.cluster_rank | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# cluster_size = self.cluster_size | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# cluster_rank = self.cluster_rank | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
quant_scales = self.quant_scales | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
if use_postquant_alltoall: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -638,7 +676,37 @@ def forward_chunk( | |||||||||||||||||||||||||||||||||||||||||||||||||||||
f"Not available alltoall method type: {self.alltoall_method_type!r}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
final_hidden_states = torch.ops.trtllm.fused_moe( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Original fused_moe call (preserved as reference) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# final_hidden_states = torch.ops.trtllm.fused_moe( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# x, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# token_selected_slots, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# token_final_scales, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# w3_w1_weight.view(weight_dtype), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# None, # w3_w1_bias | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# w2_weight.view(weight_dtype), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# None, # w2_bias | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# output_dtype, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# quant_scales=quant_scales, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# input_sf=x_sf, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# swizzled_input_sf=False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# tp_size=self.tp_size, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# tp_rank=self.tp_rank, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# ep_size=ep_size, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# ep_rank=ep_rank, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# cluster_size=cluster_size, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# cluster_rank=cluster_rank, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# enable_alltoall=use_all_to_all, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# use_w4_group_scaling=use_w4_group_scaling, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# min_latency_mode=False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# tune_max_num_tokens=self.tune_max_num_tokens, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# tuner_num_tokens=tuner_num_tokens, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# tuner_top_k=tuner_top_k, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Use backend interface with module as first parameter for automatic configuration extraction | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
final_hidden_states = self.moe_backend_impl.run_moe( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
self, # Module as first parameter | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
x, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
token_selected_slots, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
token_final_scales, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -650,21 +718,17 @@ def forward_chunk( | |||||||||||||||||||||||||||||||||||||||||||||||||||||
quant_scales=quant_scales, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
input_sf=x_sf, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
swizzled_input_sf=False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
tp_size=self.tp_size, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
tp_rank=self.tp_rank, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
ep_size=ep_size, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
ep_rank=ep_rank, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
cluster_size=cluster_size, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
cluster_rank=cluster_rank, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
enable_alltoall=use_all_to_all, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
use_w4_group_scaling=use_w4_group_scaling, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# Only need to pass runtime-variable parameters | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
min_latency_mode=False, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
tune_max_num_tokens=self.tune_max_num_tokens, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
use_fused_finalize=True, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
tuner_num_tokens=tuner_num_tokens, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
tuner_top_k=tuner_top_k, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
# print( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# f"xxi shape 4 after moe backend : {getattr(x, 'shape', None)}, final_hidden_states shape: {getattr(final_hidden_states, 'shape', None)}, token_selected_slots shape: {getattr(token_selected_slots, 'shape', None)}, token_final_scales shape: {getattr(token_final_scales, 'shape', None)}, w3_w1_weight shape: {getattr(w3_w1_weight, 'shape', None)}, w2_weight shape: {getattr(w2_weight, 'shape', None)}, quant_scales: {getattr(quant_scales, 'shape', None)}, input_sf: {getattr(x_sf, 'shape', None)}, swizzled_input_sf: False, tp_size: {self.tp_size}, tp_rank: {self.tp_rank}, ep_size: {ep_size}, ep_rank: {ep_rank}, cluster_size: {cluster_size}, cluster_rank: {cluster_rank}, enable_alltoall: {use_all_to_all}, use_deepseek_fp8_block_scale: {use_deepseek_fp8_block_scale}, use_w4_group_scaling: {use_w4_group_scaling}, min_latency_mode: False, tune_max_num_tokens: {self.tune_max_num_tokens}, tuner_num_tokens: {tuner_num_tokens}, tuner_top_k: {tuner_top_k}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
if self.layer_load_balancer and is_last_call: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
self.layer_load_balancer.start_set_cpu_stage() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -743,6 +807,10 @@ def forward( | |||||||||||||||||||||||||||||||||||||||||||||||||||||
all_rank_max_num_tokens=all_rank_max_num_tokens, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
use_dp_padding=use_dp_padding, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
repeating_info=(is_first_call, is_last_call)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# 一行打印所有信息 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# print( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# f"xxi x.shape: {getattr(x, 'shape', None)}, use_all_to_all: {use_all_to_all}, all_rank_num_tokens: {all_rank_num_tokens}, all_rank_num_tokens_padded: {all_rank_num_tokens_padded}, all_rank_max_num_tokens: {all_rank_max_num_tokens}, use_dp_padding: {use_dp_padding}, outputs.shape: {getattr(outputs, 'shape', None)}, use_dp_padding(again): {use_dp_padding}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
# ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
outputs = self.reducescatter_or_allreduce( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
outputs, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
use_all_to_all, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove debug print statement from production code
Debug print statements should not be in production code. Consider using proper logging instead.
📝 Committable suggestion
🧰 Tools
🪛 Ruff (0.12.2)
246-246: Line too long (156 > 120)
(E501)
🤖 Prompt for AI Agents