
Commit f8730d5

Merge branch 'main' into deep_ep/fp4_combine
2 parents 118f9e7 + 50e5e72 commit f8730d5

File tree

24 files changed: +997 -209 lines changed


cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp

Lines changed: 5 additions & 0 deletions
```diff
@@ -297,6 +297,11 @@ void FusedMHARunnerV2::setupLaunchParams(MHARunnerParams runnerParams)
         = mFixedParams.isSPadded ? runnerParams.b * runnerParams.qSeqLen : runnerParams.totalQSeqLen;
     mLaunchParams.total_kv_seqlen
         = mFixedParams.isSPadded ? runnerParams.b * runnerParams.kvSeqLen : runnerParams.totalKvSeqLen;
+    // Workaround for nvbug 5412456: total_kv_seqlen fallbacks to total_q_seqlen if it's zero.
+    if (mLaunchParams.total_kv_seqlen == 0)
+    {
+        mLaunchParams.total_kv_seqlen = mLaunchParams.total_q_seqlen;
+    }
 
     TLLM_CHECK_WITH_INFO(mFixedParams.headSize > 0, "Head size should be greater than 0.");
     // Pad head size to next power of 2.
```

docs/source/commands/trtllm-serve/run-benchmark-with-trtllm-serve.md

Lines changed: 77 additions & 1 deletion
````diff
@@ -3,11 +3,12 @@
 TensorRT-LLM provides the OpenAI-compatiable API via `trtllm-serve` command.
 A complete reference for the API is available in the [OpenAI API Reference](https://platform.openai.com/docs/api-reference).
 
-This step-by-step tutorial covers the following topics for running online serving benchmarking with Llama 3.1 70B:
+This step-by-step tutorial covers the following topics for running online serving benchmarking with Llama 3.1 70B and Qwen2.5-VL-7B for multimodal models:
 * Methodology Introduction
 * Launch the OpenAI-Compatibale Server with NGC container
 * Run the performance benchmark
 * Using `extra_llm_api_options`
+* Multimodal Serving and Benchmarking
 
 
 ## Methodology Introduction
@@ -220,3 +221,78 @@ The following is a list of common performance switches.
  **Default**: TRTLLM
 
 See the [TorchLlmArgs class](https://nvidia.github.io/TensorRT-LLM/llm-api/reference.html#tensorrt_llm.llmapi.TorchLlmArgs) for the full list of options which can be used in the extra\_llm\_api\_options`.`
+
+## Multimodal Serving and Benchmarking
+
+TensorRT-LLM supports multimodal models for both serving and benchmarking. This section covers how to set up multimodal serving and run benchmarks for multimodal models.
+
+### Setting up Multimodal Serving
+
+Here's an example of setting up multimodal serving with Qwen2.5-VL:
+
+```bash
+#!/bin/bash
+model_path=/path/to/qwen2.5vl-7B_model
+
+trtllm-serve ${model_path} \
+    --max_batch_size 64 \
+    --max_num_tokens 8192 \
+    --max_seq_len 4096 \
+    --kv_cache_free_gpu_memory_fraction 0.9 \
+    --tp_size 1 \
+    --ep_size 1 \
+    --trust_remote_code
+```
+
+### Multimodal Benchmarking
+
+For multimodal serving benchmarks, you can use the `benchmark_serving.py` script with multimodal datasets:
+
+```bash
+python -m tensorrt_llm.serve.scripts.benchmark_serving \
+    --model ${model_path} \
+    --backend openai-chat \
+    --dataset-name "random_image" \
+    --random-input-len 128 \
+    --random-output-len 128 \
+    --random-image-width 512 \
+    --random-image-height 512 \
+    --random-num-images 1 \
+    --num-prompts 100 \
+    --max-concurrency 8 \
+    --ignore-eos
+```
+
+Below is some example TensorRT-LLM serving benchmark output. Your actual results may vary.
+```
+============ Serving Benchmark Result ============
+Successful requests:                     1
+Benchmark duration (s):                  0.83
+Total input tokens:                      128
+Total generated tokens:                  128
+Request throughput (req/s):              1.20
+Output token throughput (tok/s):         153.92
+Total Token throughput (tok/s):          307.85
+User throughput (tok/s):                 154.15
+Mean Request AR:                         0.9845
+Median Request AR:                       0.9845
+---------------Time to First Token----------------
+Mean TTFT (ms):                          84.03
+Median TTFT (ms):                        84.03
+P99 TTFT (ms):                           84.03
+-----Time per Output Token (excl. 1st token)------
+Mean TPOT (ms):                          5.88
+Median TPOT (ms):                        5.88
+P99 TPOT (ms):                           5.88
+---------------Inter-token Latency----------------
+Mean ITL (ms):                           5.83
+Median ITL (ms):                         5.88
+P99 ITL (ms):                            6.14
+==================================================
+```
+
+**Notes for Multimodal Benchmarking:**
+- Set `--backend` as `openai-chat` since multimodal models are only supported on the chat API and require a chat template
+- Control the number of images per request with `--random-num-images`
+- Use `--random-image-width` and `--random-image-height` to specify image dimensions or `--random-image-size` for squared image dimensions.
+- The `random_image` dataset generates synthetic images for benchmarking
````
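
For readers following the new multimodal section, here is a minimal sketch of the kind of chat request the `openai-chat` backend exercises against the OpenAI-compatible endpoint exposed by `trtllm-serve`. The host, port, model path, and image URL below are placeholder assumptions, not values taken from this commit.

```python
# Hypothetical client-side sketch: one multimodal chat-completions request against
# a locally running trtllm-serve instance. Host/port, model path, and image URL
# are placeholders; adjust them to your deployment.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")

response = client.chat.completions.create(
    model="/path/to/qwen2.5vl-7B_model",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image in one sentence."},
            {"type": "image_url",
             "image_url": {"url": "https://example.com/sample.jpg"}},
        ],
    }],
    max_tokens=128,
)
print(response.choices[0].message.content)
```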

examples/disaggregated/slurm/disaggr_torch.slurm

Lines changed: 45 additions & 5 deletions
```diff
@@ -38,14 +38,42 @@ container_image=${19}
 mounts=${20}
 workdir=${21}
 model_dir=${22}
+trtllm_repo=${23}
+
+echo "================= parameters ================="
+echo "num_ctx_servers: ${num_ctx_servers}"
+echo "ctx_tp_size: ${ctx_tp_size}"
+echo "ctx_batch_size: ${ctx_batch_size}"
+echo "ctx_max_num_tokens: ${ctx_max_num_tokens}"
+echo "ctx_enable_attention_dp: ${ctx_enable_attention_dp}"
+echo "num_gen_servers: ${num_gen_servers}"
+echo "gen_tp_size: ${gen_tp_size}"
+echo "gen_batch_size: ${gen_batch_size}"
+echo "gen_max_num_tokens: ${gen_max_num_tokens}"
+echo "gen_enable_attention_dp: ${gen_enable_attention_dp}"
+echo "gen_gpu_memory_fraction: ${gen_gpu_memory_fraction}"
+echo "eplb_num_slots: ${eplb_num_slots}"
+echo "mtp_size: ${mtp_size}"
+echo "concurrency: ${concurrency}"
+echo "isl: ${isl}"
+echo "osl: ${osl}"
+echo "multi_round: ${multi_round}"
+echo "streaming: ${streaming}"
+echo "container_image: ${container_image}"
+echo "mounts: ${mounts}"
+echo "workdir: ${workdir}"
+echo "model_dir: ${model_dir}"
+echo "trtllm_repo: ${trtllm_repo}"
+echo "==========================================="
+
 
 ctx_max_seq_len=$((isl + 1))
 gen_max_seq_len=$((isl + osl))
 ctx_gpu_frac=0.75
 cache_transceiver_max_num_tokens=8448
 
 container_name=disaggr
-logdir=${workdir}/benchmark-${isl}-${osl}/
+logdir=${workdir}/benchmark-${isl}-${osl}
 mkdir -p ${logdir}
 full_logdir=${logdir}/ctx${num_ctx_servers}_gen${num_gen_servers}_dep${gen_tp_size}_batch${gen_batch_size}_eplb${eplb_num_slots}_mtp${mtp_size}
 
@@ -65,16 +93,27 @@ fi
 mkdir -p ${full_logdir}
 echo "Log will be saved to: ${full_logdir}"
 
+if [ -z "${TRT_LLM_GIT_COMMIT}" ]; then
+    export TRT_LLM_GIT_COMMIT=$(git -C ${trtllm_repo} rev-parse --short HEAD 2>/dev/null || echo "unknown")
+    echo "TRT_LLM_GIT_COMMIT: ${TRT_LLM_GIT_COMMIT}"
+fi
+
 nsys_on=""
 # nsys_on=${full_logdir} # Uncomment this line to enable Nsys profiling
-
 # start the container
 srun -l --container-image=${container_image} \
     --container-name=${container_name} \
     --container-mounts=${mounts} \
     --mpi=pmix \
     echo "Container up."
 
+if [ -n "${trtllm_repo}" ]; then
+    srun --container-name=${container_name} \
+        --container-mounts=${mounts} \
+        --mpi=pmix --overlap -N $SLURM_NNODES --ntasks-per-node=1 \
+        bash -c "cd ${trtllm_repo} && echo 'Running install operation...' && pip install -e . " 2>&1 | tee ${full_logdir}/install.log
+fi
+
 # generate the yaml file
 srun -l --container-name=${container_name} \
     --container-mounts=${mounts} \
@@ -104,11 +143,12 @@ echo "YAML file generated."
 hostname_value=$(grep '^hostname:' ${full_logdir}/config.yaml | awk -F': ' '{print $2}')
 echo "server host name: $hostname_value"
 
+
 # start the workers
 srun -l --container-name=${container_name} \
     --container-mounts=${mounts} \
-        --mpi=pmix --overlap \
-        bash ${workdir}/start_worker.sh ${full_logdir}/config.yaml "${enable_pdl}" ${ctx_gpus} ${nsys_on} &> ${full_logdir}/output_workers.log &
+    --mpi=pmix --overlap \
+    bash ${workdir}/start_worker.sh ${full_logdir}/config.yaml "${enable_pdl}" ${ctx_gpus} ${nsys_on} &> ${full_logdir}/output_workers.log &
 
 # start the server
 srun -l --container-name=${container_name} \
@@ -121,7 +161,7 @@ srun -l --container-name=${container_name} \
 srun -l --container-name=${container_name} \
     --container-mounts=${mounts} \
     --mpi=pmix --overlap -N 1 -n 1 \
-    bash ${workdir}/run_benchmark.sh ${isl} ${osl} ${multi_round} ${model_dir} "${concurrency}" ${streaming} ${full_logdir}/ > ${full_logdir}/benchmark.log 2>&1
+    bash ${workdir}/run_benchmark.sh ${isl} ${osl} ${multi_round} ${model_dir} "${concurrency}" ${streaming} ${full_logdir} > ${full_logdir}/benchmark.log 2>&1
 
 # try to kill the server and workers
 srun -l --container-name=${container_name} \
```

examples/disaggregated/slurm/run_benchmark.sh

Lines changed: 26 additions & 26 deletions
```diff
@@ -16,7 +16,7 @@ isl=$1
 osl=$2
 multi_round=$3
 model_name=$4
-concurrency=$5
+concurrency_list=$5
 streaming=$6
 log_path=$7
 
@@ -89,31 +89,31 @@ do_get_logs(){
 }
 
 # run the loadgen
-
-mkdir -p ${log_path}/concurrency_${concurrency}
-cp ${log_path}/output_workers.log ${log_path}/concurrency_${concurrency}/workers_start.log
-max_count=$((${concurrency} * ${multi_round}))
-echo "Running loadgen with concurrency: ${concurrency}, max_count: ${max_count}"
-
-python -m tensorrt_llm.serve.scripts.benchmark_serving \
-    --model ${model_name} \
-    --tokenizer ${model_name} \
-    --dataset-name random \
-    --dataset-path ${shared_gpt_path} \
-    --random-input-len ${isl} \
-    --random-output-len ${osl} \
-    --random-prefix-len 0 \
-    --num-prompts ${max_count} \
-    --max-concurrency ${concurrency} \
-    --host ${hostname} \
-    --port ${port} \
-    --ignore-eos \
-    --no-test-input \
-    $(if [ "${streaming}" = "false" ]; then echo "--non-streaming"; fi)
-
-do_get_logs ${log_path}/output_workers.log ${log_path}/concurrency_${concurrency}
-# echo "" > ${log_path}/output_workers.log
-echo "done for ${concurrency} in folder ${log_path}/concurrency_${concurrency}"
+cp ${log_path}/output_workers.log ${log_path}/workers_start.log
+for concurrency in ${concurrency_list}; do
+    mkdir -p ${log_path}/concurrency_${concurrency}
+    max_count=$((${concurrency} * ${multi_round}))
+    echo "Running loadgen with concurrency: ${concurrency}, max_count: ${max_count}"
+    python -m tensorrt_llm.serve.scripts.benchmark_serving \
+        --model ${model_name} \
+        --tokenizer ${model_name} \
+        --dataset-name random \
+        --dataset-path ${shared_gpt_path} \
+        --random-input-len ${isl} \
+        --random-output-len ${osl} \
+        --random-prefix-len 0 \
+        --num-prompts ${max_count} \
+        --max-concurrency ${concurrency} \
+        --host ${hostname} \
+        --port ${port} \
+        --ignore-eos \
+        --no-test-input \
+        $(if [ "${streaming}" = "false" ]; then echo "--non-streaming"; fi)
+
+    do_get_logs ${log_path}/output_workers.log ${log_path}/concurrency_${concurrency}
+    echo "" > ${log_path}/output_workers.log
+    echo "done for ${concurrency} in folder ${log_path}/concurrency_${concurrency}"
+done
 
 echo "Benchmark done, gracefully shutting down server and workers..."
 kill -9 $(ps aux | grep '[s]tart_server.sh' | awk '{print $2}') >/dev/null 2>&1 || true
```

examples/disaggregated/slurm/submit.sh

Lines changed: 2 additions & 0 deletions
```diff
@@ -7,6 +7,7 @@ container_image=<container_image>
 mounts=<mounts> # e.g. /mnt/data:/mnt/data
 workdir=<workdir> # Path to disaggr_torch.slurm
 model_dir=<model_dir> # Path to the model checkpoint
+repo_dir=<repo_dir> # Path to the repo to install TensorRT-LLM, if this is empty, the pre-installed version will be used
 
 ntasks_per_node=4 # 4 GPUs per GB200 node
 total_node_num=8
@@ -31,6 +32,7 @@ args=(
     $mounts
     $workdir
     $model_dir
+    $repo_dir
 )
 
 # This command starts a job with 8 nodes, 32 GPUs in total.
```

examples/quantization/quantize_mixed_precision_moe.py

Lines changed: 25 additions & 13 deletions
```diff
@@ -45,10 +45,16 @@ def load_and_preprocess_state_dict(modelopt_state_root, world_size=8):
     state_dict_list = []
     # load amax from state dict
     for rank in range(world_size):
-        state_dict_list.append(
-            torch.load(
-                f"{modelopt_state_root}/amax_dict_rank{rank}-mp{world_size}.pt",
-                map_location="cuda:0"))
+        amax_file = f"{modelopt_state_root}/amax_dict_rank{rank}-mp{world_size}.pt"
+        if os.path.exists(amax_file):
+            state_dict_list.append(torch.load(amax_file, map_location="cuda:0"))
+        else:
+            print(f"WARNING: amax file not found: {amax_file}")
+
+    if not state_dict_list:
+        print("ERROR: No amax files loaded!")
+        return {}
+
     # calculate the max across all TP ranks
     merged_state_dict = state_dict_list[0]
     for rank in range(world_size):
@@ -232,15 +238,18 @@ def get_file_name(layer):
             continue
         new_safetensors.update({key: get_tensor(key)})
 
+    # Process activation scales for all ranks
+    if os.path.isdir(args.act_scales):
+        # Extract activation scales
+        renamed_state_dict = load_and_preprocess_state_dict(
+            modelopt_state_root=args.act_scales, world_size=8)
+        scales = get_scales_from_amax(start_layer=start_layer,
+                                      end_layer=end_layer,
+                                      renamed_state_dict=renamed_state_dict)
+        new_safetensors.update(scales)
+
     if args.rank == 0:
-        if os.path.isdir(args.act_scales):
-            # Extract activation scales
-            renamed_state_dict = load_and_preprocess_state_dict(
-                modelopt_state_root=args.act_scales, world_size=8)
-            get_scales_from_amax(start_layer=start_layer,
-                                 end_layer=end_layer,
-                                 renamed_state_dict=renamed_state_dict)
-        else:
+        if not os.path.isdir(args.act_scales):
             input_scales = safe_open(args.act_scales, "pt")
             for k in input_scales.keys():
                 new_safetensors.update({k: input_scales.get_tensor(k)})
@@ -259,7 +268,10 @@ def get_file_name(layer):
         ]
         for name in names:
            shutil.copy(os.path.join(model_dir, name), output_dir)
-        shutil.copy(args.act_scales, output_dir)
+        if os.path.isdir(args.act_scales):
+            shutil.copytree(args.act_scales, output_dir, dirs_exist_ok=True)
+        else:
+            shutil.copy(args.act_scales, output_dir)
 
     # config.json
     del config['quantization_config']
```
examples/wide_ep/slurm_scripts/submit.sh

Lines changed: 4 additions & 1 deletion
```diff
@@ -9,6 +9,7 @@ container_image=<container_image>
 mounts=<mounts> # e.g. /mnt/data:/mnt/data
 workdir=<workdir> # Path to disaggr_torch.slurm
 model_dir=<model_dir> # Path to the model checkpoint
+repo_dir=<repo_dir> # Path to the repo to install TensorRT-LLM, if this is empty, the pre-installed version will be used
 
 mtp_size=0
 ntasks_per_node=4 # 4 GPUs per GB200 node
@@ -28,7 +29,7 @@ for b in 1 64 1024; do
 
     args=(
         ${ctx_num} 4 4 4480 true # Context servers arguments
-        1 16 1024 1024 "0.7" # Generation servers arguments
+        1 16 1024 1024 true "0.7" # Generation servers arguments
         $eplb_num_slots $mtp_size # Other arguments
         $concurrency # Benchmarking arguments
         $isl
@@ -39,6 +40,7 @@ for b in 1 64 1024; do
         $mounts
         $workdir
         $model_dir
+        $repo_dir
     )
 
     sbatch --nodes=${total_node_num} \
@@ -74,6 +76,7 @@ for b in 512; do
         $mounts
         $workdir
         $model_dir
+        $repo_dir
     )
 
     sbatch --nodes=${total_node_num} \
```

tensorrt_llm/_torch/modules/fused_moe/fused_moe_cutlass.py

Lines changed: 3 additions & 5 deletions
```diff
@@ -110,13 +110,11 @@ def __init__(
         assert len(
             self.initial_local_expert_ids) == self.expert_size_per_partition
 
-        max_num_tokens = model_config.max_num_tokens
         # The maximum number of tokens in MoE are multiplied by DP size when attention DP is enabled
-        if self.use_dp:
-            max_num_tokens *= model_config.mapping.world_size
-        self.moe_max_num_tokens = model_config.moe_max_num_tokens or max_num_tokens
+        moe_max_num_tokens = model_config.max_num_tokens * model_config.mapping.dp_size
+        self.moe_max_num_tokens = model_config.moe_max_num_tokens or moe_max_num_tokens
         # The auxiliary CUDA stream and CUDA events are only used when MoE chunking is applied
-        if self.moe_max_num_tokens < max_num_tokens:
+        if self.moe_max_num_tokens < moe_max_num_tokens:
             self.aux_stream = aux_stream_dict[
                 AuxStreamType.
                 MoeChunkingOverlap] if aux_stream_dict is not None else torch.cuda.Stream(
```
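
The change above replaces the `use_dp`-conditional scaling with an unconditional multiply by the DP size. Below is a small standalone sketch of the resulting default token budget and of when chunking kicks in; the function name and the example numbers are illustrative, not part of the module.

```python
from typing import Optional


def resolve_moe_max_num_tokens(max_num_tokens: int,
                               dp_size: int,
                               moe_max_num_tokens: Optional[int] = None):
    """Illustrative sketch of the budget computed in the diff above."""
    # With attention DP, each DP rank can contribute up to max_num_tokens,
    # so the MoE layer may see max_num_tokens * dp_size tokens per iteration.
    default_budget = max_num_tokens * dp_size
    resolved = moe_max_num_tokens or default_budget
    # Chunking (and the auxiliary CUDA stream) is only required when the
    # configured budget is smaller than the worst-case token count.
    needs_chunking = resolved < default_budget
    return resolved, needs_chunking


print(resolve_moe_max_num_tokens(8192, 4))         # (32768, False)
print(resolve_moe_max_num_tokens(8192, 4, 16384))  # (16384, True)
```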

0 commit comments
