diff --git a/tensorrt_llm/_torch/models/modeling_nemotron_nas.py b/tensorrt_llm/_torch/models/modeling_nemotron_nas.py
index 3ab1cdb37ca..cbd5ff4a964 100644
--- a/tensorrt_llm/_torch/models/modeling_nemotron_nas.py
+++ b/tensorrt_llm/_torch/models/modeling_nemotron_nas.py
@@ -149,12 +149,17 @@ def forward(
         position_ids: torch.IntTensor,
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
+        residual: Optional[torch.Tensor] = None,
         **kwargs,
     ) -> torch.Tensor:
         if not self.block_config.attention.no_op:
             # Self Attention
-            residual = hidden_states
-            hidden_states = self.input_layernorm(hidden_states)
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(
+                    hidden_states, residual)
 
             hidden_states = self.self_attn(
                 position_ids=position_ids,
@@ -162,16 +167,18 @@ def forward(
                 attn_metadata=attn_metadata,
                 **kwargs,
             )
-            hidden_states = residual + hidden_states
 
         if not self.block_config.ffn.no_op:
             # Fully Connected
-            residual = hidden_states
-            hidden_states = self.post_attention_layernorm(hidden_states)
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.post_attention_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.post_attention_layernorm(
+                    hidden_states, residual)
             hidden_states = self.mlp(hidden_states, **kwargs)
-            hidden_states = residual + hidden_states
 
-        return hidden_states
+        return hidden_states, residual
 
 
 class NemotronNASModel(DecoderModel):
@@ -225,6 +232,39 @@ def __init__(self, model_config):
                             eps=config.rms_norm_eps,
                             dtype=config.torch_dtype)
 
+    def forward(
+        self,
+        attn_metadata: AttentionMetadata,
+        input_ids: Optional[torch.IntTensor] = None,
+        position_ids: Optional[torch.IntTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        lora_params: Optional[dict] = None,
+        **kwargs,
+    ) -> torch.Tensor:
+
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError(
+                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+            )
+
+        if inputs_embeds is None:
+            inputs_embeds = self.embed_tokens(input_ids)
+
+        hidden_states = inputs_embeds
+        residual = None
+
+        for decoder_layer in self.layers:
+            hidden_states, residual = decoder_layer(
+                position_ids=position_ids,
+                hidden_states=hidden_states,
+                attn_metadata=attn_metadata,
+                residual=residual,
+                lora_params=lora_params,
+            )
+
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
 
 @register_auto_model("DeciLMForCausalLM")
 class NemotronNASForCausalLM(DecoderModelForCausalLM[NemotronNASModel,
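
Context for the calling convention the diff adopts: instead of adding the residual back after each sublayer, the layers now thread `residual` through the normalization call, passing `(hidden_states, residual)` and receiving both back, so the residual add can be fused with the RMSNorm (including the final `self.norm`). The sketch below is a minimal, hedged illustration of that interface only; `FusedAddRMSNorm` and its internals are hypothetical stand-ins, not the TensorRT-LLM `RMSNorm` implementation.

```python
# Minimal sketch (assumed, not the TensorRT-LLM code) of a norm module that
# optionally fuses "add residual" with RMSNorm and returns the new residual.
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn


class FusedAddRMSNorm(nn.Module):
    """Hypothetical stand-in for a norm supporting the (hidden, residual) call."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(
        self,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if residual is not None:
            # Perform the deferred residual add from the previous sublayer here,
            # so a real kernel can fuse it with the normalization. The summed
            # tensor becomes the residual carried to the next sublayer.
            hidden_states = hidden_states + residual
            residual = hidden_states
        variance = hidden_states.float().pow(2).mean(-1, keepdim=True)
        normed = hidden_states.float() * torch.rsqrt(variance + self.eps)
        normed = (normed * self.weight.float()).to(hidden_states.dtype)
        return normed if residual is None else (normed, residual)


if __name__ == "__main__":
    norm = FusedAddRMSNorm(hidden_size=8)
    x = torch.randn(2, 8)
    prev_residual = torch.randn(2, 8)
    out, new_residual = norm(x, prev_residual)
    assert torch.allclose(new_residual, x + prev_residual)
    print(out.shape, new_residual.shape)
```

Under this assumed interface, the first call in a layer (when `residual is None`) behaves like a plain RMSNorm, which matches the `if residual is None` branches in the diff; subsequent calls consume the pending residual and hand the updated one back, which is why the layer now returns `(hidden_states, residual)` instead of a single tensor.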