Commit 14f761d

refactor: Introduce ResourceManagerType enum for resource management
- Added ResourceManagerType enum to standardize resource manager keys.
- Updated references in ad_executor, model_engine, and py_executor to use ResourceManagerType.
- Refactored resource manager handling in tests to align with the new enum structure.
- Improved code readability and maintainability by replacing string literals with enum values.

Signed-off-by: Robin Kobus <[email protected]>
1 parent: 546274d

File tree: 7 files changed (+62, -44 lines)

- tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
- tensorrt_llm/_torch/pyexecutor/_util.py
- tensorrt_llm/_torch/pyexecutor/model_engine.py
- tensorrt_llm/_torch/pyexecutor/py_executor.py
- tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
- tensorrt_llm/_torch/pyexecutor/resource_manager.py
- tests/unittest/_torch/test_pytorch_model_engine.py

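The pattern the commit converges on is small: resource managers are registered in an OrderedDict keyed by ResourceManagerType members instead of bare strings, so a mistyped key fails loudly and the set of valid keys lives in one place. The following is a minimal standalone sketch of that pattern, assuming a trimmed-down ResourceManager that keeps only the registry behavior; it is not the real TensorRT-LLM class.

import enum
from collections import OrderedDict


class ResourceManagerType(enum.Enum):
    KV_CACHE_MANAGER = "KV_CACHE_MANAGER"
    DRAFT_KV_CACHE_MANAGER = "DRAFT_KV_CACHE_MANAGER"
    PEFT_CACHE_MANAGER = "PEFT_CACHE_MANAGER"
    SEQ_SLOT_MANAGER = "SEQ_SLOT_MANAGER"
    SPEC_RESOURCE_MANAGER = "SPEC_RESOURCE_MANAGER"


class ResourceManager:
    # Trimmed-down stand-in: only the registry behavior relevant to this commit.
    def __init__(self, resource_managers):
        self.resource_managers = OrderedDict(resource_managers)

    def get_resource_manager(self, key):
        # Lookup is by enum member; an old-style string key would simply miss.
        return self.resource_managers.get(key)


# Register a placeholder KV cache manager under the enum key and look it up.
kv_cache_manager = object()
rm = ResourceManager({ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager})
assert rm.get_resource_manager(ResourceManagerType.KV_CACHE_MANAGER) is kv_cache_manager
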
tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py

Lines changed: 6 additions & 4 deletions
@@ -15,7 +15,7 @@
 from ...pyexecutor.config import PyTorchConfig
 from ...pyexecutor.model_engine import ModelEngine
 from ...pyexecutor.py_executor import PyExecutor
-from ...pyexecutor.resource_manager import KVCacheManager, ResourceManager
+from ...pyexecutor.resource_manager import KVCacheManager, ResourceManager, ResourceManagerType
 from ...pyexecutor.sampler import TorchSampler
 from ...pyexecutor.scheduler import (
     BindCapacityScheduler,
@@ -151,7 +151,9 @@ def _prepare_inputs(
     ) -> bool:
         """Prepare inputs for AD Model from scheduled requests."""
         # cache manager
-        kv_cache_manager = resource_manager.get_resource_manager("kv_cache_manager")
+        kv_cache_manager = resource_manager.get_resource_manager(
+            ResourceManagerType.KV_CACHE_MANAGER
+        )

         # requests in order of context, extend (generate with draft), generate
         context_requests = scheduled_requests.context_requests
@@ -290,8 +292,8 @@ def create_autodeploy_executor(
         max_seq_len=max_seq_len,
         max_batch_size=max_batch_size,
     )
-    resource_manager = ResourceManager({"kv_cache_manager": kv_cache_manager})
-    resource_manager.resource_managers.move_to_end("kv_cache_manager", last=True)
+    resource_manager = ResourceManager({ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager})
+    resource_manager.resource_managers.move_to_end(ResourceManagerType.KV_CACHE_MANAGER, last=True)

     # scheduling
     capacitor_scheduler = BindCapacityScheduler(max_batch_size, kv_cache_manager.impl)

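Both create_autodeploy_executor above and create_py_executor_instance in _util.py below depend on the registry preserving insertion order and on move_to_end pushing the KV cache manager to the back, so it is prepared after any manager it may depend on. A minimal sketch of that ordering behavior with a plain OrderedDict and placeholder entries (the enum is abbreviated here for illustration):

from collections import OrderedDict
from enum import Enum


class ResourceManagerType(Enum):
    KV_CACHE_MANAGER = "KV_CACHE_MANAGER"
    SEQ_SLOT_MANAGER = "SEQ_SLOT_MANAGER"


managers = OrderedDict([
    (ResourceManagerType.KV_CACHE_MANAGER, "kv cache (placeholder)"),
    (ResourceManagerType.SEQ_SLOT_MANAGER, "seq slots (placeholder)"),
])
# Push the KV cache manager to the end so it is invoked after the others.
managers.move_to_end(ResourceManagerType.KV_CACHE_MANAGER, last=True)
print(list(managers))
# SEQ_SLOT_MANAGER now comes first, KV_CACHE_MANAGER last.
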
tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 19 additions & 16 deletions
@@ -21,11 +21,11 @@
 from .config_utils import is_mla, is_nemotron_hybrid
 from .kv_cache_transceiver import AttentionTypeCpp, create_kv_cache_transceiver
 from .llm_request import ExecutorResponse
-from .model_engine import (DRAFT_KV_CACHE_MANAGER_KEY, KV_CACHE_MANAGER_KEY,
-                           PyTorchModelEngine)
+from .model_engine import PyTorchModelEngine
 from .py_executor import PyExecutor
 from .resource_manager import (KVCacheManager, MambaHybridCacheManager,
-                               PeftCacheManager, ResourceManager)
+                               PeftCacheManager, ResourceManager,
+                               ResourceManagerType)
 from .sampler import (EarlyStopSampler, TorchSampler, TorchStarAttentionSampler,
                       TRTLLMSampler)
 from .scheduler import (BindCapacityScheduler, BindMicroBatchScheduler,
@@ -245,7 +245,7 @@ def estimate_max_tokens(self, py_executor: PyExecutor) -> None:
             f"Memory used outside torch (e.g., NCCL and CUDA graphs) in memory usage profiling: {extra_cost / (GB):.2f} GiB"
         )
         kv_stats = py_executor.resource_manager.resource_managers.get(
-            "kv_cache_manager").get_kv_cache_stats()
+            ResourceManagerType.KV_CACHE_MANAGER).get_kv_cache_stats()

         kv_cache_max_tokens = self._cal_max_tokens(
             peak_memory, total_gpu_memory, fraction,
@@ -349,7 +349,7 @@ def _create_kv_cache_manager(
             spec_config=spec_config,
         )
         # KVCacheManager (Non-draft) modifies the max_seq_len field, update it to executor_config
-        if model_engine.kv_cache_manager_key == KV_CACHE_MANAGER_KEY:
+        if model_engine.kv_cache_manager_key == ResourceManagerType.KV_CACHE_MANAGER:
             executor_config.max_seq_len = kv_cache_manager.max_seq_len

         return kv_cache_manager
@@ -360,17 +360,19 @@ def build_managers(self, resources: Dict) -> None:
         draft_kv_cache_manager = self._create_kv_cache_manager(
             self._draft_model_engine
         ) if self._draft_model_engine is not None else None
-        resources[KV_CACHE_MANAGER_KEY] = kv_cache_manager
-        resources[DRAFT_KV_CACHE_MANAGER_KEY] = draft_kv_cache_manager
+        resources[ResourceManagerType.KV_CACHE_MANAGER] = kv_cache_manager
+        resources[
+            ResourceManagerType.DRAFT_KV_CACHE_MANAGER] = draft_kv_cache_manager

     def teardown_managers(self, resources: Dict) -> None:
         """Clean up KV caches for model and draft model (if applicable)."""
-        resources[KV_CACHE_MANAGER_KEY].shutdown()
-        del resources[KV_CACHE_MANAGER_KEY]
-        draft_kv_cache_manager = resources[DRAFT_KV_CACHE_MANAGER_KEY]
+        resources[ResourceManagerType.KV_CACHE_MANAGER].shutdown()
+        del resources[ResourceManagerType.KV_CACHE_MANAGER]
+        draft_kv_cache_manager = resources[
+            ResourceManagerType.DRAFT_KV_CACHE_MANAGER]
         if draft_kv_cache_manager:
             draft_kv_cache_manager.shutdown()
-            del resources[DRAFT_KV_CACHE_MANAGER_KEY]
+            del resources[ResourceManagerType.DRAFT_KV_CACHE_MANAGER]


 def create_py_executor_instance(
@@ -386,7 +388,7 @@ def create_py_executor_instance(
         sampler,
         lora_config: Optional[LoraConfig] = None,
         garbage_collection_gen0_threshold: Optional[int] = None) -> PyExecutor:
-    kv_cache_manager = resources.get(KV_CACHE_MANAGER_KEY, None)
+    kv_cache_manager = resources.get(ResourceManagerType.KV_CACHE_MANAGER, None)

     spec_config = model_engine.spec_config
     if mapping.is_last_pp_rank(
@@ -463,22 +465,23 @@ def create_py_executor_instance(
             model_config=model_binding_config,
             world_config=world_config,
         )
-        resources["peft_cache_manager"] = peft_cache_manager
+        resources[ResourceManagerType.PEFT_CACHE_MANAGER] = peft_cache_manager
         model_engine.set_lora_model_config(
             lora_config.lora_target_modules,
             lora_config.trtllm_modules_to_hf_modules)

     max_num_sequences = executor_config.max_batch_size * mapping.pp_size

-    resources["seq_slot_manager"] = SeqSlotManager(max_num_sequences)
+    resources[ResourceManagerType.SEQ_SLOT_MANAGER] = SeqSlotManager(
+        max_num_sequences)

     resource_manager = ResourceManager(resources)

     # Make sure the kv cache manager is always invoked last as it could
     # depend on the results of other resource managers.
     if kv_cache_manager is not None:
-        resource_manager.resource_managers.move_to_end("kv_cache_manager",
-                                                       last=True)
+        resource_manager.resource_managers.move_to_end(
+            ResourceManagerType.KV_CACHE_MANAGER, last=True)

     capacity_scheduler = BindCapacityScheduler(
         max_num_sequences,

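build_managers and teardown_managers above fill and later empty the same resources dict that seeds the ResourceManager, so creation and cleanup must agree on the key; the enum spells that contract out in one type. A compact sketch of the lifecycle under the assumption of a throwaway DummyCache placeholder (not the real KVCacheManager):

from enum import Enum


class ResourceManagerType(Enum):
    KV_CACHE_MANAGER = "KV_CACHE_MANAGER"
    DRAFT_KV_CACHE_MANAGER = "DRAFT_KV_CACHE_MANAGER"


class DummyCache:
    # Placeholder exposing only the method the teardown path needs.
    def shutdown(self):
        print("cache shut down")


def build_managers(resources):
    resources[ResourceManagerType.KV_CACHE_MANAGER] = DummyCache()
    resources[ResourceManagerType.DRAFT_KV_CACHE_MANAGER] = None  # no draft model


def teardown_managers(resources):
    resources[ResourceManagerType.KV_CACHE_MANAGER].shutdown()
    del resources[ResourceManagerType.KV_CACHE_MANAGER]
    draft = resources[ResourceManagerType.DRAFT_KV_CACHE_MANAGER]
    if draft:
        draft.shutdown()
        del resources[ResourceManagerType.DRAFT_KV_CACHE_MANAGER]


resources = {}
build_managers(resources)
teardown_managers(resources)
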
tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 5 additions & 9 deletions
@@ -60,7 +60,7 @@
 from .guided_decoder import GuidedDecoder
 from .layerwise_nvtx_marker import LayerwiseNvtxMarker
 from .resource_manager import (BaseResourceManager, KVCacheManager,
-                               ResourceManager)
+                               ResourceManager, ResourceManagerType)
 from .scheduler import ScheduledRequests

 MAX_UINT64 = (1 << 64) - 1
@@ -272,10 +272,6 @@ def _get_random_min_max(dtype: torch.dtype) -> Tuple[int, int]:
     param.uniform_(low, high, generator=generator)


-KV_CACHE_MANAGER_KEY = 'kv_cache_manager'
-DRAFT_KV_CACHE_MANAGER_KEY = 'draft_kv_cache_manager'
-
-
 def get_rank_model_storage(model):
     total_bytes = 0
     for _, param in model.named_parameters():
@@ -497,7 +493,7 @@ def __init__(
         # We look up this key in resource_manager during forward to find the
         # kv cache manager. Can be changed to support multiple model engines
         # with different KV cache managers.
-        self.kv_cache_manager_key = KV_CACHE_MANAGER_KEY
+        self.kv_cache_manager_key = ResourceManagerType.KV_CACHE_MANAGER
         self.lora_model_config: Optional[LoraModelConfig] = None
         self.cuda_graph_dummy_request = None

@@ -541,7 +537,7 @@ def warmup(self, resource_manager: ResourceManager) -> None:
         kv_cache_manager = resource_manager.get_resource_manager(
             self.kv_cache_manager_key)
         spec_resource_manager = resource_manager.get_resource_manager(
-            'spec_resource_manager')
+            ResourceManagerType.SPEC_RESOURCE_MANAGER)
         if kv_cache_manager is None:
             logger.info("Skipping warm up as no KV Cache manager allocated.")
             return
@@ -2010,7 +2006,7 @@ def forward(self,
         attn_metadata = self._set_up_attn_metadata(kv_cache_manager)
         if self.is_spec_decode:
             spec_resource_manager = resource_manager.get_resource_manager(
-                'spec_resource_manager')
+                ResourceManagerType.SPEC_RESOURCE_MANAGER)
             spec_metadata = self._set_up_spec_metadata(spec_resource_manager,
                                                        no_cache=kv_cache_manager
                                                        is None)
@@ -2089,7 +2085,7 @@ def capture_forward_fn(inputs: Dict[str, Any]):
         if self.mapping.is_last_pp_rank(
         ) and self.guided_decoder is not None:
             seq_slot_manager = resource_manager.get_resource_manager(
-                "seq_slot_manager")
+                ResourceManagerType.SEQ_SLOT_MANAGER)
             self.guided_decoder.build(scheduled_requests, seq_slot_manager)
             self.guided_decoder.execute(scheduled_requests,
                                         outputs['logits'], seq_slot_manager)

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 9 additions & 8 deletions
@@ -16,6 +16,7 @@

 import torch

+from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
 from tensorrt_llm._utils import (customized_gc_thresholds, global_mpi_rank,
                                  is_trace_enabled, nvtx_range, trace_func)
 from tensorrt_llm.bindings.executor import (DisServingRequestStats,
@@ -215,7 +216,7 @@ def __init__(self,

         # kv cache events
         self.kv_cache_manager = self.resource_manager.resource_managers.get(
-            "kv_cache_manager")
+            ResourceManagerType.KV_CACHE_MANAGER)
         self.enable_kv_cache_events = self.kv_cache_manager is not None and self.kv_cache_manager.event_buffer_max_size > 0

         if self.draft_model_engine is not None and self.kv_cache_manager is not None:
@@ -405,7 +406,7 @@ def get_latest_iteration_stats(self):

     def get_latest_kv_cache_events(self):
         kv_cache_manager = self.resource_manager.resource_managers.get(
-            "kv_cache_manager")
+            ResourceManagerType.KV_CACHE_MANAGER)
         if not kv_cache_manager or not self.enable_kv_cache_events:
             return []

@@ -529,7 +530,7 @@ def _get_init_iter_stats(self, num_new_active_requests,
         # staticBatchingStats is not used in pytorch path
         stats.static_batching_stats = StaticBatchingStats()
         spec_resource_manager = self.resource_manager.resource_managers.get(
-            "spec_resource_manager")
+            ResourceManagerType.SPEC_RESOURCE_MANAGER)
         if spec_resource_manager is not None:
             stats.specdec_stats = SpecDecodingStats()
         return stats
@@ -606,7 +607,7 @@ def _update_iter_stats(self, stats, iter_latency_ms, num_completed_requests,
         stats.iter = self.model_engine.iter_counter

         kv_cache_manager = self.resource_manager.resource_managers.get(
-            "kv_cache_manager")
+            ResourceManagerType.KV_CACHE_MANAGER)
         if kv_cache_manager is not None:
             kv_stats = kv_cache_manager.get_kv_cache_stats()
             kv_stats_to_save = KvCacheStats()
@@ -1308,7 +1309,7 @@ def _fetch_new_requests(self) -> List[RequestQueueItem]:

     def _add_kv_cache_events(self):
         kv_cache_manager = self.resource_manager.resource_managers.get(
-            "kv_cache_manager")
+            ResourceManagerType.KV_CACHE_MANAGER)
         if not kv_cache_manager:
             return
         # Flush iteration events at each iteration to ensure that events have enough time
@@ -1514,7 +1515,7 @@ def _pad_attention_dp_dummy_request(self):
         )[0]
         llm_request.is_attention_dp_dummy = True
         spec_resource_manager = self.resource_manager.get_resource_manager(
-            'spec_resource_manager')
+            ResourceManagerType.SPEC_RESOURCE_MANAGER)
         if spec_resource_manager is not None:
             spec_resource_manager.add_dummy_requests([0])
         self.active_requests.append(llm_request)
@@ -1528,7 +1529,7 @@ def _prepare_disagg_gen_init(self, fitting_disagg_gen_init_requests):
         disagg_gen_init_to_prepare.paused_requests = []

         self.resource_manager.resource_managers[
-            'kv_cache_manager'].prepare_resources(
+            ResourceManagerType.KV_CACHE_MANAGER].prepare_resources(
                 disagg_gen_init_to_prepare)

         # Trigger KV cache exchange for new disagg_gen_init_requests
@@ -1590,7 +1591,7 @@ def _send_disagg_ctx_cache(self, scheduled_ctx_requests):
                     req.is_finished_due_to_length):
                 self.kv_cache_transceiver.respond_and_send_async(req)
                 self.resource_manager.resource_managers[
-                    "seq_slot_manager"].free_resources(req)
+                    ResourceManagerType.SEQ_SLOT_MANAGER].free_resources(req)

         self.kv_cache_transceiver.check_context_transfer_status(0)

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 5 additions & 3 deletions
@@ -8,6 +8,7 @@
 import torch

 import tensorrt_llm
+from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
 from tensorrt_llm._utils import get_sm_version
 from tensorrt_llm.bindings.executor import ContextChunkingPolicy, ExecutorConfig
 from tensorrt_llm.bindings.internal.batch_manager import ContextChunkingConfig
@@ -23,7 +24,7 @@
                      instantiate_sampler, is_mla)
 from .config import PyTorchConfig
 from .config_utils import is_mla
-from .model_engine import DRAFT_KV_CACHE_MANAGER_KEY, PyTorchModelEngine
+from .model_engine import PyTorchModelEngine
 from .py_executor import PyExecutor


@@ -242,7 +243,7 @@ def create_py_executor(
             spec_config=draft_spec_config,
             is_draft_model=True,
         )
-        draft_model_engine.kv_cache_manager_key = DRAFT_KV_CACHE_MANAGER_KEY
+        draft_model_engine.kv_cache_manager_key = ResourceManagerType.DRAFT_KV_CACHE_MANAGER
         draft_model_engine.load_weights_from_target_model(
             model_engine.model)
     else:
@@ -328,7 +329,8 @@ def create_py_executor(
         spec_resource_manager = get_spec_resource_manager(
             spec_config, model_engine, draft_model_engine)
         if spec_resource_manager is not None:
-            resources["spec_resource_manager"] = spec_resource_manager
+            resources[ResourceManagerType.
+                      SPEC_RESOURCE_MANAGER] = spec_resource_manager

         with mem_monitor.observe_creation_stage(
                 _ExecutorCreationStage.INIT_EXTRA_RESOURCES

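The draft-model branch above works because kv_cache_manager_key is an ordinary engine attribute: it defaults to the regular KV cache key and is reassigned to the draft key for a draft engine, and each engine then finds its own cache in the shared resources mapping. A small illustrative sketch with a hypothetical Engine stand-in (not PyTorchModelEngine):

from enum import Enum


class ResourceManagerType(Enum):
    KV_CACHE_MANAGER = "KV_CACHE_MANAGER"
    DRAFT_KV_CACHE_MANAGER = "DRAFT_KV_CACHE_MANAGER"


class Engine:
    # Hypothetical stand-in for a model engine; only the key lookup is shown.
    def __init__(self):
        self.kv_cache_manager_key = ResourceManagerType.KV_CACHE_MANAGER

    def pick_cache(self, resources):
        return resources[self.kv_cache_manager_key]


resources = {
    ResourceManagerType.KV_CACHE_MANAGER: "target KV cache (placeholder)",
    ResourceManagerType.DRAFT_KV_CACHE_MANAGER: "draft KV cache (placeholder)",
}
target_engine, draft_engine = Engine(), Engine()
draft_engine.kv_cache_manager_key = ResourceManagerType.DRAFT_KV_CACHE_MANAGER
print(target_engine.pick_cache(resources))  # target KV cache (placeholder)
print(draft_engine.pick_cache(resources))   # draft KV cache (placeholder)
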
tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 9 additions & 0 deletions
@@ -1,3 +1,4 @@
+import enum
 import math
 from abc import ABC, abstractmethod
 from collections import OrderedDict
@@ -36,6 +37,14 @@
 WorldConfig = tensorrt_llm.bindings.WorldConfig


+class ResourceManagerType(enum.Enum):
+    KV_CACHE_MANAGER = "KV_CACHE_MANAGER"
+    DRAFT_KV_CACHE_MANAGER = "DRAFT_KV_CACHE_MANAGER"
+    PEFT_CACHE_MANAGER = "PEFT_CACHE_MANAGER"
+    SEQ_SLOT_MANAGER = "SEQ_SLOT_MANAGER"
+    SPEC_RESOURCE_MANAGER = "SPEC_RESOURCE_MANAGER"
+
+
 def compute_page_count(token_count: int, tokens_per_page: int) -> int:
     return (token_count + tokens_per_page) // tokens_per_page

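One detail of the new enum to keep in mind when migrating remaining call sites: a plain enum.Enum member is not equal to its value string, so a mapping keyed by ResourceManagerType returns nothing for the old string keys and every lookup has to move to the enum. A tiny self-contained check (the single-member enum is just for illustration):

import enum


class ResourceManagerType(enum.Enum):
    KV_CACHE_MANAGER = "KV_CACHE_MANAGER"


resources = {ResourceManagerType.KV_CACHE_MANAGER: "kv cache (placeholder)"}
print(ResourceManagerType.KV_CACHE_MANAGER == "KV_CACHE_MANAGER")  # False
print(resources.get("kv_cache_manager"))                           # None: old string key misses
print(resources.get(ResourceManagerType.KV_CACHE_MANAGER))         # kv cache (placeholder)
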
tests/unittest/_torch/test_pytorch_model_engine.py

Lines changed: 9 additions & 4 deletions
@@ -9,8 +9,13 @@
 from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
 from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
 from tensorrt_llm._torch.pyexecutor.model_engine import PyTorchModelEngine
+
+# isort: off
 from tensorrt_llm._torch.pyexecutor.resource_manager import (KVCacheManager,
-                                                             ResourceManager)
+                                                             ResourceManager,
+                                                             ResourceManagerType
+                                                             )
+# isort: on
 from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests
 from tensorrt_llm.bindings.executor import KvCacheConfig
 from tensorrt_llm.llmapi import SamplingParams
@@ -181,7 +186,7 @@ def test_pad_generation_requests(self) -> None:
     def test_position_id_preparation(self):
         model_engine, kv_cache_manager = create_model_engine_and_kvcache()
         resource_manager = ResourceManager(
-            {"kv_cache_manager": kv_cache_manager})
+            {ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager})

         prompt_len = 256
         requests = [_create_request(prompt_len, 0)]
@@ -225,7 +230,7 @@ def test_position_id_preparation(self):
     def test_warmup(self):
         model_engine, kv_cache_manager = create_model_engine_and_kvcache()
         resource_manager = ResourceManager(
-            {"kv_cache_manager": kv_cache_manager})
+            {ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager})

         # Test with a huge batch size. The warmup run should bail out of
         # warmup instead of crashing (there's not enough KV cache space for this).
@@ -245,7 +250,7 @@ def test_layerwise_nvtx_marker(self):
             enable_layerwise_nvtx_marker=True)
         model_engine, kv_cache_manager = create_model_engine_and_kvcache(config)
         resource_manager = ResourceManager(
-            {"kv_cache_manager": kv_cache_manager})
+            {ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager})

         prompt_len = 32
         requests = [_create_request(prompt_len, 0)]
