Commit cea5dd1

[TRTLLM-5835][feat] Optimized Mamba2Mixer prefill (#5128)
Signed-off-by: Tomer Asida <[email protected]>
1 parent dd29063 commit cea5dd1

4 files changed, +183 -156 lines changed

tensorrt_llm/_torch/models/modeling_nemotron_h.py

Lines changed: 16 additions & 2 deletions
@@ -20,6 +20,8 @@
 from torch.nn import functional as F
 from transformers import AutoConfig, PretrainedConfig

+from tensorrt_llm._torch.modules.mamba.mamba2_metadata import Mamba2Metadata
+
 from ..attention_backend import AttentionMetadata
 from ..model_config import ModelConfig
 from ..modules.attention import Attention
@@ -71,6 +73,7 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
+        **kwargs,
     ) -> torch.Tensor:
         return super().forward(hidden_states)

@@ -99,6 +102,7 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
+        **kwargs,
     ) -> torch.Tensor:
         return super().forward(position_ids=None,
                                hidden_states=hidden_states,
@@ -153,12 +157,13 @@ def forward(
         position_ids: torch.IntTensor,
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
+        **kwargs,
     ) -> torch.Tensor:

         residual = hidden_states

         hidden_states = self.norm(hidden_states)
-        hidden_states = self.mixer(hidden_states, attn_metadata)
+        hidden_states = self.mixer(hidden_states, attn_metadata, **kwargs)
         hidden_states = torch.add(hidden_states, residual)

         return hidden_states
@@ -190,6 +195,8 @@ def __init__(self, model_config: ModelConfig[NemotronHConfig]):
             dtype=config.torch_dtype,
         )

+        self.mamba_metadata: Optional[Mamba2Metadata] = None
+
     def forward(
         self,
         attn_metadata: AttentionMetadata,
@@ -203,13 +210,20 @@ def forward(
                 "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
             )

+        if self.mamba_metadata is None or self.mamba_metadata.max_batch_size != attn_metadata.max_num_requests:
+            self.mamba_metadata = Mamba2Metadata(attn_metadata.max_num_requests)
+        self.mamba_metadata.prepare(attn_metadata)
+
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

         hidden_states = inputs_embeds

         for layer in self.layers:
-            hidden_states = layer(position_ids, hidden_states, attn_metadata)
+            hidden_states = layer(position_ids,
+                                  hidden_states,
+                                  attn_metadata,
+                                  mamba_metadata=self.mamba_metadata)

         hidden_states = self.norm_f(hidden_states)
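
For context, here is a minimal usage sketch of the pattern this change introduces: the model keeps a single Mamba2Metadata object, rebuilds it only when the maximum batch size changes, and calls prepare() once per forward pass so that each Mamba2Mixer layer can reuse the same prefill indexing tensors instead of recomputing them. DummyAttentionMetadata below is a hypothetical stand-in (not part of the commit) that carries only the fields the new code reads; it requires a CUDA device.

from dataclasses import dataclass

import torch

from tensorrt_llm._torch.modules.mamba.mamba2_metadata import Mamba2Metadata


@dataclass
class DummyAttentionMetadata:
    # Stand-in exposing only the attributes that Mamba2Metadata.prepare() and
    # the model forward above actually use (assumed shapes/dtypes).
    max_num_requests: int
    num_contexts: int
    seq_lens_cuda: torch.Tensor


attn_metadata = DummyAttentionMetadata(
    max_num_requests=8,
    num_contexts=2,
    seq_lens_cuda=torch.tensor([4, 6], dtype=torch.int, device="cuda"))

mamba_metadata = None
# Rebuild only when the batch-size bound changes; otherwise reuse the buffers.
if mamba_metadata is None or mamba_metadata.max_batch_size != attn_metadata.max_num_requests:
    mamba_metadata = Mamba2Metadata(attn_metadata.max_num_requests)
mamba_metadata.prepare(attn_metadata)

# The same object is then handed to every decoder layer via **kwargs
# (mamba_metadata=...), so each mixer sees identical cu_seqlens / seq_idx.
print(mamba_metadata.cu_seqlens[:3])  # tensor([0, 4, 10], device='cuda:0', dtype=torch.int32)
print(mamba_metadata.seq_idx)         # tensor([[0, 0, 0, 0, 1, 1, 1, 1, 1, 1]], ...)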

tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py

Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
+# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from tensorrt_llm._torch.attention_backend.interface import AttentionMetadata
+
+
+class Mamba2Metadata:
+
+    def __init__(self, max_batch_size: int):
+        self.max_batch_size = max_batch_size
+
+        # cumulative sequence lengths for prefill requests [batch_size+1]
+        self.cu_seqlens = torch.zeros(max_batch_size + 1,
+                                      dtype=torch.int,
+                                      device="cuda")
+
+        # sequence index for prefill requests [num_prefill_tokens] - specifies which request each token belongs to
+        self.seq_idx: torch.Tensor = None
+
+    def prepare(self, attn_metadata: AttentionMetadata):
+        num_contexts = attn_metadata.num_contexts
+        context_lens = attn_metadata.seq_lens_cuda[:num_contexts]
+        if num_contexts > 0:
+            torch.cumsum(context_lens,
+                         dim=0,
+                         dtype=torch.int,
+                         out=self.cu_seqlens[1:num_contexts + 1])
+            self.seq_idx = torch.repeat_interleave(
+                torch.arange(num_contexts,
+                             dtype=torch.int,
+                             device=self.cu_seqlens.device),
+                repeats=context_lens,
+                output_size=self.cu_seqlens[num_contexts]).unsqueeze(0)
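
To make the indexing concrete, here is a small worked example (not part of the commit) of what prepare() computes for a toy prefill batch with context lengths [3, 2, 4]: cu_seqlens holds the cumulative token offset of each request within the packed token stream, and seq_idx maps every packed token back to the request it belongs to. The snippet runs on CPU purely for illustration; the class above allocates its buffers on CUDA.

import torch

context_lens = torch.tensor([3, 2, 4], dtype=torch.int)
num_contexts = context_lens.numel()

# Cumulative sequence lengths: one leading zero, then running totals.
cu_seqlens = torch.zeros(num_contexts + 1, dtype=torch.int)
torch.cumsum(context_lens, dim=0, dtype=torch.int, out=cu_seqlens[1:])
print(cu_seqlens)  # tensor([0, 3, 5, 9], dtype=torch.int32)

# Request index of every packed prefill token, shaped [1, num_prefill_tokens].
seq_idx = torch.repeat_interleave(torch.arange(num_contexts, dtype=torch.int),
                                  repeats=context_lens,
                                  output_size=int(cu_seqlens[-1])).unsqueeze(0)
print(seq_idx)  # tensor([[0, 0, 0, 1, 1, 2, 2, 2, 2]], dtype=torch.int32)

Because the packed token layout is identical for every Mamba2 layer in the model, these tensors can be computed once per forward pass and shared across layers, which is the prefill optimization this commit targets.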

0 commit comments