3 changes: 3 additions & 0 deletions tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py
@@ -147,6 +147,8 @@ def __init__(
quant_config=config.get_quant_config(),
allreduce_strategy=config.allreduce_strategy)

self._mamba_ssm_cache_dtype = config.quant_config.mamba_ssm_cache_dtype

def forward(
self,
hidden_states: torch.Tensor,
@@ -230,6 +232,7 @@ def forward(
seq_idx=seq_idx,
return_varlen_states=True,
return_final_states=False,
mamba_ssm_cache_dtype=self._mamba_ssm_cache_dtype,
)
out.append(rearrange(y, "b l h p -> (b l) (h p)"))

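The mixer now reads the SSM cache dtype from the quant config once at construction and threads it into the chunked-scan call. A minimal sketch of that flow, using SimpleNamespace stand-ins for the real config objects (the attribute names match the diff; everything else is illustrative):

from types import SimpleNamespace

import torch

# Stand-in for the model config carrying quant_config after this PR.
config = SimpleNamespace(
    quant_config=SimpleNamespace(mamba_ssm_cache_dtype=torch.float32))

# Mirrors Mamba2Mixer.__init__: cache the dtype once.
mamba_ssm_cache_dtype = config.quant_config.mamba_ssm_cache_dtype

# Mirrors forward(): the dtype is forwarded to the chunked scan as a keyword argument.
scan_kwargs = dict(return_varlen_states=True,
                   return_final_states=False,
                   mamba_ssm_cache_dtype=mamba_ssm_cache_dtype)
print(scan_kwargs["mamba_ssm_cache_dtype"])  # torch.float32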
47 changes: 27 additions & 20 deletions tensorrt_llm/_torch/modules/mamba/ssd_combined.py
@@ -16,6 +16,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
from einops import rearrange

@@ -43,6 +45,7 @@ def _mamba_chunk_scan_combined_fwd(
cu_seqlens=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
mamba_ssm_cache_dtype=None,
):
batch, seqlen, nheads, headdim = x.shape
_, _, ngroups, dstate = B.shape
@@ -120,7 +123,7 @@ def _mamba_chunk_scan_combined_fwd(
if initial_states is not None else None),
seq_idx=seq_idx,
chunk_size=chunk_size,
out_dtype=C.dtype,
out_dtype=mamba_ssm_cache_dtype or C.dtype,
is_cont_batched=cu_seqlens is not None)
states, final_states = [
rearrange(t, "... (p n) -> ... p n", n=dstate)
@@ -174,24 +177,26 @@ def _mamba_chunk_scan_combined_fwd(
return out, out_x, dt, dA_cumsum, states, final_states, varlen_states


def mamba_chunk_scan_combined(x,
dt,
A,
B,
C,
chunk_size,
D=None,
z=None,
dt_bias=None,
initial_states=None,
seq_idx=None,
chunk_indices=None,
chunk_offsets=None,
cu_seqlens=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
return_final_states=False,
return_varlen_states=False):
def mamba_chunk_scan_combined(
x,
dt,
A,
B,
C,
chunk_size,
D=None,
z=None,
dt_bias=None,
initial_states=None,
seq_idx=None,
chunk_indices=None,
chunk_offsets=None,
cu_seqlens=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
return_final_states=False,
return_varlen_states=False,
mamba_ssm_cache_dtype: Optional[torch.dtype] = None):
"""
Argument:
x: (batch, seqlen, nheads, headdim)
@@ -207,6 +212,7 @@ def mamba_chunk_scan_combined(x,
seq_idx: (batch, seqlen)
cu_seqlens: (num_sequences + 1) or None, only used if return_varlen_states is True
dt_softplus: Whether to apply softplus to dt
mamba_ssm_cache_dtype: torch.dtype, defaults to None (falls back to C.dtype)
Return:
out: (batch, seqlen, nheads, headdim)
"""
@@ -231,7 +237,8 @@ def mamba_chunk_scan_combined(x,
chunk_offsets=chunk_offsets,
cu_seqlens=cu_seqlens,
dt_softplus=dt_softplus,
dt_limit=dt_limit)
dt_limit=dt_limit,
mamba_ssm_cache_dtype=mamba_ssm_cache_dtype)
if not return_varlen_states:
return out if not return_final_states else (out, final_states)
else:
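The only behavioral change in the kernel wrapper is the out_dtype fallback: the chunked state is produced in mamba_ssm_cache_dtype when one is given, otherwise in C's dtype as before. A small stand-alone sketch of that resolution (not the Triton path itself):

import torch

def resolve_state_dtype(mamba_ssm_cache_dtype, C):
    # Same expression as in _mamba_chunk_scan_combined_fwd:
    #   out_dtype = mamba_ssm_cache_dtype or C.dtype
    return mamba_ssm_cache_dtype or C.dtype

C = torch.empty(2, 2, dtype=torch.bfloat16)
assert resolve_state_dtype(None, C) == torch.bfloat16          # "auto": keep C's dtype
assert resolve_state_dtype(torch.float32, C) == torch.float32  # explicit override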
3 changes: 3 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -330,6 +330,7 @@ def _create_kv_cache_manager(
mamba_layer_mask = [
char == "M" for char in config.hybrid_override_pattern
]

kv_cache_manager = MambaHybridCacheManager(
# mamba cache parameters
config.ssm_state_size,
@@ -340,6 +341,8 @@ def _create_kv_cache_manager(
mamba_num_layers,
mamba_layer_mask,
config.torch_dtype,
model_engine.model.model_config.quant_config.
mamba_ssm_cache_dtype,
# kv cache parameters
executor_config.kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/config.py
@@ -64,6 +64,8 @@ class PyTorchConfig:
"""

kv_cache_dtype: str = "auto"
mamba_ssm_cache_dtype: str = "auto"

enable_iter_perf_stats: bool = False
# If true, enables per request stats per iteration
# Must also set enable_iter_perf_stats to true to get request stats
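The backend config grows a mamba_ssm_cache_dtype string field next to kv_cache_dtype, both defaulting to "auto". A minimal sketch of setting it, assuming the remaining PyTorchConfig fields keep their defaults:

from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig  # path taken from this diff

cfg = PyTorchConfig(kv_cache_dtype="auto", mamba_ssm_cache_dtype="float32")
print(cfg.mamba_ssm_cache_dtype)  # "float32"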
16 changes: 15 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -22,7 +22,8 @@
get_num_extra_kv_tokens, update_spec_config_from_model_config)
from tensorrt_llm._torch.speculative.mtp import SampleStateTensorsMTP
from tensorrt_llm._utils import (is_trace_enabled, nvtx_range, release_gc,
torch_dtype_to_str, trace_func)
str_dtype_to_torch, torch_dtype_to_str,
trace_func)
from tensorrt_llm.inputs.multimodal import (MultimodalParams,
MultimodalRuntimeData)
from tensorrt_llm.logger import logger
@@ -98,6 +99,16 @@ def warmup(self, resource_manager: ResourceManager) -> None:
_VALID_KV_CACHE_DTYPES = ("fp8", "auto")


def validate_and_set_mamba_ssm_cache_dtype(config: ModelConfig,
mamba_ssm_cache_dtype: str) -> None:
if mamba_ssm_cache_dtype == "auto":
mamba_ssm_cache_dtype = config.pretrained_config.torch_dtype
else:
mamba_ssm_cache_dtype = str_dtype_to_torch(mamba_ssm_cache_dtype)

config.quant_config.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype


def validate_and_set_kv_cache_quant(model_config: ModelConfig,
pyt_kv_cache_dtype: str) -> QuantAlgo:
logger.info(
@@ -1022,6 +1033,9 @@ def _load_model(self,

validate_and_set_kv_cache_quant(
config, self.pytorch_backend_config.kv_cache_dtype)
validate_and_set_mamba_ssm_cache_dtype(
config, self.pytorch_backend_config.mamba_ssm_cache_dtype)

num_layers = int(os.environ.get("TLLM_OVERRIDE_LAYER_NUM", "0"))
if num_layers > 0:
config.pretrained_config.num_hidden_layers = num_layers
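validate_and_set_mamba_ssm_cache_dtype mirrors the KV-cache helper: "auto" inherits the checkpoint's torch_dtype, anything else is converted with str_dtype_to_torch, and the result is stored on quant_config. A self-contained sketch of that resolution (the dict below stands in for str_dtype_to_torch):

import torch

_STR_TO_TORCH = {"float16": torch.float16, "bfloat16": torch.bfloat16,
                 "float32": torch.float32}  # stand-in for str_dtype_to_torch

def resolve_mamba_ssm_cache_dtype(pretrained_torch_dtype: torch.dtype,
                                  mamba_ssm_cache_dtype: str) -> torch.dtype:
    # "auto" inherits the model's torch_dtype; otherwise the string is converted.
    if mamba_ssm_cache_dtype == "auto":
        return pretrained_torch_dtype
    return _STR_TO_TORCH[mamba_ssm_cache_dtype]

assert resolve_mamba_ssm_cache_dtype(torch.bfloat16, "auto") == torch.bfloat16
assert resolve_mamba_ssm_cache_dtype(torch.bfloat16, "float32") == torch.float32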
11 changes: 10 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/resource_manager.py
@@ -939,9 +939,12 @@ def __init__(
max_batch_size: int,
mapping: Mapping,
dtype: torch.dtype,
ssm_cache_dtype: torch.dtype,
layer_mask: Optional[List[bool]] = None,
) -> None:

self.mamba_ssm_cache_dtype = ssm_cache_dtype

# get tp size
tp_size = mapping.tp_size

@@ -993,7 +996,7 @@ def __init__(
head_dim,
d_state,
],
dtype=dtype,
dtype=self.mamba_ssm_cache_dtype,
device=device,
)

@@ -1051,6 +1054,9 @@ def get_ssm_states(self, layer_idx: int) -> torch.Tensor:
layer_offset = self.mamba_layer_offsets[layer_idx]
return self.ssm_states[layer_offset]

def get_mamba_ssm_cache_dtype(self) -> torch.dtype:
return self.mamba_ssm_cache_dtype

def shutdown(self):
# release tensor memory, keeping python references as tensors
self.conv_states = torch.tensor([])
@@ -1072,6 +1078,8 @@ def __init__(
mamba_num_layers: int,
mamba_layer_mask: List[bool],
mamba_cache_dtype: torch.dtype,
mamba_ssm_cache_dtype: torch.dtype,

# kv cache parameters
kv_cache_config: KvCacheConfigCpp,
kv_cache_type: CacheTypeCpp,
@@ -1105,6 +1113,7 @@ def __init__(
max_batch_size,
mapping,
mamba_cache_dtype,
mamba_ssm_cache_dtype,
mamba_layer_mask,
)

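With the extra constructor argument, the SSM state buffer is allocated in the configured cache dtype rather than the model dtype, while the conv states keep using dtype. An illustrative allocation (the shape below is a placeholder, not the manager's exact layout):

import torch

mamba_ssm_cache_dtype = torch.float32        # e.g. resolved from "float32"
num_mamba_layers, max_batch_size = 2, 4
nheads, head_dim, d_state = 8, 64, 128       # placeholder sizes

ssm_states = torch.empty(
    (num_mamba_layers, max_batch_size, nheads, head_dim, d_state),
    dtype=mamba_ssm_cache_dtype,             # was `dtype` before this PR
    device="cpu",                            # the real manager allocates on the GPU
)
print(ssm_states.dtype)  # torch.float32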
6 changes: 6 additions & 0 deletions tensorrt_llm/bench/benchmark/low_latency.py
@@ -56,6 +56,12 @@
default=.90,
help="The percentage of memory to use for KV Cache after model load.",
)
@optgroup.option(
"--mamba_ssm_cache_dtype",
type=click.Choice(["auto", "float16", "bfloat16", "float32"]),
default="auto",
help="Data type for Mamba SSM cache. If 'auto', inferred from model config.",
)
@optgroup.option(
"--max_seq_len",
type=int,
6 changes: 6 additions & 0 deletions tensorrt_llm/bench/benchmark/throughput.py
@@ -84,6 +84,12 @@
default=.90,
help="The percentage of memory to use for KV Cache after model load.",
)
@optgroup.option(
"--mamba_ssm_cache_dtype",
type=click.Choice(["auto", "float16", "bfloat16", "float32"]),
default="auto",
help="Data type for Mamba SSM cache. If 'auto', inferred from model config.",
)
@optgroup.group(
"Engine Input Configuration",
help="Input configuration for driving the engine.",
8 changes: 8 additions & 0 deletions tensorrt_llm/bench/benchmark/utils/general.py
@@ -12,6 +12,7 @@
validate_and_set_kv_cache_quant
from tensorrt_llm.bench.build.build import (get_benchmark_engine_settings,
get_model_config)
from tensorrt_llm.bench.build.dataclasses import NemotronHybridConfig
from tensorrt_llm.bench.dataclasses.general import (DatasetMetadata,
InferenceRequest)
from tensorrt_llm.logger import logger
@@ -88,6 +89,7 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
enable_chunked_prefill = params.get("enable_chunked_prefill", False)

kv_cache_dtype = "auto"
mamba_ssm_cache_dtype = params.get("mamba_ssm_cache_dtype", "auto")
kv_cache_config = {}
if extra_llm_api_options:
with open(extra_llm_api_options, 'r') as f:
@@ -96,6 +98,8 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
"dtype": "auto",
})
kv_cache_dtype = kv_cache_config.get("dtype", "auto")
mamba_ssm_cache_dtype = kv_cache_config.get("mamba_ssm_cache_dtype",
mamba_ssm_cache_dtype)

enable_chunked_prefill = llm_args_dict.get("enable_chunked_prefill",
enable_chunked_prefill)
@@ -115,6 +119,9 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
else:
model_config = get_model_config(model, model_path)

if isinstance(model_config, NemotronHybridConfig):
model_config.set_mamba_ssm_cache_dtype(mamba_ssm_cache_dtype)

from tensorrt_llm._torch.model_config import ModelConfig
model = model_path or model
tllm_model_config = ModelConfig.from_pretrained(model,
@@ -161,6 +168,7 @@ def get_settings(params: dict, dataset_metadata: DatasetMetadata, model: str,
}

kv_cache_config["dtype"] = kv_cache_dtype
kv_cache_config["mamba_ssm_cache_dtype"] = mamba_ssm_cache_dtype

pyt_options = {
"cuda_graph_config": cuda_graph_config,
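In get_settings, the YAML value from extra_llm_api_options takes precedence over the CLI flag, which in turn defaults to "auto"; the resolved string is then written back into kv_cache_config. The precedence in isolation:

params = {"mamba_ssm_cache_dtype": "float32"}            # from the --mamba_ssm_cache_dtype option
kv_cache_config = {"mamba_ssm_cache_dtype": "bfloat16"}  # from extra_llm_api_options YAML

mamba_ssm_cache_dtype = params.get("mamba_ssm_cache_dtype", "auto")
mamba_ssm_cache_dtype = kv_cache_config.get("mamba_ssm_cache_dtype",
                                            mamba_ssm_cache_dtype)

kv_cache_config["mamba_ssm_cache_dtype"] = mamba_ssm_cache_dtype
assert mamba_ssm_cache_dtype == "bfloat16"               # YAML wins over the CLI flag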
4 changes: 4 additions & 0 deletions tensorrt_llm/bench/build/dataclasses.py
@@ -223,6 +223,7 @@ class NemotronHybridConfig(ModelConfig):
mamba_head_dim: int
d_inner: Optional[int] = Field(default=None)
num_mamba_layers: Optional[int] = Field(default=None)
mamba_ssm_cache_dtype: Optional[str] = Field(default="auto")

@model_validator(mode="after")
def set_values_if_none(self):
@@ -248,3 +249,6 @@ def extra_model_cache_in_gb(self, bytes_per_elem, target_seq_len=None):
def cache_memory_fraction(self, cache_memory_fraction):
# Each mamba cache entry is pretty large (~50MB for 8B model), so we are more conservative when estimating the max batch size
return cache_memory_fraction**2

def set_mamba_ssm_cache_dtype(self, mamba_ssm_cache_dtype: str):
self.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype
13 changes: 12 additions & 1 deletion tensorrt_llm/bench/build/tuning.py
@@ -1,5 +1,8 @@
from typing import Tuple

import torch

from tensorrt_llm._utils import str_dtype_to_torch
from tensorrt_llm.llmapi.llm_utils import QuantConfig
from tensorrt_llm.logger import logger
from tensorrt_llm.quantization.mode import QuantAlgo
@@ -77,8 +80,16 @@ def calc_engine_setting(
target_seq_len = target_input_len + target_output_len
cache_memory = available_memory * model_config.cache_memory_fraction(
kv_cache_gpu_mem_fraction)

bytes_per_elem = BYTES_PER_ELEM.get(QuantAlgo.NO_QUANT)
if isinstance(model_config, NemotronHybridConfig):
mamba_ssm_cache_dtype = model_config.mamba_ssm_cache_dtype
if mamba_ssm_cache_dtype != "auto":
if str_dtype_to_torch(mamba_ssm_cache_dtype) == torch.float32:
bytes_per_elem = 4.0

gb_per_extra_cache = model_config.extra_model_cache_in_gb(
BYTES_PER_ELEM.get(QuantAlgo.NO_QUANT), target_seq_len)
bytes_per_elem, target_seq_len)
kv_cache_max_requests = cache_memory / (gb_per_token * target_seq_len +
gb_per_extra_cache)
extra_cache_memory = gb_per_extra_cache * kv_cache_max_requests
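The tuning heuristic only bumps the Mamba cache cost to 4 bytes per element when the user explicitly requests float32; otherwise the NO_QUANT default is kept. A condensed sketch (the 2.0-byte default is an assumption about BYTES_PER_ELEM[QuantAlgo.NO_QUANT]):

import torch

_STR_TO_TORCH = {"float16": torch.float16, "bfloat16": torch.bfloat16,
                 "float32": torch.float32}  # stand-in for str_dtype_to_torch

def mamba_cache_bytes_per_elem(mamba_ssm_cache_dtype: str,
                               default_bytes: float = 2.0) -> float:
    bytes_per_elem = default_bytes           # assumed BYTES_PER_ELEM[NO_QUANT]
    if mamba_ssm_cache_dtype != "auto":
        if _STR_TO_TORCH[mamba_ssm_cache_dtype] == torch.float32:
            bytes_per_elem = 4.0
    return bytes_per_elem

assert mamba_cache_bytes_per_elem("auto") == 2.0
assert mamba_cache_bytes_per_elem("bfloat16") == 2.0
assert mamba_cache_bytes_per_elem("float32") == 4.0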
32 changes: 21 additions & 11 deletions tensorrt_llm/commands/serve.py
@@ -81,6 +81,7 @@ def get_llm_args(model: str,
moe_expert_parallel_size: Optional[int] = None,
gpus_per_node: Optional[int] = None,
free_gpu_memory_fraction: Optional[float] = None,
mamba_ssm_cache_dtype: str = "auto",
num_postprocess_workers: int = 0,
trust_remote_code: bool = False,
reasoning_parser: Optional[str] = None,
Expand All @@ -96,7 +97,8 @@ def get_llm_args(model: str,
max_beam_width=max_beam_width,
max_seq_len=max_seq_len)
kv_cache_config = KvCacheConfig(
free_gpu_memory_fraction=free_gpu_memory_fraction)
free_gpu_memory_fraction=free_gpu_memory_fraction,
mamba_ssm_cache_dtype=mamba_ssm_cache_dtype)

dynamic_batch_config = DynamicBatchConfig(
enable_batch_size_tuning=True,
@@ -237,6 +239,12 @@ def launch_server(host: str,
default=0.9,
help="Free GPU memory fraction reserved for KV Cache, "
"after allocating model weights and buffers.")
@click.option(
"--mamba_ssm_cache_dtype",
type=click.Choice(["auto", "float16", "bfloat16", "float32"]),
default="auto",
help="Data type for Mamba SSM cache. If 'auto', inferred from model config."
)
@click.option(
"--num_postprocess_workers",
type=int,
@@ -277,16 +285,17 @@ def launch_server(host: str,
help=
"Exit with runtime error when attention window is too large to fit even a single sequence in the KV cache."
)
def serve(
model: str, tokenizer: Optional[str], host: str, port: int,
log_level: str, backend: str, max_beam_width: int, max_batch_size: int,
max_num_tokens: int, max_seq_len: int, tp_size: int, pp_size: int,
ep_size: Optional[int], cluster_size: Optional[int],
gpus_per_node: Optional[int], kv_cache_free_gpu_memory_fraction: float,
num_postprocess_workers: int, trust_remote_code: bool,
extra_llm_api_options: Optional[str], reasoning_parser: Optional[str],
metadata_server_config_file: Optional[str], server_role: Optional[str],
fail_fast_on_attention_window_too_large: bool):
def serve(model: str, tokenizer: Optional[str], host: str, port: int,
log_level: str, backend: str, max_beam_width: int,
max_batch_size: int, max_num_tokens: int, max_seq_len: int,
tp_size: int, pp_size: int, ep_size: Optional[int],
cluster_size: Optional[int], gpus_per_node: Optional[int],
kv_cache_free_gpu_memory_fraction: float, mamba_ssm_cache_dtype: str,
num_postprocess_workers: int, trust_remote_code: bool,
extra_llm_api_options: Optional[str], reasoning_parser: Optional[str],
metadata_server_config_file: Optional[str],
server_role: Optional[str],
fail_fast_on_attention_window_too_large: bool):
"""Running an OpenAI API compatible server

MODEL: model name | HF checkpoint path | TensorRT engine path
@@ -307,6 +316,7 @@ def serve(
moe_cluster_parallel_size=cluster_size,
gpus_per_node=gpus_per_node,
free_gpu_memory_fraction=kv_cache_free_gpu_memory_fraction,
mamba_ssm_cache_dtype=mamba_ssm_cache_dtype,
num_postprocess_workers=num_postprocess_workers,
trust_remote_code=trust_remote_code,
reasoning_parser=reasoning_parser,
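On the serve path, the new flag simply rides along into KvCacheConfig. A minimal sketch of the equivalent LLM-API usage; the import path is assumed from the public tensorrt_llm.llmapi package, and the field name comes from this diff:

from tensorrt_llm.llmapi import KvCacheConfig  # assumed public import path

# Roughly what `trtllm-serve ... --mamba_ssm_cache_dtype float32` builds after this PR:
kv_cache_config = KvCacheConfig(
    free_gpu_memory_fraction=0.9,
    mamba_ssm_cache_dtype="float32",
)
print(kv_cache_config.mamba_ssm_cache_dtype)  # "float32"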