51 changes: 28 additions & 23 deletions cpp/tensorrt_llm/executor/executorImpl.cpp
@@ -947,30 +947,32 @@ std::vector<Response> Executor::Impl::awaitResponses(std::optional<std::chrono::
{
TLLM_CHECK_WITH_INFO(!mShutdownCalled, "Shutdown called");
checkParallelApiUsage(__func__);
std::vector<Response> responses;
std::unique_lock<std::mutex> lck(mResponsesMtx);
auto pred = [&mShutdown = mShutdown, &resp = this->mResponses]() -> bool { return !resp.empty() || mShutdown; };
auto storeResponses = [this, &resp = this->mResponses, &responses]()
auto pred = [this]() -> bool { return !mResponses.empty() || mShutdown; };
auto storeResponses = [this]()
{
for (auto it = resp.cbegin(); it != resp.cend();)
std::vector<Response> responses;
for (auto it = mResponses.begin(); it != mResponses.end();)
{
responses.insert(responses.end(), it->second.begin(), it->second.end());
addTerminatedReqId(it->second, it->first);
resp.erase(it++);
it = mResponses.erase(it);
}
return responses;
};

std::vector<Response> responses;
if (timeout)
{
if (mResponsesCv.wait_for(lck, timeout.value(), pred))
{
storeResponses();
responses = storeResponses();
}
}
else
{
mResponsesCv.wait(lck, pred);
storeResponses();
responses = storeResponses();
}
return responses;
}
@@ -980,15 +982,16 @@ std::vector<Response> Executor::Impl::awaitResponses(
{
TLLM_CHECK_WITH_INFO(!mShutdownCalled, "Shutdown called");
checkParallelApiUsage(__func__);
std::vector<Response> responses;
std::unique_lock<std::mutex> lck(mResponsesMtx);
auto pred = [&mShutdown = mShutdown, &resp = this->mResponses, reqId]() -> bool
{ return (resp.find(reqId) != resp.end() && !resp.at(reqId).empty()) || mShutdown; };
auto storeIdResponse = [this, &resp = this->mResponses, &responses, reqId]()
auto pred = [this, reqId]() -> bool
{ return (mResponses.find(reqId) != mResponses.end() && !mResponses.at(reqId).empty()) || mShutdown; };
auto storeIdResponse = [this, reqId]()
{
responses.swap(resp.at(reqId));
resp.erase(reqId);
std::vector<Response> responses;
responses.swap(mResponses.at(reqId));
mResponses.erase(reqId);
addTerminatedReqId(responses, reqId);
return responses;
};

// We don't process a terminated request again. Terminated request is defined as a response
@@ -1005,17 +1008,18 @@ std::vector<Response> Executor::Impl::awaitResponses(
return {Response(reqId, err)};
}

std::vector<Response> responses;
if (timeout)
{
if (mResponsesCv.wait_for(lck, timeout.value(), pred))
{
storeIdResponse();
responses = storeIdResponse();
}
}
else
{
mResponsesCv.wait(lck, pred);
storeIdResponse();
responses = storeIdResponse();
}
return responses;
}
@@ -1025,26 +1029,27 @@ std::vector<std::vector<Response>> Executor::Impl::awaitResponses(
{
TLLM_CHECK_WITH_INFO(!mShutdownCalled, "Shutdown called");
checkParallelApiUsage(__func__);
std::vector<std::vector<Response>> v(requestIds.size());
std::vector<std::vector<Response>> responses;
responses.reserve(requestIds.size());
if (timeout)
{
auto const start_time = std::chrono::high_resolution_clock::now();
for (unsigned i = 0; i < v.size(); ++i)
for (auto const requestId : requestIds)
{
auto const elapsed_ms = std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::high_resolution_clock::now() - start_time);
v[i] = awaitResponses(requestIds[i],
timeout.value() > elapsed_ms ? timeout.value() - elapsed_ms : std::chrono::milliseconds{0});
responses.emplace_back(awaitResponses(
requestId, timeout.value() > elapsed_ms ? timeout.value() - elapsed_ms : std::chrono::milliseconds{0}));
}
}
else
{
for (unsigned i = 0; i < v.size(); ++i)
for (auto const requestId : requestIds)
{
v[i] = awaitResponses(requestIds[i]);
responses.emplace_back(awaitResponses(requestId));
}
}
return v;
return responses;
}

SizeType32 Executor::Impl::getNumResponsesReady(std::optional<IdType> const& optId) const
@@ -1663,7 +1668,7 @@ void Executor::Impl::terminateActiveRequests(RequestList& activeRequests, std::s
}

// Remove from the requestList
activeRequests.erase(it++);
it = activeRequests.erase(it);
}
}

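The refactor above replaces lambdas that write into an outer `responses` vector captured by reference with lambdas that build and return their own vector, and it switches `erase(it++)` to `it = container.erase(it)` when draining the response map (and again in `terminateActiveRequests`). The sketch below is a minimal, self-contained illustration of those idioms plus the remaining-timeout clamp used in the batched overload; the `ResponseQueue`/`Response` names are illustrative stand-ins, not the actual TensorRT-LLM executor types.

```cpp
// Minimal sketch of the idioms adopted above: a drain routine that builds and
// returns its own vector instead of writing through a captured reference, and
// `it = map.erase(it)` for erasing while iterating. Illustrative names only.
#include <chrono>
#include <cstdint>
#include <map>
#include <mutex>
#include <string>
#include <utility>
#include <vector>

struct Response
{
    std::uint64_t reqId;
    std::string payload;
};

class ResponseQueue
{
public:
    void push(std::uint64_t reqId, Response resp)
    {
        std::lock_guard<std::mutex> lock(mMutex);
        mResponses[reqId].push_back(std::move(resp));
    }

    // Drain every pending response. erase() hands back the iterator to the next
    // element, keeping the loop valid, and the result is returned by value.
    std::vector<Response> drainAll()
    {
        std::lock_guard<std::mutex> lock(mMutex);
        std::vector<Response> drained;
        for (auto it = mResponses.begin(); it != mResponses.end();)
        {
            drained.insert(drained.end(), it->second.begin(), it->second.end());
            it = mResponses.erase(it); // uniform replacement for erase(it++)
        }
        return drained;
    }

private:
    std::mutex mMutex;
    std::map<std::uint64_t, std::vector<Response>> mResponses;
};

// Remaining-budget helper mirroring the batched overload above: one overall
// timeout is spread over sequential waits and clamped so it never goes negative.
inline std::chrono::milliseconds remainingBudget(
    std::chrono::milliseconds total, std::chrono::steady_clock::time_point start)
{
    auto const elapsed = std::chrono::duration_cast<std::chrono::milliseconds>(
        std::chrono::steady_clock::now() - start);
    return total > elapsed ? total - elapsed : std::chrono::milliseconds{0};
}

int main()
{
    ResponseQueue queue;
    queue.push(1, Response{1, "first"});
    queue.push(2, Response{2, "second"});
    auto const all = queue.drainAll(); // queue is empty afterwards
    return all.size() == 2 ? 0 : 1;
}
```

Returning the drained vector by value keeps the lambdas free of extra reference captures and is cheap under move semantics; `it = container.erase(it)` is the uniformly valid way to erase while iterating any standard container.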
2 changes: 1 addition & 1 deletion cpp/tensorrt_llm/executor/executorImpl.h
@@ -107,7 +107,7 @@ class Executor::Impl
std::vector<Response> awaitResponses(std::optional<std::chrono::milliseconds> const& timeout = std::nullopt);

std::vector<Response> awaitResponses(
IdType const& optId, std::optional<std::chrono::milliseconds> const& optTimeout = std::nullopt);
IdType const& reqId, std::optional<std::chrono::milliseconds> const& optTimeout = std::nullopt);

std::vector<std::vector<Response>> awaitResponses(
std::vector<IdType> const& requestIds, std::optional<std::chrono::milliseconds> const& timeout);
12 changes: 8 additions & 4 deletions examples/models/core/deepseek_v3/README.md
@@ -30,7 +30,7 @@ Please refer to [this guide](https://nvidia.github.io/TensorRT-LLM/installation/
- [trtllm-serve](#trtllm-serve)
- [Disaggregated Serving](#disaggregated-serving)
- [Dynamo](#dynamo)
- [tensorrtllm_backend for triton inference server (Experimental)](#tensorrtllm_backend-for-triton-inference-server-experimental)
- [tensorrtllm\_backend for triton inference server (Experimental)](#tensorrtllm_backend-for-triton-inference-server-experimental)
- [Advanced Usages](#advanced-usages)
- [Multi-node](#multi-node)
- [mpirun](#mpirun)
@@ -40,6 +40,8 @@ Please refer to [this guide](https://nvidia.github.io/TensorRT-LLM/installation/
- [FlashMLA](#flashmla)
- [FP8 KV Cache and MLA](#fp8-kv-cache-and-mla)
- [W4AFP8](#w4afp8)
- [Activation calibration](#activation-calibration)
- [Weight quantization and assembling](#weight-quantization-and-assembling)
- [KV Cache Reuse](#kv-cache-reuse)
- [Notes and Troubleshooting](#notes-and-troubleshooting)
- [Known Issues](#known-issues)
@@ -227,6 +229,8 @@ trtllm-eval --model <YOUR_MODEL_DIR> \
## Serving
### trtllm-serve

Take the max-throughput scenario on B200 as an example; the settings below are extracted from the [blog](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/blogs/Best_perf_practice_on_DeepSeek-R1_in_TensorRT-LLM.md#b200-max-throughput). **For users' own models and use cases, the specific settings may need to be adjusted to achieve the best performance.**

To serve the model using `trtllm-serve`:

```bash
@@ -253,12 +257,12 @@ trtllm-serve \
--host localhost \
--port 8000 \
--backend pytorch \
--max_batch_size 161 \
--max_num_tokens 1160 \
--max_batch_size 384 \
--max_num_tokens 1536 \
--tp_size 8 \
--ep_size 8 \
--pp_size 1 \
--kv_cache_free_gpu_memory_fraction 0.95 \
--kv_cache_free_gpu_memory_fraction 0.85 \
--extra_llm_api_options ./extra-llm-api-config.yml
```

6 changes: 4 additions & 2 deletions tensorrt_llm/_torch/models/modeling_utils.py
@@ -219,7 +219,6 @@ class PPInitCaller(type):

def __call__(cls, *args, **kwargs):
obj = type.__call__(cls, *args, **kwargs)
obj.__pp_init__()
return obj


@@ -235,6 +234,7 @@ def __init__(self, model_config: ModelConfig):
self.model_config = model_config
self.prologue = []
self.epilogue = []
self.keep_embed_tokens = False

def forward(
self,
@@ -278,7 +278,7 @@ def __pp_init__(self):
)
return

if hasattr(self, "embed_tokens"):
if hasattr(self, "embed_tokens") and not self.keep_embed_tokens:
self.prologue.append(self.embed_tokens)
if hasattr(self, "norm"):
self.epilogue.append(self.norm)
@@ -394,6 +394,8 @@ def __init__(self, model: TModel, *, config: ModelConfig[TConfig],
assert self.lm_head.tp_mode == self.model.embed_tokens.tp_mode, (
"lm_head and vocab embedding should use the same TP mode")
self.lm_head.weight = self.model.embed_tokens.weight
if config.mapping.is_last_pp_rank():
self.model.keep_embed_tokens = True

self.logits_processor = LogitsProcessor()

5 changes: 3 additions & 2 deletions tensorrt_llm/bench/benchmark/utils/general.py
@@ -143,8 +143,9 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
if not enable_chunked_prefill and max_num_tokens < dataset_metadata.max_isl:
logger.warning(
f"Chunked prefill is disabled, but max_num_tokens ({max_num_tokens}) is less than the max ISL ({dataset_metadata.max_isl}). "
f"Forcing max_num_tokens to {dataset_metadata.max_isl}.")
max_num_tokens = dataset_metadata.max_isl
f"Forcing max_num_tokens to {dataset_metadata.max_isl + max_batch_size}."
)
max_num_tokens = dataset_metadata.max_isl + max_batch_size

pyt_options = {
"use_cuda_graph":
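For context on the new cap above: when chunked prefill is disabled, a prompt cannot be split across scheduler iterations, so the per-iteration token budget has to fit the longest input sequence in one step; adding `max_batch_size` on top presumably leaves headroom for the in-flight generation requests that each consume roughly one token per step. A small sketch of that assumed rule, written in illustrative C++ rather than the benchmark's Python:

```cpp
// Illustrative sketch (not the benchmark code itself) of the token-budget rule
// applied above: without chunked prefill, the longest prompt must fit in a
// single scheduler step, with extra headroom for concurrent decode requests.
#include <cstdint>
#include <iostream>

std::int64_t adjustMaxNumTokens(std::int64_t maxNumTokens, std::int64_t maxIsl,
    std::int64_t maxBatchSize, bool chunkedPrefillEnabled)
{
    if (!chunkedPrefillEnabled && maxNumTokens < maxIsl)
    {
        // Assumed rationale: one full prompt plus ~one decode token per
        // concurrently scheduled request.
        return maxIsl + maxBatchSize;
    }
    return maxNumTokens;
}

int main()
{
    // e.g. a budget of 1536 is too small for a 4096-token prompt at batch size 384
    std::cout << adjustMaxNumTokens(1536, 4096, 384, false) << "\n"; // prints 4480
    return 0;
}
```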
5 changes: 5 additions & 0 deletions tensorrt_llm/executor/result.py
@@ -618,6 +618,11 @@ def _topk_logprobs(logits: torch.Tensor, top_k: int,
# reshape from [1, T, V] to [T, V]
logits = logits.squeeze(0)

if tokens is not None and logits.size(0) > len(tokens):
# WAR for nvbug 5324291 where TRT backend might return more logits
# than output tokens.
logits = logits[:len(tokens)]

logprobs = F.log_softmax(logits.to("cuda", dtype=torch.float32), dim=-1)
topk_vals, topk_indices = torch.topk(logprobs, k=top_k, dim=-1)

14 changes: 9 additions & 5 deletions tensorrt_llm/runtime/model_runner_cpp.py
@@ -935,13 +935,17 @@ def _initialize_and_fill_output(
output_ids = [[[] for _ in range(num_sequences)]
for _ in range(len(request_ids))]

multi_responses = self.session.await_responses(request_ids)
responses = [
response for responses in multi_responses for response in responses
]
all_responses = []
finished_request_ids = set()
while finished_request_ids != set(request_ids):
responses = self.session.await_responses()
for response in responses:
if response.result.is_final:
finished_request_ids.add(response.request_id)
all_responses.extend(responses)

return self._fill_output(
responses=responses,
responses=all_responses,
output_ids=output_ids,
end_id=end_id,
return_dict=return_dict,
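The change above replaces a single batched `await_responses(request_ids)` call with a loop that keeps awaiting until every request id has produced a final response, collecting everything returned along the way. Below is a generic sketch of that collect-until-final pattern; `FakeSession` and its canned batches are hypothetical stand-ins, not the real runtime bindings.

```cpp
// Generic sketch of the collect-until-final pattern used above. FakeSession and
// its canned batches are hypothetical stand-ins, not the real runtime bindings.
#include <cstdint>
#include <set>
#include <utility>
#include <vector>

struct Response
{
    std::uint64_t requestId;
    bool isFinal;
};

struct FakeSession
{
    std::vector<std::vector<Response>> batches; // canned responses, consumed from the back

    std::vector<Response> awaitResponses()
    {
        if (batches.empty())
        {
            return {};
        }
        auto batch = std::move(batches.back());
        batches.pop_back();
        return batch;
    }
};

// Keep awaiting until every request id has produced a final response; return
// everything received, in arrival order.
template <typename SessionT>
std::vector<Response> collectUntilFinal(SessionT& session, std::vector<std::uint64_t> const& requestIds)
{
    std::set<std::uint64_t> const expected(requestIds.begin(), requestIds.end());
    std::set<std::uint64_t> finished;
    std::vector<Response> all;
    while (finished != expected)
    {
        auto const batch = session.awaitResponses();
        for (auto const& response : batch)
        {
            if (response.isFinal)
            {
                finished.insert(response.requestId);
            }
        }
        all.insert(all.end(), batch.begin(), batch.end());
    }
    return all;
}

int main()
{
    FakeSession session;
    // Request 1 streams one partial response before its final one; request 2
    // finishes immediately. Batches are popped from the back.
    session.batches = {{{2, true}}, {{1, false}, {1, true}}};
    auto const all = collectUntilFinal(session, {1, 2});
    return all.size() == 3 ? 0 : 1;
}
```

The loop exits only once the set of finished ids equals the set of requested ids, so late or out-of-order responses are handled naturally.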
1 change: 1 addition & 0 deletions tests/integration/defs/examples/test_gemma.py
@@ -93,6 +93,7 @@ def test_llm_hf_gemma_quantization_1gpu_vswa(batch_size, data_type,
gemma_example_root,
llm_datasets_root, llm_rouge_root,
qformat):
skip_fp8_pre_ada(use_fp8=qformat == "fp8")
max_attention_window = VSWA_ATTENTION[Path(gemma_model_root).stem]
hf_gemma_quantization_1gpu(batch_size, data_type, gemma_model_root,
llm_venv, cmodel_dir, engine_dir,
6 changes: 3 additions & 3 deletions tests/integration/defs/test_e2e.py
@@ -1977,15 +1977,15 @@ def test_ptp_quickstart_multimodal(llm_root, llm_venv, model_name, model_path,
],
"video": [
["city", "night", "lights", "jacket", "wet"],
["earth", "spinning", "black", "illuminated", "lights"],
["earth", "spinning", "black"],
],
},
"qwen2.5-vl-7b-instruct": {
"image": [
["dramatic", "moody", "stormy", "turbulent", "wave"],
[
"dome", "yosemite", "landmark", "sunny", "rock", "clouds",
"pleasant"
"large", "dome", "yosemite", "landmark", "rock", "road",
"formation"
],
["highway", "traffic", "vehicles", "bus", "police"],
],
19 changes: 10 additions & 9 deletions tests/integration/defs/triton_server/conftest.py
@@ -574,15 +574,16 @@ def setup_cache_data(request, tensorrt_llm_example_root):


def cleanup_engine_outputs(output_dir_root):
for dirpath, dirnames, _ in os.walk(output_dir_root, topdown=False):
for dirname in dirnames:
if "engine_dir" in dirname or "model_dir" in dirname or "ckpt_dir" in dirname:
folder_path = os.path.join(dirpath, dirname)
try:
shutil.rmtree(folder_path)
print_info(f"Deleted folder: {folder_path}")
except Exception as e:
print_info(f"Error deleting {folder_path}: {e}")
if output_dir_root is not None:
for dirpath, dirnames, _ in os.walk(output_dir_root, topdown=False):
for dirname in dirnames:
if "engine_dir" in dirname or "model_dir" in dirname or "ckpt_dir" in dirname:
folder_path = os.path.join(dirpath, dirname)
try:
shutil.rmtree(folder_path)
print_info(f"Deleted folder: {folder_path}")
except Exception as e:
print_info(f"Error deleting {folder_path}: {e}")


# Teardown hook to clean up engine outputs after each group of test cases are finished
4 changes: 3 additions & 1 deletion tests/integration/defs/triton_server/test_triton_memleak.py
@@ -70,6 +70,7 @@ def test_llama_v3_8b_rss_increasement(
inflight_batcher_llm_client_root,
tensorrt_llm_llama_example_root,
llama_v3_8b_model_root,
tensorrt_llm_example_root,
llm_backend_venv,
):
if BATCHING_STRATEGY == "V1" and BATCH_SCHEDULER_POLICY == "max_utilization":
@@ -83,7 +84,8 @@

llm_backend_repo_root = os.environ["LLM_BACKEND_ROOT"]
# Build engine
ENGINE_PATH = prepare_llama_v3_8b_engine(tensorrt_llm_llama_example_root,
ENGINE_PATH = prepare_llama_v3_8b_engine(tensorrt_llm_example_root,
tensorrt_llm_llama_example_root,
llama_v3_8b_model_root,
workers=1)

1 change: 0 additions & 1 deletion tests/integration/test_lists/waives.txt
@@ -407,7 +407,6 @@ accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=2-overlap_scheduler=False] SKIP (https://nvbugs/5322354)
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=2-overlap_scheduler=True] SKIP (https://nvbugs/5322354)
test_e2e.py::test_ptp_quickstart_advanced[Nemotron-H-8B-Nemotron-H-8B-Base-8K] SKIP (https://nvbugs/5325284)
accuracy/test_llm_api_pytorch.py::TestLlama3_2_1B::test_fp8_pp2 SKIP (https://nvbugspro.nvidia.com/bug/5312750)
test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-70B-NVFP4-nvfp4-quantized/Meta-Llama-3.1-70B] SKIP (https://nvbugs/5323316)
disaggregated/test_disaggregated.py::test_disaggregated_single_gpu_with_mpirun[TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/5328160)
test_e2e.py::test_trtllm_bench_llmapi_launch[trt_backend-llama-v3-llama3-8b] SKIP (https://nvbugs/5320234)