2 changes: 1 addition & 1 deletion setup.py
@@ -260,4 +260,4 @@ def extract_from_precompiled(precompiled_location: str, package_data: List[str],
     install_requires=required_deps,
     dependency_links=
     extra_URLs, # Warning: Dependency links support has been dropped by pip 19.0
-    python_requires=">=3.7, <4")
+    python_requires=">=3.10, <4")
6 changes: 0 additions & 6 deletions tensorrt_llm/_torch/auto_deploy/llm_args.py
@@ -105,12 +105,6 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings):
         description="Disable the overlap scheduler in trtllm runtime",
     )
 
-    enable_mixed_sampler: bool = Field(
-        default=False,
-        description="If true, will iterate over sampling_params of each request and use the corresponding "
-        "sampling strategy, e.g. top-k, top-p, etc.",
-    )
-
     world_size: int = Field(
         default=1,
         ge=0,
5 changes: 0 additions & 5 deletions tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
@@ -337,16 +337,11 @@ def create_autodeploy_executor(ad_config: LlmArgs):
     scheduler = SimpleScheduler(capacitor_scheduler, mb_scheduler)
 
     # search sampler with speculative decoding
-    # TODO (lucaslie, fridah-nv): some models require enable_mixed_sampler=True to have good outputs, see
-    # https://github.com/NVIDIA/TensorRT-LLM/issues/5254
-    # We should expose mixed_sample to our build_and_run_ad script so we can configure this
-    # correctly for models as needed.
     sampler_args = TorchSampler.Args(
         max_seq_len=ad_config.max_seq_len,
         max_draft_len=max_draft_len,
         max_num_sequences=max_num_sequences,
         max_beam_width=ad_config.max_beam_width,
-        enable_mixed_sampler=ad_config.enable_mixed_sampler,
     )
     sampler = TorchSampler(sampler_args)
 
37 changes: 25 additions & 12 deletions tensorrt_llm/_torch/modules/rms_norm.py
@@ -14,7 +14,8 @@
 # limitations under the License.
 
 import enum
-from typing import Optional, Tuple, Union
+from types import EllipsisType  # https://stackoverflow.com/a/66636313
+from typing import Optional, Tuple, TypeAlias, Union, cast
 
 import torch
 from torch import nn
@@ -24,6 +25,9 @@
 
 class RMSNorm(nn.Module):
 
+    _ARGUMENT_NOT_SPECIFIED_SENTINEL = ...
+    _ArgumentNotSpecifiedSentinelType: TypeAlias = EllipsisType
+
     def __init__(
         self,
         *,
@@ -48,12 +52,19 @@ def __init__(
     def forward(
         self,
         hidden_states: torch.Tensor,
-        residual: Optional[torch.Tensor] = ...,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        residual: Union[
+            Optional[torch.Tensor],
+            _ArgumentNotSpecifiedSentinelType] = _ARGUMENT_NOT_SPECIFIED_SENTINEL,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
+        return_residual = True
+        if residual is self._ARGUMENT_NOT_SPECIFIED_SENTINEL:
+            return_residual = False
+            residual = None
+
         if IS_FLASHINFER_AVAILABLE:
             from ..custom_ops import (flashinfer_fused_add_rmsnorm,
                                       flashinfer_rmsnorm)
-            if isinstance(residual, torch.Tensor):
+            if residual is not None:
                 flashinfer_fused_add_rmsnorm(hidden_states, residual,
                                              self.weight, self.variance_epsilon)
             else:
@@ -62,7 +73,7 @@ def forward(
         else:
             input_dtype = hidden_states.dtype
             hidden_states = hidden_states.to(torch.float32)
-            if isinstance(residual, torch.Tensor):
+            if residual is not None:
                 hidden_states = hidden_states + residual.to(torch.float32)
                 residual = hidden_states.to(input_dtype)
 
@@ -71,20 +82,22 @@ def forward(
                                                         self.variance_epsilon)
             hidden_states = self.weight * hidden_states.to(input_dtype)
 
-        if residual is ...:
-            return hidden_states
+        if return_residual:
+            return hidden_states, cast(Optional[torch.Tensor], residual)
         else:
-            return hidden_states, residual
+            return hidden_states
 
     def skip_forward(
         self,
         hidden_states: torch.Tensor,
-        residual: Optional[torch.Tensor] = ...,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        if residual is ...:
+        residual: Union[
+            Optional[torch.Tensor],
+            _ArgumentNotSpecifiedSentinelType] = _ARGUMENT_NOT_SPECIFIED_SENTINEL,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]:
+        if residual is self._ARGUMENT_NOT_SPECIFIED_SENTINEL:
             return hidden_states
         else:
-            return hidden_states, residual
+            return hidden_states, cast(Optional[torch.Tensor], residual)
 
 
 class GroupRMSNormKernelSelection(enum.Enum):
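The rms_norm.py change above replaces the bare `...` default with a named, typed sentinel, so that an explicit `residual=None` still produces a two-tuple return while omitting the argument returns a single tensor. A minimal standalone sketch of that sentinel pattern (a hypothetical `normalize` function, not code from this PR):

```python
# Sketch of the "argument not specified" sentinel used in RMSNorm.forward /
# skip_forward: `...` (EllipsisType, Python 3.10+) lets the callee distinguish
# an omitted argument from an explicit None, which a plain `default=None` cannot.
from types import EllipsisType
from typing import Optional, Tuple, Union

_NOT_SPECIFIED = ...


def normalize(x: float,
              residual: Union[Optional[float],
                              EllipsisType] = _NOT_SPECIFIED
              ) -> Union[float, Tuple[float, Optional[float]]]:
    if residual is _NOT_SPECIFIED:
        return x            # argument omitted -> single return value
    return x, residual      # argument supplied (possibly None) -> tuple return


print(normalize(1.0))                  # 1.0
print(normalize(1.0, residual=None))   # (1.0, None)
print(normalize(1.0, residual=2.0))    # (1.0, 2.0)
```

The bump to `python_requires=">=3.10, <4"` in setup.py is also what makes `from types import EllipsisType` usable here without fallbacks, since `types.EllipsisType` was added in Python 3.10.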
4 changes: 1 addition & 3 deletions tensorrt_llm/_torch/pyexecutor/_util.py
@@ -697,7 +697,7 @@ def create_py_executor_instance(
 
 
 def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
-                              enable_mixed_sampler: bool, max_batch_size: int,
+                              max_batch_size: int,
                              speculative_config: SpeculativeConfig,
                              max_beam_width: int):
     max_num_sequences = max_batch_size * mapping.pp_size
@@ -708,7 +708,6 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
         max_draft_len=max_draft_len,
         max_num_sequences=max_num_sequences,
         max_beam_width=max_beam_width,
-        enable_mixed_sampler=enable_mixed_sampler,
     )
 
 
@@ -722,7 +721,6 @@ def instantiate_sampler(engine: PyTorchModelEngine,
     sampler_args = create_torch_sampler_args(
         mapping,
         max_seq_len=engine.max_seq_len,
-        enable_mixed_sampler=pytorch_backend_config.enable_mixed_sampler,
         max_batch_size=max_batch_size,
         speculative_config=speculative_config,
         max_beam_width=max_beam_width)
5 changes: 0 additions & 5 deletions tensorrt_llm/_torch/pyexecutor/config.py
@@ -56,11 +56,6 @@ class PyTorchConfig:
 
     moe_disable_finalize_fusion: bool = False
 
-    enable_mixed_sampler: bool = False
-    """
-    If true, will iterate over sampling_params of each request and use the
-    corresponding sampling strategy, e.g. top-k, top-p, etc.
-    """
     sampler_type: SamplerType = SamplerType.auto
     """
     The type of sampler to use. Options are TRTLLMSampler, TorchSampler or auto.
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/pyexecutor/llm_request.py
@@ -366,7 +366,7 @@ def __init__(
                          exclude_last_generation_logits)
         self.child_requests = []
 
-        self._py_embedding_bias_1d = None
+        self._py_embedding_bias_1d: Optional[torch.Tensor] = None
         if hasattr(self, 'embedding_bias') and self.embedding_bias is not None:
             # Pre-squeeze to 1D if needed (remove batch dimension)
             if self.embedding_bias.dim() > 1:
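The annotation added above only narrows the type of `_py_embedding_bias_1d`; the surrounding, unchanged lines are where a batched embedding bias is pre-squeezed to one dimension. A rough illustration of that pre-squeeze under stated assumptions (the vocabulary size and the `squeeze(0)` call are hypothetical, not taken from the diff):

```python
# Illustration only: collapse a [1, vocab_size] embedding bias to [vocab_size]
# so it can be applied per token without carrying a batch dimension.
from typing import Optional

import torch

embedding_bias: Optional[torch.Tensor] = torch.zeros(1, 32000)  # assumed shape
embedding_bias_1d: Optional[torch.Tensor] = None
if embedding_bias is not None:
    embedding_bias_1d = (embedding_bias.squeeze(0)
                         if embedding_bias.dim() > 1 else embedding_bias)
print(embedding_bias_1d.shape)  # torch.Size([32000])
```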
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -349,7 +349,7 @@ def create_py_executor(
         if _get_allow_chain_drafter():
             use_chain_drafter = (
                 guided_decoding_config is None
-                and not pytorch_backend_config.enable_mixed_sampler
+                and draft_spec_config._allow_greedy_draft_tokens
                 and pytorch_backend_config.attn_backend == "TRTLLM")
         else:
             use_chain_drafter = False