diff --git a/docs/source/scripts/disaggregated/README.md b/docs/source/scripts/disaggregated/README.md new file mode 100644 index 00000000000..ed21b998ddd --- /dev/null +++ b/docs/source/scripts/disaggregated/README.md @@ -0,0 +1,93 @@ +# Disaggregated Inference Benchmark Scripts + +This directory contains scripts to run disaggregated inference benchmarks using TensorRT-LLM and SLURM. + +## Overview + +The benchmarking process is orchestrated through a set of shell scripts and a Python script that work together: + +1. `submit.sh`: The main entry point for submitting benchmark jobs to SLURM. It runs a parameter sweep by calling `sbatch` with different configurations. +2. `disaggr_torch.slurm`: The SLURM script that sets up and runs a single benchmark experiment. It launches a container, generates a configuration file, starts the server and workers, and runs the benchmark client. +3. `gen_yaml.py`: A Python script that generates the `config.yaml` file needed by `trtllm-serve`. It determines the server and worker configuration based on SLURM environment variables and script arguments. +4. `start_worker.sh`: A shell script responsible for starting a `trtllm-serve disaggregated_mpi_worker` on each allocated machine. +5. `run_benchmark.sh`: A shell script that waits for the server to be healthy and then runs the actual benchmark client (`run_benchmark.py`, not included in this directory). + +## File Descriptions + +### `submit.sh` + +This script is used to submit multiple SLURM jobs for running benchmarks with different parameters. It iterates through various configurations and uses `sbatch` to submit `disaggr_torch.slurm` for each one. + +**Usage:** + +```bash +./submit.sh +``` + +You can modify the loops in this script to change the parameter space for the benchmark sweep. + +### `disaggr_torch.slurm` + +This is the core SLURM script for a single benchmark run. It is not meant to be run directly, but rather submitted via `sbatch` (e.g., by `submit.sh`). + +It takes the following arguments in order: + +1. `num_ctx_servers`: Number of context servers. +2. `ctx_tp_size`: Tensor parallel size for context servers. +3. `ctx_batch_size`: Max batch size for context servers. +4. `ctx_max_num_tokens`: Max number of tokens for context servers. +5. `ctx_enable_attention_dp`: `true` or `false` to enable attention DP for context servers. +6. `num_gen_servers`: Number of generation servers. +7. `gen_tp_size`: Tensor parallel size for generation servers. +8. `gen_batch_size`: Max batch size for generation servers. +9. `gen_max_num_tokens`: Max number of tokens for generation servers. +10. `gen_enable_attention_dp`: `true` or `false` to enable attention DP for generation servers. +11. `gen_gpu_memory_fraction`: GPU memory fraction for generation servers. +12. `concurrency_list`: A space-separated list of concurrencies to test (e.g., "1 2 4 8"). +13. `sub_file`: A subdirectory name for logs. + +### `gen_yaml.py` + +This Python script generates the `config.yaml` file that configures the `trtllm-serve` application. It reads SLURM environment variables (`SLURM_JOB_NODELIST`, `SLURM_TASKS_PER_NODE`) to distribute workers across nodes. + +**Usage:** + +The script is called from within `disaggr_torch.slurm`. It takes numerous arguments to define the model, parallelism, and server configurations. + +### `start_worker.sh` + +This script starts a `trtllm-serve disaggregated_mpi_worker`. It is launched by `srun` from the `disaggr_torch.slurm` script on all allocated nodes. + +**Arguments:** + +1. 
`config_file`: Path to the `config.yaml` file. +2. `enable_pdl`: `true` or `false`. +3. `ctx_gpus`: Number of GPUs used for the context phase. +4. `work_dir`: (Optional) Directory to store nsys profiling output. + +### `run_benchmark.sh` + +This script orchestrates the execution of the benchmark client. It waits for the `config.yaml` to be created and for the server's `/health` endpoint to respond, then it runs the benchmark. + +**Arguments:** + +1. `isl`: Input sequence length. +2. `osl`: Output sequence length. +3. `multi_round`: Number of rounds for the benchmark. +4. `model_name`: Name of the model being benchmarked. +5. `concurrency_list`: Space-separated list of concurrencies. +6. `streaming`: `true` or `false`. +7. `log_path`: Path to the log directory. + +## Workflow + +1. The user runs `./submit.sh`. +2. `submit.sh` submits one or more jobs to SLURM by calling `sbatch disaggr_torch.slurm` with different parameters. +3. For each job, SLURM allocates resources and runs `disaggr_torch.slurm`. +4. `disaggr_torch.slurm` runs `gen_yaml.py` to create a `config.yaml`. +5. `disaggr_torch.slurm` uses `srun` to launch `start_worker.sh` on all nodes, starting the MPI workers. +6. `disaggr_torch.slurm` starts the main `trtllm-serve` process. +7. `disaggr_torch.slurm` runs `run_benchmark.sh` which waits for the server to be ready. +8. `run_benchmark.sh` executes the benchmark for each concurrency level specified. +9. After the benchmark, `run_benchmark.sh` and `disaggr_torch.slurm` attempt to kill the server and worker processes. +10. Logs for each run are stored in a subdirectory specified by the `sub_file` parameter. diff --git a/docs/source/scripts/disaggregated/disaggr_torch.slurm b/docs/source/scripts/disaggregated/disaggr_torch.slurm new file mode 100644 index 00000000000..ae047c23552 --- /dev/null +++ b/docs/source/scripts/disaggregated/disaggr_torch.slurm @@ -0,0 +1,112 @@ +#!/bin/bash +#SBATCH --nodes=2 +#SBATCH --ntasks=8 +#SBATCH --ntasks-per-node=4 +#SBATCH --partition=batch +#SBATCH --account=${account} +#SBATCH --time=02:00:00 +#SBATCH --job-name="${account}:disaggr-test" + +isl=8192 +osl=256 +multi_round=10 +gen_yaml_file=gen_yaml.py +container_image=${docker_image} +mount_dir=/${account}/${user}/ +workdir=/${account}/${user}/8k-${osl}/disaggr-e2e/ +model_dir=/${account}/${user}/DeepSeek-R1-nvfp4_allmoe/ +logdir=$workdir/bm_deepseek-r1-8k-${osl}-disaggr-e2e-nostream +streaming=false +mkdir -p ${logdir} + +dep_dir=${workdir} +run_benchmark_cmd="bash ${dep_dir}/run_benchmark.sh" + +container_name=disaggr-test + +num_ctx_servers=$1 +ctx_tp_size=$2 +ctx_batch_size=$3 +ctx_max_num_tokens=$4 +ctx_enable_attention_dp=$5 +num_gen_servers=$6 +gen_tp_size=$7 +gen_batch_size=$8 +gen_max_num_tokens=$9 +gen_enable_attention_dp=${10} +gen_gpu_memory_fraction=${11} +concurrency_list=${12} +sub_file=${13} + +# concurrency=$((concurrency * gen_tp_size)) +echo "concurrency_list: ${concurrency_list}" + +ctx_gpus=$((num_ctx_servers * ctx_tp_size)) +gen_gpus=$((num_gen_servers * gen_tp_size)) + +echo "enable_attention_dp: ${ctx_enable_attention_dp}, ${gen_enable_attention_dp}, gpu_memory_fraction: ${gen_gpu_memory_fraction}" + +enable_pdl=false +if [ "${gen_enable_attention_dp}" = "false" ]; then + enable_pdl=true +fi + +full_logdir=${logdir}/${sub_file} +mkdir -p ${full_logdir} + +# start the container +srun -l --container-image=${container_image} \ + --container-name=${container_name} \ + --container-mounts=${mount_dir}:${mount_dir} \ + --mpi=pmix \ + echo "Container up." 
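+
+# Illustrative sketch only (not part of the generated output): the config.yaml
+# produced by the gen_yaml.py step below has roughly this shape; the authoritative
+# layout comes from gen_yaml.py itself.
+#   model: ${model_dir}
+#   hostname: <first allocated node>
+#   port: 8333                       # gen_yaml.py's default --server_port
+#   backend: pytorch
+#   context_servers:
+#     num_instances: ${num_ctx_servers}
+#     urls: ["<node>:<worker_port>", ...]
+#   generation_servers:
+#     num_instances: ${num_gen_servers}
+#     urls: ["<node>:<worker_port>", ...]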
+ +# generate the yaml file +srun -l --container-name=${container_name} \ + --container-mounts=${mount_dir}:${mount_dir} \ + --mpi=pmix --overlap \ + python3 ${dep_dir}/${gen_yaml_file} --config ${full_logdir}/config.yaml \ + --model ${model_dir} \ + --num_ctx_servers ${num_ctx_servers} \ + --ctx_tp_size ${ctx_tp_size} \ + --ctx_batch_size ${ctx_batch_size} \ + --ctx_max_num_tokens ${ctx_max_num_tokens} \ + --num_gen_servers ${num_gen_servers} \ + --gen_tp_size ${gen_tp_size} \ + --gen_batch_size ${gen_batch_size} \ + --gen_max_num_tokens ${gen_max_num_tokens} \ + --gen_gpu_memory_fraction ${gen_gpu_memory_fraction} \ + $(if [ "${gen_enable_attention_dp}" = "true" ]; then echo "--gen_enable_attention_dp"; fi) \ + $(if [ "${ctx_enable_attention_dp}" = "true" ]; then echo "--ctx_enable_attention_dp"; fi) + +echo "YAML file generated." + +hostname_value=$(grep '^hostname:' ${full_logdir}/config.yaml | awk -F': ' '{print $2}') +echo "server host name: $hostname_value" + +nsys_on="" +# nsys_on=${full_logdir} + +# start the workers +srun -l --container-name=${container_name} \ + --container-mounts=${mount_dir}:${mount_dir} \ + --mpi=pmix --overlap \ + bash ${dep_dir}/start_worker.sh ${full_logdir}/config.yaml "${enable_pdl}" ${ctx_gpus} ${nsys_on} &> ${full_logdir}/output_workers.log & +# start the server +srun -l --container-name=${container_name} \ + --container-mounts=${mount_dir}:${mount_dir} \ + --mpi=pmix --overlap -N 1 -n 1 \ + bash trtllm-serve disaggregated -c ${full_logdir}/config.yaml -t 1800 -r 1800 &> ${full_logdir}/output_server.log & +# start benchmark +srun -l --container-name=${container_name} \ + --container-mounts=${mount_dir}:${mount_dir} \ + --mpi=pmix --overlap -N 1 -n 1 \ + --nodelist=${hostname_value} \ + ${run_benchmark_cmd} ${isl} ${osl} ${multi_round} ${model_dir} "${concurrency_list}" ${streaming} ${full_logdir}/ > ${full_logdir}/benchmark.log 2>&1 +wait + +# try to kill the server and workers +srun -l --container-name=${container_name} \ + --container-mounts=${mount_dir}:${mount_dir} \ + --mpi=pmix --overlap \ + pkill -f "trtllm-serve" || true diff --git a/docs/source/scripts/disaggregated/gen_yaml.py b/docs/source/scripts/disaggregated/gen_yaml.py new file mode 100644 index 00000000000..d9924ebaa5f --- /dev/null +++ b/docs/source/scripts/disaggregated/gen_yaml.py @@ -0,0 +1,298 @@ +import argparse +import os +import re +from typing import Dict, List + +import yaml + + +def process_node_and_task() -> tuple[int, List[str], List[str]]: + """ + Process SLURM node and task environment variables. 
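+
+    Example (hypothetical values): SLURM_JOB_NODELIST="node[0001-0002]" combined
+    with SLURM_TASKS_PER_NODE="4(x2)" yields max_tasks_per_node=4,
+    nodes=["node0001", "node0002"], and task_nodes listing each hostname four
+    times (one entry per SLURM task).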
+ + Returns: + tuple: (max_tasks_per_node, nodes, task_nodes) + """ + slurm_job_nodelist = os.getenv('SLURM_JOB_NODELIST', '') + print(f"SLURM_JOB_NODELIST: {slurm_job_nodelist}") + if not slurm_job_nodelist: + raise ValueError(f"Environment variable SLURM_JOB_NODELIST not found.") + + slurm_tasks_per_node = os.getenv('SLURM_TASKS_PER_NODE', '') + print(f"SLURM_TASKS_PER_NODE: {slurm_tasks_per_node}") + if not slurm_tasks_per_node: + raise ValueError( + f"Environment variable SLURM_TASKS_PER_NODE not found.") + + # Generate list of nodes + if '[' in slurm_job_nodelist: + # Handle nodelist with range format (e.g., "ptyche[0065-0066]") + node_prefix = re.match(r'^[a-zA-Z]+', slurm_job_nodelist).group(0) + node_range = re.search(r'\[(.*?)\]', slurm_job_nodelist).group(1) + nodes = [] + for part in node_range.split(','): + if '-' in part: + start, end = part.split('-') + # Get the width of the number format from the first number + width = len(start) + # Convert to integers after getting the width + start, end = int(start), int(end) + # Format numbers with leading zeros + nodes.extend([ + f"{node_prefix}{str(i).zfill(width)}" + for i in range(start, end + 1) + ]) + else: + # Preserve the original format for single numbers + nodes.append(f"{node_prefix}{part}") + else: + # Handle single node format (e.g., "ptyche0065") + nodes = [slurm_job_nodelist] + print(f"Nodes: {nodes}") + + # Generate tasks per node + tasks_per_node = [] + for part in slurm_tasks_per_node.split(','): + if '(x' in part: + count, repeat = map(int, re.findall(r'\d+', part)) + tasks_per_node.extend([count] * repeat) + else: + tasks_per_node.append(int(part)) + print(f"Tasks per node: {tasks_per_node}") + + if (len(tasks_per_node) != len(nodes)): + raise ValueError( + f"Number of nodes and tasks per node do not match. Number of nodes: {len(nodes)}, Number of tasks per node: {len(tasks_per_node)}" + ) + + max_tasks_per_node = max(tasks_per_node) + task_nodes = [] + for node, tasks in zip(nodes, tasks_per_node): + task_nodes.extend([node] * tasks) + + return max_tasks_per_node, nodes, task_nodes + + +def generate_urls(ctx_or_gen: str, + num_instances: int, + tensor_parallel_size: int, + pipeline_parallel_size: int, + max_tasks_per_node: int, + nodes: List[str], + task_nodes: List[str], + node_to_port: Dict[str, int], + task_nodes_offset: int = 0) -> tuple[List[str], int]: + """ + Generate URLs for context or generation servers. + + Returns: + tuple: (urls, updated_task_nodes_offset) + """ + urls = [] + + for instance in range(num_instances): + tasks_needed = tensor_parallel_size * pipeline_parallel_size + + if (task_nodes_offset + tasks_needed) > len(task_nodes): + print(f"{ctx_or_gen} urls so far: {urls}") + raise ValueError( + f"For {ctx_or_gen} instance {instance}, there are not enough tasks available. task_nodes_offset: {task_nodes_offset}, tasks_needed: {tasks_needed}, len(task_nodes): {len(task_nodes)}" + ) + + min_node = (tasks_needed + max_tasks_per_node - 1) / max_tasks_per_node + instance_nodes = set(task_nodes[task_nodes_offset:task_nodes_offset + + tasks_needed]) + if len(instance_nodes) > min_node: + raise ValueError( + f"Tasks for a instance {instance} of {ctx_or_gen} instances use more node than expected. 
Nodes used: {instance_nodes}, number of nodes expected: {min_node}, max_tasks_per_node: {max_tasks_per_node}" + ) + + node = task_nodes[task_nodes_offset] + port = node_to_port[node] + node_to_port[node] += 1 + task_nodes_offset += tasks_needed + + urls.append(f"{node}:{port}") + + print(f"{ctx_or_gen} urls: {urls}") + return urls, task_nodes_offset + + +def gen_config_file(config_path: str, + model_path: str, + num_ctx_servers: int, + ctx_tp_size: int, + ctx_batch_size: int, + ctx_max_num_tokens: int, + ctx_enable_attention_dp: bool, + num_gen_servers: int, + gen_tp_size: int, + gen_batch_size: int, + gen_max_num_tokens: int, + gen_enable_attention_dp: bool, + gen_gpu_memory_fraction: float, + worker_start_port: int = 8001, + server_port: int = 8000) -> None: + """ + Generate configuration YAML file for disaggregated inference. + + Args: + config_path: Path to save the config file + model_path: Path to the model + num_ctx_servers: Number of context servers + ctx_tp_size: Tensor parallel size for context servers + ctx_batch_size: Batch size for context servers + ctx_max_num_tokens: Max number of tokens for context servers + ctx_enable_attention_dp: Enable attention DP for context servers + num_gen_servers: Number of generation servers + gen_tp_size: Tensor parallel size for generation servers + gen_batch_size: Batch size for generation servers + gen_max_num_tokens: Max number of tokens for generation servers + gen_enable_attention_dp: Enable attention DP for generation servers + gen_gpu_memory_fraction: GPU memory fraction for generation servers + worker_start_port: Start port for workers + server_port: Server port + """ + gen_cuda_graph_batch_sizes = [ + 1, 2, 4, 8, 16, 32, 64, 128, 256, gen_batch_size + ] + + config = { + 'model': model_path, + 'hostname': 'localhost', + 'port': server_port, + 'backend': 'pytorch', + 'context_servers': { + 'num_instances': num_ctx_servers, + 'max_batch_size': ctx_batch_size, + 'max_num_tokens': ctx_max_num_tokens, + 'max_seq_len': 8300, + 'free_gpu_memory_fraction': 0.7, + 'tensor_parallel_size': ctx_tp_size, + 'moe_expert_parallel_size': ctx_tp_size, + 'enable_attention_dp': ctx_enable_attention_dp, + 'pipeline_parallel_size': 1, + 'print_iter_log': True, + 'disable_overlap_scheduler': True, + 'kv_cache_dtype': 'fp8', + 'cache_transceiver_config': { + 'max_num_tokens': 8320, + }, + }, + 'generation_servers': { + 'num_instances': num_gen_servers, + 'tensor_parallel_size': gen_tp_size, + 'moe_expert_parallel_size': gen_tp_size, + 'enable_attention_dp': gen_enable_attention_dp, + 'pipeline_parallel_size': 1, + 'max_batch_size': gen_batch_size, + 'max_num_tokens': gen_max_num_tokens, + 'max_seq_len': 8576, + 'free_gpu_memory_fraction': gen_gpu_memory_fraction, + 'use_cuda_graph': True, + 'cuda_graph_padding_enabled': True, + 'cuda_graph_batch_sizes': gen_cuda_graph_batch_sizes, + 'print_iter_log': True, + 'kv_cache_dtype': 'fp8', + 'moe_backend': 'TRTLLM', + 'cache_transceiver_config': { + 'max_num_tokens': 8320, + }, + } + } + + # Process nodes and generate URLs + max_tasks_per_node, nodes, task_nodes = process_node_and_task() + node_ports = {node: worker_start_port for node in nodes} + + # Generate URLs for context and generation servers + ctx_urls, task_nodes_offset = generate_urls("ctx", num_ctx_servers, + ctx_tp_size, 1, + max_tasks_per_node, nodes, + task_nodes, node_ports) + if num_ctx_servers > 0: + config['context_servers']['urls'] = ctx_urls + + gen_urls, _ = generate_urls("gen", num_gen_servers, gen_tp_size, 1, + max_tasks_per_node, nodes, 
task_nodes, + node_ports, task_nodes_offset) + config['generation_servers']['urls'] = gen_urls + + # set the hostname to the first node + config['hostname'] = nodes[0] + + # Write config to file + with open(config_path, 'w') as f: + yaml.dump(config, f, default_flow_style=False, sort_keys=False) + + +# gen main and args +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, default="/tmp/config.yaml") + parser.add_argument("--model", + type=str, + required=True, + help="Path to the model") + parser.add_argument("--num_ctx_servers", + type=int, + required=True, + help="Number of context servers") + parser.add_argument("--ctx_tp_size", + type=int, + required=True, + help="Tensor parallel size for context servers") + parser.add_argument("--ctx_batch_size", + type=int, + required=True, + help="Batch size for context servers") + parser.add_argument("--ctx_max_num_tokens", + type=int, + required=True, + help="Max number of tokens for context servers") + parser.add_argument("--ctx_enable_attention_dp", + dest='ctx_enable_attention_dp', + action='store_true', + help="Enable attention DP for context servers") + parser.add_argument("--num_gen_servers", + type=int, + required=True, + help="Number of generation servers") + parser.add_argument("--gen_tp_size", + type=int, + required=True, + help="Tensor parallel size for generation servers") + parser.add_argument("--gen_batch_size", + type=int, + required=True, + help="Batch size for generation servers") + parser.add_argument("--gen_max_num_tokens", + type=int, + required=True, + help="Max number of tokens for generation servers") + parser.add_argument("--gen_enable_attention_dp", + dest='gen_enable_attention_dp', + action='store_true', + help="Enable attention DP for generation servers") + parser.add_argument("--gen_gpu_memory_fraction", + type=float, + required=True, + help="GPU memory fraction for generation servers") + parser.add_argument("--worker_start_port", + type=int, + default=8336, + help="Start port for workers") + parser.add_argument("--server_port", + type=int, + default=8333, + help="Server port") + + args = parser.parse_args() + + gen_config_file(args.config, args.model, args.num_ctx_servers, + args.ctx_tp_size, args.ctx_batch_size, + args.ctx_max_num_tokens, args.ctx_enable_attention_dp, + args.num_gen_servers, args.gen_tp_size, args.gen_batch_size, + args.gen_max_num_tokens, args.gen_enable_attention_dp, + args.gen_gpu_memory_fraction, args.worker_start_port, + args.server_port) diff --git a/docs/source/scripts/disaggregated/run_benchmark.sh b/docs/source/scripts/disaggregated/run_benchmark.sh new file mode 100644 index 00000000000..1232cf9c415 --- /dev/null +++ b/docs/source/scripts/disaggregated/run_benchmark.sh @@ -0,0 +1,92 @@ +#!/bin/bash + +set -e +set -u +trap 'echo "Error occurred at line $LINENO"; exit 1' ERR + +if [ "$#" -lt 7 ]; then + echo "Error: Missing required arguments" + echo "Usage: $0 isl osl multi_round model_name concurrency_list streaming log_path" + exit 1 +fi + +isl=$1 +osl=$2 +multi_round=$3 +model_name=$4 +concurrency_list=$5 +streaming=$6 +log_path=$7 + +set -x +config_file=${log_path}/config.yaml + +# check if the config file exists every 10 seconds timeout 1800 seconds +timeout=1800 +start_time=$(date +%s) +while [ ! 
-f ${config_file} ]; do + current_time=$(date +%s) + elapsed=$((current_time - start_time)) + if [ $elapsed -ge $timeout ]; then + echo "Error: Config file ${config_file} not found within ${timeout} seconds" + exit 1 + fi + if [ $((elapsed % 30)) -eq 0 ]; then + echo "Waiting for config file... (${elapsed}s elapsed)" + fi + sleep 10 +done + +# grep the host and port from the config file +hostname=$(grep -i "hostname:" ${config_file} | awk '{print $2}') +port=$(grep -i "port:" ${config_file} | awk '{print $2}') +if [ -z "$hostname" ] || [ -z "$port" ]; then + echo "Error: Failed to extract hostname or port from config file" + exit 1 +fi +echo "Hostname: ${hostname}, Port: ${port}" + +# check server is health by curl every 10 seconds timeout 1800 seconds +timeout=1800 +start_time=$(date +%s) +while ! curl -s -o /dev/null -w "%{http_code}" http://${hostname}:${port}/health; do + hostname=$(grep -i "hostname:" ${config_file} | awk '{print $2}') + port=$(grep -i "port:" ${config_file} | awk '{print $2}') + echo "Hostname: ${hostname}, Port: ${port}" + current_time=$(date +%s) + elapsed=$((current_time - start_time)) + if [ $elapsed -ge $timeout ]; then + echo "Error: Server is not healthy after ${timeout} seconds" + exit 1 + fi + if [ $((elapsed % 30)) -eq 0 ]; then + echo "Waiting for server to be healthy... (${elapsed}s elapsed)" + fi + sleep 10 +done + +# run the benchmark +for concurrency in ${concurrency_list}; do + mkdir -p ${log_path}/concurrency_${concurrency} + max_count=$((${concurrency} * ${multi_round})) + echo "Running benchmark with concurrency: ${concurrency}, max_count: ${max_count}" + # run your benchmark here + python run_benchmark.py --model_name ${model_name} \ + --isl ${isl} \ + --osl ${osl} \ + --concurrency ${concurrency} \ + --max_count ${max_count} \ + --log_path ${log_path}/concurrency_${concurrency} + echo "done for ${concurrency} in folder ${log_path}/concurrency_${concurrency}" +done + +echo "Benchmark done, gracefully shutting down server and workers..." +pkill -f "start_worker.sh" || true +pkill -f "trtllm-serve" || true +sleep 20 # + +if pgrep -f "trtllm-serve"; then + echo "Warning: Some processes may still be running" +else + echo "All processes successfully terminated" +fi diff --git a/docs/source/scripts/disaggregated/start_worker.sh b/docs/source/scripts/disaggregated/start_worker.sh new file mode 100644 index 00000000000..6ba61d4906e --- /dev/null +++ b/docs/source/scripts/disaggregated/start_worker.sh @@ -0,0 +1,32 @@ +#! 
/bin/bash + +config_file=$1 +enable_pdl=$2 +ctx_gpus=$3 +work_dir=$4 + +export TLLM_LOG_LEVEL=INFO +export TRTLLM_USE_MPI_KVCACHE=1 +export TRTLLM_MNNVL_AR_ENABLED=1 + +if [ "${enable_pdl}" = "true" ]; then + export TRTLLM_ENABLE_PDL=1 +fi + +#check if work_dir is provided +if [ -z "${work_dir}" ]; then + trtllm-serve disaggregated_mpi_worker -c ${config_file} +else + nsys_prefix="" + nsys_file=${work_dir}/nsys_worker_proc_${SLURM_PROCID} + export TLLM_PROFILE_RECORD_GC=1 + export TLLM_NVTX_DEBUG=1 + if [ ${SLURM_PROCID} -ge ${ctx_gpus} ]; then + export TLLM_PROFILE_START_STOP=300-400 + else + export TLLM_PROFILE_START_STOP=25-100 + fi + nsys_prefix="nsys profile -e \"NSYS_MPI_STORE_TEAMS_PER_RANK=1\" -o ${nsys_file} -f true -t cuda,nvtx,python-gil -c cudaProfilerApi --cuda-graph-trace node --capture-range-end=stop --gpu-metrics-devices=all" + + ${nsys_prefix} trtllm-serve disaggregated_mpi_worker -c ${config_file} +fi diff --git a/docs/source/scripts/disaggregated/submit.sh b/docs/source/scripts/disaggregated/submit.sh new file mode 100644 index 00000000000..9757dc7d32f --- /dev/null +++ b/docs/source/scripts/disaggregated/submit.sh @@ -0,0 +1,36 @@ +#! /bin/bash + +slurm_file=disaggr_torch.slurm + +# ctx1dep4_gen1tep4, max_batch16 +for c in 1 2 4 8 16 32 48 64; do + sbatch --nodes=2 --ntasks=8 --ntasks-per-node=4 ${slurm_file} 1 4 1 8300 true 1 4 32 32 false "0.95" "$c" ctx1dep4_gen1tep4_${c} +done + +# ctx2dep4_gen1tep4, max_batch 64 +for c in 64 96 128; do + sbatch --nodes=3 --ntasks=12 --ntasks-per-node=4 ${slurm_file} 2 4 1 8300 true 1 4 64 64 false "0.9" "$c" ctx2dep4_gen1tep4_${c} +done + +for c in 128 192 256; do + sbatch --nodes=4 --ntasks=16 --ntasks-per-node=4 ${slurm_file} 3 4 1 8300 true 1 4 32 32 true "0.9" "$c" ctx3dep4_gen1dep4_${c} +done + +for c in 256 384 512; do + sbatch --nodes=5 --ntasks=20 --ntasks-per-node=4 ${slurm_file} 4 4 1 8300 true 1 4 64 64 true "0.9" "$c" ctx4dep4_gen1dep4_${c} +done + +# ctx5dep4_gen1dep4, max_batch +for c in 256 384 512; do + sbatch --nodes=6 --ntasks=24 --ntasks-per-node=4 ${slurm_file} 5 4 1 8300 true 1 4 64 64 true "0.9" "$c" ctx5dep4_gen1dep4_${c} +done + +# ctx7dep4_gen1dep4 +for c in 512 768 1024; do + sbatch --nodes=8 --ntasks=32 --ntasks-per-node=4 ${slurm_file} 7 4 1 8300 true 1 4 128 128 true "0.9" "$c" ctx7dep4_gen1dep4_${c} +done + +# ctx8dep4_gen1dep4 +for c in 512 768 1024; do + sbatch --nodes=9 --ntasks=36 --ntasks-per-node=4 ${slurm_file} 8 4 1 8300 true 1 4 128 128 true "0.9" "$c" ctx8dep4_gen1dep4_${c} +done
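+
+# Illustrative only: to submit a single ad-hoc configuration rather than a sweep,
+# pass the 13 positional arguments directly (the order is parsed at the top of
+# disaggr_torch.slurm and documented in the README). The values below mirror the
+# first loop above and are placeholders, not a tuned setting:
+# sbatch --nodes=2 --ntasks=8 --ntasks-per-node=4 ${slurm_file} \
+#     1 4 1 8300 true 1 4 32 32 false "0.95" "1 2 4 8" ctx1dep4_gen1tep4_example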