Merged
10 changes: 6 additions & 4 deletions tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
@@ -15,7 +15,7 @@
from ...pyexecutor.config import PyTorchConfig
from ...pyexecutor.model_engine import ModelEngine
from ...pyexecutor.py_executor import PyExecutor
from ...pyexecutor.resource_manager import KVCacheManager, ResourceManager
from ...pyexecutor.resource_manager import KVCacheManager, ResourceManager, ResourceManagerType
from ...pyexecutor.sampler import TorchSampler
from ...pyexecutor.scheduler import (
BindCapacityScheduler,
@@ -151,7 +151,9 @@ def _prepare_inputs(
) -> bool:
"""Prepare inputs for AD Model from scheduled requests."""
# cache manager
kv_cache_manager = resource_manager.get_resource_manager("kv_cache_manager")
kv_cache_manager = resource_manager.get_resource_manager(
ResourceManagerType.KV_CACHE_MANAGER
)

# requests in order of context, extend (generate with draft), generate
context_requests = scheduled_requests.context_requests
@@ -290,8 +292,8 @@ def create_autodeploy_executor(
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
)
resource_manager = ResourceManager({"kv_cache_manager": kv_cache_manager})
resource_manager.resource_managers.move_to_end("kv_cache_manager", last=True)
resource_manager = ResourceManager({ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager})
resource_manager.resource_managers.move_to_end(ResourceManagerType.KV_CACHE_MANAGER, last=True)

# scheduling
capacitor_scheduler = BindCapacityScheduler(max_batch_size, kv_cache_manager.impl)
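The hunks above replace bare string keys with ResourceManagerType members at both the registration site (the ResourceManager constructor and move_to_end) and the lookup site (get_resource_manager). A minimal sketch of the new pattern, using a trimmed-down stand-in for ResourceManager and a placeholder object instead of the real KVCacheManager:

import enum
from collections import OrderedDict


class ResourceManagerType(enum.Enum):
    # Subset of the enum added in resource_manager.py.
    KV_CACHE_MANAGER = "KV_CACHE_MANAGER"
    SEQ_SLOT_MANAGER = "SEQ_SLOT_MANAGER"


class ResourceManager:
    # Simplified stand-in: keeps managers in an OrderedDict and resolves them by key.
    def __init__(self, managers):
        self.resource_managers = OrderedDict(managers)

    def get_resource_manager(self, key):
        return self.resource_managers.get(key)


kv_cache_manager = object()  # placeholder for the real KVCacheManager

resource_manager = ResourceManager(
    {ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager})
resource_manager.resource_managers.move_to_end(
    ResourceManagerType.KV_CACHE_MANAGER, last=True)

assert resource_manager.get_resource_manager(
    ResourceManagerType.KV_CACHE_MANAGER) is kv_cache_manager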
35 changes: 19 additions & 16 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -21,11 +21,11 @@
from .config_utils import is_mla, is_nemotron_hybrid
from .kv_cache_transceiver import AttentionTypeCpp, create_kv_cache_transceiver
from .llm_request import ExecutorResponse
from .model_engine import (DRAFT_KV_CACHE_MANAGER_KEY, KV_CACHE_MANAGER_KEY,
PyTorchModelEngine)
from .model_engine import PyTorchModelEngine
from .py_executor import PyExecutor
from .resource_manager import (KVCacheManager, MambaHybridCacheManager,
PeftCacheManager, ResourceManager)
PeftCacheManager, ResourceManager,
ResourceManagerType)
from .sampler import (EarlyStopSampler, TorchSampler, TorchStarAttentionSampler,
TRTLLMSampler)
from .scheduler import (BindCapacityScheduler, BindMicroBatchScheduler,
@@ -245,7 +245,7 @@ def estimate_max_tokens(self, py_executor: PyExecutor) -> None:
f"Memory used outside torch (e.g., NCCL and CUDA graphs) in memory usage profiling: {extra_cost / (GB):.2f} GiB"
)
kv_stats = py_executor.resource_manager.resource_managers.get(
"kv_cache_manager").get_kv_cache_stats()
ResourceManagerType.KV_CACHE_MANAGER).get_kv_cache_stats()

kv_cache_max_tokens = self._cal_max_tokens(
peak_memory, total_gpu_memory, fraction,
@@ -349,7 +349,7 @@ def _create_kv_cache_manager(
spec_config=spec_config,
)
# KVCacheManager (Non-draft) modifies the max_seq_len field, update it to executor_config
if model_engine.kv_cache_manager_key == KV_CACHE_MANAGER_KEY:
if model_engine.kv_cache_manager_key == ResourceManagerType.KV_CACHE_MANAGER:
executor_config.max_seq_len = kv_cache_manager.max_seq_len

return kv_cache_manager
@@ -360,17 +360,19 @@ def build_managers(self, resources: Dict) -> None:
draft_kv_cache_manager = self._create_kv_cache_manager(
self._draft_model_engine
) if self._draft_model_engine is not None else None
resources[KV_CACHE_MANAGER_KEY] = kv_cache_manager
resources[DRAFT_KV_CACHE_MANAGER_KEY] = draft_kv_cache_manager
resources[ResourceManagerType.KV_CACHE_MANAGER] = kv_cache_manager
resources[
ResourceManagerType.DRAFT_KV_CACHE_MANAGER] = draft_kv_cache_manager

def teardown_managers(self, resources: Dict) -> None:
"""Clean up KV caches for model and draft model (if applicable)."""
resources[KV_CACHE_MANAGER_KEY].shutdown()
del resources[KV_CACHE_MANAGER_KEY]
draft_kv_cache_manager = resources[DRAFT_KV_CACHE_MANAGER_KEY]
resources[ResourceManagerType.KV_CACHE_MANAGER].shutdown()
del resources[ResourceManagerType.KV_CACHE_MANAGER]
draft_kv_cache_manager = resources[
ResourceManagerType.DRAFT_KV_CACHE_MANAGER]
if draft_kv_cache_manager:
draft_kv_cache_manager.shutdown()
del resources[DRAFT_KV_CACHE_MANAGER_KEY]
del resources[ResourceManagerType.DRAFT_KV_CACHE_MANAGER]


def create_py_executor_instance(
@@ -386,7 +388,7 @@ def create_py_executor_instance(
sampler,
lora_config: Optional[LoraConfig] = None,
garbage_collection_gen0_threshold: Optional[int] = None) -> PyExecutor:
kv_cache_manager = resources.get(KV_CACHE_MANAGER_KEY, None)
kv_cache_manager = resources.get(ResourceManagerType.KV_CACHE_MANAGER, None)

spec_config = model_engine.spec_config
if mapping.is_last_pp_rank(
@@ -463,22 +465,23 @@ def create_py_executor_instance(
model_config=model_binding_config,
world_config=world_config,
)
resources["peft_cache_manager"] = peft_cache_manager
resources[ResourceManagerType.PEFT_CACHE_MANAGER] = peft_cache_manager
model_engine.set_lora_model_config(
lora_config.lora_target_modules,
lora_config.trtllm_modules_to_hf_modules)

max_num_sequences = executor_config.max_batch_size * mapping.pp_size

resources["seq_slot_manager"] = SeqSlotManager(max_num_sequences)
resources[ResourceManagerType.SEQ_SLOT_MANAGER] = SeqSlotManager(
max_num_sequences)

resource_manager = ResourceManager(resources)

# Make sure the kv cache manager is always invoked last as it could
# depend on the results of other resource managers.
if kv_cache_manager is not None:
resource_manager.resource_managers.move_to_end("kv_cache_manager",
last=True)
resource_manager.resource_managers.move_to_end(
ResourceManagerType.KV_CACHE_MANAGER, last=True)

capacity_scheduler = BindCapacityScheduler(
max_num_sequences,
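The comment right above the move_to_end call explains that the KV cache manager must be invoked last because it can depend on the other managers' results; move_to_end(..., last=True) is what enforces that. A small sketch of the ordering behavior, assuming (as the call implies) that ResourceManager.resource_managers is an OrderedDict iterated in insertion order when resources are prepared; plain strings stand in for the ResourceManagerType members:

from collections import OrderedDict

managers = OrderedDict([
    ("KV_CACHE_MANAGER", "kv"),
    ("PEFT_CACHE_MANAGER", "peft"),
    ("SEQ_SLOT_MANAGER", "slots"),
])

# Push the KV cache manager to the end so it runs after the managers it may depend on.
managers.move_to_end("KV_CACHE_MANAGER", last=True)

assert list(managers) == ["PEFT_CACHE_MANAGER", "SEQ_SLOT_MANAGER", "KV_CACHE_MANAGER"]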
14 changes: 5 additions & 9 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -60,7 +60,7 @@
from .guided_decoder import GuidedDecoder
from .layerwise_nvtx_marker import LayerwiseNvtxMarker
from .resource_manager import (BaseResourceManager, KVCacheManager,
ResourceManager)
ResourceManager, ResourceManagerType)
from .scheduler import ScheduledRequests

MAX_UINT64 = (1 << 64) - 1
@@ -272,10 +272,6 @@ def _get_random_min_max(dtype: torch.dtype) -> Tuple[int, int]:
param.uniform_(low, high, generator=generator)


KV_CACHE_MANAGER_KEY = 'kv_cache_manager'
DRAFT_KV_CACHE_MANAGER_KEY = 'draft_kv_cache_manager'


def get_rank_model_storage(model):
total_bytes = 0
for _, param in model.named_parameters():
@@ -497,7 +493,7 @@ def __init__(
# We look up this key in resource_manager during forward to find the
# kv cache manager. Can be changed to support multiple model engines
# with different KV cache managers.
self.kv_cache_manager_key = KV_CACHE_MANAGER_KEY
self.kv_cache_manager_key = ResourceManagerType.KV_CACHE_MANAGER
self.lora_model_config: Optional[LoraModelConfig] = None
self.cuda_graph_dummy_request = None

@@ -541,7 +537,7 @@ def warmup(self, resource_manager: ResourceManager) -> None:
kv_cache_manager = resource_manager.get_resource_manager(
self.kv_cache_manager_key)
spec_resource_manager = resource_manager.get_resource_manager(
'spec_resource_manager')
ResourceManagerType.SPEC_RESOURCE_MANAGER)
if kv_cache_manager is None:
logger.info("Skipping warm up as no KV Cache manager allocated.")
return
@@ -2010,7 +2006,7 @@ def forward(self,
attn_metadata = self._set_up_attn_metadata(kv_cache_manager)
if self.is_spec_decode:
spec_resource_manager = resource_manager.get_resource_manager(
'spec_resource_manager')
ResourceManagerType.SPEC_RESOURCE_MANAGER)
spec_metadata = self._set_up_spec_metadata(spec_resource_manager,
no_cache=kv_cache_manager
is None)
@@ -2089,7 +2085,7 @@ def capture_forward_fn(inputs: Dict[str, Any]):
if self.mapping.is_last_pp_rank(
) and self.guided_decoder is not None:
seq_slot_manager = resource_manager.get_resource_manager(
"seq_slot_manager")
ResourceManagerType.SEQ_SLOT_MANAGER)
self.guided_decoder.build(scheduled_requests, seq_slot_manager)
self.guided_decoder.execute(scheduled_requests,
outputs['logits'], seq_slot_manager)
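The kv_cache_manager_key attribute is the indirection described by the comment in __init__: each engine stores a key and resolves it against the shared ResourceManager at warmup/forward time, so the draft engine can be pointed at a separate cache just by swapping the key (py_executor_creator.py below sets it to DRAFT_KV_CACHE_MANAGER). A hedged sketch of that resolution step, with simplified stand-ins for the engine and the resource manager and strings in place of the enum members:

class _StubResourceManager:
    # Simplified stand-in: resolves managers from a plain dict by key.
    def __init__(self, managers):
        self._managers = dict(managers)

    def get_resource_manager(self, key):
        return self._managers.get(key)


class _StubEngine:
    def __init__(self, kv_cache_manager_key="KV_CACHE_MANAGER"):
        # The real engine defaults to ResourceManagerType.KV_CACHE_MANAGER;
        # the draft engine is re-keyed to DRAFT_KV_CACHE_MANAGER by its creator.
        self.kv_cache_manager_key = kv_cache_manager_key

    def forward(self, resource_manager):
        # Resolve the KV cache manager by key, as the real forward() does.
        return resource_manager.get_resource_manager(self.kv_cache_manager_key)


rm = _StubResourceManager({"KV_CACHE_MANAGER": "target-cache",
                           "DRAFT_KV_CACHE_MANAGER": "draft-cache"})
assert _StubEngine().forward(rm) == "target-cache"
assert _StubEngine("DRAFT_KV_CACHE_MANAGER").forward(rm) == "draft-cache"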
17 changes: 9 additions & 8 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -16,6 +16,7 @@

import torch

from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
from tensorrt_llm._utils import (customized_gc_thresholds, global_mpi_rank,
is_trace_enabled, nvtx_range, trace_func)
from tensorrt_llm.bindings.executor import (DisServingRequestStats,
@@ -215,7 +216,7 @@ def __init__(self,

# kv cache events
self.kv_cache_manager = self.resource_manager.resource_managers.get(
"kv_cache_manager")
ResourceManagerType.KV_CACHE_MANAGER)
self.enable_kv_cache_events = self.kv_cache_manager is not None and self.kv_cache_manager.event_buffer_max_size > 0

if self.draft_model_engine is not None and self.kv_cache_manager is not None:
@@ -405,7 +406,7 @@ def get_latest_iteration_stats(self):

def get_latest_kv_cache_events(self):
kv_cache_manager = self.resource_manager.resource_managers.get(
"kv_cache_manager")
ResourceManagerType.KV_CACHE_MANAGER)
if not kv_cache_manager or not self.enable_kv_cache_events:
return []

@@ -529,7 +530,7 @@ def _get_init_iter_stats(self, num_new_active_requests,
# staticBatchingStats is not used in pytorch path
stats.static_batching_stats = StaticBatchingStats()
spec_resource_manager = self.resource_manager.resource_managers.get(
"spec_resource_manager")
ResourceManagerType.SPEC_RESOURCE_MANAGER)
if spec_resource_manager is not None:
stats.specdec_stats = SpecDecodingStats()
return stats
@@ -606,7 +607,7 @@ def _update_iter_stats(self, stats, iter_latency_ms, num_completed_requests,
stats.iter = self.model_engine.iter_counter

kv_cache_manager = self.resource_manager.resource_managers.get(
"kv_cache_manager")
ResourceManagerType.KV_CACHE_MANAGER)
if kv_cache_manager is not None:
kv_stats = kv_cache_manager.get_kv_cache_stats()
kv_stats_to_save = KvCacheStats()
@@ -1308,7 +1309,7 @@ def _fetch_new_requests(self) -> List[RequestQueueItem]:

def _add_kv_cache_events(self):
kv_cache_manager = self.resource_manager.resource_managers.get(
"kv_cache_manager")
ResourceManagerType.KV_CACHE_MANAGER)
if not kv_cache_manager:
return
# Flush iteration events at each iteration to ensure that events have enough time
@@ -1514,7 +1515,7 @@ def _pad_attention_dp_dummy_request(self):
)[0]
llm_request.is_attention_dp_dummy = True
spec_resource_manager = self.resource_manager.get_resource_manager(
'spec_resource_manager')
ResourceManagerType.SPEC_RESOURCE_MANAGER)
if spec_resource_manager is not None:
spec_resource_manager.add_dummy_requests([0])
self.active_requests.append(llm_request)
@@ -1528,7 +1529,7 @@ def _prepare_disagg_gen_init(self, fitting_disagg_gen_init_requests):
disagg_gen_init_to_prepare.paused_requests = []

self.resource_manager.resource_managers[
'kv_cache_manager'].prepare_resources(
ResourceManagerType.KV_CACHE_MANAGER].prepare_resources(
disagg_gen_init_to_prepare)

# Trigger KV cache exchange for new disagg_gen_init_requests
@@ -1590,7 +1591,7 @@ def _send_disagg_ctx_cache(self, scheduled_ctx_requests):
req.is_finished_due_to_length):
self.kv_cache_transceiver.respond_and_send_async(req)
self.resource_manager.resource_managers[
"seq_slot_manager"].free_resources(req)
ResourceManagerType.SEQ_SLOT_MANAGER].free_resources(req)

self.kv_cache_transceiver.check_context_transfer_status(0)

8 changes: 5 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -8,6 +8,7 @@
import torch

import tensorrt_llm
from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.bindings.executor import ContextChunkingPolicy, ExecutorConfig
from tensorrt_llm.bindings.internal.batch_manager import ContextChunkingConfig
@@ -23,7 +24,7 @@
instantiate_sampler, is_mla)
from .config import PyTorchConfig
from .config_utils import is_mla
from .model_engine import DRAFT_KV_CACHE_MANAGER_KEY, PyTorchModelEngine
from .model_engine import PyTorchModelEngine
from .py_executor import PyExecutor


@@ -242,7 +243,7 @@ def create_py_executor(
spec_config=draft_spec_config,
is_draft_model=True,
)
draft_model_engine.kv_cache_manager_key = DRAFT_KV_CACHE_MANAGER_KEY
draft_model_engine.kv_cache_manager_key = ResourceManagerType.DRAFT_KV_CACHE_MANAGER
draft_model_engine.load_weights_from_target_model(
model_engine.model)
else:
@@ -328,7 +329,8 @@ def create_py_executor(
spec_resource_manager = get_spec_resource_manager(
spec_config, model_engine, draft_model_engine)
if spec_resource_manager is not None:
resources["spec_resource_manager"] = spec_resource_manager
resources[ResourceManagerType.
SPEC_RESOURCE_MANAGER] = spec_resource_manager

with mem_monitor.observe_creation_stage(
_ExecutorCreationStage.INIT_EXTRA_RESOURCES
9 changes: 9 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -1,3 +1,4 @@
import enum
import math
from abc import ABC, abstractmethod
from collections import OrderedDict
@@ -36,6 +37,14 @@
WorldConfig = tensorrt_llm.bindings.WorldConfig


class ResourceManagerType(enum.Enum):
KV_CACHE_MANAGER = "KV_CACHE_MANAGER"
DRAFT_KV_CACHE_MANAGER = "DRAFT_KV_CACHE_MANAGER"
PEFT_CACHE_MANAGER = "PEFT_CACHE_MANAGER"
SEQ_SLOT_MANAGER = "SEQ_SLOT_MANAGER"
SPEC_RESOURCE_MANAGER = "SPEC_RESOURCE_MANAGER"


def compute_page_count(token_count: int, tokens_per_page: int) -> int:
return (token_count + tokens_per_page) // tokens_per_page

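Because ResourceManagerType is a plain enum.Enum, its members are distinct from their string values; a dict keyed by the enum can no longer be read back with the old 'kv_cache_manager'-style strings, which is why every registration and lookup site in this PR switches to the enum member. A quick illustration relying only on standard-library enum behavior:

import enum


class ResourceManagerType(enum.Enum):
    KV_CACHE_MANAGER = "KV_CACHE_MANAGER"


resources = {ResourceManagerType.KV_CACHE_MANAGER: "kv-manager"}

# Enum members hash and compare as themselves, not as their string values.
assert ResourceManagerType.KV_CACHE_MANAGER != "KV_CACHE_MANAGER"
assert resources.get("kv_cache_manager") is None
assert resources.get(ResourceManagerType.KV_CACHE_MANAGER) == "kv-manager"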
13 changes: 9 additions & 4 deletions tests/unittest/_torch/test_pytorch_model_engine.py
@@ -9,8 +9,13 @@
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest
from tensorrt_llm._torch.pyexecutor.model_engine import PyTorchModelEngine

# isort: off
from tensorrt_llm._torch.pyexecutor.resource_manager import (KVCacheManager,
ResourceManager)
ResourceManager,
ResourceManagerType
)
# isort: on
from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.llmapi import SamplingParams
@@ -181,7 +186,7 @@ def test_pad_generation_requests(self) -> None:
def test_position_id_preparation(self):
model_engine, kv_cache_manager = create_model_engine_and_kvcache()
resource_manager = ResourceManager(
{"kv_cache_manager": kv_cache_manager})
{ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager})

prompt_len = 256
requests = [_create_request(prompt_len, 0)]
@@ -225,7 +230,7 @@ def test_position_id_preparation(self):
def test_warmup(self):
model_engine, kv_cache_manager = create_model_engine_and_kvcache()
resource_manager = ResourceManager(
{"kv_cache_manager": kv_cache_manager})
{ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager})

# Test with a huge batch size. The warmup run should bail out of
# warmup instead of crashing (there's not enough KV cache space for this).
@@ -245,7 +250,7 @@ def test_layerwise_nvtx_marker(self):
enable_layerwise_nvtx_marker=True)
model_engine, kv_cache_manager = create_model_engine_and_kvcache(config)
resource_manager = ResourceManager(
{"kv_cache_manager": kv_cache_manager})
{ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager})

prompt_len = 32
requests = [_create_request(prompt_len, 0)]