6 changes: 6 additions & 0 deletions components/backends/trtllm/README.md
@@ -303,3 +303,9 @@ sampling_params.logits_processor = create_trtllm_adapters(processors)
 ## Performance Sweep

 For detailed instructions on running comprehensive performance sweeps across both aggregated and disaggregated serving configurations, see the [TensorRT-LLM Benchmark Scripts for DeepSeek R1 model](./performance_sweeps/README.md). This guide covers recommended benchmarking setups, usage of provided scripts, and best practices for evaluating system performance.
+
+## Dynamo KV Block Manager Integration
+
+Dynamo with TensorRT-LLM currently supports integration with the Dynamo KV Block Manager. This integration can significantly reduce time-to-first-token (TTFT) latency, particularly in usage patterns such as multi-turn conversations and repeated long-context requests.
+
+For setup instructions, see [Running KVBM in TensorRT-LLM](./../../../docs/guides/run_kvbm_in_trtllm.md).
15 changes: 11 additions & 4 deletions container/Dockerfile.trtllm
@@ -4,6 +4,7 @@
 ARG BASE_IMAGE="nvcr.io/nvidia/pytorch"
 ARG BASE_IMAGE_TAG="25.06-py3"
 ARG RELEASE_BUILD
+ARG ENABLE_KVBM=false
 ARG RUNTIME_IMAGE="nvcr.io/nvidia/cuda"
 ARG RUNTIME_IMAGE_TAG="12.9.1-runtime-ubuntu24.04"

@@ -234,6 +235,7 @@ ARG ARCH_ALT
 FROM quay.io/pypa/manylinux_2_28_${ARCH_ALT} AS wheel_builder
 ARG RELEASE_BUILD
 ARG CARGO_BUILD_JOBS
+ARG ENABLE_KVBM
 # Set CARGO_BUILD_JOBS to 16 if not provided
 # This is to prevent cargo from building $(nproc) jobs in parallel,
 # which might exceed the number of opened files limit.
@@ -279,16 +281,21 @@ COPY launch /workspace/launch
 RUN cargo build \
     --release \
     --locked \
-    --features dynamo-llm/block-manager \
+    --features block-manager \
     --workspace

 # Build dynamo wheels
 RUN uv build --wheel --out-dir /workspace/dist && \
     cd /workspace/lib/bindings/python && \
     uv build --wheel --out-dir /workspace/dist --python 3.12 && \
+    uv pip install maturin[patchelf] && \
+    if [ "$ENABLE_KVBM" = "true" ]; then \
+        maturin build --release --features block-manager --out /workspace/dist; \
+    else \
+        maturin build --release --out /workspace/dist; \
+    fi && \
     if [ "$RELEASE_BUILD" = "true" ]; then \
-        uv build --wheel --out-dir /workspace/dist --python 3.11 && \
-        uv build --wheel --out-dir /workspace/dist --python 3.10; \
+        uv run --python 3.11 maturin build --release --out /workspace/dist && \
+        uv run --python 3.10 maturin build --release --out /workspace/dist; \
     fi

########################################
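The `ENABLE_KVBM` argument introduced above is what routes the bindings wheel through the `--features block-manager` branch of the `maturin` build. The documented entry point is `./container/build.sh --enable-kvbm` (see the KVBM guide below); purely as a sketch, the same switch could also be passed to a direct `docker build`:

```bash
# Sketch only: ./container/build.sh --enable-kvbm is the documented path.
# Shown here just to illustrate how the build argument is consumed.
docker build -f container/Dockerfile.trtllm \
    --build-arg ENABLE_KVBM=true \
    -t dynamo-trtllm:kvbm .
```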
8 changes: 6 additions & 2 deletions docs/guides/run_kvbm_in_trtllm.md
@@ -38,8 +38,12 @@ To use KVBM in TensorRT-LLM, you can follow the steps below:
 # start up etcd for KVBM leader/worker registration and discovery
 docker compose -f deploy/docker-compose.yml up -d

-# Build a container that includes TensorRT-LLM and KVBM. Note: KVBM integration is only available in TensorRT-LLM commit ce580ce4f52af3ad0043a800b3f9469e1f1109f6 or newer.
-./container/build.sh --framework trtllm --tensorrtllm-commit ce580ce4f52af3ad0043a800b3f9469e1f1109f6 --enable-kvbm
+# Build a container that includes TensorRT-LLM and KVBM. Note: KVBM integration is only available in TensorRT-LLM commit dcd110cfac07e577ce01343c455917832b0f3d5e or newer.
+# When building with the --tensorrtllm-commit option, https://github.com may repeatedly prompt for a username and password.
+# This happens because cloning TensorRT-LLM can hit GitHub's rate limit.
+# To work around this, keep pressing "Enter" ("Return") at each prompt.
+# Setting "export GIT_LFS_SKIP_SMUDGE=1" may also reduce the number of prompts.
+./container/build.sh --framework trtllm --tensorrtllm-commit dcd110cfac07e577ce01343c455917832b0f3d5e --enable-kvbm

 # launch the container
 ./container/run.sh --framework trtllm -it --mount-workspace --use-nixl-gds
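Combining the workarounds above, a less interactive build might look like the following sketch (using the commit hash this change requires):

```bash
# Skip Git LFS smudging during the TensorRT-LLM clone to reduce credential prompts
export GIT_LFS_SKIP_SMUDGE=1

./container/build.sh --framework trtllm \
    --tensorrtllm-commit dcd110cfac07e577ce01343c455917832b0f3d5e \
    --enable-kvbm
```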
@@ -8,8 +8,8 @@
     KvCacheConnectorScheduler,
     SchedulerOutput,
 )
-from tensorrt_llm.bindings.executor import ExecutorConfig
 from tensorrt_llm.bindings.internal.batch_manager import LlmRequest
+from tensorrt_llm.llmapi.llm_args import TorchLlmArgs

 from dynamo.llm import KvbmLeader
 from dynamo.llm.trtllm_integration.rust import KvbmRequest
@@ -21,21 +21,21 @@


 class DynamoKVBMConnectorLeader(KvCacheConnectorScheduler):
-    def __init__(self, executor_config: ExecutorConfig):
-        super().__init__(executor_config)
+    def __init__(self, llm_args: TorchLlmArgs):
+        super().__init__(llm_args)
         self.drt = DistributedRuntime.detached()

-        world_size = self._config.mapping.world_size
-        self.block_size = self._config.tokens_per_block
+        mappings = self._llm_args.parallel_config.to_mapping()
+
+        world_size = mappings.world_size
+        self.block_size = self._llm_args.kv_cache_config.tokens_per_block

         # Set bytes_per_block to 0, because we will retrieve the actual value from the worker side.
         leader = KvbmLeader(world_size, drt=self.drt)

-        print(
-            f"KvConnectorLeader initialized with rank: {executor_config.mapping.rank}"
-        )
+        print(f"KvConnectorLeader initialized with rank: {mappings.rank}")
         self._connector = RustKvConnectorLeader(
-            executor_config.mapping.rank, self.drt, self.block_size, leader
+            mappings.rank, self.drt, self.block_size, leader
         )

     def build_connector_meta(self, scheduler_output: SchedulerOutput) -> bytes:
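Both connector classes follow the same mechanical migration: every `ExecutorConfig` read is replaced by a `TorchLlmArgs` counterpart. A minimal sketch of that mapping, using only names visible in the diff above (`resolve_topology` is a hypothetical helper, not part of the codebase):

```python
# Sketch of the ExecutorConfig -> TorchLlmArgs attribute migration in this change:
#   executor_config.mapping.world_size -> llm_args.parallel_config.to_mapping().world_size
#   executor_config.mapping.rank       -> llm_args.parallel_config.to_mapping().rank
#   executor_config.tokens_per_block   -> llm_args.kv_cache_config.tokens_per_block

def resolve_topology(llm_args):
    """Derive the values the leader needs, following the new code path."""
    mapping = llm_args.parallel_config.to_mapping()
    return (
        mapping.world_size,
        mapping.rank,
        llm_args.kv_cache_config.tokens_per_block,
    )
```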
@@ -4,7 +4,7 @@
 import torch
 from tensorrt_llm import logger
 from tensorrt_llm._torch.pyexecutor.kv_cache_connector import KvCacheConnectorWorker
-from tensorrt_llm.bindings.executor import ExecutorConfig
+from tensorrt_llm.llmapi.llm_args import TorchLlmArgs

 from dynamo.llm.trtllm_integration.rust import (
     KvConnectorWorker as RustKvConnectorWorker,
@@ -13,16 +13,15 @@


 class DynamoKVBMConnectorWorker(KvCacheConnectorWorker):
-    def __init__(self, executor_config: ExecutorConfig):
-        super().__init__(executor_config)
+    def __init__(self, llm_args: TorchLlmArgs):
+        super().__init__(llm_args)

         self.drt = DistributedRuntime.detached()

-        self.rank = executor_config.mapping.rank
+        mappings = self._llm_args.parallel_config.to_mapping()
+        self.rank = mappings.rank

-        self._connector = RustKvConnectorWorker(
-            self.drt, str(executor_config.mapping.rank)
-        )
+        self._connector = RustKvConnectorWorker(self.drt, str(self.rank))

     def register_kv_caches(self, kv_cache_tensor: torch.Tensor):
         """
@@ -33,11 +32,11 @@ def register_kv_caches(self, kv_cache_tensor: torch.Tensor):
"""
print(f"Register KV Caches on rank {self.rank}")
logger.info(
f"KvConnectorWorker started registering the kv caches on rank {self._config.mapping.rank}"
f"KvConnectorWorker started registering the kv caches on rank {self.rank}"
)

num_device_blocks = kv_cache_tensor.shape[0]
page_size = self._config.tokens_per_block
page_size = self._llm_args.kv_cache_config.tokens_per_block
device_id = kv_cache_tensor.device.index
kv_cache_dtype = kv_cache_tensor.dtype

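For intuition about `register_kv_caches`, the worker reads the cache geometry directly off the pooled tensor it receives. Below is a standalone sketch of those reads; the tensor shape is invented for illustration and is not TensorRT-LLM's actual pool layout:

```python
import torch

# Invented layout: 128 blocks x (K,V) x 32 tokens per block x 8 heads x 64 head dim.
kv_cache_tensor = torch.zeros(128, 2, 32, 8, 64, dtype=torch.float16)

num_device_blocks = kv_cache_tensor.shape[0]  # blocks resident on this device
device_id = kv_cache_tensor.device.index      # None on CPU, an integer on CUDA
kv_cache_dtype = kv_cache_tensor.dtype        # torch.float16

# The leader passes bytes_per_block=0 and defers to the worker; under a dense
# layout, the worker-side value would be elements per block * element size.
bytes_per_block = kv_cache_tensor[0].numel() * kv_cache_tensor.element_size()
print(num_device_blocks, device_id, kv_cache_dtype, bytes_per_block)
```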