From 23cb5299801d7334d8a713ebb78da3b202067156 Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Sun, 10 Aug 2025 00:52:56 -0700 Subject: [PATCH 1/5] pre-allocate workspace for dg moe. Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- .../modules/fused_moe/fused_moe_deepgemm.py | 255 ++++++++++++++++-- tensorrt_llm/quantization/utils/fp8_utils.py | 16 +- 2 files changed, 237 insertions(+), 34 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py index fcfdeb0fad2..1e32cea0eea 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -11,7 +11,7 @@ from ...distributed import allgather from ...model_config import ModelConfig -from ...utils import AuxStreamType, Fp4QuantizedTensor +from ...utils import AuxStreamType, EventType, Fp4QuantizedTensor from .fused_moe_cutlass import CutlassFusedMoE from .quantization import (DeepSeekFP8BlockScalesFusedMoEMethodDeepGemm, MoEWeightLoadingMode, UnquantizedFusedMoEMethod) @@ -88,6 +88,7 @@ def _masked_index_copy_group_quant_fp8( def masked_index_copy_group_quant_fp8( output: torch.Tensor, + output_s: torch.Tensor, input: torch.Tensor, start_offsets: torch.Tensor, row_indices: torch.Tensor, @@ -107,15 +108,11 @@ def masked_index_copy_group_quant_fp8( row_size = output.shape[0] col_size = output.shape[1] dim_size = output.shape[2] - - # create padded output_s + alignment = 4 scale_dim = (dim_size + group_size - 1) // group_size padded_dim_size = (scale_dim + alignment - 1) // alignment * alignment padded_col_size = (col_size + alignment - 1) // alignment * alignment - output_s = torch.zeros((row_size, padded_dim_size // 4, padded_col_size), - dtype=torch.int32, - device='cuda') # get block/grid/stage/warp num_groups = (dim_size + group_size - 1) // group_size @@ -247,6 +244,7 @@ def preprocess_after_permute(expert_first_token_offset_tensor, @nvtx_range("[DG]") def deepgemm_fp8_group_blockwise_gemm( + d: torch.Tensor, a: torch.Tensor, b: torch.Tensor, sfa: torch.Tensor, @@ -254,10 +252,6 @@ def deepgemm_fp8_group_blockwise_gemm( masked_m: torch.Tensor, expected_m: int, ) -> torch.Tensor: - d = torch.empty((a.shape[0], a.shape[1], b.shape[1]), - device=b.device, - dtype=torch.bfloat16) - # NOTES: shape must be `[G, M, K] @ [G, N, K].mT` assert a.stride(-1) == 1 assert b.stride(-1) == 1 @@ -287,7 +281,7 @@ def deepgemm_fp8_group_blockwise_gemm( masked_m, expected_m, disable_ue8m0_cast=True) - return d + return class DeepGemmFusedMoE(CutlassFusedMoE): @@ -341,6 +335,38 @@ def __init__( apply_router_weight_on_input=apply_router_weight_on_input, layer_idx=layer_idx, ) + + def get_workspace(self, m_max: int, group_size: int): + hidden_size_0 = max(self.hidden_size, self.w3_w1_weight.shape[1] // 2) + workspace_0 = torch.empty( + (self.expert_size_per_partition * m_max * hidden_size_0), + dtype=torch.float8_e4m3fn, + device='cuda') + + hidden_size_1 = max(self.w3_w1_weight.shape[1], self.w2_weight.shape[1]) + workspace_1 = torch.empty( + (self.expert_size_per_partition * m_max * self.hidden_size), + dtype=torch.bfloat16, + device='cuda') + + alignment = 4 + scale_dim = (self.hidden_size + group_size - 1) // group_size + padded_dim_size = (scale_dim + alignment - 1) // alignment * alignment + padded_col_size = (m_max + alignment - 1) // alignment * alignment + scale_k = (self.w3_w1_weight.shape[1] // 2 + group_size - 1) // 
group_size + scale_k_padded = (scale_k + alignment - 1) // alignment * alignment + row_size = max(padded_dim_size // 4, scale_k_padded // 4) + workspace_sf = torch.empty( + (self.expert_size_per_partition * row_size * padded_col_size), + dtype=torch.int32, + device='cuda') + + workspace = { + "workspace_0": workspace_0, + "workspace_1": workspace_1, + "workspace_sf": workspace_sf, + } + return workspace def _get_quant_method(self): if self.quant_config is not None and self.quant_config.layer_quant_mode.has_any_quant( @@ -362,6 +388,7 @@ def forward_chunk( output_dtype: Optional[torch.dtype] = None, all_rank_num_tokens: Optional[List[int]] = None, use_dp_padding: Optional[bool] = None, + workspace: Optional[dict] = None, ) -> torch.Tensor: if isinstance(x, Fp4QuantizedTensor): assert output_dtype is not None @@ -437,22 +464,44 @@ def forward_chunk( masked_m, token_to_expert_map = preprocess_after_permute( expert_first_token_offset_tensor, permuted_data_tensor) - m_max = (x.shape[0] + 127) // 128 * 128 expected_m = (token_selected_experts.numel() + self.expert_size_per_partition - 1) // self.expert_size_per_partition - act_input_fp8 = torch.empty( - (self.expert_size_per_partition, m_max, self.hidden_size), - dtype=torch.float8_e4m3fn, - device='cuda') + # prepare workspace + m_max = (x.shape[0] + 127) // 128 * 128 + act_input_fp8 = workspace["workspace_0"][0: self.expert_size_per_partition * m_max * self.hidden_size] + # act_input_fp8.view(self.expert_size_per_partition, m_max, self.hidden_size) + act_input_fp8 = act_input_fp8.as_strided( + size=(self.expert_size_per_partition, m_max, self.hidden_size), + stride=(m_max * self.hidden_size, self.hidden_size, 1), + ) + alignment = 4 + scale_dim = (self.hidden_size + 128 - 1) // 128 + padded_dim_size = (scale_dim + alignment - 1) // alignment * alignment + padded_col_size = (m_max + alignment - 1) // alignment * alignment + act_input_sf = workspace["workspace_sf"][0: self.expert_size_per_partition * padded_dim_size // 4 * padded_col_size] + # act_input_sf.view(self.expert_size_per_partition, padded_dim_size // 4, padded_col_size) + act_input_sf = act_input_sf.as_strided( + size=(self.expert_size_per_partition, padded_dim_size // 4, padded_col_size), + stride=(padded_dim_size // 4 * padded_col_size, padded_col_size, 1), + ) act_input_sf = masked_index_copy_group_quant_fp8( act_input_fp8, + act_input_sf, permuted_data_tensor, expert_first_token_offset_tensor, token_to_expert_map, group_size=128) - - h1 = deepgemm_fp8_group_blockwise_gemm( + + # prepare workspace + h1 = workspace["workspace_1"][0: self.expert_size_per_partition * m_max * self.w3_w1_weight.shape[1]] + # h1.view(self.expert_size_per_partition, m_max, self.w3_w1_weight.shape[1]) + h1 = h1.as_strided( + size=(self.expert_size_per_partition, m_max, self.w3_w1_weight.shape[1]), + stride=(m_max * self.w3_w1_weight.shape[1], self.w3_w1_weight.shape[1], 1), + ) + deepgemm_fp8_group_blockwise_gemm( + d=h1, a=act_input_fp8, b=self.w3_w1_weight, sfa=act_input_sf, @@ -460,10 +509,36 @@ def forward_chunk( masked_m=masked_m, expected_m=expected_m, ) - act_input_fp8, act_input_sf = fp8_utils.silu_and_mul_masked_post_quant_fwd( + + # prepare workspace + h2 = workspace["workspace_0"][0: self.expert_size_per_partition * m_max * self.w3_w1_weight.shape[1] // 2] + # h2.view(self.expert_size_per_partition, m_max, self.w3_w1_weight.shape[1] // 2) + h2 = h2.as_strided( + size=(self.expert_size_per_partition, m_max, self.w3_w1_weight.shape[1] // 2), + stride=(m_max * self.w3_w1_weight.shape[1] // 2, 
self.w3_w1_weight.shape[1] // 2, 1), + ) + scale_k = (self.w3_w1_weight.shape[1] // 2 + 128 - 1) // 128 + scale_k_padded = (scale_k + alignment - 1) // alignment * alignment + h2_sf = workspace["workspace_sf"][0: self.expert_size_per_partition * scale_k_padded // 4 * padded_col_size] + # h2_sf.view(self.expert_size_per_partition, scale_k_padded // 4, padded_col_size) + h2_sf = h2_sf.as_strided( + size=(self.expert_size_per_partition, scale_k_padded // 4, padded_col_size), + stride=(scale_k_padded // 4 * padded_col_size, padded_col_size, 1), + ) + act_input_sf = fp8_utils.silu_and_mul_masked_post_quant_fwd( + output=h2, output_scale=h2_sf, input=h1, quant_group_size=128, masked_m=masked_m, scale_ue8m0=True) - h3 = deepgemm_fp8_group_blockwise_gemm( - a=act_input_fp8, + + # prepare workspace + h3 = workspace["workspace_1"][0: self.expert_size_per_partition * m_max * self.w2_weight.shape[1]] + # h3.view(self.expert_size_per_partition, m_max, self.w2_weight.shape[1]) + h3 = h3.as_strided( + size=(self.expert_size_per_partition, m_max, self.w2_weight.shape[1]), + stride=(m_max * self.w2_weight.shape[1], self.w2_weight.shape[1], 1), + ) + deepgemm_fp8_group_blockwise_gemm( + d=h3, + a=h2, b=self.w2_weight, sfa=act_input_sf, sfb=self.quant_scales[1], @@ -495,3 +570,141 @@ def forward_chunk( ) return final_hidden_states + + def forward( + self, + x: Union[torch.Tensor, Fp4QuantizedTensor], + router_logits: torch.Tensor, + do_finalize: bool = True, # used by other MoE backends + output_dtype: Optional[torch.dtype] = None, + all_rank_num_tokens: Optional[List[int]] = None, + all_rank_max_num_tokens: Optional[int] = None, + use_dp_padding: Optional[bool] = None, + ) -> torch.Tensor: + assert do_finalize, "CutlassFusedMoE does not support do_finalize=False" + if self.use_dp and self.parallel_size > 1: + assert all_rank_num_tokens is not None + assert use_dp_padding is not None + num_rows = sum(all_rank_num_tokens) + else: + num_rows = x.shape[0] + + # in case of num_rows is larger than max_chunk_size, we need to split the input into multiple chunks + num_chunks = (num_rows + self.moe_max_num_tokens - + 1) // self.moe_max_num_tokens + + if use_dp_padding: + all_rank_num_tokens_padded = [all_rank_max_num_tokens + ] * len(all_rank_num_tokens) + else: + all_rank_num_tokens_padded = all_rank_num_tokens + + if num_chunks == 1: + # create workspace + num_rows = x.shape[0] + if self.use_dp: + num_rows = sum(all_rank_num_tokens_padded) + m_max = (num_rows + 127) // 128 * 128 + workspace = self.get_workspace(m_max, 128) + outputs = self.forward_chunk( + x, + router_logits, + output_dtype, + all_rank_num_tokens=all_rank_num_tokens_padded, + use_dp_padding=use_dp_padding, + workspace=workspace) + outputs = self.reducescatter_or_allreduce( + outputs, + all_rank_num_tokens=all_rank_num_tokens_padded, + use_dp_padding=use_dp_padding) + else: + if self.use_dp: + all_rank_chunk_size_list = [ + self.split_chunk(val, num_chunks) + for val in all_rank_num_tokens_padded + ] + all_rank_num_tokens_list = [[ + val[idx_chunk] for val in all_rank_chunk_size_list + ] for idx_chunk in range(num_chunks)] + chunk_size_list = all_rank_chunk_size_list[self.rank] + else: + all_rank_num_tokens_list = [None] * num_chunks + chunk_size_list = self.split_chunk(x.shape[0], num_chunks) + + # create workspace + chunk_size_0 = sum(all_rank_num_tokens_list[0]) if self.use_dp else chunk_size_list[0] + workspace_0 = self.get_workspace((chunk_size_0 + 127) // 128 * 128, 128) + chunk_size_1 = sum(all_rank_num_tokens_list[1]) if self.use_dp else 
chunk_size_list[1] + workspace_1 = self.get_workspace((chunk_size_1 + 127) // 128 * 128, 128) + + x_list = x.split(chunk_size_list) + router_logits_list = router_logits.split(chunk_size_list) + + self.event_dict[EventType.Main].record() + with torch.cuda.stream(self.aux_stream): + self.event_dict[EventType.Main].wait() + + def _forward_chunk(x_, router_logits_, idx, workspace): + # num_rows = x_.shape[0] + # if self.use_dp: + # num_rows = sum(all_rank_num_tokens_list[idx]) + # m_max = (num_rows + 127) // 128 * 128 + # workspace = self.get_workspace(m_max, 128) + return self.forward_chunk( + x_, + router_logits_, + all_rank_num_tokens=all_rank_num_tokens_list[idx] + if self.use_dp else None, + use_dp_padding=use_dp_padding, + workspace=workspace) + + def _reducescatter_or_allreduce(x_, idx): + return self.reducescatter_or_allreduce( + x_, + all_rank_num_tokens=all_rank_num_tokens_list[idx], + use_dp_padding=use_dp_padding) + + outputs_list = [] + # Postpone reduce-scatter/all-reduce to the next iteration to achieve better overlap + for idx_chunk, (x, router_logits) in enumerate( + zip(x_list, router_logits_list)): + + if idx_chunk % 2 == 0: + with torch.cuda.stream(self.aux_stream): + chunk_size = sum(all_rank_num_tokens_list[idx_chunk]) if self.use_dp else chunk_size_list[idx_chunk] + if chunk_size != chunk_size_0: + workspace_0 = self.get_workspace((chunk_size + 127) // 128 * 128, 128) + chunk_size_0 = chunk_size + outputs = _forward_chunk(x, router_logits, idx_chunk, workspace_0) + if idx_chunk > 0: + outputs_list[-1] = _reducescatter_or_allreduce( + outputs_list[-1], idx_chunk - 1) + else: + chunk_size = sum(all_rank_num_tokens_list[idx_chunk]) if self.use_dp else chunk_size_list[idx_chunk] + if chunk_size != chunk_size_1: + workspace_1 = self.get_workspace((chunk_size + 127) // 128 * 128, 128) + chunk_size_1 = chunk_size + outputs = _forward_chunk(x, router_logits, idx_chunk, workspace_1) + with torch.cuda.stream(self.aux_stream): + outputs_list[-1] = _reducescatter_or_allreduce( + outputs_list[-1], idx_chunk - 1) + + outputs_list.append(outputs) + + if num_chunks % 2 == 0: + outputs_list[-1] = _reducescatter_or_allreduce( + outputs_list[-1], -1) + else: + with torch.cuda.stream(self.aux_stream): + outputs_list[-1] = _reducescatter_or_allreduce( + outputs_list[-1], -1) + with torch.cuda.stream(self.aux_stream): + self.event_dict[EventType.MoeChunkingOverlap].record() + self.event_dict[EventType.MoeChunkingOverlap].wait() + + outputs = torch.cat(outputs_list) + + if self.use_dp and self.parallel_size > 1: + rank = self.mapping.tp_rank + outputs = outputs[:all_rank_num_tokens[rank]] + return outputs \ No newline at end of file diff --git a/tensorrt_llm/quantization/utils/fp8_utils.py b/tensorrt_llm/quantization/utils/fp8_utils.py index 19bd24671dd..4c486a15115 100644 --- a/tensorrt_llm/quantization/utils/fp8_utils.py +++ b/tensorrt_llm/quantization/utils/fp8_utils.py @@ -302,6 +302,8 @@ def _silu_and_mul_post_quant_kernel( def silu_and_mul_masked_post_quant_fwd( + output: torch.Tensor, + output_scale: torch.Tensor, input: torch.Tensor, quant_group_size: int, masked_m: torch.Tensor, @@ -328,18 +330,6 @@ def silu_and_mul_masked_post_quant_fwd( g, m, k = input.shape k = k // 2 - # Create output - output = torch.empty((g, m, k), dtype=torch.float8_e4m3fn, device="cuda") - - # Create output scale - alignment = 4 - scale_k = ceil_div(k, quant_group_size) - m_padded = align(m, alignment) - scale_k_padded = align(scale_k, alignment) - output_scale = torch.zeros((g, scale_k_padded // 4, 
m_padded), - dtype=torch.int32, - device='cuda') - # Get block/grid/stage/warp expert_num = len(masked_m) @@ -382,7 +372,7 @@ def silu_and_mul_masked_post_quant_fwd( g, tma_stride_check=True, ) - return output, output_scale + return output_scale @triton.jit From c3c39d05b1115f4231d11daa30328d2c3f11b278 Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Mon, 11 Aug 2025 02:24:55 -0700 Subject: [PATCH 2/5] fix. Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- .../modules/fused_moe/fused_moe_deepgemm.py | 109 ++++++++++-------- 1 file changed, 62 insertions(+), 47 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py index 1e32cea0eea..7e9d2fda7f0 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -108,7 +108,7 @@ def masked_index_copy_group_quant_fp8( row_size = output.shape[0] col_size = output.shape[1] dim_size = output.shape[2] - + alignment = 4 scale_dim = (dim_size + group_size - 1) // group_size padded_dim_size = (scale_dim + alignment - 1) // alignment * alignment @@ -335,32 +335,33 @@ def __init__( apply_router_weight_on_input=apply_router_weight_on_input, layer_idx=layer_idx, ) - + def get_workspace(self, m_max: int, group_size: int): hidden_size_0 = max(self.hidden_size, self.w3_w1_weight.shape[1] // 2) workspace_0 = torch.empty( (self.expert_size_per_partition * m_max * hidden_size_0), dtype=torch.float8_e4m3fn, device='cuda') - - hidden_size_1 = max(self.w3_w1_weight.shape[1], self.w2_weight.shape[1]) + + max(self.w3_w1_weight.shape[1], self.w2_weight.shape[1]) workspace_1 = torch.empty( (self.expert_size_per_partition * m_max * self.hidden_size), dtype=torch.bfloat16, device='cuda') - + alignment = 4 scale_dim = (self.hidden_size + group_size - 1) // group_size padded_dim_size = (scale_dim + alignment - 1) // alignment * alignment padded_col_size = (m_max + alignment - 1) // alignment * alignment - scale_k = (self.w3_w1_weight.shape[1] // 2 + group_size - 1) // group_size + scale_k = (self.w3_w1_weight.shape[1] // 2 + group_size - + 1) // group_size scale_k_padded = (scale_k + alignment - 1) // alignment * alignment row_size = max(padded_dim_size // 4, scale_k_padded // 4) workspace_sf = torch.empty( (self.expert_size_per_partition * row_size * padded_col_size), dtype=torch.int32, device='cuda') - + workspace = { "workspace_0": workspace_0, "workspace_1": workspace_1, @@ -469,7 +470,9 @@ def forward_chunk( 1) // self.expert_size_per_partition # prepare workspace m_max = (x.shape[0] + 127) // 128 * 128 - act_input_fp8 = workspace["workspace_0"][0: self.expert_size_per_partition * m_max * self.hidden_size] + act_input_fp8 = workspace["workspace_0"][0:self. + expert_size_per_partition * + m_max * self.hidden_size] # act_input_fp8.view(self.expert_size_per_partition, m_max, self.hidden_size) act_input_fp8 = act_input_fp8.as_strided( size=(self.expert_size_per_partition, m_max, self.hidden_size), @@ -479,10 +482,14 @@ def forward_chunk( scale_dim = (self.hidden_size + 128 - 1) // 128 padded_dim_size = (scale_dim + alignment - 1) // alignment * alignment padded_col_size = (m_max + alignment - 1) // alignment * alignment - act_input_sf = workspace["workspace_sf"][0: self.expert_size_per_partition * padded_dim_size // 4 * padded_col_size] + act_input_sf = workspace["workspace_sf"][0:self. 
+ expert_size_per_partition * + padded_dim_size // 4 * + padded_col_size] # act_input_sf.view(self.expert_size_per_partition, padded_dim_size // 4, padded_col_size) act_input_sf = act_input_sf.as_strided( - size=(self.expert_size_per_partition, padded_dim_size // 4, padded_col_size), + size=(self.expert_size_per_partition, padded_dim_size // 4, + padded_col_size), stride=(padded_dim_size // 4 * padded_col_size, padded_col_size, 1), ) act_input_sf = masked_index_copy_group_quant_fp8( @@ -492,13 +499,16 @@ def forward_chunk( expert_first_token_offset_tensor, token_to_expert_map, group_size=128) - + # prepare workspace - h1 = workspace["workspace_1"][0: self.expert_size_per_partition * m_max * self.w3_w1_weight.shape[1]] + h1 = workspace["workspace_1"][0:self.expert_size_per_partition * m_max * + self.w3_w1_weight.shape[1]] # h1.view(self.expert_size_per_partition, m_max, self.w3_w1_weight.shape[1]) h1 = h1.as_strided( - size=(self.expert_size_per_partition, m_max, self.w3_w1_weight.shape[1]), - stride=(m_max * self.w3_w1_weight.shape[1], self.w3_w1_weight.shape[1], 1), + size=(self.expert_size_per_partition, m_max, + self.w3_w1_weight.shape[1]), + stride=(m_max * self.w3_w1_weight.shape[1], + self.w3_w1_weight.shape[1], 1), ) deepgemm_fp8_group_blockwise_gemm( d=h1, @@ -509,32 +519,44 @@ def forward_chunk( masked_m=masked_m, expected_m=expected_m, ) - + # prepare workspace - h2 = workspace["workspace_0"][0: self.expert_size_per_partition * m_max * self.w3_w1_weight.shape[1] // 2] + h2 = workspace["workspace_0"][0:self.expert_size_per_partition * m_max * + self.w3_w1_weight.shape[1] // 2] # h2.view(self.expert_size_per_partition, m_max, self.w3_w1_weight.shape[1] // 2) h2 = h2.as_strided( - size=(self.expert_size_per_partition, m_max, self.w3_w1_weight.shape[1] // 2), - stride=(m_max * self.w3_w1_weight.shape[1] // 2, self.w3_w1_weight.shape[1] // 2, 1), + size=(self.expert_size_per_partition, m_max, + self.w3_w1_weight.shape[1] // 2), + stride=(m_max * self.w3_w1_weight.shape[1] // 2, + self.w3_w1_weight.shape[1] // 2, 1), ) scale_k = (self.w3_w1_weight.shape[1] // 2 + 128 - 1) // 128 scale_k_padded = (scale_k + alignment - 1) // alignment * alignment - h2_sf = workspace["workspace_sf"][0: self.expert_size_per_partition * scale_k_padded // 4 * padded_col_size] + h2_sf = workspace["workspace_sf"][0:self.expert_size_per_partition * + scale_k_padded // 4 * padded_col_size] # h2_sf.view(self.expert_size_per_partition, scale_k_padded // 4, padded_col_size) h2_sf = h2_sf.as_strided( - size=(self.expert_size_per_partition, scale_k_padded // 4, padded_col_size), + size=(self.expert_size_per_partition, scale_k_padded // 4, + padded_col_size), stride=(scale_k_padded // 4 * padded_col_size, padded_col_size, 1), ) act_input_sf = fp8_utils.silu_and_mul_masked_post_quant_fwd( - output=h2, output_scale=h2_sf, - input=h1, quant_group_size=128, masked_m=masked_m, scale_ue8m0=True) - + output=h2, + output_scale=h2_sf, + input=h1, + quant_group_size=128, + masked_m=masked_m, + scale_ue8m0=True) + # prepare workspace - h3 = workspace["workspace_1"][0: self.expert_size_per_partition * m_max * self.w2_weight.shape[1]] + h3 = workspace["workspace_1"][0:self.expert_size_per_partition * m_max * + self.w2_weight.shape[1]] # h3.view(self.expert_size_per_partition, m_max, self.w2_weight.shape[1]) h3 = h3.as_strided( - size=(self.expert_size_per_partition, m_max, self.w2_weight.shape[1]), - stride=(m_max * self.w2_weight.shape[1], self.w2_weight.shape[1], 1), + size=(self.expert_size_per_partition, m_max, + 
self.w2_weight.shape[1]), + stride=(m_max * self.w2_weight.shape[1], self.w2_weight.shape[1], + 1), ) deepgemm_fp8_group_blockwise_gemm( d=h3, @@ -630,12 +652,16 @@ def forward( else: all_rank_num_tokens_list = [None] * num_chunks chunk_size_list = self.split_chunk(x.shape[0], num_chunks) - + # create workspace - chunk_size_0 = sum(all_rank_num_tokens_list[0]) if self.use_dp else chunk_size_list[0] - workspace_0 = self.get_workspace((chunk_size_0 + 127) // 128 * 128, 128) - chunk_size_1 = sum(all_rank_num_tokens_list[1]) if self.use_dp else chunk_size_list[1] - workspace_1 = self.get_workspace((chunk_size_1 + 127) // 128 * 128, 128) + chunk_size_0 = sum(all_rank_num_tokens_list[0] + ) if self.use_dp else chunk_size_list[0] + workspace_0 = self.get_workspace((chunk_size_0 + 127) // 128 * 128, + 128) + chunk_size_1 = sum(all_rank_num_tokens_list[1] + ) if self.use_dp else chunk_size_list[1] + workspace_1 = self.get_workspace((chunk_size_1 + 127) // 128 * 128, + 128) x_list = x.split(chunk_size_list) router_logits_list = router_logits.split(chunk_size_list) @@ -645,11 +671,6 @@ def forward( self.event_dict[EventType.Main].wait() def _forward_chunk(x_, router_logits_, idx, workspace): - # num_rows = x_.shape[0] - # if self.use_dp: - # num_rows = sum(all_rank_num_tokens_list[idx]) - # m_max = (num_rows + 127) // 128 * 128 - # workspace = self.get_workspace(m_max, 128) return self.forward_chunk( x_, router_logits_, @@ -671,20 +692,14 @@ def _reducescatter_or_allreduce(x_, idx): if idx_chunk % 2 == 0: with torch.cuda.stream(self.aux_stream): - chunk_size = sum(all_rank_num_tokens_list[idx_chunk]) if self.use_dp else chunk_size_list[idx_chunk] - if chunk_size != chunk_size_0: - workspace_0 = self.get_workspace((chunk_size + 127) // 128 * 128, 128) - chunk_size_0 = chunk_size - outputs = _forward_chunk(x, router_logits, idx_chunk, workspace_0) + outputs = _forward_chunk(x, router_logits, idx_chunk, + workspace_0) if idx_chunk > 0: outputs_list[-1] = _reducescatter_or_allreduce( outputs_list[-1], idx_chunk - 1) else: - chunk_size = sum(all_rank_num_tokens_list[idx_chunk]) if self.use_dp else chunk_size_list[idx_chunk] - if chunk_size != chunk_size_1: - workspace_1 = self.get_workspace((chunk_size + 127) // 128 * 128, 128) - chunk_size_1 = chunk_size - outputs = _forward_chunk(x, router_logits, idx_chunk, workspace_1) + outputs = _forward_chunk(x, router_logits, idx_chunk, + workspace_1) with torch.cuda.stream(self.aux_stream): outputs_list[-1] = _reducescatter_or_allreduce( outputs_list[-1], idx_chunk - 1) @@ -707,4 +722,4 @@ def _reducescatter_or_allreduce(x_, idx): if self.use_dp and self.parallel_size > 1: rank = self.mapping.tp_rank outputs = outputs[:all_rank_num_tokens[rank]] - return outputs \ No newline at end of file + return outputs From d171a49b1bc2a09079b0be34e60b3afb52f9f8bd Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Mon, 11 Aug 2025 20:30:56 -0700 Subject: [PATCH 3/5] code clean. 
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- .../modules/fused_moe/fused_moe_deepgemm.py | 155 ++++++++---------- 1 file changed, 67 insertions(+), 88 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py index 7e9d2fda7f0..6b3f9932f08 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -284,6 +284,15 @@ def deepgemm_fp8_group_blockwise_gemm( return +def set_strides(workspace: torch.Tensor, g: int, m: int, k: int): + workspace = workspace[0:g * m * k] + workspace = workspace.as_strided( + size=(g, m, k), + stride=(m * k, k, 1), + ) + return workspace + + class DeepGemmFusedMoE(CutlassFusedMoE): """ Python Flow of Fused Mixture of Experts (MoE) Layer. @@ -337,28 +346,26 @@ def __init__( ) def get_workspace(self, m_max: int, group_size: int): - hidden_size_0 = max(self.hidden_size, self.w3_w1_weight.shape[1] // 2) - workspace_0 = torch.empty( - (self.expert_size_per_partition * m_max * hidden_size_0), - dtype=torch.float8_e4m3fn, - device='cuda') - - max(self.w3_w1_weight.shape[1], self.w2_weight.shape[1]) + hidden_size = self.hidden_size + intermediate_size = self.intermediate_size + num_experts = self.expert_size_per_partition + + # create workspace + fp8_dim = max(hidden_size, intermediate_size) + workspace_0 = torch.empty((num_experts * m_max * fp8_dim), + dtype=torch.float8_e4m3fn, + device='cuda') workspace_1 = torch.empty( - (self.expert_size_per_partition * m_max * self.hidden_size), + (num_experts * m_max * max(intermediate_size * 2, hidden_size)), dtype=torch.bfloat16, device='cuda') - alignment = 4 - scale_dim = (self.hidden_size + group_size - 1) // group_size - padded_dim_size = (scale_dim + alignment - 1) // alignment * alignment - padded_col_size = (m_max + alignment - 1) // alignment * alignment - scale_k = (self.w3_w1_weight.shape[1] // 2 + group_size - - 1) // group_size - scale_k_padded = (scale_k + alignment - 1) // alignment * alignment - row_size = max(padded_dim_size // 4, scale_k_padded // 4) + # create workspace for scaling factors + m_padded = fp8_utils.align(m_max, 4) + scale_k = fp8_utils.ceil_div(fp8_dim, group_size) + scale_k_padded = fp8_utils.align(scale_k, 4) workspace_sf = torch.empty( - (self.expert_size_per_partition * row_size * padded_col_size), + (num_experts * (scale_k_padded // 4) * m_padded), dtype=torch.int32, device='cuda') @@ -468,30 +475,20 @@ def forward_chunk( expected_m = (token_selected_experts.numel() + self.expert_size_per_partition - 1) // self.expert_size_per_partition - # prepare workspace - m_max = (x.shape[0] + 127) // 128 * 128 - act_input_fp8 = workspace["workspace_0"][0:self. - expert_size_per_partition * - m_max * self.hidden_size] - # act_input_fp8.view(self.expert_size_per_partition, m_max, self.hidden_size) - act_input_fp8 = act_input_fp8.as_strided( - size=(self.expert_size_per_partition, m_max, self.hidden_size), - stride=(m_max * self.hidden_size, self.hidden_size, 1), - ) - alignment = 4 - scale_dim = (self.hidden_size + 128 - 1) // 128 - padded_dim_size = (scale_dim + alignment - 1) // alignment * alignment - padded_col_size = (m_max + alignment - 1) // alignment * alignment - act_input_sf = workspace["workspace_sf"][0:self. 
- expert_size_per_partition * - padded_dim_size // 4 * - padded_col_size] - # act_input_sf.view(self.expert_size_per_partition, padded_dim_size // 4, padded_col_size) - act_input_sf = act_input_sf.as_strided( - size=(self.expert_size_per_partition, padded_dim_size // 4, - padded_col_size), - stride=(padded_dim_size // 4 * padded_col_size, padded_col_size, 1), - ) + + # padding and quantization + m_max = fp8_utils.align(x.shape[0], 128) + act_input_fp8 = set_strides(workspace["workspace_0"], + self.expert_size_per_partition, m_max, + self.hidden_size) + + m_padded = fp8_utils.align(m_max, 4) + scale_k = fp8_utils.ceil_div(self.hidden_size, 128) + scale_k_padded = fp8_utils.align(scale_k, 4) + act_input_sf = set_strides(workspace["workspace_sf"], + self.expert_size_per_partition, + scale_k_padded // 4, m_padded) + act_input_sf = masked_index_copy_group_quant_fp8( act_input_fp8, act_input_sf, @@ -500,16 +497,11 @@ def forward_chunk( token_to_expert_map, group_size=128) - # prepare workspace - h1 = workspace["workspace_1"][0:self.expert_size_per_partition * m_max * - self.w3_w1_weight.shape[1]] - # h1.view(self.expert_size_per_partition, m_max, self.w3_w1_weight.shape[1]) - h1 = h1.as_strided( - size=(self.expert_size_per_partition, m_max, - self.w3_w1_weight.shape[1]), - stride=(m_max * self.w3_w1_weight.shape[1], - self.w3_w1_weight.shape[1], 1), - ) + # grouped gemm 1 + h1 = set_strides(workspace["workspace_1"], + self.expert_size_per_partition, m_max, + self.intermediate_size * 2) + deepgemm_fp8_group_blockwise_gemm( d=h1, a=act_input_fp8, @@ -520,47 +512,33 @@ def forward_chunk( expected_m=expected_m, ) - # prepare workspace - h2 = workspace["workspace_0"][0:self.expert_size_per_partition * m_max * - self.w3_w1_weight.shape[1] // 2] - # h2.view(self.expert_size_per_partition, m_max, self.w3_w1_weight.shape[1] // 2) - h2 = h2.as_strided( - size=(self.expert_size_per_partition, m_max, - self.w3_w1_weight.shape[1] // 2), - stride=(m_max * self.w3_w1_weight.shape[1] // 2, - self.w3_w1_weight.shape[1] // 2, 1), - ) - scale_k = (self.w3_w1_weight.shape[1] // 2 + 128 - 1) // 128 - scale_k_padded = (scale_k + alignment - 1) // alignment * alignment - h2_sf = workspace["workspace_sf"][0:self.expert_size_per_partition * - scale_k_padded // 4 * padded_col_size] - # h2_sf.view(self.expert_size_per_partition, scale_k_padded // 4, padded_col_size) - h2_sf = h2_sf.as_strided( - size=(self.expert_size_per_partition, scale_k_padded // 4, - padded_col_size), - stride=(scale_k_padded // 4 * padded_col_size, padded_col_size, 1), - ) + # activation and quantization + act_input_fp8 = set_strides(workspace["workspace_0"], + self.expert_size_per_partition, m_max, + self.intermediate_size) + + scale_k = fp8_utils.ceil_div(self.intermediate_size, 128) + scale_k_padded = fp8_utils.align(scale_k, 4) + act_input_sf = set_strides(workspace["workspace_sf"], + self.expert_size_per_partition, + scale_k_padded // 4, m_padded) + act_input_sf = fp8_utils.silu_and_mul_masked_post_quant_fwd( - output=h2, - output_scale=h2_sf, + output=act_input_fp8, + output_scale=act_input_sf, input=h1, quant_group_size=128, masked_m=masked_m, scale_ue8m0=True) - # prepare workspace - h3 = workspace["workspace_1"][0:self.expert_size_per_partition * m_max * - self.w2_weight.shape[1]] - # h3.view(self.expert_size_per_partition, m_max, self.w2_weight.shape[1]) - h3 = h3.as_strided( - size=(self.expert_size_per_partition, m_max, - self.w2_weight.shape[1]), - stride=(m_max * self.w2_weight.shape[1], self.w2_weight.shape[1], - 1), - ) + # grouped 
gemm 2 + h3 = set_strides(workspace["workspace_1"], + self.expert_size_per_partition, m_max, + self.hidden_size) + deepgemm_fp8_group_blockwise_gemm( d=h3, - a=h2, + a=act_input_fp8, b=self.w2_weight, sfa=act_input_sf, sfb=self.quant_scales[1], @@ -568,6 +546,7 @@ def forward_chunk( expected_m=expected_m, ) + # gather and finalize triton_masked_index_gather(permuted_data_tensor, h3, expert_first_token_offset_tensor, token_to_expert_map) @@ -626,7 +605,7 @@ def forward( num_rows = x.shape[0] if self.use_dp: num_rows = sum(all_rank_num_tokens_padded) - m_max = (num_rows + 127) // 128 * 128 + m_max = fp8_utils.align(num_rows, 128) workspace = self.get_workspace(m_max, 128) outputs = self.forward_chunk( x, @@ -656,11 +635,11 @@ def forward( # create workspace chunk_size_0 = sum(all_rank_num_tokens_list[0] ) if self.use_dp else chunk_size_list[0] - workspace_0 = self.get_workspace((chunk_size_0 + 127) // 128 * 128, - 128) chunk_size_1 = sum(all_rank_num_tokens_list[1] ) if self.use_dp else chunk_size_list[1] - workspace_1 = self.get_workspace((chunk_size_1 + 127) // 128 * 128, + workspace_0 = self.get_workspace(fp8_utils.align(chunk_size_0, 128), + 128) + workspace_1 = self.get_workspace(fp8_utils.align(chunk_size_1, 128), 128) x_list = x.split(chunk_size_list) From 306b5c6ae7abbe73bcf64710d6b3061d6318554b Mon Sep 17 00:00:00 2001 From: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> Date: Tue, 12 Aug 2025 04:16:06 +0000 Subject: [PATCH 4/5] Update default_moe_max_num_tokens. Signed-off-by: Yuxian Qiu <142763828+yuxianq@users.noreply.github.com> --- .../_torch/modules/fused_moe/fused_moe_cutlass.py | 7 +++---- .../_torch/modules/fused_moe/fused_moe_deepgemm.py | 10 ++++++++++ .../_torch/modules/fused_moe/fused_moe_vanilla.py | 7 ++----- .../_torch/modules/fused_moe/fused_moe_wide_ep.py | 6 +++--- 4 files changed, 18 insertions(+), 12 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py index 56274f875f4..d7be42b170f 100755 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py @@ -112,11 +112,10 @@ def __init__( max_num_tokens = model_config.max_num_tokens # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled - if self.use_dp: - max_num_tokens *= model_config.mapping.world_size - self.moe_max_num_tokens = model_config.moe_max_num_tokens or max_num_tokens + moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size + self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens # The auxiliary CUDA stream and CUDA events are only used when MoE chunking is applied - if self.moe_max_num_tokens < max_num_tokens: + if self.moe_max_num_tokens < moe_max_num_tokens: self.aux_stream = aux_stream_dict[ AuxStreamType. 
MoeChunkingOverlap] if aux_stream_dict is not None else torch.cuda.Stream( diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py index 6b3f9932f08..03afe554ddf 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -330,6 +330,16 @@ def __init__( apply_router_weight_on_input: bool = False, layer_idx: Optional[int] = None, ): + if model_config.moe_max_num_tokens is None: + moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size + # The default moe_max_num_tokens is calculated from the following formula: + # max_isl = 8196, max_batch_size = 1024, mtp = 0 + # max_num_tokens = ((mtp+1)*max_batch_size+max_isl+128+63)//64*64 = 9344 + # moe_max_num_tokens = max_num_tokens * 2 = 18688 + # It can avoid OOM for 8k/1k cases. + default_moe_max_num_tokens = 18688 + if moe_max_num_tokens > default_moe_max_num_tokens: + model_config.moe_max_num_tokens = default_moe_max_num_tokens super().__init__( routing_method=routing_method, diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py index 3249bac979b..337057c0f4d 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py @@ -83,11 +83,8 @@ def __init__( max_num_tokens = model_config.max_num_tokens # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled - if self.use_dp: - max_num_tokens *= model_config.mapping.world_size - self.moe_max_num_tokens = (model_config.moe_max_num_tokens - if model_config.moe_max_num_tokens - is not None else max_num_tokens) + moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size + self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens self._weights_created = False if not model_config.skip_create_weights_in_init: diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py index c74c8966b68..c0b3919e237 100755 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py @@ -152,10 +152,10 @@ def __init__( max_num_tokens = model_config.max_num_tokens # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled - max_num_tokens *= model_config.mapping.world_size - self.moe_max_num_tokens = model_config.moe_max_num_tokens if model_config.moe_max_num_tokens is not None else max_num_tokens + moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size + self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens # The auxiliary CUDA stream and CUDA events are only used when MoE chunking is applied - if self.moe_max_num_tokens < max_num_tokens: + if self.moe_max_num_tokens < moe_max_num_tokens: self.aux_stream = aux_stream_dict[ AuxStreamType. MoeChunkingOverlap] if aux_stream_dict is not None else torch.cuda.Stream( From d6150c5df17ad519ebef1e40734485c5695bde50 Mon Sep 17 00:00:00 2001 From: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> Date: Mon, 11 Aug 2025 22:07:36 -0700 Subject: [PATCH 5/5] Only switch to chunked dg moe when num_rows is greater than self.moe_max_num_tokens * 2. 
Signed-off-by: Fanrong Li <23290157+lfr-0531@users.noreply.github.com> --- .../_torch/modules/fused_moe/fused_moe_cutlass.py | 1 - .../_torch/modules/fused_moe/fused_moe_deepgemm.py | 11 ++++++++--- .../_torch/modules/fused_moe/fused_moe_vanilla.py | 1 - .../_torch/modules/fused_moe/fused_moe_wide_ep.py | 1 - tensorrt_llm/mapping.py | 4 ++++ 5 files changed, 12 insertions(+), 6 deletions(-) diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py index d7be42b170f..2e257c306ae 100755 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py @@ -110,7 +110,6 @@ def __init__( assert len( self.initial_local_expert_ids) == self.expert_size_per_partition - max_num_tokens = model_config.max_num_tokens # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py index 03afe554ddf..a5ca05694b9 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_deepgemm.py @@ -339,7 +339,9 @@ def __init__( # It can avoid OOM for 8k/1k cases. default_moe_max_num_tokens = 18688 if moe_max_num_tokens > default_moe_max_num_tokens: + model_config._frozen = False model_config.moe_max_num_tokens = default_moe_max_num_tokens + model_config._frozen = True super().__init__( routing_method=routing_method, @@ -600,9 +602,12 @@ def forward( else: num_rows = x.shape[0] - # in case of num_rows is larger than max_chunk_size, we need to split the input into multiple chunks - num_chunks = (num_rows + self.moe_max_num_tokens - - 1) // self.moe_max_num_tokens + # In case of num_rows is larger than max_chunk_size * 2, we need to split the input into multiple chunks. + # Because we will use two streams in chunked moe and preallocate two workspaces. 
+ num_chunks = 1 + if num_rows > self.moe_max_num_tokens * 2: + num_chunks = (num_rows + self.moe_max_num_tokens - + 1) // self.moe_max_num_tokens if use_dp_padding: all_rank_num_tokens_padded = [all_rank_max_num_tokens diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py index 337057c0f4d..ed6f11993b2 100644 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_vanilla.py @@ -81,7 +81,6 @@ def __init__( self.num_experts) self.expert_size_per_partition = self.expert_end - self.expert_start - max_num_tokens = model_config.max_num_tokens # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py index c0b3919e237..4eb9d77606b 100755 --- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py +++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py @@ -150,7 +150,6 @@ def __init__( assert len( self.initial_local_expert_ids) == self.expert_size_per_partition - max_num_tokens = model_config.max_num_tokens # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens diff --git a/tensorrt_llm/mapping.py b/tensorrt_llm/mapping.py index f78fe093f71..22824ea350d 100644 --- a/tensorrt_llm/mapping.py +++ b/tensorrt_llm/mapping.py @@ -372,6 +372,10 @@ def node_rank(self): def local_rank(self): return self.rank % self.gpus_per_node + @property + def dp_size(self): + return self.tp_size if self.enable_attention_dp else 1 + def has_cp(self): return self.cp_size > 1
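
Note: the sketch below is a standalone illustration of the workspace pattern this series introduces (flat pre-allocated buffers sized once per chunk, then viewed with `as_strided` for each grouped GEMM). The helper names `get_workspace`, `set_strides`, `ceil_div`, and `align` mirror the ones added in the patches, but this version, its parameters, and the example sizes are assumptions for demonstration only, not the exact module code; it also assumes a CUDA device and a PyTorch build with float8 support.

# Minimal sketch of the pre-allocated workspace + strided-view reuse pattern.
# Names and sizes are illustrative assumptions, not the module's attributes.
import torch


def ceil_div(a: int, b: int) -> int:
    # Same semantics as fp8_utils.ceil_div: round a/b up to the next integer.
    return (a + b - 1) // b


def align(a: int, b: int) -> int:
    # Same semantics as fp8_utils.align: round a up to a multiple of b.
    return ceil_div(a, b) * b


def set_strides(flat: torch.Tensor, g: int, m: int, k: int) -> torch.Tensor:
    # Take the first g*m*k elements of a flat buffer and view them as a
    # contiguous (g, m, k) tensor without copying.
    return flat[:g * m * k].as_strided(size=(g, m, k), stride=(m * k, k, 1))


def get_workspace(num_experts: int, m_max: int, hidden_size: int,
                  intermediate_size: int, group_size: int = 128,
                  device: str = "cuda") -> dict:
    # One fp8 buffer large enough for either GEMM input, one bf16 buffer large
    # enough for either GEMM output, and one int32 buffer for packed scaling
    # factors (4 scales per int32, dimensions padded to multiples of 4).
    fp8_dim = max(hidden_size, intermediate_size)
    scale_k_padded = align(ceil_div(fp8_dim, group_size), 4)
    m_padded = align(m_max, 4)
    return {
        "workspace_0": torch.empty(num_experts * m_max * fp8_dim,
                                   dtype=torch.float8_e4m3fn, device=device),
        "workspace_1": torch.empty(
            num_experts * m_max * max(intermediate_size * 2, hidden_size),
            dtype=torch.bfloat16, device=device),
        "workspace_sf": torch.empty(
            num_experts * (scale_k_padded // 4) * m_padded,
            dtype=torch.int32, device=device),
    }


if __name__ == "__main__":
    # Example: the same flat buffers are viewed with different shapes for the
    # two grouped GEMMs, so no per-forward allocations are needed.
    E, H, I = 8, 1024, 2048            # assumed expert/hidden/intermediate sizes
    m_max = align(300, 128)            # token count rounded up to a 128 multiple
    ws = get_workspace(E, m_max, H, I)
    gemm1_in = set_strides(ws["workspace_0"], E, m_max, H)       # (E, m_max, H) fp8
    gemm1_out = set_strides(ws["workspace_1"], E, m_max, 2 * I)  # (E, m_max, 2I) bf16
    gemm2_in = set_strides(ws["workspace_0"], E, m_max, I)       # reuses workspace_0
    gemm2_out = set_strides(ws["workspace_1"], E, m_max, H)      # reuses workspace_1
    print(gemm1_in.shape, gemm1_out.shape, gemm2_in.shape, gemm2_out.shape)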