#!/bin/bash
#SBATCH --job-name={{ job_name }}
#SBATCH --nodes={{ total_nodes }}
#SBATCH --ntasks={{ total_nodes }}
#SBATCH --ntasks-per-node=1
#SBATCH --account={{ account }}
#SBATCH --time={{ time_limit }}
#SBATCH --output=logs/%j/log.out
#SBATCH --error=logs/%j/log.err

# SLURM batch script template (rendered with Jinja2) that starts one
# containerized `worker_setup.py` process per allocated node:
#   * the first PREFILL_NODES nodes run "prefill" workers,
#   * the following DECODE_NODES nodes run "decode" workers.
# Every worker receives the IP address of the first prefill node and of the
# first decode node, looked up on NETWORK_INTERFACE.
#
# NOTE(review): the --output/--error paths above live in a per-job directory.
# Confirm that your Slurm setup can open files under logs/JOBID/ before this
# script runs; the mkdir below only happens after the job has started.

set -euo pipefail

# Constants
PREFILL_NODES={{ prefill_nodes }}
DECODE_NODES={{ decode_nodes }}
GPUS_PER_NODE={{ gpus_per_node }}
LOG_DIR="${SLURM_SUBMIT_DIR}/logs/${SLURM_JOB_ID}"
SCRIPT_DIR="${SLURM_SUBMIT_DIR}/scripts"
OUTPUT_DIR="${SLURM_SUBMIT_DIR}/outputs"
MODEL_DIR="{{ model_dir }}"
CONFIG_DIR="{{ config_dir }}"
CONTAINER_IMAGE="{{ container_image }}"
NETWORK_INTERFACE="{{ network_interface }}"

{% raw %}

# Print an error message to stderr and abort the job script.
die() {
  printf 'Error: %s\n' "$*" >&2
  exit 1
}

# Print the first IPv4 address of NETWORK_INTERFACE on the given node.
# Arguments: $1 - node hostname inside the current allocation
# Outputs:   the address on stdout, or an empty line if none was found
get_host_ip() {
  local host=$1
  local ip
  # grep exits non-zero when nothing matches; swallow that here so the caller
  # can report a descriptive error instead of the script dying under set -e.
  ip=$(srun --nodes=1 --ntasks=1 --nodelist="${host}" \
    ifconfig "${NETWORK_INTERFACE}" \
    | grep -oP 'inet \K[0-9.]+' \
    | head -n 1) || true
  printf '%s\n' "${ip}"
}

# Launch one background srun step per node for a group of workers.
# Globals:   nodes, ENROOT_ARGS, LOG_DIR, PREFILL_HOST_IP, DECODE_HOST_IP,
#            GPUS_PER_NODE (read); WORKER_PIDS (appended)
# Arguments: $1 - worker type ("prefill" or "decode")
#            $2 - index into nodes[] of the group's first node
#            $3 - number of nodes in the group
launch_workers() {
  local worker_type=$1
  local first_index=$2
  local count=$3
  local i node rank
  local -a srun_args worker_cmd

  for ((i = first_index; i < first_index + count; i++)); do
    node=${nodes[i]}
    # Ranks are numbered from 0 within each worker group.
    rank=$((i - first_index))

    srun_args=(
      "${ENROOT_ARGS[@]}"
      --nodes=1 --ntasks=1 --nodelist="${node}"
      --output="${LOG_DIR}/${node}_${worker_type}.out"
      --error="${LOG_DIR}/${node}_${worker_type}.err"
    )
    worker_cmd=(
      python /scripts/worker_setup.py
      --prefill_host_ip "${PREFILL_HOST_IP}"
      --decode_host_ip "${DECODE_HOST_IP}"
      --rank "${rank}"
      --total_nodes "${count}"
      --worker_type "${worker_type}"
      --gpus_per_node "${GPUS_PER_NODE}"
      --gpu_utilization_log "/logs/${node}_${worker_type}_gpu_utilization.log"
    )

    echo "Launching ${worker_type} task on node ${i} (rank ${rank}): ${node}"
    echo "Srun args: ${srun_args[*]}"
    echo "Command: ${worker_cmd[*]} &"
    srun "${srun_args[@]}" "${worker_cmd[@]}" &
    WORKER_PIDS+=("$!")
  done
}

# Both groups need at least one node: nodes[0] is used as the prefill host and
# nodes[PREFILL_NODES] as the decode host.
[[ "${PREFILL_NODES}" =~ ^[1-9][0-9]*$ ]] \
  || die "PREFILL_NODES must be a positive integer, got '${PREFILL_NODES}'"
[[ "${DECODE_NODES}" =~ ^[1-9][0-9]*$ ]] \
  || die "DECODE_NODES must be a positive integer, got '${DECODE_NODES}'"
TOTAL_NODES=$((PREFILL_NODES + DECODE_NODES))

mkdir -p "${OUTPUT_DIR}" "${LOG_DIR}"

# Expand the compact SLURM node list into one hostname per array element.
mapfile -t nodes < <(scontrol show hostnames "${SLURM_NODELIST}")
if [[ ${#nodes[@]} -ne ${TOTAL_NODES} ]]; then
  die "Expected $TOTAL_NODES nodes but got ${#nodes[@]} nodes"
fi

# Print node information
for i in "${!nodes[@]}"; do
  echo "Node $i: ${nodes[$i]}"
done

PREFILL_HOST_IP=$(get_host_ip "${nodes[0]}")
if [[ -z "${PREFILL_HOST_IP}" ]]; then
  die "Could not retrieve IP address for prefill host ${nodes[0]} on interface $NETWORK_INTERFACE"
fi
echo "Prefill host IP address: $PREFILL_HOST_IP"

DECODE_HOST_IP=$(get_host_ip "${nodes[PREFILL_NODES]}")
if [[ -z "${DECODE_HOST_IP}" ]]; then
  die "Could not retrieve IP address for decode host ${nodes[PREFILL_NODES]} on interface $NETWORK_INTERFACE"
fi
echo "Decode host IP address: $DECODE_HOST_IP"

# Container arguments passed to every worker srun command. Kept as an array so
# that paths containing whitespace survive as single arguments.
ENROOT_ARGS=(
  --container-image="${CONTAINER_IMAGE}"
  --no-container-entrypoint
  --container-mount-home
  --no-container-remap-root
  --container-mounts="${MODEL_DIR}:/model/,${CONFIG_DIR}:/configs/,${SCRIPT_DIR}:/scripts/,${OUTPUT_DIR}:/outputs/,${LOG_DIR}:/logs/"
)

# PIDs of the background srun steps, collected so their exit codes can be checked.
WORKER_PIDS=()

# Launch prefill tasks on the first PREFILL_NODES nodes
launch_workers prefill 0 "${PREFILL_NODES}"

# Launch decode tasks on the next DECODE_NODES nodes
launch_workers decode "${PREFILL_NODES}" "${DECODE_NODES}"

echo ""
echo "To connect to the host prefill node:"
echo "srun ${ENROOT_ARGS[*]} --jobid $SLURM_JOB_ID -w ${nodes[0]} --overlap --pty bash"

echo ""
echo "Make sure to cancel the job at the end:"
echo "scancel $SLURM_JOB_ID"

# Wait for all tasks to complete. A bare `wait` always returns 0, so wait on
# each PID individually and remember whether any worker step failed.
exit_code=0
for pid in "${WORKER_PIDS[@]}"; do
  if ! wait "${pid}"; then
    echo "Error: worker step with PID ${pid} exited with a non-zero status" >&2
    exit_code=1
  fi
done
echo "Script finished at $(date)"
exit "${exit_code}"

{% endraw %}