2 changes: 2 additions & 0 deletions docker/Makefile
@@ -121,6 +121,7 @@ else
endif
SOURCE_DIR ?= $(shell readlink -f ..)
CODE_DIR ?= /code/tensorrt_llm
EXTRA_VOLUMES ?=
CCACHE_DIR ?= ${CODE_DIR}/cpp/.ccache
CONAN_DIR ?= ${CODE_DIR}/cpp/.conan
RUN_CMD ?=
@@ -138,6 +139,7 @@ endif
docker run $(DOCKER_RUN_OPTS) $(DOCKER_RUN_ARGS) \
$(GPU_OPTS) \
--volume $(SOURCE_DIR):$(CODE_DIR) \
$(EXTRA_VOLUMES) \
--env "CCACHE_DIR=${CCACHE_DIR}" \
--env "CCACHE_BASEDIR=${CODE_DIR}" \
--env "CONAN_HOME=${CONAN_DIR}" \
5 changes: 5 additions & 0 deletions docker/README.md
@@ -44,6 +44,11 @@ Containers can be started with the local user instead of `root` by appending `LO
make -C docker devel_run LOCAL_USER=1
```

Extra Docker volumes can be mounted in addition to the source code by appending `EXTRA_VOLUMES` to the run target:
```bash
make -C docker devel_run LOCAL_USER=1 EXTRA_VOLUMES="--volume /pathA:/pathA --volume /pathB:/pathB"
```

Specific CUDA architectures supported by the `wheel` can be specified with `CUDA_ARCHS`:

```bash
5 changes: 3 additions & 2 deletions examples/pytorch/README.md
@@ -35,14 +35,15 @@ python3 quickstart_advanced.py --model_dir nvidia/Nemotron-H-8B-Base-8K --disabl

```bash
# default inputs
python3 quickstart_multimodal.py --model_dir Efficient-Large-Model/NVILA-8B --modality image [--use_cuda_graph]
TLLM_MULTIMODAL_DISAGGREGATED=1 python3 quickstart_multimodal.py --model_dir llava-hf/llava-v1.6-mistral-7b-hf --modality image [--use_cuda_graph]

# user inputs
# supported modes:
# (1) N prompts, N media (the N requests are in-flight batched)
# (2) 1 prompt, N media
# Note: media should be either image or video. Mixing image and video is not supported.
python3 quickstart_multimodal.py --model_dir Efficient-Large-Model/NVILA-8B --modality video --prompt "Tell me what you see in the video briefly." "Describe the scene in the video briefly." --media "https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/OAI-sora-tokyo-walk.mp4" "https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/world.mp4" --max_tokens 128 [--use_cuda_graph]
TLLM_MULTIMODAL_DISAGGREGATED=0 python3 quickstart_multimodal.py --model_dir llava-hf/llava-v1.6-mistral-7b-hf --modality video --prompt "Tell me what you see in the video briefly." "Describe the scene in the video briefly." --media "https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/OAI-sora-tokyo-walk.mp4" "https://huggingface.co/datasets/Efficient-Large-Model/VILA-inference-demos/resolve/main/world.mp4" --max_tokens 64 [--use_cuda_graph] [--enable_overlap_scheduler]
# use TLLM_MULTIMODAL_DISAGGREGATED to run the vision encoder and LLM in a single forward pass (0) or in separate forward passes (1)
```

### Supported Models
55 changes: 31 additions & 24 deletions examples/pytorch/quickstart_multimodal.py
@@ -1,7 +1,8 @@
import argparse
import json
import os
from typing import Any, Dict, List
from functools import partial
from typing import Any, Callable, Dict, List

from quickstart_advanced import add_llm_args, setup_llm

@@ -28,24 +29,24 @@
]


def prepare_multimodal_inputs(model_dir: str,
model_type: str,
modality: str,
prompts: List[str],
media: List[str],
image_data_format: str = "pt",
num_frames: int = 8) -> List[Dict[str, Any]]:
def prepare_multimodal_inputs(
model_dir: str,
modality: str,
prompts: List[str],
media: List[str],
input_formatter: Callable,
mm_loader: Callable,
data_format: str = "pt", # Options: "pt" or "pil"
device: str = "cuda") -> List[Dict[str, Any]]:

inputs = []
if modality == "image":
inputs = default_image_loader(prompts, media, image_data_format)
elif modality == "video":
inputs = default_video_loader(prompts, media, image_data_format,
num_frames)
if modality in ["image", "video"]:
assert mm_loader, "multimodal data loader is required for image/video modality"
inputs = mm_loader(prompts, media, data_format, device)
else:
raise ValueError(f"Unsupported modality: {modality}")

inputs = INPUT_FORMATTER_MAP[model_type](model_dir, inputs)
inputs = input_formatter(model_dir, inputs)

return inputs

@@ -95,17 +96,23 @@ def main():

llm, sampling_params = setup_llm(args)

image_format = "pt" # ["pt", "pil"]
if args.model_type is not None:
model_type = args.model_type
else:
model_type = json.load(
open(os.path.join(llm._hf_model_dir, 'config.json')))['model_type']
# feel free to override the default formatter and loaders based on your application
model_type = args.model_type if args.model_type else json.load(
open(os.path.join(llm._hf_model_dir, 'config.json')))['model_type']
assert model_type in INPUT_FORMATTER_MAP, f"Unsupported model_type: {model_type}"

inputs = prepare_multimodal_inputs(args.model_dir, model_type,
args.modality, args.prompt, args.media,
image_format, args.num_frames)
input_formatter = INPUT_FORMATTER_MAP[model_type]
mm_loader = None
if args.modality == "image":
mm_loader = default_image_loader
elif args.modality == "video":
mm_loader = partial(default_video_loader, num_frames=args.num_frames)
data_format = "pt" # ["pt", "pil"]
media_input_device = "cpu"

inputs = prepare_multimodal_inputs(args.model_dir, args.modality,
args.prompt, args.media, input_formatter,
mm_loader, data_format,
media_input_device)

outputs = llm.generate(inputs, sampling_params)

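For orientation when reading the refactor above: `prepare_multimodal_inputs` no longer hard-codes the image/video loaders; the caller picks a formatter from `INPUT_FORMATTER_MAP` and passes a loader callable. Below is a minimal usage sketch in the spirit of the updated `main()` — the model directory and media path are illustrative placeholders, and the helper names (`INPUT_FORMATTER_MAP`, `default_video_loader`, `prepare_multimodal_inputs`) are assumed to be available in `quickstart_multimodal.py`'s namespace.

```python
import json
import os
from functools import partial

# Illustrative values -- substitute a real checkpoint directory and media file.
model_dir = "/path/to/multimodal/model"
model_type = json.load(open(os.path.join(model_dir, "config.json")))["model_type"]

input_formatter = INPUT_FORMATTER_MAP[model_type]           # formatter chosen per model_type
video_loader = partial(default_video_loader, num_frames=8)  # bind the frame count up front

inputs = prepare_multimodal_inputs(
    model_dir=model_dir,
    modality="video",
    prompts=["Describe the scene in the video briefly."],
    media=["/path/to/clip.mp4"],
    input_formatter=input_formatter,
    mm_loader=video_loader,
    data_format="pt",    # or "pil"
    device="cpu",        # matches media_input_device in the updated main()
)
```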
96 changes: 54 additions & 42 deletions tensorrt_llm/_torch/attention_backend/interface.py
@@ -53,12 +53,6 @@ class AttentionMetadata:
# Whether CUDA graph is enabled.
is_cuda_graph: bool = field(default=False, repr=False)

# The length of each sequence in the batch for query.
# The shape is (batch_size), and located on CPU memory.
# For sub metadata of cross attention, it's automatically
# initialized to seq_lens of parent metadata.
seq_lens: Optional[torch.Tensor] # Implemented using property

# The number of context-phase sequences in the batch.
num_contexts: int # Implemented using property

@@ -71,6 +65,12 @@ class AttentionMetadata:
# The parameters for the KV cache.
kv_cache_params: Optional[KVCacheParams] = None

# The length of each sequence in the batch for query.
# The shape is (batch_size), and located on CPU memory.
# For sub metadata of cross attention, it's automatically
# initialized to seq_lens of parent metadata.
seq_lens: Optional[torch.Tensor] # Implemented using property

# The length of each sequence in the batch for key and value.
# The shape is (batch_size), and located on CPU memory.
# It defaults to seq_lens if not set.
@@ -110,6 +110,8 @@ class AttentionMetadata:
# For generation-phase sequence, the value is the token number of its context phase.
# The shape is (batch_size) if provided.
prompt_lens: Optional[List[int]] = None
# The "original" prompt length of each sequence in the batch. Useful in cases where the prompt tokens can be updated and thus differ from the original user input prompt, e.g. for multimodal models.
orig_prompt_lens: Optional[List[int]] = None

# These fields indicate whether the runtime can use various features.
# The kernels may or may not have different behaviors when these
@@ -148,6 +150,35 @@ def on_update(self):
elif self._seq_lens is not None:
self._num_tokens = self._seq_lens.sum().item()

def on_update_gpu(self, key="all"):
'''
Update underlying GPU buffers when seq_lens or seq_lens_kv is updated.
'''
if key in ["seq_lens", "all"]:
# The model executor sets seq_lens to None initially.
if self._seq_lens is not None:
self._seq_lens = self._seq_lens.pin_memory()

if self.is_cuda_graph and self._seq_lens_cuda is not None:
# Very important: do not reallocate if we are using CUDA graphs.
# This copy is safe because the batch size is guaranteed to not
# change in the CUDA graph case. The seqlens can change if we
# are doing spec decode.
self._seq_lens_cuda.copy_(self._seq_lens, non_blocking=True)
else:
self._seq_lens_cuda = self._seq_lens.cuda(non_blocking=True)

if self.has_cross_sub_metadata:
self.cross._seq_lens = self._seq_lens
self.cross._seq_lens_cuda = self._seq_lens_cuda

if key in ["seq_lens_kv", "all"]:
# The model executor sets seqlens to None initially.
if self._seq_lens_kv is not None:
self._seq_lens_kv = self._seq_lens_kv.pin_memory()
self._seq_lens_kv_cuda = self._seq_lens_kv.cuda(
non_blocking=True)

@property
def seq_lens(self) -> Optional[torch.Tensor]:
return self._seq_lens
@@ -158,23 +189,26 @@ def seq_lens(self, value: Optional[torch.Tensor]):
value = value if value is not AttentionMetadata.seq_lens else None
self._seq_lens = value
self.on_update()
self.on_update_gpu("seq_lens")

# The model executor sets seq_lens to None initially.
if self._seq_lens is not None:
self._seq_lens = self._seq_lens.pin_memory()
@property
def seq_lens_cuda(self):
return self._seq_lens_cuda

if self.is_cuda_graph and self._seq_lens_cuda is not None:
# Very important: do not reallocate if we are using CUDA graphs.
# This copy is safe because the batch size is guaranteed to not
# change in the CUDA graph case. The seqlens can change if we
# are doing spec decode.
self._seq_lens_cuda.copy_(self._seq_lens, non_blocking=True)
else:
self._seq_lens_cuda = self._seq_lens.cuda(non_blocking=True)
@property
def seq_lens_kv(self) -> Optional[torch.Tensor]:
return self._seq_lens_kv if self._seq_lens_kv is not None else self._seq_lens

if self.has_cross_sub_metadata:
self.cross._seq_lens = self._seq_lens
self.cross._seq_lens_cuda = self._seq_lens_cuda
@seq_lens_kv.setter
def seq_lens_kv(self, value: Optional[torch.Tensor]):
value = value if value is not AttentionMetadata.seq_lens_kv else None
self._seq_lens_kv = value
self.on_update()
self.on_update_gpu("seq_lens_kv")

@property
def seq_lens_kv_cuda(self):
return self._seq_lens_kv_cuda if self._seq_lens_kv_cuda is not None else self._seq_lens_cuda

@property
def num_contexts(self) -> int:
@@ -196,28 +230,6 @@ def num_generations(self, value: int):
self._num_generations = value
self.on_update()

@property
def seq_lens_cuda(self):
return self._seq_lens_cuda

@property
def seq_lens_kv(self) -> Optional[torch.Tensor]:
return self._seq_lens_kv if self._seq_lens_kv is not None else self._seq_lens

@seq_lens_kv.setter
def seq_lens_kv(self, value: Optional[torch.Tensor]):
value = value if value is not AttentionMetadata.seq_lens_kv else None
self._seq_lens_kv = value
self.on_update()
# The model executor sets seqlens to None initially.
if self._seq_lens_kv is not None:
self._seq_lens_kv = self._seq_lens_kv.pin_memory()
self._seq_lens_kv_cuda = self._seq_lens_kv.cuda(non_blocking=True)

@property
def seq_lens_kv_cuda(self):
return self._seq_lens_kv_cuda if self._seq_lens_kv_cuda is not None else self._seq_lens_cuda

@property
def context_lens(self) -> torch.Tensor:
"""
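In short, the setters for `seq_lens` and `seq_lens_kv` now delegate GPU-side bookkeeping to a single `on_update_gpu` helper, which pins the CPU tensor and copies it to the device, reusing the pre-allocated buffer when CUDA graphs are enabled instead of reallocating. A rough sketch of the resulting behavior follows; `metadata` stands in for an already-constructed `AttentionMetadata` subclass instance built by the model runner, and the sequence-length values are illustrative.

```python
import torch

# `metadata` is assumed to be an AttentionMetadata subclass instance built elsewhere.
metadata.seq_lens = torch.tensor([5, 3, 1], dtype=torch.int32)  # CPU tensor; values are illustrative

# Assigning through the property triggers on_update() (token-count bookkeeping) and
# on_update_gpu("seq_lens"), which:
#   * pins the CPU tensor,
#   * copies into the existing _seq_lens_cuda buffer when is_cuda_graph is set (no reallocation),
#   * otherwise uploads a fresh device tensor with .cuda(non_blocking=True),
#   * and mirrors both tensors into the cross-attention sub-metadata when present.
assert metadata.seq_lens_cuda.device.type == "cuda"
assert metadata.seq_lens.is_pinned()
```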
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/model_config.py
@@ -135,6 +135,8 @@ def from_pretrained(cls,
model_dir = Path(
transformers.utils.hub.cached_file(checkpoint_dir,
'config.json')).parent
pretrained_config.checkpoint_dir = model_dir

quant_config = QuantConfig()
layer_quant_config = None
# quantized ckpt in modelopt format