 from torch.nn import functional as F
 from transformers import AutoConfig, PretrainedConfig
 
+from tensorrt_llm._torch.modules.mamba.mamba2_metadata import Mamba2Metadata
+
 from ..attention_backend import AttentionMetadata
 from ..model_config import ModelConfig
 from ..modules.attention import Attention
@@ -71,6 +73,7 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
+        **kwargs,
     ) -> torch.Tensor:
         return super().forward(hidden_states)
 
@@ -99,6 +102,7 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
+        **kwargs,
     ) -> torch.Tensor:
         return super().forward(position_ids=None,
                                hidden_states=hidden_states,
@@ -153,12 +157,13 @@ def forward(
         position_ids: torch.IntTensor,
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
+        **kwargs,
     ) -> torch.Tensor:
 
         residual = hidden_states
 
         hidden_states = self.norm(hidden_states)
-        hidden_states = self.mixer(hidden_states, attn_metadata)
+        hidden_states = self.mixer(hidden_states, attn_metadata, **kwargs)
         hidden_states = torch.add(hidden_states, residual)
 
         return hidden_states
@@ -190,6 +195,8 @@ def __init__(self, model_config: ModelConfig[NemotronHConfig]):
             dtype=config.torch_dtype,
         )
 
+        self.mamba_metadata: Optional[Mamba2Metadata] = None
+
     def forward(
         self,
         attn_metadata: AttentionMetadata,
@@ -203,13 +210,20 @@ def forward(
203210 "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
204211 )
205212
213+ if self .mamba_metadata is None or self .mamba_metadata .max_batch_size != attn_metadata .max_num_requests :
214+ self .mamba_metadata = Mamba2Metadata (attn_metadata .max_num_requests )
215+ self .mamba_metadata .prepare (attn_metadata )
216+
206217 if inputs_embeds is None :
207218 inputs_embeds = self .embed_tokens (input_ids )
208219
209220 hidden_states = inputs_embeds
210221
211222 for layer in self .layers :
212- hidden_states = layer (position_ids , hidden_states , attn_metadata )
223+ hidden_states = layer (position_ids ,
224+ hidden_states ,
225+ attn_metadata ,
226+ mamba_metadata = self .mamba_metadata )
213227
214228 hidden_states = self .norm_f (hidden_states )
215229
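For context, a minimal standalone sketch of the allocate-once, prepare-every-step pattern this diff introduces: the metadata object is rebuilt only when the maximum batch size changes, while prepare() is called on every forward pass. SimpleAttnMetadata, SimpleMambaMetadata, and TinyModel below are illustrative stand-ins for this sketch, not TensorRT-LLM APIs.

from dataclasses import dataclass
from typing import Optional


@dataclass
class SimpleAttnMetadata:
    # Mirrors the single field the diff reads from AttentionMetadata.
    max_num_requests: int


class SimpleMambaMetadata:
    # Stand-in for Mamba2Metadata: sized once, refreshed every step.
    def __init__(self, max_batch_size: int):
        self.max_batch_size = max_batch_size
        self.prepared = False

    def prepare(self, attn_metadata: SimpleAttnMetadata) -> None:
        # Real code would recompute per-step state here; this sketch only flags it.
        self.prepared = True


class TinyModel:
    def __init__(self):
        # Cached across forward calls, like self.mamba_metadata in the diff.
        self.mamba_metadata: Optional[SimpleMambaMetadata] = None

    def forward(self, attn_metadata: SimpleAttnMetadata) -> None:
        # Reallocate only when the maximum batch size changes.
        if (self.mamba_metadata is None
                or self.mamba_metadata.max_batch_size != attn_metadata.max_num_requests):
            self.mamba_metadata = SimpleMambaMetadata(attn_metadata.max_num_requests)
        # Refresh per-step bookkeeping on every call.
        self.mamba_metadata.prepare(attn_metadata)


model = TinyModel()
model.forward(SimpleAttnMetadata(max_num_requests=8))
assert model.mamba_metadata is not None and model.mamba_metadata.max_batch_size == 8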