140 changes: 83 additions & 57 deletions tensorrt_llm/_torch/models/modeling_llama.py
@@ -550,50 +550,60 @@ def forward(
                 hidden_states, residual)
 
         if (self.fusion_config.POST_MOE_FUSION
-                or self.fusion_config.POST_MLP_FUSION
-            ) and self.next_layer_layernorm is not None:
-            # Get the scale for the next allreduce fusion op
-            if self.next_attn is not None and (self.is_nvfp4
-                                               or self.is_fp8_quant):
-                scale = self.next_attn.qkv_proj.input_scale
-            else:
-                # Add just the fusion op to RESIDUAL_RMS_NORM due to this is the last decoder layer
-                self.post_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
-                scale = None
-
-            # TODO: MIN_LATENCY_MODE is hardcoded to False
-            if cutlass_min_latency_mode:
-                shared_output = hidden_states[0]
-                hidden_states_activated_experts = hidden_states[1]
-                num_activated_experts_per_node = hidden_states[2]
-                experts_to_token_score = hidden_states[3]
-
-                allreduce_output = self.moe_allreduce(
-                    residual,
-                    self.next_layer_layernorm.weight,
-                    device_num_experts=num_activated_experts_per_node,
-                    scale_input=experts_to_token_score,
-                    active_experts_token_input=hidden_states_activated_experts,
-                    token_input=shared_output,
-                    eps=self.next_layer_layernorm.variance_epsilon,
-                )
-            else:
-                allreduce_output = self.all_reduce(
+                or self.fusion_config.POST_MLP_FUSION):
+            # If there is no extra layernorm, do another pure allreduce because
+            # the allreduce in the feed-forward module has been disabled.
+            if self.next_layer_layernorm is None:
+                hidden_states, residual = self.all_reduce(
                     hidden_states,
                     all_reduce_params=AllReduceParams(
-                        fusion_op=self.post_feed_forward_fusion_op,
+                        fusion_op=None,
                         residual=residual,
-                        norm_weight=self.next_layer_layernorm.weight,
-                        scale=scale,
-                        eps=self.next_layer_layernorm.variance_epsilon,
                     ))
-
-            # Unpack the allreduce output
-            if self.next_attn is not None and self.is_nvfp4:
-                act_fp4, act_sf, residual = allreduce_output
-                hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
             else:
-                hidden_states, residual = allreduce_output
+                # The next layernorm exists, but this could be the last decoder layer.
+                # Adjust the scale and fusion pattern accordingly.
+                if self.next_attn is not None and (self.is_nvfp4
+                                                   or self.is_fp8_quant):
+                    scale = self.next_attn.qkv_proj.input_scale
+                else:
+                    self.post_feed_forward_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
+                    scale = None
+
+                # TODO: MIN_LATENCY_MODE is hardcoded to False
+                if cutlass_min_latency_mode:
+                    shared_output = hidden_states[0]
+                    hidden_states_activated_experts = hidden_states[1]
+                    num_activated_experts_per_node = hidden_states[2]
+                    experts_to_token_score = hidden_states[3]
+
+                    allreduce_output = self.moe_allreduce(
+                        residual,
+                        self.next_layer_layernorm.weight,
+                        device_num_experts=num_activated_experts_per_node,
+                        scale_input=experts_to_token_score,
+                        active_experts_token_input=
+                        hidden_states_activated_experts,
+                        token_input=shared_output,
+                        eps=self.next_layer_layernorm.variance_epsilon,
+                    )
+                else:
+                    allreduce_output = self.all_reduce(
+                        hidden_states,
+                        all_reduce_params=AllReduceParams(
+                            fusion_op=self.post_feed_forward_fusion_op,
+                            residual=residual,
+                            norm_weight=self.next_layer_layernorm.weight,
+                            scale=scale,
+                            eps=self.next_layer_layernorm.variance_epsilon,
+                        ))
+
+                # Unpack the allreduce output
+                if self.next_attn is not None and self.is_nvfp4:
+                    act_fp4, act_sf, residual = allreduce_output
+                    hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
+                else:
+                    hidden_states, residual = allreduce_output
         elif self.next_layer_layernorm:
             hidden_states, residual = self.next_layer_layernorm(
                 hidden_states, residual)
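
Editor's note on the fused path above: in TensorRT-LLM, the RESIDUAL_RMS_NORM fusion folds the residual add and the next layer's RMSNorm into the allreduce kernel, so the op returns both the normalized activations and the updated residual stream. The sketch below shows only the per-rank math this fusion computes; it is a single-process illustration, the cross-rank summation is elided, and `rms_norm`/`fused_residual_rms_norm` are helper names invented for this note, not TensorRT-LLM APIs.

import torch


def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Standard RMSNorm: scale by the reciprocal root-mean-square, then by weight.
    variance = x.pow(2).mean(-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight


def fused_residual_rms_norm(hidden, residual, weight, eps=1e-5):
    # What the fused kernel computes per rank after the (elided) allreduce sum:
    # add the residual, then apply the *next* layer's RMSNorm, and return both
    # the normalized activations and the updated residual stream.
    new_residual = hidden + residual
    return rms_norm(new_residual, weight, eps), new_residual


hidden = torch.randn(2, 8)
residual = torch.randn(2, 8)
weight = torch.ones(8)
normed, new_residual = fused_residual_rms_norm(hidden, residual, weight)
print(normed.shape, new_residual.shape)  # torch.Size([2, 8]) torch.Size([2, 8])

This is why the last decoder layer needs special handling in the diff: with no `next_layer_layernorm` there is no norm to fold in, so the code falls back to a pure allreduce.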
Expand Down Expand Up @@ -706,6 +716,7 @@ def forward(
scale = self.mlp.gate_up_proj.input_scale
else:
scale = None

all_reduce_output = self.all_reduce(
hidden_states,
all_reduce_params=AllReduceParams(
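
A note on the scale selection seen in this hunk and the ones around it: when the fused allreduce also quantizes its output, it must quantize with the input scale of the projection that will consume the result (here `self.mlp.gate_up_proj`, elsewhere `self.next_attn.qkv_proj`). A minimal sketch of that decision, with an invented helper name:

from typing import Optional


def pick_fusion_scale(next_consumer_input_scale: Optional[float],
                      is_quantized: bool) -> Optional[float]:
    # Quantized path: reuse the downstream projection's input scale so the
    # fused quantization matches what the consuming GEMM expects.
    if is_quantized and next_consumer_input_scale is not None:
        return next_consumer_input_scale
    # Unquantized path: no scale; the fusion degrades to residual + RMSNorm.
    return None


assert pick_fusion_scale(0.25, True) == 0.25
assert pick_fusion_scale(None, False) is None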
@@ -748,25 +759,40 @@ def forward(
 
         spec_metadata.maybe_capture_hidden_states(self.layer_idx,
                                                   hidden_states, residual)
-        if self.POST_MLP_FUSION and self.next_attn is not None:
-            if self.is_nvfp4 or self.is_fp8_quant:
-                scale = self.next_attn.qkv_proj.input_scale
-            else:
-                scale = None
-            all_reduce_output = self.all_reduce(
-                hidden_states,
-                all_reduce_params=AllReduceParams(
-                    fusion_op=self.post_mlp_fusion_op,
-                    residual=residual,
-                    norm_weight=self.next_layer_layernorm.weight,
-                    scale=scale,
-                    eps=self.next_layer_layernorm.variance_epsilon,
-                ))
-            if self.is_nvfp4:
-                act_fp4, act_sf, residual = all_reduce_output
-                hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
+
+        if self.POST_MLP_FUSION:
+            # If there is no extra layernorm, do another pure allreduce.
+            if self.next_layer_layernorm is None:
+                hidden_states, residual = self.all_reduce(
+                    hidden_states,
+                    all_reduce_params=AllReduceParams(
+                        fusion_op=None,
+                        residual=residual,
+                    ))
             else:
-                hidden_states, residual = all_reduce_output
+                # The next layernorm exists, but this could be the last decoder layer.
+                # Adjust the scale and fusion pattern accordingly.
+                if self.next_attn is not None and (self.is_nvfp4
+                                                   or self.is_fp8_quant):
+                    scale = self.next_attn.qkv_proj.input_scale
+                else:
+                    self.post_mlp_fusion_op = AllReduceFusionOp.RESIDUAL_RMS_NORM
+                    scale = None
+
+                all_reduce_output = self.all_reduce(
+                    hidden_states,
+                    all_reduce_params=AllReduceParams(
+                        fusion_op=self.post_mlp_fusion_op,
+                        residual=residual,
+                        norm_weight=self.next_layer_layernorm.weight,
+                        scale=scale,
+                        eps=self.next_layer_layernorm.variance_epsilon,
+                    ))
+                if self.next_attn is not None and self.is_nvfp4:
+                    act_fp4, act_sf, residual = all_reduce_output
+                    hidden_states = Fp4QuantizedTensor(act_fp4, act_sf)
+                else:
+                    hidden_states, residual = all_reduce_output
         elif self.next_layer_layernorm:
             hidden_states, residual = self.next_layer_layernorm(
                 hidden_states, residual)
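
Editor's note on the unpack convention shared by both large hunks: when the fused allreduce also quantizes its output to NVFP4, it returns a 3-tuple of packed activations, scale factors, and the updated residual, which the layer wraps in Fp4QuantizedTensor; every other path (including FP8, which still yields a two-tuple here) returns a plain (hidden_states, residual) pair. A minimal sketch with an invented MockFp4QuantizedTensor stand-in:

import torch


class MockFp4QuantizedTensor:
    # Illustrative stand-in for Fp4QuantizedTensor: packed FP4 payload plus
    # scale factors, carried together for downstream GEMMs.
    def __init__(self, fp4_data: torch.Tensor, scale_factors: torch.Tensor):
        self.fp4_data = fp4_data
        self.scale_factors = scale_factors


def unpack_allreduce_output(output, produces_nvfp4: bool):
    if produces_nvfp4:
        # Quantizing fusion: (packed activations, scale factors, residual).
        act_fp4, act_sf, residual = output
        return MockFp4QuantizedTensor(act_fp4, act_sf), residual
    # Non-quantizing fusion: a plain (hidden_states, residual) pair.
    hidden_states, residual = output
    return hidden_states, residual


# Non-quantized path: the pair passes straight through.
hs, res = unpack_allreduce_output((torch.randn(2, 8), torch.randn(2, 8)), False)
# NVFP4 path: the 3-tuple is wrapped so data and scales travel together.
qt, res = unpack_allreduce_output(
    (torch.zeros(2, 4, dtype=torch.uint8), torch.ones(2, 2), torch.randn(2, 8)),
    True)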