Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 47 additions & 7 deletions tensorrt_llm/_torch/models/modeling_nemotron_nas.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,29 +149,36 @@ def forward(
position_ids: torch.IntTensor,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
if not self.block_config.attention.no_op:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)

hidden_states = self.self_attn(
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
**kwargs,
)
hidden_states = residual + hidden_states

if not self.block_config.ffn.no_op:
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
if residual is None:
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
else:
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states, **kwargs)
hidden_states = residual + hidden_states

return hidden_states
return hidden_states, residual


class NemotronNASModel(DecoderModel):
Expand Down Expand Up @@ -225,6 +232,39 @@ def __init__(self, model_config):
eps=config.rms_norm_eps,
dtype=config.torch_dtype)

def forward(
self,
attn_metadata: AttentionMetadata,
input_ids: Optional[torch.IntTensor] = None,
position_ids: Optional[torch.IntTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
lora_params: Optional[dict] = None,
**kwargs,
) -> torch.Tensor:

if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

hidden_states = inputs_embeds
residual = None

for decoder_layer in self.layers:
hidden_states, residual = decoder_layer(
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
residual=residual,
lora_params=lora_params,
)

hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states


@register_auto_model("DeciLMForCausalLM")
class NemotronNASForCausalLM(DecoderModelForCausalLM[NemotronNASModel,
Expand Down