tensorrt_llm/_mnnvl_utils.py (10 changes: 6 additions & 4 deletions)

@@ -599,6 +599,7 @@ def mnnvl_moe_alltoallv_combine(
         top_k: int,
         token_count: int,
         use_low_precision_combine: bool = False,
+        do_reduce: bool = True,
     ):
         assert x.dim() == 2, "2D tensor supported, please reshape."
         output_tensors = torch.ops.trtllm.moe_comm(
@@ -614,7 +615,8 @@
             [True],
             use_low_precision_combine,
         )
-        output_tensor = output_tensors[0]
-        return torch.sum(
-            output_tensor.reshape(token_count, top_k, x.shape[1]), dim=1, keepdim=False
-        )
+        output_tensor = output_tensors[0].reshape(token_count, top_k, x.shape[1])
+        if do_reduce:
+            return torch.sum(output_tensor, dim=1, keepdim=False)
+        else:
+            return output_tensor
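A minimal sketch of what the new do_reduce flag toggles, using toy sizes and a random stand-in for output_tensors[0] (both are assumptions; the shapes follow the reshape in the hunk above):

```python
import torch

# Toy sizes; in the real path these come from the alltoall combine.
token_count, top_k, hidden_size = 4, 2, 8
combined = torch.randn(token_count * top_k, hidden_size)  # stand-in for output_tensors[0]

per_token_expert = combined.reshape(token_count, top_k, hidden_size)
reduced = torch.sum(per_token_expert, dim=1, keepdim=False)  # do_reduce=True path
unreduced = per_token_expert                                 # do_reduce=False path

assert reduced.shape == (token_count, hidden_size)
assert unreduced.shape == (token_count, top_k, hidden_size)
```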
tensorrt_llm/_torch/models/modeling_deepseekv3.py (22 changes: 19 additions & 3 deletions)

@@ -128,6 +128,12 @@ def weight_dequant(x: torch.Tensor,
     return y


+@torch.compile(dynamic=True)
+def moe_reduce_add_shared_output(routed_output, shared_output):
+    routed_output = torch.sum(routed_output, dim=1, keepdim=False)
+    return shared_output + routed_output
+
+
 class DeepseekV3MTPHead(nn.Module):

     def __init__(self, model_config: ModelConfig[PretrainedConfig]):
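A quick equivalence check for the compiled helper added above; the function body is copied from the hunk, while the tensor sizes are toy assumptions:

```python
import torch

@torch.compile(dynamic=True)
def moe_reduce_add_shared_output(routed_output, shared_output):
    # Reduce the per-expert (top_k) dimension, then add the shared-expert output.
    routed_output = torch.sum(routed_output, dim=1, keepdim=False)
    return shared_output + routed_output

routed = torch.randn(4, 2, 8)  # (num_tokens, top_k, hidden), assumed toy sizes
shared = torch.randn(4, 8)     # (num_tokens, hidden)

torch.testing.assert_close(
    moe_reduce_add_shared_output(routed, shared),
    shared + routed.sum(dim=1),
)
```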
@@ -585,6 +591,8 @@ def _compute_routed_output():
                                          do_finalize)
             return routed_output

+        # NOTE: define compiled helpers at module scope to avoid defining decorators inside compiled frames
+
         routed_output, shared_output = maybe_execute_in_parallel(
             _compute_routed_output, _compute_shared_output,
             self.event_dict[EventType.Main],
@@ -593,9 +601,17 @@
         if not do_finalize:
             return [shared_output, *routed_output]
         else:
-            assert shared_output.size() == routed_output.size(
-            ), f'unmatched tensor shape'
-            final_hidden_states = shared_output + routed_output
+            if routed_output.dim() == 3:
+                assert shared_output.numel(
+                ) * self.top_k == routed_output.numel(
+                ), 'unmatched tensor shape'
+                final_hidden_states = moe_reduce_add_shared_output(
+                    routed_output, shared_output)
+            else:
+                assert shared_output.size() == routed_output.size(
+                ), 'unmatched tensor shape'
+                final_hidden_states = shared_output + routed_output
+
         if not self.use_dp and self.mapping.tp_size > 1:
             final_hidden_states = self.allreduce(
                 final_hidden_states,
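For reference, a shape-level sketch of the two branches above, with toy sizes (assumed): when the routed path returns an unreduced 3-D tensor, the numel check scales by top_k; the 2-D path keeps the old exact-shape check:

```python
import torch

num_tokens, top_k, hidden = 4, 2, 8
shared_output = torch.randn(num_tokens, hidden)

# 3-D case: routed output is still per-(token, expert); reduction is deferred.
routed_3d = torch.randn(num_tokens, top_k, hidden)
assert shared_output.numel() * top_k == routed_3d.numel()

# 2-D case: routed output was already reduced/finalized.
routed_2d = routed_3d.sum(dim=1)
assert shared_output.size() == routed_2d.size()
```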
@@ -906,7 +906,8 @@ def alltoall_combine(self, final_hidden_states: torch.Tensor,
             ep_size=self.ep_size,
             top_k=top_k,
             token_count=token_count,
-            use_low_precision_combine=self.use_low_precision_combine)
+            use_low_precision_combine=self.use_low_precision_combine,
+            do_reduce=False)

         return final_hidden_states
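Design note: with do_reduce=False, the alltoall combine returns the unreduced (token_count, top_k, hidden) tensor, deferring the top-k reduction so it can be fused with the shared-expert addition in the torch.compile'd moe_reduce_add_shared_output helper above; saving the separate elementwise pass appears to be the motivation implied by the change, not something measured in this view.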