2 changes: 1 addition & 1 deletion examples/auto_deploy/.vscode/launch.json
@@ -8,7 +8,7 @@
"program": "build_and_run_ad.py",
"args": [
"--config",
"{\"batch_size\": 2, \"page_size\": 16, \"world_size\": 2, \"compile_backend\": \"torch-simple\", \"attn_backend\": \"FlashInfer\", \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\", \"benchmark\": false}",
"{\"batch_size\": 2, \"page_size\": 16, \"world_size\": 2, \"compile_backend\": \"torch-simple\", \"attn_backend\": \"FlashInfer\",\"model_factory\": \"AutoModelForCausalLM\", \"model\": \"meta-llama/Meta-Llama-3.1-8B-Instruct\", \"benchmark\": false}",
"--model-kwargs",
"{}",
// "{\"num_hidden_layers\": 3}",
4 changes: 2 additions & 2 deletions examples/auto_deploy/.vscode/settings.json
@@ -4,7 +4,7 @@
// Please also modify the .env file accordingly
"terminal.integrated.env.linux": {
"OMPI_MCA_opal_cuda_support": "true",
"LD_LIBRARY_PATH": "/lib/x86_64-linux-gnu:${env:HOME}/miniconda3/envs/tekit/lib:${env:LD_LIBRARY_PATH}",
"LD_LIBRARY_PATH": "/lib/x86_64-linux-gnu:<PATH-TO-CONDA-ENV>/lib:${env:LD_LIBRARY_PATH}",
},
// STANDARD SETTINGS, DO NOT MODIFY ////////////////////////////////////////////////////////////
"python.envFile": "${workspaceFolder}/.vscode/.env",
@@ -23,7 +23,7 @@
// https://code.visualstudio.com/updates/v1_81#_python
// When that happens, we can remove the --ignore-glob flag and the gpu tests will be skipped
"python.testing.pytestArgs": [
"./tests",
"./tests/unittest/_torch/auto_deploy",
"--no-cov",
],
"files.exclude": {
22 changes: 11 additions & 11 deletions examples/auto_deploy/README.md
@@ -67,21 +67,21 @@ The exported graph then undergoes a series of automated transformations, includi

**Bring Your Own Model**: AutoDeploy leverages `torch.export` and dynamic graph pattern matching, enabling seamless integration for a wide variety of models without relying on hard-coded architectures.

Additionally, we have officially verified and fully optimized support for the following models:
Additionally, we have officially verified support for the following models:

<details>
<summary>Click to expand supported models list</summary>

| Model Series | HF Model Card | Precision | World Size | Runtime | Compile Backend ||| Attention Backend |||
|--------------|----------------------|-----------|------------|---------|-----------------|--------------------|--------------------|--------------------|----------|----------|
| | | | | | torch-simple | torch-compile | torch-opt | TritonWithFlattenedInputs | FlashInfer | MultiHeadLatentAttention |
| LLaMA | meta-llama/Llama-2-7b-chat-hf<br>meta-llama/Meta-Llama-3.1-8B-Instruct<br>meta-llama/Llama-3.1-70B-Instruct<br>codellama/CodeLlama-13b-Instruct-hf | BF16 | 1,2,4 | demollm, trtllm | ✅ | ✅ | ✅ | ✅ | ✅ | n/a |
| Nvidia Minitron | nvidia/Llama-3_1-Nemotron-51B-Instruct<br>nvidia/Llama-3.1-Minitron-4B-Width-Base<br>nvidia/Llama-3.1-Minitron-4B-Depth-Base | BF16 | 1,2,4 | demollm, trtllm | ✅ | ✅ | ✅ | ✅ | ✅ | n/a |
| Nvidia Model Optimizer | nvidia/Llama-3.1-8B-Instruct-FP8<br>nvidia/Llama-3.1-405B-Instruct-FP8 | FP8 | 1,2,4 | demollm, trtllm | ✅ | ✅ | ✅ | ✅ | ✅ | n/a |
| DeepSeek | deepseek-ai/DeepSeek-R1-Distill-Llama-70B | BF16 | 1,2,4 | demollm, trtllm | ✅ | ✅ | ✅ | ✅ | ✅ | n/a |
| Mistral | mistralai/Mixtral-8x7B-Instruct-v0.1<br>mistralai/Mistral-7B-Instruct-v0.3 | BF16 | 1,2,4 | demollm, trtllm | ✅ | ✅ | ✅ | ✅ | ✅ | n/a |
| BigCode | bigcode/starcoder2-15b | FP32 | 1,2,4 | demollm, trtllm | ✅ | ✅ | ✅ | ✅ | ✅ | n/a |
| Deepseek-V3 | deepseek-ai/DeepSeek-V3 | BF16 | 1,2,4 | demollm | ✅ | ❌ | ❌ | n/a | n/a | ✅ |
| Model Series | HF Model Card | Model Factory | Precision | World Size | Runtime | Compile Backend ||| Attention Backend |||
|--------------|----------------------|----------------|-----------|------------|---------|-----------------|--------------------|--------------------|--------------------|----------|----------|
| | | | | | | torch-simple | torch-compile | torch-opt | TritonWithFlattenedInputs | FlashInfer | MultiHeadLatentAttention |
| LLaMA | meta-llama/Llama-2-7b-chat-hf<br>meta-llama/Meta-Llama-3.1-8B-Instruct<br>meta-llama/Llama-3.1-70B-Instruct<br>codellama/CodeLlama-13b-Instruct-hf | AutoModelForCausalLM | BF16 | 1,2,4 | demollm, trtllm | ✅ | ✅ | ✅ | ✅ | ✅ | n/a |
| Nvidia Minitron | nvidia/Llama-3_1-Nemotron-51B-Instruct<br>nvidia/Llama-3.1-Minitron-4B-Width-Base<br>nvidia/Llama-3.1-Minitron-4B-Depth-Base | AutoModelForCausalLM | BF16 | 1,2,4 | demollm, trtllm | ✅ | ✅ | ✅ | ✅ | ✅ | n/a |
| Nvidia Model Optimizer | nvidia/Llama-3.1-8B-Instruct-FP8<br>nvidia/Llama-3.1-405B-Instruct-FP8 | AutoModelForCausalLM | FP8 | 1,2,4 | demollm, trtllm | ✅ | ✅ | ✅ | ✅ | ✅ | n/a |
| DeepSeek | deepseek-ai/DeepSeek-R1-Distill-Llama-70B | AutoModelForCausalLM | BF16 | 1,2,4 | demollm, trtllm | ✅ | ✅ | ✅ | ✅ | ✅ | n/a |
| Mistral | mistralai/Mixtral-8x7B-Instruct-v0.1<br>mistralai/Mistral-7B-Instruct-v0.3 | AutoModelForCausalLM | BF16 | 1,2,4 | demollm, trtllm | ✅ | ✅ | ✅ | ✅ | ✅ | n/a |
| BigCode | bigcode/starcoder2-15b | AutoModelForCausalLM | FP32 | 1,2,4 | demollm, trtllm | ✅ | ✅ | ✅ | ✅ | ✅ | n/a |
| Deepseek-V3 | deepseek-ai/DeepSeek-V3 | AutoModelForCausalLM | BF16 | 1,2,4 | demollm | ✅ | ❌ | ❌ | n/a | n/a | ✅ |

</details>
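
A minimal sketch of selecting one of these factories from Python, assuming `SimpleConfig` is the dataclass defined in `examples/auto_deploy/simple_config.py` (see its change below) and accepts these fields as keyword arguments:

```python
# Sketch only: field names are taken from simple_config.py in this PR;
# the import path assumes you run from within examples/auto_deploy.
from simple_config import SimpleConfig

config = SimpleConfig(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    model_factory="AutoModelForCausalLM",  # renamed from "hf" in this PR
)
```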

2 changes: 1 addition & 1 deletion examples/auto_deploy/simple_config.py
@@ -21,7 +21,7 @@ class SimpleConfig:
# If no `model` argument is provided, the checkpoint directory is used to infer the model
# architecture.
model: Optional[str] = None
model_factory: Literal["hf"] = "hf"
model_factory: Literal["AutoModelForCausalLM"] = "AutoModelForCausalLM"
skip_loading_weights: bool = False # only load the architecture, not the weights
customize_tokenizer: bool = False # True: tokenizer from the model factory, False: from LLM api

37 changes: 19 additions & 18 deletions tensorrt_llm/_torch/auto_deploy/models/factory.py
@@ -14,8 +14,9 @@
class ModelFactory(ABC):
"""An interface to return and correctly initialize a model from a desired source.

NOTE: we make the assumption that the model can be prompted with a set of input_ids of shape
(batch_size, seq_len) and will return a tensor of shape (batch_size, seq_len, embedding_size).
NOTE: we make the assumption that the model can be prompted with a set of input_ids and
position_ids of shape (batch_size, seq_len) and will return a tensor of shape
(batch_size, seq_len, embedding_size).
"""

def __init__(self, model: Optional[str] = None, skip_loading_weights: bool = False, **kwargs):
@@ -47,16 +48,16 @@ def forward(
self, input_ids: torch.Tensor, position_ids: torch.Tensor
) -> Sequence[torch.Tensor]: ...

``logits`` are assumeg to be the first output of the model, i.e.,
``logits`` are assumed to be the first output of the model, i.e.,
``model(input_ids, position_ids)[0]`` should return a ``logits`` tensor.

Moreover, we assume the following tensor shapes:

.. code-block:: python

input_ids.shape = (batch_size, seq_len)
position_ids.shape = (batch_size, seq_len)
logits.shape = (batch_size, seq_len, vocab_size)
input_ids.shape == (batch_size, seq_len)
position_ids.shape == (batch_size, seq_len)
logits.shape == (batch_size, seq_len, vocab_size)
"""

def get_quant_config(self) -> Dict:
Expand Down Expand Up @@ -84,20 +85,20 @@ def prefetch_checkpoint(self):
"""Try prefetching checkpoint."""
pass

def load_or_random_init(self, model: nn.Module, **kwargs):
def load_or_random_init(self, model: nn.Module, device: DeviceLikeType):
"""Load the checkpoint into the model or randomly initialize the model.

Args:
model: The model to load the checkpoint into. Note that the model does not need to be
the same model that is built above but it needs to have a state dict compatible with
the model built above.
**kwargs: Keyword arguments that will be passed to torch.load.
device: The device to load the model on.
"""
ad_logger.info("Loading and initializing weights.")
if self.skip_loading_weights:
self._load_random_init(model, **kwargs)
self._load_random_init(model, device)
else:
self._load_checkpoint(model, **kwargs)
self._load_checkpoint(model, device)

@staticmethod
def _to_maybe_empty(model: nn.Module, device: DeviceLikeType):
@@ -114,25 +115,25 @@ def _to_maybe_empty(model: nn.Module, device: DeviceLikeType):
)

@classmethod
def _load_random_init(cls, model: nn.Module, **kwargs):
def _load_random_init(cls, model: nn.Module, device: DeviceLikeType):
"""Randomly initialize model."""
cls._to_maybe_empty(model, kwargs.get("map_location"))
cls._to_maybe_empty(model, device)
state_dict = model.state_dict()
for k in state_dict:
state_dict[k] = torch.normal(
0.0, 1.0, size=state_dict[k].shape, device=kwargs.get("map_location")
).to(state_dict[k].dtype)
model.load_state_dict(state_dict, strict=True)
state_dict[k] = torch.normal(0.0, 1.0, size=state_dict[k].shape, device=device).to(
state_dict[k].dtype
)
model.load_state_dict(state_dict, strict=True, assign=True)

@abstractmethod
def _load_checkpoint(self, model: nn.Module, **kwargs):
def _load_checkpoint(self, model: nn.Module, device: DeviceLikeType):
"""Load the checkpoint into the model.

Args:
model: The model to load the checkpoint into. Note that the model does not need to be
the same model that is built above but it needs to have a state dict compatible with
the model built above.
**kwargs: Keyword arguments that will be passed to torch.load.
device: The device to load the model on.
"""


122 changes: 69 additions & 53 deletions tensorrt_llm/_torch/auto_deploy/models/hf.py
@@ -4,16 +4,16 @@
import os
import types
from contextlib import contextmanager, nullcontext
from typing import Any, Dict, Optional
from typing import Any, Dict, Mapping, Optional

import torch
import torch.nn as nn
from accelerate import init_empty_weights
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from accelerate.utils import modeling
from huggingface_hub import snapshot_download
from huggingface_hub.utils import HFValidationError, validate_repo_id
from torch._prims_common import DeviceLikeType
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.modeling_utils import load_sharded_checkpoint, load_state_dict
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, PretrainedConfig

from ..custom_ops.attention_interface import CacheConfig
from ..utils.logger import ad_logger
@@ -31,8 +31,10 @@ def load_state_dict_with_assign():
original_load_state_dict = torch.nn.Module.load_state_dict

# Define and apply the patched version
def load_state_dict_with_assign(*args, **kwargs):
return original_load_state_dict(*args, **kwargs, assign=True)
def load_state_dict_with_assign(
self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False
):
return original_load_state_dict(self, state_dict, strict=strict, assign=True)

# Apply the patch
torch.nn.Module.load_state_dict = load_state_dict_with_assign
@@ -45,22 +47,28 @@ def load_state_dict_with_assign(*args, **kwargs):
torch.nn.Module.load_state_dict = original_load_state_dict


def _to_maybe_empty(model: nn.Module, device: DeviceLikeType):
"""A mix of ``model.to(device)`` and ``model.to_empty(device)``.
@contextmanager
def hf_load_state_dict_with_device(device: DeviceLikeType):
"""Patch HF load_state_dict to use provided device."""
# save the original load_state_dict method
original_load_state_dict = modeling.load_state_dict

If a parameter is already initialized, then we will call `to()` on it. Otherwise, we will
initialize it with an empty tensor on the given device.
# Define and apply the patched version
def load_state_dict_with_device(checkpoint_file, device_map=None):
return original_load_state_dict(checkpoint_file, device_map={"": device})

"""
model._apply(
lambda t: torch.empty_like(t, device=device)
if t.device == torch.device("meta")
else t.to(device)
)
# Apply the patch
modeling.load_state_dict = load_state_dict_with_device

try:
yield
finally:
# Restore the original method, even if an exception occurred
modeling.load_state_dict = original_load_state_dict


@ModelFactoryRegistry.register("hf")
class HFFactory(ModelFactory):
@ModelFactoryRegistry.register("AutoModelForCausalLM")
class AutoModelForCausalLMFactory(ModelFactory):
def __init__(
self,
model_kwargs: Optional[Dict[str, Any]] = None,
@@ -70,10 +78,12 @@ def __init__(
super().__init__(**kwargs)

self.model_kwargs = model_kwargs or {}
self.model_kwargs["use_cache"] = False
self.tokenizer_kwargs = tokenizer_kwargs or {}
self._quant_config = None

# heuristic to disable use_cache
self.model_kwargs["use_cache"] = False

# prefetch the model+checkpoint
self.prefetch_checkpoint()
# load the quantization config
@@ -100,21 +110,49 @@ def _simple_forward(model: nn.Module, input_ids: torch.Tensor, position_ids: tor
"""
return type(model).forward(model, input_ids=input_ids, position_ids=position_ids)

def _recursive_update_config(self, config: PretrainedConfig, update_dict: Dict[str, Any]):
"""
Recursively update a PretrainedConfig object with values from update_dict.

Args:
config: PretrainedConfig object to update
update_dict: Dictionary with values to update in the config

Returns:
The updated PretrainedConfig object
"""
for key, value_new in update_dict.items():
# Check if the key exists in config
if not hasattr(config, key):
continue

target_value = getattr(config, key)

# Handle nested PretrainedConfig objects...
if isinstance(value_new, dict) and isinstance(target_value, PretrainedConfig):
# Recursively update nested configs
updated_value = self._recursive_update_config(target_value, value_new)
setattr(config, key, updated_value)
else:
# Direct update for simple values
setattr(config, key, value_new)

return config
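
As the NOTE in `build_model` below explains, this exists because HF overwrites nested sub-configs wholesale; a small sketch of the intended behavior (the values are illustrative only, and `factory` stands for an `AutoModelForCausalLMFactory` instance such as the one in the earlier sketch):

```python
# Sketch: keys that exist on the config are updated (recursing into nested
# PretrainedConfig objects); unknown keys are silently skipped.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
cfg = factory._recursive_update_config(cfg, {"num_hidden_layers": 3, "not_a_field": 1})
assert cfg.num_hidden_layers == 3       # existing attribute updated in place
assert not hasattr(cfg, "not_a_field")  # unknown key ignored rather than added
```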

def build_model(self, device: DeviceLikeType) -> nn.Module:
"""Build the model on the desired device."""
# We only support fp16 to fp4 conversion.
if self._quant_config and self._quant_config.get("quant_algo", None) == "NVFP4":
self.model_kwargs["torch_dtype"] = torch.half

model_config = self.autoconfig_from_pretrained(
self.model, trust_remote_code=True, **self.model_kwargs
)
# NOTE (lucaslie): HF doesn't recursively update nested PretrainedConfig objects. Instead,
# the entire subconfig will be overwritten.
# we want to recursively update model_config from model_kwargs here.
model_config = self.autoconfig_from_pretrained(self.model, trust_remote_code=True)
model_config = self._recursive_update_config(model_config, self.model_kwargs)

with (init_empty_weights if device == "meta" else nullcontext)():
default_dtype = torch.get_default_dtype()
torch.set_default_dtype(model_config.torch_dtype)
model = self.automodel_from_config(model_config, trust_remote_code=True)
torch.set_default_dtype(default_dtype)

# post-init --> this must be called explicitly for HF models the way we initialize them since
# this "gets lost" with the init_empty_weights context manager.
@@ -163,10 +201,10 @@ def prefetch_checkpoint(self):
is_hf_repo = False
if is_hf_repo:
# we don't expect to use bin files or pt/pth checkpoint files (they are quite large)
ignore_patterns = ["**pytorch_model*.bin*", "**.pt", "**.pth"]
ignore_patterns = ["*.bin", "*.pt", "*.pth"]
# we will also ignore the .safetensors files if we skip loading weights
if self.skip_loading_weights:
ignore_patterns.append("**safetensors")
ignore_patterns.append("*.safetensors")
ad_logger.info("Pre-fetching checkpoint directory from HF repo.")
fetched_dir = snapshot_download(self.model, ignore_patterns=ignore_patterns)
else:
@@ -177,7 +215,7 @@

self._prefetched_path = fetched_dir

def _load_checkpoint(self, model, **kwargs):
def _load_checkpoint(self, model: nn.Module, device: DeviceLikeType):
"""Load the checkpoint into the model."""
# check if we skip loading weights
if self.skip_loading_weights:
@@ -186,31 +224,9 @@ def _load_checkpoint(self, model, **kwargs):
# prefetch if needed
self.prefetch_checkpoint()

ckpt_path = self.model

# sharded checkpoint
if os.path.isfile(os.path.join(ckpt_path, "model.safetensors.index.json")):
_to_maybe_empty(model, device="cpu")
with load_state_dict_with_assign():
load_sharded_checkpoint(model, ckpt_path, strict=False)
return

# look for a single file in the directory ending with .safetensors or .pt/.pth
safetensors_files = [f for f in os.listdir(ckpt_path) if f.endswith(".safetensors")]
torch_files = [f for f in os.listdir(ckpt_path) if f.endswith((".pt", ".pth"))]
if len(safetensors_files) > 1:
raise ValueError(f"Multiple .safetensors files in {ckpt_path}: {safetensors_files}")
elif len(safetensors_files) == 1:
state_dict = load_state_dict(os.path.join(ckpt_path, safetensors_files[0]))
elif len(torch_files) > 1:
raise ValueError(f"Multiple .pt/.pth files found in {ckpt_path}: {torch_files}")
elif len(torch_files) == 1:
state_dict = torch.load(os.path.join(ckpt_path, torch_files[0]), **kwargs)
else:
raise ValueError(f"No checkpoint found in {ckpt_path}")

with load_state_dict_with_assign():
model.load_state_dict(state_dict, strict=False)
# reuse the load checkpoint utility from accelerate
with load_state_dict_with_assign(), hf_load_state_dict_with_device(device):
load_checkpoint_and_dispatch(model, checkpoint=self.model)

def _load_quantization_config(self):
assert self.model
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/auto_deploy/shim/demollm.py
@@ -372,7 +372,7 @@ def __init__(
self.mpi_session = None
self.runtime_context = None
self._tokenizer = self._try_load_tokenizer()
self.input_processor = create_input_processor(model, self.tokenizer)
self.input_processor = create_input_processor(None, self.tokenizer)

# construct sequence info object
seq_info = SequenceInfo(
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/auto_deploy/shim/interface.py
@@ -88,7 +88,7 @@ def resize_cache(self, new_num_pages: int):
@dataclass
class AutoDeployConfig(PyTorchConfig):
# model factory to choose from
model_factory: str = "hf" # only 'hf' supported for "trtllm" runtime
model_factory: str = "AutoModelForCausalLM"

### MODEL EXTRA KWARGS ###
# Extra kwargs for the model config class to customize the model config. Those arguments will