
Commit e7ae5e2

Authored by pcastonguay, raayandhar, reasonsolo, and coderabbitai[bot]

feat: Add support for disaggregation with pp with pytorch backend (#6369)

Signed-off-by: Patrice Castonguay <[email protected]>
Signed-off-by: raayandhar <[email protected]>
Signed-off-by: Lizhi Zhou <[email protected]>
Signed-off-by: pcastonguay <[email protected]>
Co-authored-by: raayandhar <[email protected]>
Co-authored-by: Lizhi Zhou <[email protected]>
Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

1 parent a2514d9 · commit e7ae5e2

File tree

15 files changed: +497 −22 lines changed

cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp

Lines changed: 2 additions & 0 deletions
@@ -840,6 +840,8 @@ void CacheFormatter::unformat(TransferSession& session)
     if (selfConfig.getModelConfig().mNbKvHeadsPerLayer.size() != destConfig.getModelConfig().mNbKvHeadsPerLayer.size())
     {
         TLLM_LOG_WARNING("CacheFormatter::inquireSupport: only support same number of layers");
+        TLLM_LOG_WARNING("self: %zu dest %zu", selfConfig.getModelConfig().mNbKvHeadsPerLayer.size(),
+            destConfig.getModelConfig().mNbKvHeadsPerLayer.size());
         return false;
     }
     int selfNumLayers = selfConfig.getModelConfig().mNbKvHeadsPerLayer.size();

scripts/build_wheel.py

Lines changed: 4 additions & 1 deletion
@@ -71,7 +71,10 @@ def clear_folder(folder_path):
         if os.path.isdir(item_path) and not os.path.islink(item_path):
             rmtree(item_path)
         else:
-            os.remove(item_path)
+            try:
+                os.remove(item_path)
+            except (OSError, IOError) as e:
+                print(f"Failed to remove {item_path}: {e}", file=sys.stderr)


 def sysconfig_scheme(override_vars=None):
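
The change above keeps the wheel build going when an individual file cannot be deleted, instead of aborting the whole folder cleanup. A minimal standalone sketch of that pattern (the helper below is illustrative, not the exact build_wheel.py function):

import os
import sys
from shutil import rmtree

def clear_folder(folder_path: str) -> None:
    # Remove everything inside folder_path, tolerating files that cannot be deleted.
    for item in os.listdir(folder_path):
        item_path = os.path.join(folder_path, item)
        if os.path.isdir(item_path) and not os.path.islink(item_path):
            rmtree(item_path)
        else:
            try:
                os.remove(item_path)
            except OSError as e:  # OSError subsumes IOError on Python 3
                print(f"Failed to remove {item_path}: {e}", file=sys.stderr)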

tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py

Lines changed: 2 additions & 2 deletions
@@ -96,13 +96,13 @@ def __init__(self, mapping: Mapping, kv_cache_manager: KVCacheManager,
                  attention_type: AttentionTypeCpp,
                  cache_transceiver_config: CacheTransceiverConfig):
         world_config = mapping_to_world_config(mapping)
-        num_kv_heads_per_layer = kv_cache_manager.num_kv_heads_per_layer
+        total_num_kv_heads_per_layer = kv_cache_manager.total_num_kv_heads_per_layer
        head_dim = kv_cache_manager.head_dim
        tokens_per_block = kv_cache_manager.tokens_per_block
        dtype = kv_cache_manager.dtype

        self.impl = CacheTransceiverCpp(kv_cache_manager.impl,
-                                        num_kv_heads_per_layer, head_dim,
+                                        total_num_kv_heads_per_layer, head_dim,
                                         tokens_per_block, world_config, dtype,
                                         attention_type,
                                         cache_transceiver_config)

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 50 additions & 2 deletions
@@ -122,6 +122,7 @@ class BatchState:
 @dataclasses.dataclass
 class BatchStatePP(BatchState):
     microbatch_id: int = -1
+    scheduled_ctx_reqs: list[LlmRequest] = None


 class PyExecutor:
@@ -641,6 +642,7 @@ def _need_return_log_probs(self, scheduled_requests: ScheduledRequests):
         return False

     def _executor_loop_pp(self):
+        logger.debug(f"Starting executor loop for pp_rank {self.dist.pp_rank}")
         torch.cuda.set_device(self.device_id)
         microbatch_id = 0
         with self._profiler() as profile_step:
@@ -654,6 +656,9 @@
                 if self.should_stop_processing:
                     break

+                if self.kv_cache_transceiver:
+                    self._check_disagg_gen_transfer_status()
+
                 if self.enable_iter_perf_stats:
                     iter_stats = self._get_init_iter_stats(
                         len(new_requests),
@@ -662,9 +667,23 @@
                 self._pad_attention_dp_dummy_request()

-                scheduled_batch, _, _ = self._schedule()
+                scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
+                )
+
+                if self.kv_cache_transceiver:
+                    # For requests that are fitting disagg gen init, also prepare resources for KV cache manager
+                    self._prepare_disagg_gen_init(
+                        fitting_disagg_gen_init_requests)
+
+                    if num_fitting_reqs == 0 and not fitting_disagg_gen_init_requests:
+                        logger.warning(
+                            "num_fitting_reqs=0 and fitting_disagg_gen_init_requests is empty, may not have enough kvCache"
+                        )
+                        self.kv_cache_transceiver.check_context_transfer_status(
+                            1)

                 self.num_scheduled_requests = scheduled_batch.batch_size
+
                 logger.debug(
                     f'has {len(self.active_requests)} active_request, '
                     f'scheduled {len(scheduled_batch.context_requests)} context requests and '
@@ -677,7 +696,7 @@
                     can_queue = 0 not in tp_batch_sizes
                 else:
                     can_queue = scheduled_batch.batch_size > 0
-                if not can_queue:
+                if not can_queue and not self.kv_cache_transceiver:
                     assert len(self.inflight_req_ids) > 0, (
                         "fail to schedule any pending request, probably run out of resource"
                     )
@@ -686,8 +705,28 @@
                     self.micro_batches[microbatch_id] = None
                 else:
                     self._add_inflight_ids(scheduled_batch)
+
+                    if self.kv_cache_transceiver:
+                        # For generation requests which have completed KV cache transfer
+                        self._prepare_disagg_gen_transmission_complete(
+                            scheduled_batch)
+
                     self.resource_manager.prepare_resources(scheduled_batch)

+                    # Generation requests that do not have a batch_idx need to be
+                    # in front of the batch due to the assumptions made in
+                    # model_engine.py::_forward_step. This is only important for
+                    # disaggregated serving. For non-disaggregated serving,
+                    # the generation requests always have batch_idx.
+                    scheduled_batch.generation_requests = sorted(  # stable sort
+                        scheduled_batch.generation_requests,
+                        key=lambda req: int(req.py_batch_idx is not None),
+                    )
+
+                    if self.kv_cache_transceiver:
+                        # Return the first token to the client
+                        self._handle_first_token_response(scheduled_batch)
+
                     # Stage 1: Async forward (all ranks) and decoding pass (last rank only)
                     if not self.dist.is_last_pp_rank:
                         sample_state = self._forward_step_inter_pp(
@@ -715,6 +754,7 @@
                             iter_start_time=iter_start_time,
                             iter_stats=iter_stats,
                             microbatch_id=microbatch_id,
+                            scheduled_ctx_reqs=scheduled_batch.context_requests,
                         )

                         self.micro_batches[microbatch_id] = batch_state
@@ -779,6 +819,11 @@
                 if previous_batch is not None:
                     with torch.cuda.nvtx.range("_handle_previous_batch_pp"):
                         self._update_requests(previous_batch.sample_state)
+
+                        if self.kv_cache_transceiver and previous_batch.scheduled_ctx_reqs:
+                            self._send_disagg_ctx_cache(
+                                previous_batch.scheduled_ctx_reqs)
+
                         self._handle_canceled_requests()
                         finished_requests = self._handle_responses()
                         previous_scheduled_batch = previous_batch.sample_state.scheduled_requests
@@ -787,6 +832,9 @@
                         self._remove_inflight_ids(previous_scheduled_batch)
                     self.micro_batches[prev_microbatch_id] = None

+                if self.kv_cache_transceiver and self.ctx_in_transmission_requests:
+                    self._terminate_ctx_finished_requests()
+
                 # march forward in microbatch slots
                 microbatch_id = (microbatch_id + 1) % self.num_micro_batches
tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 20 additions & 5 deletions
@@ -155,18 +155,33 @@ def __init__(
                 (num_kv_heads + tp_size - 1) // tp_size
                 for _ in range(self.num_local_layers)
             ]
+            self.total_num_kv_heads_per_layer = [
+                (num_kv_heads + tp_size - 1) // tp_size
+                for _ in range(self.num_layers)
+            ]
         else:
             assert len(num_kv_heads) == self.num_layers

+            def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int],
+                                             kv_head: Optional[int]):
+                if kv_head is not None:
+                    num_kv_heads_per_layer.append(
+                        (kv_head + tp_size - 1) // tp_size)
+                else:
+                    num_kv_heads_per_layer.append(0)
+
             self.num_kv_heads_per_layer = []
             if self.num_local_layers > 0:
                 for i in self.pp_layers:
                     kv_head = num_kv_heads[i]
-                    if kv_head is not None:
-                        self.num_kv_heads_per_layer.append(
-                            (kv_head + tp_size - 1) // tp_size)
-                    else:
-                        self.num_kv_heads_per_layer.append(0)
+                    append_to_kv_heads_per_layer(self.num_kv_heads_per_layer,
+                                                 kv_head)
+
+            self.total_num_kv_heads_per_layer = []
+            for i in range(self.num_layers):
+                kv_head = num_kv_heads[i]
+                append_to_kv_heads_per_layer(self.total_num_kv_heads_per_layer,
+                                             kv_head)

         self.num_kv_heads = num_kv_heads
         self.head_dim = head_dim
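
The new total_num_kv_heads_per_layer mirrors the existing per-rank list but covers every layer of the model, not just the layers owned by this pipeline stage; each entry is the layer's KV head count ceiling-divided across the tensor-parallel group, with KV-cache-less layers contributing 0. A small worked sketch under assumed toy values (the layer split and head counts below are illustrative):

from typing import List, Optional

def heads_per_rank(kv_head: Optional[int], tp_size: int) -> int:
    # Layers without KV cache contribute 0; others are ceiling-divided over TP ranks.
    return 0 if kv_head is None else (kv_head + tp_size - 1) // tp_size

num_kv_heads: List[Optional[int]] = [8, 8, None, 8]  # toy 4-layer model
tp_size = 3
pp_layers = [2, 3]  # layers owned by this PP rank (assumed split)

num_kv_heads_per_layer = [heads_per_rank(num_kv_heads[i], tp_size) for i in pp_layers]
total_num_kv_heads_per_layer = [heads_per_rank(h, tp_size) for h in num_kv_heads]

assert num_kv_heads_per_layer == [0, 3]        # ceil(8/3) == 3
assert total_num_kv_heads_per_layer == [3, 3, 0, 3]

The kv_cache_transceiver.py change above now hands the total list to the C++ CacheTransceiver, presumably so that a context server and a generation server describe the full model layout even when their pipeline splits differ (the cacheFormatter.cpp check requires both sides to report the same number of layers).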

tests/integration/defs/accuracy/accuracy_core.py

Lines changed: 11 additions & 0 deletions
@@ -735,3 +735,14 @@ def setup_class(cls):
         logger.set_level("info")
         yield
         logger.set_level(original_level)
+
+
+def get_accuracy_task(dataset_name: str):
+    try:
+        task_class = globals()[dataset_name]
+        if issubclass(task_class, AccuracyTask):
+            return task_class
+        else:
+            raise ValueError(f"Unknown dataset: {dataset_name}.")
+    except KeyError:
+        raise ValueError(f"Not registered dataset: {dataset_name}.")

tests/integration/defs/accuracy/test_disaggregated_serving.py

Lines changed: 97 additions & 12 deletions
@@ -20,9 +20,11 @@
 from tensorrt_llm.llmapi import CompletionOutput, RequestOutput, SamplingParams
 from tensorrt_llm.llmapi.llm_args import LlmArgs

-from ..conftest import llm_models_root, parametrize_with_ids, skip_pre_hopper
+from ..conftest import (get_device_count, llm_models_root, parametrize_with_ids,
+                        skip_pre_hopper)
 from ..trt_test_alternative import popen
-from .accuracy_core import GSM8K, MMLU, LlmapiAccuracyTestHarness
+from .accuracy_core import (GSM8K, MMLU, LlmapiAccuracyTestHarness,
+                            get_accuracy_task)


 class Result(GenerationResultBase):
@@ -71,6 +73,12 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
     temp_dir = tempfile.TemporaryDirectory()
     disaggregated_serving_config_path = os.path.join(
         temp_dir.name, "disaggregated_serving_config.yaml")
+
+    if tensor_parallel_size > 1:
+        print(
+            f"Using unified tp parameter for testing is not recommended. Please use server configs instead."
+        )
+
     with open(disaggregated_serving_config_path, "w") as f:
         yaml.dump(disaggregated_server_config, f)
     ctx_server_config_path = os.path.join(temp_dir.name,
@@ -88,27 +96,40 @@ def launch_disaggregated_llm(disaggregated_server_config: Dict[str, Any],
     trtllm_serve_path = "trtllm-serve"
     # Common arguments for both servers
     common_args = [
-        trtllm_serve_path, model_name, "--host", "localhost", "--backend",
-        "pytorch"
+        trtllm_serve_path,
+        model_name,
+        "--host",
+        "localhost",
+        "--backend",
+        "pytorch",
     ]
-
-    if tensor_parallel_size > 1:
-        common_args.append(f"--tp_size={tensor_parallel_size}")
+    gen_tp, gen_pp = gen_server_config.get(
+        "tensor_parallel_size",
+        tensor_parallel_size), gen_server_config.get("pipeline_parallel_size",
+                                                     1)
+    ctx_tp, ctx_pp = ctx_server_config.get(
+        "tensor_parallel_size",
+        tensor_parallel_size), ctx_server_config.get("pipeline_parallel_size",
+                                                     1)
+
+    ctx_total_gpus = ctx_tp * ctx_pp
+    gen_total_gpus = gen_tp * gen_pp

     env_ctx = os.environ.copy()
     env_ctx["TRTLLM_USE_UCX_KVCACHE"] = "1"
-    env_ctx["CUDA_VISIBLE_DEVICES"] = ",".join(
-        map(str, range(tensor_parallel_size)))
+    env_ctx["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, range(ctx_total_gpus)))

     env_gen = os.environ.copy()
     env_gen["TRTLLM_USE_UCX_KVCACHE"] = "1"
     env_gen["CUDA_VISIBLE_DEVICES"] = ",".join(
-        map(str, range(tensor_parallel_size, 2 * tensor_parallel_size)))
+        map(str, range(ctx_total_gpus, ctx_total_gpus + gen_total_gpus)))
     ctx_server_args = common_args + [
-        "--port", "8001", "--extra_llm_api_options", ctx_server_config_path
+        "--port", "8001", "--extra_llm_api_options", ctx_server_config_path,
+        f"--tp_size={ctx_tp}", f"--pp_size={ctx_pp}"
     ]
     gen_server_args = common_args + [
-        "--port", "8002", "--extra_llm_api_options", gen_server_config_path
+        "--port", "8002", "--extra_llm_api_options", gen_server_config_path,
+        f"--tp_size={gen_tp}", f"--pp_size={gen_pp}"
     ]
     if "max_num_tokens" in ctx_server_config:
         ctx_server_args.append(
@@ -182,6 +203,56 @@ def generate_async(prompt: str,
         disaggregated_server.wait()


+def run_parallel_test(model_name: str, model_path: str, ctx_pp: int,
+                      ctx_tp: int, gen_pp: int, gen_tp: int,
+                      test_set: LlmapiAccuracyTestHarness):
+    if ctx_tp * ctx_pp + gen_tp * gen_pp > get_device_count():
+        pytest.fail(
+            f"Not enough devices for ctx_pp={ctx_pp}+ctx_tp={ctx_tp} and gen_pp={gen_pp}+gen_tp={gen_tp} test"
+        )
+
+    kv_cache_config = {
+        "free_gpu_memory_fraction": 0.5,
+        "enable_block_reuse": False
+    }
+    ctx_server_config = {
+        "pipeline_parallel_size": ctx_pp,
+        "tensor_parallel_size": ctx_tp,
+        "disable_overlap_scheduler": True,
+        "kv_cache_config": kv_cache_config,
+        "cache_transceiver_config": {
+            "backend": "default"
+        }
+    }
+    gen_server_config = {
+        "tensor_parallel_size": gen_tp,
+        "pipeline_parallel_size": gen_pp,
+        "disable_overlap_scheduler": True,
+        "kv_cache_config": kv_cache_config,
+        "cache_transceiver_config": {
+            "backend": "default"
+        }
+    }
+    disaggregated_server_config = {
+        "hostname": "localhost",
+        "port": 8000,
+        "backend": "pytorch",
+        "context_servers": {
+            "num_instances": 1,
+            "urls": ["localhost:8001"]
+        },
+        "generation_servers": {
+            "num_instances": 1,
+            "urls": ["localhost:8002"]
+        }
+    }
+    with launch_disaggregated_llm(disaggregated_server_config,
+                                  ctx_server_config, gen_server_config,
+                                  model_path) as llm:
+        task = test_set(model_name)
+        task.evaluate(llm)
+
+
 @pytest.mark.timeout(3600)
 class TestLlama3_1_8BInstruct(LlmapiAccuracyTestHarness):
     MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
@@ -315,6 +386,20 @@ def test_eagle3(self, overlap_scheduler):
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)

+    @pytest.mark.parametrize("tp,pp", [(1, 2), (2, 1), (2, 2)],
+                             ids=["tp1pp2", "tp2pp1", "tp2pp2"])
+    @pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
+    def test_tp_pp_symmetric(self, tp, pp, testset):
+        return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, pp, tp, pp,
+                                 tp, get_accuracy_task(testset))
+
+    @parametrize_with_ids("ctx_pp", [2, 4])
+    @parametrize_with_ids("gen_tp", [1, 2])
+    @pytest.mark.parametrize("testset", ["GSM8K", "MMLU"])
+    def test_ctx_pp_gen_tp_asymmetric(self, ctx_pp, gen_tp, testset):
+        return run_parallel_test(self.MODEL_NAME, self.MODEL_PATH, ctx_pp, 1, 1,
+                                 gen_tp, get_accuracy_task(testset))
+

 @pytest.mark.skip_less_device_memory(140000)
 @pytest.mark.timeout(3600)
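
The launcher above pins the two trtllm-serve instances to disjoint GPU ranges: the context server sees the first ctx_tp * ctx_pp devices and the generation server sees the next gen_tp * gen_pp devices, via CUDA_VISIBLE_DEVICES. A small sketch of that device bookkeeping, with illustrative parallelism values matching one of the asymmetric test cases:

import os

def device_env(ctx_tp: int, ctx_pp: int, gen_tp: int, gen_pp: int):
    # Context server owns GPUs [0, ctx_tp*ctx_pp); generation server the next gen_tp*gen_pp.
    ctx_total_gpus = ctx_tp * ctx_pp
    gen_total_gpus = gen_tp * gen_pp
    env_ctx = os.environ.copy()
    env_gen = os.environ.copy()
    env_ctx["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, range(ctx_total_gpus)))
    env_gen["CUDA_VISIBLE_DEVICES"] = ",".join(
        map(str, range(ctx_total_gpus, ctx_total_gpus + gen_total_gpus)))
    return env_ctx, env_gen

# e.g. test_ctx_pp_gen_tp_asymmetric with ctx_pp=2, gen_tp=2:
# context server on GPUs 0-1, generation server on GPUs 2-3.
env_ctx, env_gen = device_env(ctx_tp=1, ctx_pp=2, gen_tp=2, gen_pp=1)
assert env_ctx["CUDA_VISIBLE_DEVICES"] == "0,1"
assert env_gen["CUDA_VISIBLE_DEVICES"] == "2,3"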
