Commit ec1ee4f

add llm args to the connector api
Signed-off-by: richardhuo-nv <[email protected]>

fix

Signed-off-by: richardhuo-nv <[email protected]>
1 parent 9a4f606 commit ec1ee4f

File tree

3 files changed: +15, -10 lines


examples/llm-api/llm_kv_cache_connector.py

Lines changed: 6 additions & 6 deletions
@@ -15,7 +15,7 @@
 from tensorrt_llm._torch.pyexecutor.kv_cache_connector import (
     KvCacheConnectorScheduler, KvCacheConnectorWorker, SchedulerOutput)
 from tensorrt_llm.bindings.internal.batch_manager import LlmRequest
-from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig
+from tensorrt_llm.llmapi.llm_args import KvCacheConnectorConfig, TorchLlmArgs

 # This is a simple example of the use of the KV cache connector.
 # It persists KV cache contents into a folder, and can load them back on subsequent runs.
@@ -33,8 +33,8 @@ class PersistentKvCacheConnectorMetadata:

 class PersistentKvCacheConnectorWorker(KvCacheConnectorWorker):

-    def __init__(self):
-        super().__init__()
+    def __init__(self, llm_args: TorchLlmArgs):
+        super().__init__(llm_args)

         self.kv_cache_tensor = None

@@ -80,10 +80,10 @@ def get_finished(

 class PersistentKvCacheConnectorLeader(KvCacheConnectorScheduler):

-    def __init__(self, tokens_per_block):
-        super().__init__()
+    def __init__(self, llm_args: TorchLlmArgs, tokens_per_block: int):
+        super().__init__(llm_args, tokens_per_block)

-        self.block_size = tokens_per_block
+        self.block_size = self._tokens_per_block
         self.pending_loads = {}

         self.cache_folder = os.environ.get(CONNECTOR_CACHE_FOLDER_KEY,
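
For downstream code that subclasses the connector classes, the same constructor update applies. Below is a minimal sketch of a custom connector adapted to the new signatures; the MyConnector* class names and placeholder bodies are illustrative and not part of this commit, and the abstract methods of the base classes are omitted:

```python
from tensorrt_llm._torch.pyexecutor.kv_cache_connector import (
    KvCacheConnectorScheduler, KvCacheConnectorWorker)
from tensorrt_llm.llmapi.llm_args import TorchLlmArgs


class MyConnectorWorker(KvCacheConnectorWorker):
    """Hypothetical worker; only the constructor wiring is shown."""

    def __init__(self, llm_args: TorchLlmArgs):
        # llm_args must now be forwarded to the base class, which stores
        # it as self._llm_args.
        super().__init__(llm_args)
        self.kv_cache_tensor = None


class MyConnectorScheduler(KvCacheConnectorScheduler):
    """Hypothetical scheduler; only the constructor wiring is shown."""

    def __init__(self, llm_args: TorchLlmArgs, tokens_per_block: int):
        # The base class stores tokens_per_block as self._tokens_per_block,
        # so it can be read back instead of keeping a separate copy.
        super().__init__(llm_args, tokens_per_block)
        self.block_size = self._tokens_per_block
```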

tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py

Lines changed: 6 additions & 2 deletions
@@ -47,6 +47,7 @@
 from tensorrt_llm.bindings.internal.batch_manager import \
     KvCacheConnectorManager as KvCacheConnectorManagerCpp
 from tensorrt_llm.bindings.internal.batch_manager import LlmRequest
+from tensorrt_llm.llmapi.llm_args import TorchLlmArgs

 from .scheduler import ScheduledRequests

@@ -80,7 +81,8 @@ class SchedulerOutput:

 class KvCacheConnectorWorker(ABC):

-    def __init__(self):
+    def __init__(self, llm_args: TorchLlmArgs):
+        self._llm_args = llm_args
         self._metadata = None
         super().__init__()

@@ -160,7 +162,9 @@ def get_finished(

 class KvCacheConnectorScheduler(ABC):

-    def __init__(self):
+    def __init__(self, llm_args: TorchLlmArgs, tokens_per_block: int):
+        self._llm_args = llm_args
+        self._tokens_per_block = tokens_per_block
         super().__init__()

     @abstractmethod
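
Putting the hunks together, the connector base-class constructors after this change read roughly as follows (reconstructed from the diff above; the rest of both classes is omitted):

```python
from abc import ABC

from tensorrt_llm.llmapi.llm_args import TorchLlmArgs


class KvCacheConnectorWorker(ABC):

    def __init__(self, llm_args: TorchLlmArgs):
        self._llm_args = llm_args  # engine-level LLM arguments, new in this commit
        self._metadata = None
        super().__init__()


class KvCacheConnectorScheduler(ABC):

    def __init__(self, llm_args: TorchLlmArgs, tokens_per_block: int):
        self._llm_args = llm_args                  # engine-level LLM arguments
        self._tokens_per_block = tokens_per_block  # KV cache block size, in tokens
        super().__init__()
```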

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 3 additions & 2 deletions
@@ -438,11 +438,12 @@ def drafting_loop_wrapper(model):
     # In this case, the worker may be dependent on the scheduler, or vice-versa.
     # To deal with cases like this, we instantiate them both concurrently.
     with ThreadPoolExecutor(max_workers=2) as executor:
-        connector_worker_task = executor.submit(worker_cls)
+        connector_worker_task = executor.submit(worker_cls, llm_args)

         if scheduler_cls is not None and rank == 0:
             connector_scheduler_task = executor.submit(
-                scheduler_cls, executor_config.tokens_per_block)
+                scheduler_cls, llm_args,
+                executor_config.tokens_per_block)
             connector_scheduler = connector_scheduler_task.result()
         else:
             connector_scheduler = None
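
Pulled out of its surrounding function, the instantiation pattern above looks roughly like the sketch below; create_connector is a hypothetical wrapper, and worker_cls, scheduler_cls, llm_args, tokens_per_block and rank stand in for the values the executor creator already has in scope:

```python
from concurrent.futures import ThreadPoolExecutor


def create_connector(worker_cls, scheduler_cls, llm_args, tokens_per_block, rank):
    """Sketch of the concurrent connector construction used in py_executor_creator."""
    # The worker may depend on the scheduler (or vice-versa) during
    # construction, so both are built concurrently on separate threads.
    with ThreadPoolExecutor(max_workers=2) as executor:
        worker_task = executor.submit(worker_cls, llm_args)

        if scheduler_cls is not None and rank == 0:
            scheduler_task = executor.submit(scheduler_cls, llm_args, tokens_per_block)
            scheduler = scheduler_task.result()
        else:
            scheduler = None

        worker = worker_task.result()

    return worker, scheduler
```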
