Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions examples/llm-api/llm_kv_cache_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from tensorrt_llm._torch.pyexecutor.kv_cache_connector import (
KvCacheConnectorScheduler, KvCacheConnectorWorker, SchedulerOutput)
from tensorrt_llm.bindings.internal.batch_manager import LlmRequest
from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig
from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig, TorchLlmArgs

# This is a simple example of the use of the KV cache connector.
# It persists KV cache contents into a folder, and can load them back on subsequent runs.
Expand All @@ -33,8 +33,8 @@ class PersistentKvCacheConnectorMetadata:

class PersistentKvCacheConnectorWorker(KvCacheConnectorWorker):

def __init__(self):
super().__init__()
def __init__(self, llm_args: TorchLlmArgs):
super().__init__(llm_args)

self.kv_cache_tensor = None

Expand Down Expand Up @@ -80,10 +80,10 @@ def get_finished(

class PersistentKvCacheConnectorLeader(KvCacheConnectorScheduler):

def __init__(self, tokens_per_block):
super().__init__()
def __init__(self, llm_args: TorchLlmArgs):
super().__init__(llm_args)

self.block_size = tokens_per_block
self.block_size = self._llm_args.kv_cache_config.tokens_per_block
self.pending_loads = {}

self.cache_folder = os.environ.get(CONNECTOR_CACHE_FOLDER_KEY,
Expand Down
7 changes: 5 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from tensorrt_llm.bindings.internal.batch_manager import \
KvCacheConnectorManager as KvCacheConnectorManagerCpp
from tensorrt_llm.bindings.internal.batch_manager import LlmRequest
from tensorrt_llm.llmapi.llm_args import TorchLlmArgs

from .scheduler import ScheduledRequests

Expand Down Expand Up @@ -80,7 +81,8 @@ class SchedulerOutput:

class KvCacheConnectorWorker(ABC):

def __init__(self):
def __init__(self, llm_args: TorchLlmArgs):
self._llm_args = llm_args
self._metadata = None
super().__init__()

Expand Down Expand Up @@ -160,7 +162,8 @@ def get_finished(

class KvCacheConnectorScheduler(ABC):

def __init__(self):
def __init__(self, llm_args: TorchLlmArgs):
self._llm_args = llm_args
super().__init__()

@abstractmethod
Expand Down
4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,11 +445,11 @@ def drafting_loop_wrapper(model):
# In this case, the worker may depend on the scheduler, or vice versa.
# To handle cases like this, we instantiate them both concurrently.
with ThreadPoolExecutor(max_workers=2) as executor:
connector_worker_task = executor.submit(worker_cls)
connector_worker_task = executor.submit(worker_cls, llm_args)

if scheduler_cls is not None and rank == 0:
connector_scheduler_task = executor.submit(
scheduler_cls, executor_config.tokens_per_block)
scheduler_cls, llm_args)
connector_scheduler = connector_scheduler_task.result()
else:
connector_scheduler = None
Expand Down