docs/source/scripts/disaggregated/README.md

# Disaggregated Inference Benchmark Scripts

This directory contains scripts to run disaggregated inference benchmarks using TensorRT-LLM and SLURM.

## Overview

The benchmarking process is orchestrated through a set of shell scripts and a Python script that work together:

1. `submit.sh`: The main entry point for submitting benchmark jobs to SLURM. It runs a parameter sweep by calling `sbatch` with different configurations.
2. `disaggr_torch.slurm`: The SLURM script that sets up and runs a single benchmark experiment. It launches a container, generates a configuration file, starts the server and workers, and runs the benchmark client.
3. `gen_yaml.py`: A Python script that generates the `config.yaml` file needed by `trtllm-serve`. It determines the server and worker configuration based on SLURM environment variables and script arguments.
4. `start_worker.sh`: A shell script responsible for starting a `trtllm-serve disaggregated_mpi_worker` on each allocated machine.
5. `run_benchmark.sh`: A shell script that waits for the server to be healthy and then runs the actual benchmark client (`run_benchmark.py`, not included in this directory).

## File Descriptions

### `submit.sh`

This script is used to submit multiple SLURM jobs for running benchmarks with different parameters. It iterates through various configurations and uses `sbatch` to submit `disaggr_torch.slurm` for each one.

**Usage:**

```bash
./submit.sh
```

You can modify the loops in this script to change the parameter space for the benchmark sweep.
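
For reference, the sweep usually takes the shape of nested loops around `sbatch`. The sketch below is illustrative only; the loop variables, parameter values, and `sub_file` naming are assumptions, not taken from the actual script:

```bash
# Illustrative sketch only; the real submit.sh may use different
# parameter names, values, and log-directory naming.
for concurrency_list in "1 2 4 8" "16 32 64"; do
    for gen_tp_size in 8 16; do
        sub_file="gen_tp${gen_tp_size}"
        sbatch disaggr_torch.slurm \
            1 4 4 4480 true \
            1 "${gen_tp_size}" 1024 1024 true \
            0.8 "${concurrency_list}" "${sub_file}"
    done
done
```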

### `disaggr_torch.slurm`

This is the core SLURM script for a single benchmark run. It is not meant to be run directly, but rather submitted via `sbatch` (e.g., by `submit.sh`).

It takes the following arguments, in order (see the example invocation after the list):

1. `num_ctx_servers`: Number of context servers.
2. `ctx_tp_size`: Tensor parallel size for context servers.
3. `ctx_batch_size`: Max batch size for context servers.
4. `ctx_max_num_tokens`: Max number of tokens for context servers.
5. `ctx_enable_attention_dp`: `true` or `false` to enable attention DP for context servers.
6. `num_gen_servers`: Number of generation servers.
7. `gen_tp_size`: Tensor parallel size for generation servers.
8. `gen_batch_size`: Max batch size for generation servers.
9. `gen_max_num_tokens`: Max number of tokens for generation servers.
10. `gen_enable_attention_dp`: `true` or `false` to enable attention DP for generation servers.
11. `gen_gpu_memory_fraction`: GPU memory fraction for generation servers.
12. `concurrency_list`: A space-separated list of concurrencies to test (e.g., "1 2 4 8").
13. `sub_file`: A subdirectory name for logs.
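
For concreteness, a single submission with all thirteen positional arguments might look like this (the values are illustrative only):

```bash
# Argument order:
#   num_ctx_servers ctx_tp_size ctx_batch_size ctx_max_num_tokens ctx_enable_attention_dp
#   num_gen_servers gen_tp_size gen_batch_size gen_max_num_tokens gen_enable_attention_dp
#   gen_gpu_memory_fraction concurrency_list sub_file
sbatch disaggr_torch.slurm \
    1 4 4 4480 true \
    1 8 1024 1024 true \
    0.8 "1 2 4 8" "example_run"
```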

### `gen_yaml.py`

This Python script generates the `config.yaml` file that configures the `trtllm-serve` application. It reads SLURM environment variables (`SLURM_JOB_NODELIST`, `SLURM_TASKS_PER_NODE`) to distribute workers across nodes.

**Usage:**

The script is called from within `disaggr_torch.slurm`. It takes numerous arguments to define the model, parallelism, and server configurations.
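
Inside `disaggr_torch.slurm`, the script is invoked roughly as follows (flags mirror that script; paths and shell variables are placeholders):

```bash
# --ctx_enable_attention_dp and --gen_enable_attention_dp are appended
# only when the corresponding *_enable_attention_dp argument is "true".
python3 gen_yaml.py --config ${full_logdir}/config.yaml \
    --model ${model_dir} \
    --num_ctx_servers ${num_ctx_servers} \
    --ctx_tp_size ${ctx_tp_size} \
    --ctx_batch_size ${ctx_batch_size} \
    --ctx_max_num_tokens ${ctx_max_num_tokens} \
    --num_gen_servers ${num_gen_servers} \
    --gen_tp_size ${gen_tp_size} \
    --gen_batch_size ${gen_batch_size} \
    --gen_max_num_tokens ${gen_max_num_tokens} \
    --gen_gpu_memory_fraction ${gen_gpu_memory_fraction}
```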

### `start_worker.sh`

This script starts a `trtllm-serve disaggregated_mpi_worker`. It is launched by `srun` from the `disaggr_torch.slurm` script on all allocated nodes.

**Arguments:**

1. `config_file`: Path to the `config.yaml` file.
2. `enable_pdl`: `true` or `false`.
3. `ctx_gpus`: Number of GPUs used for the context phase.
4. `work_dir`: (Optional) Directory to store nsys profiling output.
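
As launched from `disaggr_torch.slurm` (the surrounding `srun` flags are omitted here):

```bash
bash start_worker.sh ${full_logdir}/config.yaml "${enable_pdl}" ${ctx_gpus} ${nsys_on} \
    &> ${full_logdir}/output_workers.log &
```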

### `run_benchmark.sh`

This script orchestrates the execution of the benchmark client. It waits for `config.yaml` to be created and for the server's `/health` endpoint to respond, and then runs the benchmark.
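
A minimal sketch of that readiness check, assuming the server's `hostname` and `port` can be read from the generated `config.yaml` (the `port` key and the polling intervals are assumptions, not taken from the actual script):

```bash
# Sketch only: keys other than "hostname" and the sleep intervals are assumptions.
config_file=${log_path}/config.yaml
while [ ! -f "${config_file}" ]; do sleep 10; done   # wait for gen_yaml.py output

hostname=$(grep '^hostname:' "${config_file}" | awk -F': ' '{print $2}')
port=$(grep '^port:' "${config_file}" | awk -F': ' '{print $2}')   # assumed key
until curl -sf "http://${hostname}:${port}/health" > /dev/null; do
    sleep 10   # poll until the /health endpoint responds
done
```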

**Arguments:**

1. `isl`: Input sequence length.
2. `osl`: Output sequence length.
3. `multi_round`: Number of rounds for the benchmark.
4. `model_name`: Name of (or path to) the model being benchmarked.
5. `concurrency_list`: Space-separated list of concurrencies.
6. `streaming`: `true` or `false`.
7. `log_path`: Path to the log directory.
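
As invoked from `disaggr_torch.slurm` (again with the `srun` wrapper omitted):

```bash
bash run_benchmark.sh ${isl} ${osl} ${multi_round} ${model_dir} \
    "${concurrency_list}" ${streaming} ${full_logdir}/ > ${full_logdir}/benchmark.log 2>&1
```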

## Workflow

1. The user runs `./submit.sh`.
2. `submit.sh` submits one or more jobs to SLURM by calling `sbatch disaggr_torch.slurm` with different parameters.
3. For each job, SLURM allocates resources and runs `disaggr_torch.slurm`.
4. `disaggr_torch.slurm` runs `gen_yaml.py` to create a `config.yaml`.
5. `disaggr_torch.slurm` uses `srun` to launch `start_worker.sh` on all nodes, starting the MPI workers.
6. `disaggr_torch.slurm` starts the main `trtllm-serve` process.
7. `disaggr_torch.slurm` runs `run_benchmark.sh` which waits for the server to be ready.
8. `run_benchmark.sh` executes the benchmark for each concurrency level specified.
9. After the benchmark, `run_benchmark.sh` and `disaggr_torch.slurm` attempt to kill the server and worker processes.
10. Logs for each run are stored in a subdirectory specified by the `sub_file` parameter.
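
Based on `disaggr_torch.slurm`, each run's log subdirectory ends up containing at least the following files:

```bash
ls ${logdir}/${sub_file}/
# config.yaml          generated by gen_yaml.py
# output_workers.log   output of start_worker.sh on all nodes
# output_server.log    output of the trtllm-serve disaggregated server
# benchmark.log        output of run_benchmark.sh
```
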
docs/source/scripts/disaggregated/disaggr_torch.slurm

#!/bin/bash
#SBATCH --nodes=2
#SBATCH --ntasks=8
#SBATCH --ntasks-per-node=4
#SBATCH --partition=batch
#SBATCH --account=${account}
#SBATCH --time=02:00:00
#SBATCH --job-name="${account}:disaggr-test"
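# NOTE: ${account}, ${user}, and ${docker_image} are not defined in this script;
# they must come from the environment or be substituted before submission.
# Also note that sbatch does not expand shell variables inside #SBATCH directives,
# so the --account and --job-name values above are taken literally unless this
# file is preprocessed.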

isl=8192
osl=256
multi_round=10
gen_yaml_file=gen_yaml.py
container_image=${docker_image}
mount_dir=/${account}/${user}/
workdir=/${account}/${user}/8k-${osl}/disaggr-e2e/
model_dir=/${account}/${user}/DeepSeek-R1-nvfp4_allmoe/
logdir=$workdir/bm_deepseek-r1-8k-${osl}-disaggr-e2e-nostream
streaming=false
mkdir -p ${logdir}

dep_dir=${workdir}
run_benchmark_cmd="bash ${dep_dir}/run_benchmark.sh"

container_name=disaggr-test

num_ctx_servers=$1
ctx_tp_size=$2
ctx_batch_size=$3
ctx_max_num_tokens=$4
ctx_enable_attention_dp=$5
num_gen_servers=$6
gen_tp_size=$7
gen_batch_size=$8
gen_max_num_tokens=$9
gen_enable_attention_dp=${10}
gen_gpu_memory_fraction=${11}
concurrency_list=${12}
sub_file=${13}

# concurrency=$((concurrency * gen_tp_size))
echo "concurrency_list: ${concurrency_list}"

ctx_gpus=$((num_ctx_servers * ctx_tp_size))
gen_gpus=$((num_gen_servers * gen_tp_size))

echo "enable_attention_dp: ${ctx_enable_attention_dp}, ${gen_enable_attention_dp}, gpu_memory_fraction: ${gen_gpu_memory_fraction}"

enable_pdl=false
if [ "${gen_enable_attention_dp}" = "false" ]; then
enable_pdl=true
fi

full_logdir=${logdir}/${sub_file}
mkdir -p ${full_logdir}

# start the container
srun -l --container-image=${container_image} \
--container-name=${container_name} \
--container-mounts=${mount_dir}:${mount_dir} \
--mpi=pmix \
echo "Container up."

# generate the yaml file
srun -l --container-name=${container_name} \
--container-mounts=${mount_dir}:${mount_dir} \
--mpi=pmix --overlap \
python3 ${dep_dir}/${gen_yaml_file} --config ${full_logdir}/config.yaml \
--model ${model_dir} \
--num_ctx_servers ${num_ctx_servers} \
--ctx_tp_size ${ctx_tp_size} \
--ctx_batch_size ${ctx_batch_size} \
--ctx_max_num_tokens ${ctx_max_num_tokens} \
--num_gen_servers ${num_gen_servers} \
--gen_tp_size ${gen_tp_size} \
--gen_batch_size ${gen_batch_size} \
--gen_max_num_tokens ${gen_max_num_tokens} \
--gen_gpu_memory_fraction ${gen_gpu_memory_fraction} \
$(if [ "${gen_enable_attention_dp}" = "true" ]; then echo "--gen_enable_attention_dp"; fi) \
$(if [ "${ctx_enable_attention_dp}" = "true" ]; then echo "--ctx_enable_attention_dp"; fi)

echo "YAML file generated."

hostname_value=$(grep '^hostname:' ${full_logdir}/config.yaml | awk -F': ' '{print $2}')
echo "server host name: $hostname_value"

nsys_on=""
# nsys_on=${full_logdir}

# start the workers
srun -l --container-name=${container_name} \
--container-mounts=${mount_dir}:${mount_dir} \
--mpi=pmix --overlap \
bash ${dep_dir}/start_worker.sh ${full_logdir}/config.yaml "${enable_pdl}" ${ctx_gpus} ${nsys_on} &> ${full_logdir}/output_workers.log &
# start the server
srun -l --container-name=${container_name} \
--container-mounts=${mount_dir}:${mount_dir} \
--mpi=pmix --overlap -N 1 -n 1 \
bash trtllm-serve disaggregated -c ${full_logdir}/config.yaml -t 1800 -r 1800 &> ${full_logdir}/output_server.log &
# start benchmark
srun -l --container-name=${container_name} \
--container-mounts=${mount_dir}:${mount_dir} \
--mpi=pmix --overlap -N 1 -n 1 \
--nodelist=${hostname_value} \
${run_benchmark_cmd} ${isl} ${osl} ${multi_round} ${model_dir} "${concurrency_list}" ${streaming} ${full_logdir}/ > ${full_logdir}/benchmark.log 2>&1
wait

# try to kill the server and workers
srun -l --container-name=${container_name} \
--container-mounts=${mount_dir}:${mount_dir} \
--mpi=pmix --overlap \
pkill -f "trtllm-serve" || true