diff --git a/examples/llm-api/llm_kv_cache_connector.py b/examples/llm-api/llm_kv_cache_connector.py
index 599fab6f9ac..1eac9a9cd98 100644
--- a/examples/llm-api/llm_kv_cache_connector.py
+++ b/examples/llm-api/llm_kv_cache_connector.py
@@ -15,7 +15,7 @@
 from tensorrt_llm._torch.pyexecutor.kv_cache_connector import (
     KvCacheConnectorScheduler, KvCacheConnectorWorker, SchedulerOutput)
 from tensorrt_llm.bindings.internal.batch_manager import LlmRequest
-from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig
+from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig, TorchLlmArgs
 
 # This is a simple example of the use of the KV cache connector.
 # It persists KV cache contents into a folder, and can load them back on subsequent runs.
@@ -33,8 +33,8 @@ class PersistentKvCacheConnectorMetadata:
 
 class PersistentKvCacheConnectorWorker(KvCacheConnectorWorker):
 
-    def __init__(self):
-        super().__init__()
+    def __init__(self, llm_args: TorchLlmArgs):
+        super().__init__(llm_args)
 
         self.kv_cache_tensor = None
 
@@ -80,10 +80,10 @@ def get_finished(
 
 class PersistentKvCacheConnectorLeader(KvCacheConnectorScheduler):
 
-    def __init__(self, tokens_per_block):
-        super().__init__()
+    def __init__(self, llm_args: TorchLlmArgs):
+        super().__init__(llm_args)
 
-        self.block_size = tokens_per_block
+        self.block_size = self._llm_args.kv_cache_config.tokens_per_block
 
         self.pending_loads = {}
 
         self.cache_folder = os.environ.get(CONNECTOR_CACHE_FOLDER_KEY,
diff --git a/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py b/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py
index 9bec793a8c4..813b36112fa 100644
--- a/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py
+++ b/tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py
@@ -47,6 +47,7 @@
 from tensorrt_llm.bindings.internal.batch_manager import \
     KvCacheConnectorManager as KvCacheConnectorManagerCpp
 from tensorrt_llm.bindings.internal.batch_manager import LlmRequest
+from tensorrt_llm.llmapi.llm_args import TorchLlmArgs
 
 from .scheduler import ScheduledRequests
 
@@ -80,7 +81,8 @@ class SchedulerOutput:
 
 class KvCacheConnectorWorker(ABC):
 
-    def __init__(self):
+    def __init__(self, llm_args: TorchLlmArgs):
+        self._llm_args = llm_args
         self._metadata = None
         super().__init__()
 
@@ -160,7 +162,8 @@ def get_finished(
 
 class KvCacheConnectorScheduler(ABC):
 
-    def __init__(self):
+    def __init__(self, llm_args: TorchLlmArgs):
+        self._llm_args = llm_args
         super().__init__()
 
     @abstractmethod
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
index 5a7502f844f..2ec87a1c1c5 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -445,11 +445,11 @@ def drafting_loop_wrapper(model):
         # In this case, the worker may be dependent on the scheduler, or vice-versa.
         # To deal with cases like this, we instantiate them both concurrently.
         with ThreadPoolExecutor(max_workers=2) as executor:
-            connector_worker_task = executor.submit(worker_cls)
+            connector_worker_task = executor.submit(worker_cls, llm_args)
 
             if scheduler_cls is not None and rank == 0:
                 connector_scheduler_task = executor.submit(
-                    scheduler_cls, executor_config.tokens_per_block)
+                    scheduler_cls, llm_args)
                 connector_scheduler = connector_scheduler_task.result()
             else:
                 connector_scheduler = None
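
For downstream implementations, the practical effect of this diff is that both connector classes now receive the full TorchLlmArgs at construction time, rather than individually plumbed values such as tokens_per_block. Below is a minimal sketch of a custom connector pair under the new contract; the class names are hypothetical, and the abstract methods of both base classes are elided for brevity, so these classes are not directly instantiable as written.

from tensorrt_llm._torch.pyexecutor.kv_cache_connector import (
    KvCacheConnectorScheduler, KvCacheConnectorWorker)
from tensorrt_llm.llmapi.llm_args import TorchLlmArgs


class MyConnectorWorker(KvCacheConnectorWorker):

    def __init__(self, llm_args: TorchLlmArgs):
        # The base class now stores the args as self._llm_args.
        super().__init__(llm_args)


class MyConnectorLeader(KvCacheConnectorScheduler):

    def __init__(self, llm_args: TorchLlmArgs):
        super().__init__(llm_args)
        # tokens_per_block is no longer passed in explicitly; it is read
        # from the stored args, exactly as the updated example above does.
        self.block_size = self._llm_args.kv_cache_config.tokens_per_block

Note that py_executor_creator.py now calls worker_cls(llm_args) on each rank and scheduler_cls(llm_args) on rank 0 only, so any additional configuration a connector needs must be reachable from TorchLlmArgs or, as in the persistent-cache example, from the environment.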