
Commit c9dca69

dominicshanshan, Naveassaf, yechank-nvidia, brb-nv, and Tabrizian authored
[None][chore] Mass integration of release/1.0 - 3rd (#7519)
Signed-off-by: Nave Assaf <[email protected]>
Signed-off-by: Wangshanshan <[email protected]>
Signed-off-by: yechank <[email protected]>
Signed-off-by: Balaram Buddharaju <[email protected]>
Signed-off-by: Iman Tabrizian <[email protected]>
Signed-off-by: qqiao <[email protected]>
Signed-off-by: Superjomn <[email protected]>
Signed-off-by: Bo Deng <[email protected]>
Signed-off-by: Jin Li <[email protected]>
Signed-off-by: Yifei Zhang <[email protected]>
Signed-off-by: Amit Zuker <[email protected]>
Signed-off-by: Erin Ho <[email protected]>
Signed-off-by: Chenfei Zhang <[email protected]>
Signed-off-by: Christina Zhang <[email protected]>
Signed-off-by: Venky Ganesh <[email protected]>
Signed-off-by: Pamela <[email protected]>
Signed-off-by: Hui Gao <[email protected]>
Signed-off-by: Alexandre Milesi <[email protected]>
Signed-off-by: Shixiaowei02 <[email protected]>
Signed-off-by: Michal Guzek <[email protected]>
Signed-off-by: peaceh <[email protected]>
Signed-off-by: nv-guomingz <[email protected]>
Signed-off-by: Wanli Jiang <[email protected]>
Signed-off-by: Patrice Castonguay <[email protected]>
Signed-off-by: ruodil <[email protected]>
Signed-off-by: Linda-Stadter <[email protected]>
Signed-off-by: Yuxian Qiu <[email protected]>
Signed-off-by: Jiagan Cheng <[email protected]>
Signed-off-by: William Zhang <[email protected]>
Signed-off-by: Dom Brown <[email protected]>
Co-authored-by: Nave Assaf <[email protected]>
Co-authored-by: Yechan Kim <[email protected]>
Co-authored-by: brb-nv <[email protected]>
Co-authored-by: Iman Tabrizian <[email protected]>
Co-authored-by: Emma Qiao <[email protected]>
Co-authored-by: Yan Chunwei <[email protected]>
Co-authored-by: Bo Deng <[email protected]>
Co-authored-by: Jin Li <[email protected]>
Co-authored-by: yifeizhang-c <[email protected]>
Co-authored-by: amitz-nv <[email protected]>
Co-authored-by: Erin <[email protected]>
Co-authored-by: chenfeiz0326 <[email protected]>
Co-authored-by: ChristinaZ <[email protected]>
Co-authored-by: Venky <[email protected]>
Co-authored-by: Pamela Peng <[email protected]>
Co-authored-by: HuiGao-NV <[email protected]>
Co-authored-by: milesial <[email protected]>
Co-authored-by: Shi Xiaowei <[email protected]>
Co-authored-by: Michal Guzek <[email protected]>
Co-authored-by: peaceh-nv <[email protected]>
Co-authored-by: Guoming Zhang <[email protected]>
Co-authored-by: Wanli Jiang <[email protected]>
Co-authored-by: pcastonguay <[email protected]>
Co-authored-by: ruodil <[email protected]>
Co-authored-by: Linda <[email protected]>
Co-authored-by: Zhanrui Sun <[email protected]>
Co-authored-by: Yuxian Qiu <[email protected]>
Co-authored-by: Jiagan Cheng <[email protected]>
Co-authored-by: William Zhang <[email protected]>
Co-authored-by: Larry <[email protected]>
Co-authored-by: Sharan Chetlur <[email protected]>
Co-authored-by: Dom Brown <[email protected]>
1 parent 504bb7f commit c9dca69

File tree: 37 files changed (+372, -146 lines)


cpp/include/tensorrt_llm/batch_manager/logitsPostProcessor.h

Lines changed: 1 addition & 1 deletion

@@ -47,7 +47,7 @@ class LogitsPostProcessor : Algorithm

    bool operator()(DecoderInputBuffers& inputBuffers, bool replicateLogitsPostProcessor,
        runtime::WorldConfig const& worldConfig, CudaStreamPtr const& stream,
-        std::optional<LogitsPostProcessorBatched> logitsPostProcessorBatched = std::nullopt) const;
+        std::optional<LogitsPostProcessorBatched> const& logitsPostProcessorBatched = std::nullopt) const;
};

} // namespace tensorrt_llm::batch_manager

cpp/tensorrt_llm/batch_manager/logitsPostProcessor.cpp

Lines changed: 1 addition & 1 deletion

@@ -34,7 +34,7 @@ using SizeType32 = tensorrt_llm::runtime::SizeType32;

bool LogitsPostProcessor::operator()(DecoderInputBuffers& inputBuffers, bool replicateLogitsPostProcessor,
    tr::WorldConfig const& worldConfig, CudaStreamPtr const& stream,
-    std::optional<LogitsPostProcessorBatched> logitsPostProcessorBatched) const
+    std::optional<LogitsPostProcessorBatched> const& logitsPostProcessorBatched) const
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    NVTX3_SCOPED_RANGE(LogitsPostProcessor);

docs/source/commands/trtllm-serve/trtllm-serve.rst

Lines changed: 34 additions & 30 deletions

@@ -201,56 +201,60 @@ Metrics Endpoint

 .. note::

-    This endpoint is beta maturity.
+    The metrics endpoint for the default PyTorch backend are in beta and are not as comprehensive as those for the TensorRT backend.

-    The statistics for the PyTorch backend are beta and not as comprehensive as those for the TensorRT backend.
+    Some fields, such as CPU memory usage, are not yet available for the PyTorch backend.

-    Some fields, such as CPU memory usage, are not available for the PyTorch backend.
+    Enabling ``enable_iter_perf_stats`` in the PyTorch backend can slightly impact performance, depending on the serving configuration.

-    Enabling ``enable_iter_perf_stats`` in the PyTorch backend can impact performance slightly, depending on the serving configuration.
+The ``/metrics`` endpoint provides runtime iteration statistics such as GPU memory usage and KV cache details.

-The ``/metrics`` endpoint provides runtime-iteration statistics such as GPU memory use and inflight-batching details.
-For the TensorRT backend, these statistics are enabled by default.
-However, for the PyTorch backend, you must explicitly enable iteration statistics logging by setting the `enable_iter_perf_stats` field in a YAML configuration file as shown in the following example:
+For the default PyTorch backend, iteration statistics logging is enabled by setting the ``enable_iter_perf_stats`` field in a YAML file:

 .. code-block:: yaml

-    # extra-llm-api-config.yml
-    pytorch_backend_config:
-      enable_iter_perf_stats: true
+    # extra_llm_config.yaml
+    enable_iter_perf_stats: true

-Then start the server and specify the ``--extra_llm_api_options`` argument with the path to the YAML file as shown in the following example:
+Start the server and specify the ``--extra_llm_api_options`` argument with the path to the YAML file:

 .. code-block:: bash

-    trtllm-serve <model> \
-      --extra_llm_api_options <path-to-extra-llm-api-config.yml> \
-      [--tp_size <tp> --pp_size <pp> --ep_size <ep> --host <host> --port <port>]
+    trtllm-serve "TinyLlama/TinyLlama-1.1B-Chat-v1.0" --extra_llm_api_options extra_llm_config.yaml

-After at least one inference request is sent to the server, you can fetch the runtime-iteration statistics by polling the `/metrics` endpoint:
+After sending at least one inference request to the server, you can fetch runtime iteration statistics by polling the ``/metrics`` endpoint.
+Since the statistics are stored in an internal queue and removed once retrieved, it's recommended to poll the endpoint shortly after each request and store the results if needed.

 .. code-block:: bash

-    curl -X GET http://<host>:<port>/metrics
+    curl -X GET http://localhost:8000/metrics

-*Example Output*
+Example output:

 .. code-block:: json

-    [
-      {
-        "gpuMemUsage": 56401920000,
-        "inflightBatchingStats": {
+    [
+      {
+        "gpuMemUsage": 76665782272,
+        "iter": 154,
+        "iterLatencyMS": 7.00688362121582,
+        "kvCacheStats": {
+          "allocNewBlocks": 3126,
+          "allocTotalBlocks": 3126,
+          "cacheHitRate": 0.00128,
+          "freeNumBlocks": 101253,
+          "maxNumBlocks": 101256,
+          "missedBlocks": 3121,
+          "reusedBlocks": 4,
+          "tokensPerBlock": 32,
+          "usedNumBlocks": 3
+        },
+        "numActiveRequests": 1
           ...
-        },
-        "iter": 1,
-        "iterLatencyMS": 16.505143404006958,
-        "kvCacheStats": {
-          ...
-        },
-        "newActiveRequestsQueueLatencyMS": 0.0007503032684326172
-      }
-    ]
+      }
+    ]
+
+

 Syntax
 ------
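The revised documentation above notes that iteration statistics sit in an internal queue and are dropped once retrieved, so a small polling loop that stores what it fetches can be useful. The sketch below is illustrative rather than part of this commit; it assumes a trtllm-serve instance listening on localhost:8000 and uses the third-party requests package.

# Hypothetical helper, not part of this commit: poll /metrics after sending requests
# and keep the per-iteration statistics, since the server discards them once retrieved.
import json
import time

import requests  # third-party HTTP client

METRICS_URL = "http://localhost:8000/metrics"  # assumed trtllm-serve host and port


def collect_metrics(num_polls: int = 5, interval_s: float = 0.5) -> list:
    """Poll the /metrics endpoint a few times and accumulate the returned records."""
    collected = []
    for _ in range(num_polls):
        resp = requests.get(METRICS_URL, timeout=5)
        resp.raise_for_status()
        collected.extend(resp.json())  # the endpoint returns a JSON list of iteration stats
        time.sleep(interval_s)
    return collected


if __name__ == "__main__":
    stats = collect_metrics()
    for record in stats:
        kv = record.get("kvCacheStats", {})
        print(record.get("iter"), record.get("iterLatencyMS"), kv.get("usedNumBlocks"))
    with open("metrics_log.json", "w") as f:
        json.dump(stats, f, indent=2)  # persist for later analysis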

docs/source/legacy/tensorrt_quickstart.md

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 # LLM API with TensorRT Engine
 A simple inference example with TinyLlama using the LLM API:

-```{literalinclude} ../../examples/llm-api/_tensorrt_engine/quickstart_example.py
+```{literalinclude} ../../../examples/llm-api/_tensorrt_engine/quickstart_example.py
 :language: python
 :linenos:
 ```

examples/llm-api/_tensorrt_engine/quickstart_example.py

Lines changed: 8 additions & 2 deletions

@@ -1,11 +1,17 @@
-from tensorrt_llm import LLM, SamplingParams
+from tensorrt_llm import BuildConfig, SamplingParams
+from tensorrt_llm._tensorrt_engine import LLM  # NOTE the change


 def main():

+    build_config = BuildConfig()
+    build_config.max_batch_size = 256
+    build_config.max_num_tokens = 1024
+
     # Model could accept HF model name, a path to local HF model,
     # or TensorRT Model Optimizer's quantized checkpoints like nvidia/Llama-3.1-8B-Instruct-FP8 on HF.
-    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+              build_config=build_config)

     # Sample prompts.
     prompts = [
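The hunk above only covers the top of the quickstart. For context, a fully assembled version of the updated example might look like the following sketch; the prompts, sampling parameters, and output loop are illustrative assumptions, not the exact remainder of the file.

# Assembled sketch of the updated quickstart; parts outside the hunk above are illustrative.
from tensorrt_llm import BuildConfig, SamplingParams
from tensorrt_llm._tensorrt_engine import LLM  # NOTE the change


def main():
    build_config = BuildConfig()
    build_config.max_batch_size = 256
    build_config.max_num_tokens = 1024

    # Model could accept an HF model name, a path to a local HF model,
    # or TensorRT Model Optimizer's quantized checkpoints.
    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
              build_config=build_config)

    # Sample prompts (illustrative).
    prompts = [
        "Hello, my name is",
        "The capital of France is",
    ]
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

    for output in llm.generate(prompts, sampling_params):
        print(f"Prompt: {output.prompt!r}, Generated text: {output.outputs[0].text!r}")


if __name__ == "__main__":
    main()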

examples/llm-api/llm_mgmn_trtllm_bench.sh

Lines changed: 1 addition & 0 deletions

@@ -76,6 +76,7 @@ srun -l \

    # This is optional
    cat > /tmp/pytorch_extra_args.txt << EOF
+cuda_graph_config: null
print_iter_log: true
enable_attention_dp: false
EOF

tensorrt_llm/_torch/attention_backend/flashinfer.py

Lines changed: 2 additions & 1 deletion

@@ -170,7 +170,8 @@ def __post_init__(self) -> None:
     def create_cuda_graph_metadata(self,
                                    max_batch_size: int,
                                    sub_cross_metadata: bool = False,
-                                   max_draft_tokens: int = 0) -> Self:
+                                   max_draft_tokens: int = 0,
+                                   buffers=None) -> Self:
         metadata = super().create_cuda_graph_metadata(max_batch_size,
                                                       sub_cross_metadata,
                                                       max_draft_tokens)

tensorrt_llm/_torch/attention_backend/interface.py

Lines changed: 4 additions & 1 deletion

@@ -140,6 +140,7 @@ class AttentionMetadata:

     # This buffer is currently only used for TrtllmAttentionMetadata.
     cache_indirection: Optional[torch.Tensor] = None
+    cuda_graph_buffers: dict[str, list[torch.Tensor]] = None

     _saved_tensors: Dict[str, torch.Tensor] = field(init=False,
                                                     default_factory=dict)

@@ -288,7 +289,8 @@ def prepare(self):
     def create_cuda_graph_metadata(self,
                                    max_batch_size: int,
                                    sub_cross_metadata: bool = False,
-                                   max_draft_tokens: int = 0) -> Self:
+                                   max_draft_tokens: int = 0,
+                                   buffers=None) -> Self:
        """
        Creates metadata for CUDA graph execution.
        CUDA graphs require to use pre-allocated buffers for all tensors in fields.

@@ -300,6 +302,7 @@ def create_cuda_graph_metadata(self,

         cuda_graph_metadata = copy.copy(self)
         cuda_graph_metadata.is_cuda_graph = True
+        cuda_graph_metadata.cuda_graph_buffers = buffers
         if self.has_cross_sub_metadata:
             cuda_graph_metadata.cross = cuda_graph_metadata.cross.create_cuda_graph_metadata(
                 max_batch_size, True)

tensorrt_llm/_torch/attention_backend/trtllm.py

Lines changed: 71 additions & 16 deletions

@@ -600,21 +600,76 @@ host_kv_cache_pool_mapping(self) -> Optional[torch.Tensor]:

     def __post_init__(self) -> None:
         super().__post_init__()
+        self._post_init_with_buffers(self.cuda_graph_buffers)
+
+    def _post_init_with_buffers(self, buffers) -> None:
+
         # Set a default value, as max_num_sequences is not always set.
         if self.max_num_sequences is None:
             self.max_num_sequences = self.max_num_requests

-        self.prompt_lens_cuda = torch.empty(
+        def get_empty(tensor_shape: list[int], dtype: torch.dtype,
+                      cache_name: str) -> torch.Tensor:
+            """
+            Finds a compatible, reusable buffer from a cache or creates a new one.
+
+            This function searches for a pre-allocated tensor (buffer) that can be
+            reused for an operation involving a tensor with the shape of `tensor_shape`.
+
+            The compatibility rules are: The buffer's total elements must be >= tensor_shape's.
+
+            If a compatible buffer is found, it's returned immediately. Otherwise, a new
+            buffer is allocated on the 'cuda' device with the give properties of 'tensor_shape' and 'dtype'.
+
+            Args:
+                tensor_shape: The required shape.
+                dtype: The required dtype.
+                cache_name: The key for the specific list of buffers to search in.
+
+            Returns:
+                An existing compatible buffer or a newly created one.
+            """
+            if buffers is not None:
+                # Safely get the list of candidates. Defaults to an empty list if key is missing.
+                candidate_buffers = buffers.get(cache_name, [])
+                numel_like = math.prod(tensor_shape)
+
+                for buffer in candidate_buffers:
+                    numel_buffer = buffer.numel()
+
+                    # buffer just needs to be large enough.
+                    if numel_buffer >= numel_like:
+                        return buffer[0:numel_like].view(
+                            tensor_shape)  # Found a fit, return immediately.
+
+            # If we get here, no suitable buffer was found in the cache. Create a new one.
+            new_buffer = torch.zeros(tensor_shape, device='cuda', dtype=dtype)
+            if buffers is not None:
+                buffers.setdefault(cache_name, []).append(new_buffer)
+            return new_buffer
+
+        def get_empty_like(like_tensor: torch.Tensor,
+                           cache_name: str) -> torch.Tensor:
+            return get_empty(
+                like_tensor.shape,
+                cache_name=cache_name,
+                dtype=like_tensor.dtype,
+            )
+
+        self.prompt_lens_cuda = get_empty(
             (self.max_num_sequences, ),
-            device='cuda',
+            cache_name="prompt_lens_cuda",
             dtype=torch.int,
         )
         self.prompt_lens_cpu = torch.empty_like(
             self.prompt_lens_cuda,
             device='cpu',
             pin_memory=True,
         )
-        self.kv_lens_cuda = torch.empty_like(self.prompt_lens_cuda)
+        self.kv_lens_cuda = get_empty_like(
+            self.prompt_lens_cuda,
+            cache_name="kv_lens_cuda",
+        )
         self.kv_lens = torch.empty_like(self.kv_lens_cuda,
                                         device='cpu',
                                         pin_memory=True)

@@ -629,13 +684,13 @@ def __post_init__(self) -> None:
             dtype=torch.int8,
         )
         if self.kv_cache_manager is not None:
-            self.kv_cache_block_offsets = torch.empty(
+            self.kv_cache_block_offsets = get_empty(
                 [
                     self.kv_cache_manager.num_pools, self.max_num_sequences, 2,
                     self.kv_cache_manager.max_blocks_per_seq
                 ],
+                cache_name="kv_cache_block_offsets",
                 dtype=torch.int32,
-                device='cuda',
             )
             self.host_kv_cache_block_offsets = torch.empty_like(
                 self.kv_cache_block_offsets,

@@ -645,37 +700,37 @@ def __post_init__(self) -> None:
         self.block_ids_per_seq = None
         self.kv_block_ids_per_seq = None
         if self.enable_flash_mla:
-            self.block_ids_per_seq = torch.zeros(
+            self.block_ids_per_seq = get_empty(
                 [
                     self.kv_cache_manager.max_batch_size,
                     self.kv_cache_manager.max_blocks_per_seq
                 ],
+                cache_name="block_ids_per_seq",
                 dtype=torch.int32,
-                device='cuda',
             )
-            self.kv_block_ids_per_seq = torch.zeros(
+            self.kv_block_ids_per_seq = get_empty(
                 [
                     self.kv_cache_manager.max_batch_size,
                     self.kv_cache_manager.max_blocks_per_seq
                 ],
+                cache_name="kv_block_ids_per_seq",
                 dtype=torch.int32,
-                device='cuda',
             )
         if self.enable_context_mla_with_cached_kv:
             # for kv cache reuse/chunked context in MLA
-            self.ctx_cached_token_indptr = torch.zeros(
+            self.ctx_cached_token_indptr = get_empty(
                 (self.max_num_requests + 1, ),
-                device='cuda',
+                cache_name="ctx_cached_token_indptr",
                 dtype=torch.int64,
             )
             self.host_ctx_cached_token_indptr = torch.zeros_like(
                 self.ctx_cached_token_indptr,
                 device='cpu',
                 pin_memory=True,
             )
-            self.ctx_uncached_token_indptr = torch.zeros(
+            self.ctx_uncached_token_indptr = get_empty(
                 (self.max_num_requests + 1, ),
-                device='cuda',
+                cache_name="ctx_uncached_token_indptr",
                 dtype=torch.int64,
             )
             self.host_ctx_uncached_token_indptr = torch.zeros_like(

@@ -684,9 +739,9 @@ def __post_init__(self) -> None:
                 device='cpu',
                 pin_memory=True,
             )
             # context full seqlens include cached tokens and uncached tokens
-            self.ctx_kv_indptr = torch.zeros(
+            self.ctx_kv_indptr = get_empty(
                 (self.max_num_requests + 1, ),
-                device='cuda',
+                cache_name="ctx_kv_indptr",
                 dtype=torch.int64,
             )
             self.host_ctx_kv_indptr = torch.zeros_like(

@@ -1165,7 +1220,7 @@ def forward(
             host_kv_cache_pool_pointers=metadata.host_kv_cache_pool_pointers,
             host_kv_cache_pool_mapping=metadata.host_kv_cache_pool_mapping,
             block_ids_per_seq=metadata.block_ids_per_seq,
-            workspace=metadata.workspace,
+            workspace=None,
             cache_indirection=metadata.cache_indirection,
             kv_scale_orig_quant=self.kv_scale_orig_quant,
             kv_scale_quant_orig=self.kv_scale_quant_orig,
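The get_empty helper introduced above reuses named, pre-allocated buffers, which appears intended to let attention metadata built for CUDA graphs (for example, for different batch sizes) share storage instead of allocating fresh device tensors each time. The standalone sketch below reproduces only that reuse pattern; it uses CPU tensors and illustrative names and is not the code from this commit.

# Standalone sketch of the buffer-reuse pattern; CPU tensors so it runs without a GPU.
# Buffers are stored flat here for clarity; the real helper stores shaped CUDA tensors.
import math
from typing import Dict, List

import torch


def get_cached_buffer(buffers: Dict[str, List[torch.Tensor]], tensor_shape: list,
                      dtype: torch.dtype, cache_name: str) -> torch.Tensor:
    """Return a view of a cached buffer large enough for tensor_shape, or allocate one."""
    numel_needed = math.prod(tensor_shape)
    for buffer in buffers.get(cache_name, []):
        if buffer.numel() >= numel_needed:
            # Large enough: reuse the first numel_needed elements with the requested shape.
            return buffer[:numel_needed].view(tensor_shape)
    new_buffer = torch.zeros(numel_needed, dtype=dtype)
    buffers.setdefault(cache_name, []).append(new_buffer)
    return new_buffer.view(tensor_shape)


shared: Dict[str, List[torch.Tensor]] = {}

# The first call allocates storage for the largest shape...
big = get_cached_buffer(shared, [8, 16], torch.int32, "prompt_lens")
# ...and later calls with smaller shapes reuse that storage instead of allocating again.
small = get_cached_buffer(shared, [2, 16], torch.int32, "prompt_lens")
assert small.data_ptr() == big.data_ptr()
assert len(shared["prompt_lens"]) == 1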

tensorrt_llm/_torch/autotuner.py

Lines changed: 1 addition & 1 deletion

@@ -371,7 +371,7 @@ def choose_one(
         if not is_cache_hit:
             logger.warning_once(
                 f"[AutoTunner] Using the fallback tactic, due to cache miss on input shapes={input_shapes}",
-                key=(custom_op))
+                key=custom_op)

         return (best_runner, best_tactic)
