diff --git a/setup.py b/setup.py index 5f72c709d6b..82eae4883e8 100644 --- a/setup.py +++ b/setup.py @@ -260,4 +260,4 @@ def extract_from_precompiled(precompiled_location: str, package_data: List[str], install_requires=required_deps, dependency_links= extra_URLs, # Warning: Dependency links support has been dropped by pip 19.0 - python_requires=">=3.7, <4") + python_requires=">=3.10, <4") diff --git a/tensorrt_llm/_torch/auto_deploy/llm_args.py b/tensorrt_llm/_torch/auto_deploy/llm_args.py index 2ff564b17e3..6a02aa65ad5 100644 --- a/tensorrt_llm/_torch/auto_deploy/llm_args.py +++ b/tensorrt_llm/_torch/auto_deploy/llm_args.py @@ -105,12 +105,6 @@ class AutoDeployConfig(DynamicYamlMixInForSettings, BaseSettings): description="Disable the overlap scheduler in trtllm runtime", ) - enable_mixed_sampler: bool = Field( - default=False, - description="If true, will iterate over sampling_params of each request and use the corresponding " - "sampling strategy, e.g. top-k, top-p, etc.", - ) - world_size: int = Field( default=1, ge=0, diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py index 998c8a178f2..ca067235324 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py @@ -337,16 +337,11 @@ def create_autodeploy_executor(ad_config: LlmArgs): scheduler = SimpleScheduler(capacitor_scheduler, mb_scheduler) # search sampler with speculative decoding - # TODO (lucaslie, fridah-nv): some models require enable_mixed_sampler=True to have good outputs, see - # https://github.com/NVIDIA/TensorRT-LLM/issues/5254 - # We should expose mixed_sample to our build_and_run_ad script so we can configure this - # correctly for models as needed. sampler_args = TorchSampler.Args( max_seq_len=ad_config.max_seq_len, max_draft_len=max_draft_len, max_num_sequences=max_num_sequences, max_beam_width=ad_config.max_beam_width, - enable_mixed_sampler=ad_config.enable_mixed_sampler, ) sampler = TorchSampler(sampler_args) diff --git a/tensorrt_llm/_torch/modules/rms_norm.py b/tensorrt_llm/_torch/modules/rms_norm.py index 2a22d858250..4b7a388983b 100644 --- a/tensorrt_llm/_torch/modules/rms_norm.py +++ b/tensorrt_llm/_torch/modules/rms_norm.py @@ -14,7 +14,8 @@ # limitations under the License. import enum -from typing import Optional, Tuple, Union +from types import EllipsisType # https://stackoverflow.com/a/66636313 +from typing import Optional, Tuple, TypeAlias, Union, cast import torch from torch import nn @@ -24,6 +25,9 @@ class RMSNorm(nn.Module): + _ARGUMENT_NOT_SPECIFIED_SENTINEL = ... 
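+    # The Ellipsis literal above acts as a "not specified" sentinel, so an explicit
+    # residual=None can be distinguished from omitting the argument; EllipsisType
+    # (aliased below) lets the forward/skip_forward signatures spell this out.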
+ _ArgumentNotSpecifiedSentinelType: TypeAlias = EllipsisType + def __init__( self, *, @@ -48,12 +52,19 @@ def __init__( def forward( self, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor] = ..., - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + residual: Union[ + Optional[torch.Tensor], + _ArgumentNotSpecifiedSentinelType] = _ARGUMENT_NOT_SPECIFIED_SENTINEL, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]: + return_residual = True + if residual is self._ARGUMENT_NOT_SPECIFIED_SENTINEL: + return_residual = False + residual = None + if IS_FLASHINFER_AVAILABLE: from ..custom_ops import (flashinfer_fused_add_rmsnorm, flashinfer_rmsnorm) - if isinstance(residual, torch.Tensor): + if residual is not None: flashinfer_fused_add_rmsnorm(hidden_states, residual, self.weight, self.variance_epsilon) else: @@ -62,7 +73,7 @@ def forward( else: input_dtype = hidden_states.dtype hidden_states = hidden_states.to(torch.float32) - if isinstance(residual, torch.Tensor): + if residual is not None: hidden_states = hidden_states + residual.to(torch.float32) residual = hidden_states.to(input_dtype) @@ -71,20 +82,22 @@ def forward( self.variance_epsilon) hidden_states = self.weight * hidden_states.to(input_dtype) - if residual is ...: - return hidden_states + if return_residual: + return hidden_states, cast(Optional[torch.Tensor], residual) else: - return hidden_states, residual + return hidden_states def skip_forward( self, hidden_states: torch.Tensor, - residual: Optional[torch.Tensor] = ..., - ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - if residual is ...: + residual: Union[ + Optional[torch.Tensor], + _ArgumentNotSpecifiedSentinelType] = _ARGUMENT_NOT_SPECIFIED_SENTINEL, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, Optional[torch.Tensor]]]: + if residual is self._ARGUMENT_NOT_SPECIFIED_SENTINEL: return hidden_states else: - return hidden_states, residual + return hidden_states, cast(Optional[torch.Tensor], residual) class GroupRMSNormKernelSelection(enum.Enum): diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index e8d68a59381..47bf53e48fb 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -697,7 +697,7 @@ def create_py_executor_instance( def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int, - enable_mixed_sampler: bool, max_batch_size: int, + max_batch_size: int, speculative_config: SpeculativeConfig, max_beam_width: int): max_num_sequences = max_batch_size * mapping.pp_size @@ -708,7 +708,6 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int, max_draft_len=max_draft_len, max_num_sequences=max_num_sequences, max_beam_width=max_beam_width, - enable_mixed_sampler=enable_mixed_sampler, ) @@ -722,7 +721,6 @@ def instantiate_sampler(engine: PyTorchModelEngine, sampler_args = create_torch_sampler_args( mapping, max_seq_len=engine.max_seq_len, - enable_mixed_sampler=pytorch_backend_config.enable_mixed_sampler, max_batch_size=max_batch_size, speculative_config=speculative_config, max_beam_width=max_beam_width) diff --git a/tensorrt_llm/_torch/pyexecutor/config.py b/tensorrt_llm/_torch/pyexecutor/config.py index 7f46c521b6f..407d3fa1ec2 100644 --- a/tensorrt_llm/_torch/pyexecutor/config.py +++ b/tensorrt_llm/_torch/pyexecutor/config.py @@ -56,11 +56,6 @@ class PyTorchConfig: moe_disable_finalize_fusion: bool = False - enable_mixed_sampler: bool = False - """ - If true, will iterate over sampling_params of each 
request and use the - corresponding sampling strategy, e.g. top-k, top-p, etc. - """ sampler_type: SamplerType = SamplerType.auto """ The type of sampler to use. Options are TRTLLMSampler, TorchSampler or auto. diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index b20382006ec..068ffce4704 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -366,7 +366,7 @@ def __init__( exclude_last_generation_logits) self.child_requests = [] - self._py_embedding_bias_1d = None + self._py_embedding_bias_1d: Optional[torch.Tensor] = None if hasattr(self, 'embedding_bias') and self.embedding_bias is not None: # Pre-squeeze to 1D if needed (remove batch dimension) if self.embedding_bias.dim() > 1: diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index 5f97e2e37af..d2187138e2f 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -349,7 +349,7 @@ def create_py_executor( if _get_allow_chain_drafter(): use_chain_drafter = ( guided_decoding_config is None - and not pytorch_backend_config.enable_mixed_sampler + and draft_spec_config._allow_greedy_draft_tokens and pytorch_backend_config.attn_backend == "TRTLLM") else: use_chain_drafter = False diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index 6e6f57bc214..9a800628620 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -1,7 +1,12 @@ +import dataclasses +import enum +import sys from abc import ABC, abstractmethod +from collections import defaultdict from collections.abc import Iterable from dataclasses import dataclass -from typing import List, Literal, Optional +from itertools import repeat +from typing import Any, List, Literal, Optional, cast import torch @@ -25,6 +30,11 @@ from .llm_request import LlmRequest, LlmRequestState, get_draft_token_length from .scheduler import ScheduledRequests +if sys.version_info[:2] >= (3, 12): + from typing import override +else: + from typing_extensions import override + @dataclass(kw_only=True) class SampleStateTensors: @@ -74,6 +84,10 @@ def beam_width(scheduled_requests: Iterable[LlmRequest]) -> int: def is_generation_model(self) -> bool: raise NotImplementedError + def should_provide_draft_probs(self, request: LlmRequest) -> bool: + """Check if sampler wants to receive draft token probabilities.""" + return True # conservative default + class EarlyStopSampler(Sampler): """ @@ -81,11 +95,13 @@ class EarlyStopSampler(Sampler): such as encoder-only model (e.g., BERT) or reward models that only need context phase. 
""" + @override def sample_async(self, scheduled_requests: ScheduledRequests, model_outputs, num_context_logits_prefix_sum: list[int]) -> SampleState: host = SampleStateTensors(new_tokens=torch.empty(0)) return SampleState(scheduled_requests=scheduled_requests, host=host) + @override def update_requests(self, state: SampleState) -> None: assert isinstance(state, SampleState) scheduled_requests = state.scheduled_requests @@ -95,6 +111,7 @@ def update_requests(self, state: SampleState) -> None: # NOTE: This is a hack: set finish reason manually and set the beam 0 request.set_finished_reason(FinishReason.LENGTH, 0) + @override def is_generation_model(self) -> bool: return False @@ -119,6 +136,7 @@ class EarlyStopWithMMResult(Sampler): Use for skipping decoding step for non generation model, and return the batch_output (such as mm_embeddings) """ + @override def sample_async( self, scheduled_requests: ScheduledRequests, model_outputs, num_context_logits_prefix_sum: list[int] @@ -128,6 +146,7 @@ def sample_async( return SampleStateWithMMResult(scheduled_requests=scheduled_requests, data=data) + @override def update_requests(self, state: SampleStateWithMMResult) -> None: assert isinstance(state, SampleStateWithMMResult) scheduled_requests = state.scheduled_requests @@ -145,13 +164,16 @@ def update_requests(self, state: SampleStateWithMMResult) -> None: request.py_result.append_mm_embeddings(mm_embedding) + @override def is_generation_model(self) -> bool: return False -def top_k_sampling_batch(logits, - top_k=50, - generator: Optional[torch.Generator] = None): +def top_k_sampling_batch( + logits, + top_k=50, + generator: Optional[torch.Generator] = None +) -> tuple[torch.Tensor, torch.Tensor]: logits_dim = logits.dim() if logits_dim == 1: logits = logits.unsqueeze(0) @@ -176,13 +198,14 @@ def top_k_sampling_batch(logits, return next_tokens, softmax -def top_p_sampling_batch(logits: torch.Tensor, - top_p: float = 0.9, - temperature: float = 1.0, - generator: Optional[torch.Generator] = None): +def top_p_sampling_batch( + logits: torch.Tensor, + *, + top_p: float = 0.9, + temperature: float = 1.0, + generator: Optional[torch.Generator] = None +) -> tuple[torch.Tensor, torch.Tensor]: logits_dim = logits.dim() - if logits_dim == 1: - logits = logits.unsqueeze(0) assert logits_dim == 2, "logits should be 2D: [batch_size, vocab_size]" if temperature != 0: @@ -214,20 +237,19 @@ def top_p_sampling_batch(logits: torch.Tensor, def top_k_top_p_sampling_batch(logits: torch.Tensor, + *, top_k: int, top_p: float, temperature: float = 1.0, generator: Optional[torch.Generator] = None): logits_dim = logits.dim() - if logits_dim == 1: - logits = logits.unsqueeze(0) assert logits_dim == 2, "logits should be 2D: [batch_size, vocab_size]" if temperature != 0: logits = logits / max(temperature, 1e-5) batch_size, vocab_size = logits.size() # get first top_k logits of each sample and their indices if top_k > 0: - values, indices = torch.topk(logits, top_k, dim=-1) + values, _ = torch.topk(logits, top_k, dim=-1) min_values = values[:, -1].unsqueeze(-1).expand(batch_size, vocab_size) # set the logits who is less than first top_k logits to -inf @@ -259,19 +281,39 @@ def top_k_top_p_sampling_batch(logits: torch.Tensor, return next_tokens, softmax -def greedy_search_sampling_batch(logits): +def greedy_search_sampling_batch( + logits, + *, + softmax_indices: Optional[torch.IntTensor] = None +) -> tuple[torch.Tensor, torch.Tensor]: next_tokens = torch.argmax(logits, dim=-1) + if softmax_indices is not None: + logits = 
logits[softmax_indices.to(logits.device, non_blocking=True)] softmax = torch.softmax(logits, dim=-1) return next_tokens, softmax def get_rejected_indices(draft_probs: torch.Tensor, target_probs: torch.Tensor, - generator: torch.Generator, draft_tokens: list[int]): - - p = draft_probs[torch.arange(len(draft_tokens)), draft_tokens] - q = target_probs[:-1] - q = q[torch.arange(len(draft_tokens)), draft_tokens] - accept_probs = torch.minimum(torch.ones(()), q / p) + generator: torch.Generator, + draft_tokens: list[int]) -> torch.Tensor: + # NB: ModelDrafter._pad_to_max_draft_tokens pads draft_tokens, but + # not draft_probs. Relying on shape of draft_probs here. + num_draft_tokens = draft_probs.size(0) + draft_tokens = draft_tokens[:num_draft_tokens] + # NB: torch.arange is needed to enable "advanced indexing", + # cf. https://numpy.org/devdocs/user/basics.indexing.html#integer-array-indexing + token_idx = torch.arange(num_draft_tokens, + dtype=torch.int32, + device=generator.device) + draft_tokens_cuda = torch.tensor(draft_tokens, + dtype=torch.int32, + pin_memory=True).to( + device=generator.device, + non_blocking=True) + p = draft_probs[token_idx, draft_tokens_cuda] + q = target_probs.squeeze(0)[token_idx, draft_tokens_cuda] + accept_probs = torch.minimum( + torch.ones((), device=generator.device, dtype=q.dtype), q / p) # Use deterministic random generation for multi-GPU consistency rejected_indices = (torch.rand(accept_probs.shape, generator=generator, @@ -298,45 +340,83 @@ def sample_rejected(draft_probs: torch.Tensor, target_probs: torch.Tensor, TopKTopP = tuple[Literal["top_k_top_p"], int, float, float] Greedy = tuple[Literal["greedy"], None] GREEDY: Greedy = ("greedy", None) -Strategy = TopK | TopP | Greedy +Strategy = TopK | TopP | Greedy | TopKTopP + +def _request_strategy(request: LlmRequest) -> Strategy: + # top_p and top_K with temperature=0.0 reduces to greedy + # sampling + temperature = request.sampling_config.temperature + if temperature is not None: + temperature = temperature[0] + if temperature == 0.0: + return GREEDY -def request_strategy(request: LlmRequest) -> Strategy: if request.sampling_config.top_k is not None and len( request.sampling_config.top_k ) > 0 and request.sampling_config.top_p is not None and len( request.sampling_config.top_p) > 0: return ("top_k_top_p", request.sampling_config.top_k[0], - request.sampling_config.top_p[0], - request.sampling_config.temperature[0]) - if request.sampling_config.top_p is not None and len( + request.sampling_config.top_p[0], temperature) + elif request.sampling_config.top_p is not None and len( request.sampling_config.top_p) > 0: - return ("top_p", request.sampling_config.top_p[0], - request.sampling_config.temperature[0]) + top_p = request.sampling_config.top_p[0] + return ("top_p", top_p, temperature) elif request.sampling_config.top_k is not None and len( request.sampling_config.top_k) > 0: return ("top_k", request.sampling_config.top_k[0]) else: - return ("greedy", None) - - -def sampling_strategies(requests: Iterable[LlmRequest]) -> list[Strategy]: - return [request_strategy(req) for req in requests] - - -def sample(strategy: Strategy, - logits: torch.Tensor, - generator: Optional[torch.Generator] = None): + return GREEDY + + +def _group_requests_by_sampling_strategy( + requests: Iterable[LlmRequest], + *, + pin_memory: bool = False) -> dict[Strategy, torch.Tensor]: + strategy_dict: dict[Strategy, list[int]] = defaultdict(list) + for req_index, req in enumerate(requests): + 
strategy_dict[_request_strategy(req)].append(req_index) + return { + strategy: torch.tensor(indices, + pin_memory=pin_memory, + dtype=torch.int32) + for strategy, indices in strategy_dict.items() + } + + +def sample( + strategy: Strategy, + logits: torch.Tensor, + generator: Optional[torch.Generator] = None, + *, + softmax_indices: Optional[torch.IntTensor] = None +) -> tuple[torch.Tensor, torch.Tensor]: + filter_softmax = True match strategy: case ("top_k", top_k): - return top_k_sampling_batch(logits, top_k, generator) + tokens, softmax = top_k_sampling_batch(logits, top_k, generator) case ("top_p", top_p, temperature): - return top_p_sampling_batch(logits, top_p, temperature, generator) + tokens, softmax = top_p_sampling_batch( + logits, + top_p=top_p, + generator=generator, + **(dict(temperature=temperature) + if temperature is not None else dict())) case ("top_k_top_p", top_k, top_p, temperature): - return top_k_top_p_sampling_batch(logits, top_k, top_p, temperature, - generator) + tokens, softmax = top_k_top_p_sampling_batch( + logits, + top_k=top_k, + top_p=top_p, + generator=generator, + **(dict(temperature=temperature) + if temperature is not None else dict())) case ("greedy", None): - return greedy_search_sampling_batch(logits) + tokens, softmax = greedy_search_sampling_batch( + logits, softmax_indices=softmax_indices) + filter_softmax = False + if filter_softmax and softmax_indices is not None: + softmax = softmax[softmax_indices.to(softmax.device, non_blocking=True)] + return tokens, softmax def add_token(request: LlmRequest, @@ -355,10 +435,326 @@ def int_tensor(shape: tuple[int, ...], device: str = 'cuda') -> torch.Tensor: return torch.empty(shape, dtype=torch.int, device=device) +@dataclass(kw_only=True, frozen=True) +class _BatchedSamplingResult: + # Original request indices for all requests (permuted due to batching by strategy): + batch_req_indices: torch.Tensor + # Next tokens for all requests: + batch_next_tokens_cuda_int: torch.Tensor + # Probs for all requests which need them: + batch_softmax_cuda: torch.Tensor + # (request, batch_softmax indices), for requests having py_draft_logits / requesting py_target_probs: + py_draft_logits_indices: list[tuple[LlmRequest, + torch.Tensor]] = dataclasses.field( + default_factory=list) + + +# Inspired by https://github.com/pytorch/pytorch/issues/80577; note also the +# suggestion to consider torch.nested. +def torch_multi_arange(ends: torch.Tensor, + *, + starts: Optional[torch.Tensor] = None, + steps: Optional[torch.Tensor] = None) -> torch.Tensor: + """Efficiently compute torch.cat([torch.arange(b, e, d) for b, e, d in zip(starts, ends, steps)]). + + Starts, ends, steps need to share dtype and shape. Invalid ranges like range(1, 2, -1) are + silently discarded. 'steps' defaults to 1 and 'starts' defaults to 0. + """ + if steps is not None: + assert ends.dtype == steps.dtype + assert ends.shape == steps.shape + if starts is not None: + assert ends.dtype == starts.dtype + assert ends.shape == starts.shape + + # This algorithm combines torch.repeat_interleaved() and torch.cumsum() to + # construct the result. + # + # 1. Given N ranges (characterized by starts, ends, steps), construct a sequence + # of 2N numbers, in which the non-overlapping pairs of consecutive numbers + # correspond to the ranges. 
For a given range, the pair (a, b) is chosen such + # that upon torch.cumsum() application 'a' turns the last element of the + # preceding range into the start element for the current range and 'b' is + # simply the step size for the current range. + # + repeats = ends # number of elements in each range + if starts is not None: + repeats = repeats.clone() + repeats -= starts + if steps is not None: + repeats = repeats.div(steps, rounding_mode="trunc") + repeats = repeats.clip(0) # ignore invalid ranges + range_ends = repeats - 1 # last element in each range + if steps is not None: + range_ends *= steps + if starts is not None: + range_ends += starts + prev_range_ends = range_ends.roll( + 1) # last element in preceding range (or 0) + prev_range_ends[0] = 0 + ones = torch.tensor(1, dtype=ends.dtype, pin_memory=True).to( + device=ends.device, non_blocking=True).broadcast_to(ends.shape) + if steps is None: + steps = ones + jumps = -prev_range_ends # delta from one range to the next + if starts is not None: + jumps += starts + seq = torch.cat((jumps.unsqueeze(-1), steps.unsqueeze(-1)), dim=1).view(-1) + # + # 2. Construct output via torch.repeat_interleave() and torch.cumsum() + seq_repeats = torch.cat((ones.unsqueeze(-1), (repeats - 1).unsqueeze(-1)), + dim=1).view(-1) + seq = seq.repeat_interleave(seq_repeats) + seq = seq.cumsum(0) + return seq + + +# Helper class for _PackedStepIndexer and _UnpackedStepIndexer, facilitating the +# selection of memory locations of tokens associated with given sets of requests. +class _StepIndexTranslator(ABC): + + def __init__( + self, + *, + num_steps: torch.Tensor, + req_offsets: Optional[torch.Tensor] = None, + max_steps: Optional[int] = None, + index_dtype: Optional[torch.dtype] = None, + ): + """Build the index. + + Arguments: + index_dtype: torch.dtype to use for indices (defaults to torch.int32). + num_steps (index_dtype): Number of steps/tokens for each request + req_offsets (index_dtype): Index offset at which the data for each request starts. + If not provided, it is computed using calculate_request_offsets(), + which assumes dense packing. + max_steps (int): The largest value allowed to occur in num_steps. + If not provided, it is computed from num_steps. 
+ """ + if req_offsets is None: + req_offsets, _ = self.calculate_request_offsets(num_steps) + if max_steps is None: + max_steps = cast(int, num_steps.max().item()) + self._index_map, self._index_mask = self._build_index( + req_offsets=req_offsets, + num_steps=num_steps, + max_steps=max_steps, + index_dtype=(index_dtype or torch.int32), + ) + + @staticmethod + def calculate_request_offsets( + req_num_steps: torch.Tensor, + pin_memory: bool = False) -> tuple[torch.Tensor, int]: + if req_num_steps.numel(): + req_offsets = torch.cumsum(req_num_steps, 0) + sum_steps = int(req_offsets[-1].item()) + req_offsets_rolled = torch.empty_like(req_offsets, + pin_memory=pin_memory) + req_offsets_rolled[1:] = req_offsets[:-1] + req_offsets_rolled[0] = 0 + req_offsets = req_offsets_rolled + else: + req_offsets = torch.empty_like(req_num_steps, pin_memory=pin_memory) + sum_steps = 0 + return req_offsets, sum_steps + + def _build_index( + self, req_offsets: torch.Tensor, num_steps: torch.Tensor, + max_steps: int, + index_dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]: + steps_dim = torch.arange(max_steps, + device=num_steps.device, + dtype=index_dtype) + valid_mask = steps_dim.unsqueeze(0) < num_steps.unsqueeze(-1) + indices = self._compute_index_map(index_dtype=index_dtype, + steps_dim=steps_dim, + req_offsets=req_offsets) + # NB: steps_dim and req_offsets may have been overwritten by this point. + return indices, valid_mask + + @abstractmethod + def _compute_index_map(self, index_dtype: torch.dtype, + steps_dim: torch.Tensor, + req_offsets: torch.Tensor) -> torch.Tensor: + """Compute full tensor index map. + + Should return a tensor of shape (len(num_steps), max_steps) containing the linear + token index (index_dtype) corresponding to a given request and decoding step. + Each row corresponds to a request (same ordering as 'req_offsets' and 'num_steps'), + and the columns correspond to decoding steps 0, ..., num_steps[i]. Entries corresponding + to decoding steps which are invalid for the given request are masked elsewhere within + _StepIndexTranslator. + + This method is allowed to repurpose/overwrite 'steps_dim' and 'req_offsets'. + + Arguments: + num_steps (index_dtype): Number of steps/tokens for each request + req_offsets (index_dtype): Index offset at which the data for each request starts. + steps_dim (index_dtype): arange(max_steps) + index_dtype: torch.dtype to use for indices + """ + + def __getitem__(self, req_indices: Any) -> torch.Tensor: + """Gather indices for a given set of requests. + + Arguments: + req_indices: Any 1d torch-compatible indexing expression to select requests, corresponds + to the linear indices of the entries in 'num_steps' and 'req_offsets' (cf. __init__). + Returns: + Array of linear indices (index_dtype) selecting the tokens/steps associated + with the requests identified by req_indices, in the same order as + req_indices. + """ + indices = self._index_map[req_indices].view(-1) + mask = self._index_mask[req_indices].view(-1) + # NB: Return value has dynamic shape (depends on mask nnz), which + # implies stream sync if CUDA is used. + return indices[mask] + + +# Helper class for _PackedStepIndexer and _UnpackedStepIndexer, facilitating the +# selection of memory locations of tokens associated with given sets of requests, +# for memory layouts that can be parametrized via request offsets and step stride. 
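+# The linear index of (request i, step j) is req_offsets[i] + step_stride * j,
+# with step_stride defaulting to 1.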
+class _StridedStepIndexTranslator(_StepIndexTranslator): + + def __init__( + self, + *, + num_steps: torch.Tensor, + req_offsets: Optional[torch.Tensor] = None, + max_steps: Optional[int] = None, + index_dtype: Optional[torch.dtype] = None, + step_stride: Optional[int] = None, + ): + """Build the index. + + Allows to specify a custom stride for steps dimension. + + Arguments: + index_dtype: torch.dtype to use for indices (defaults to torch.int32). + num_steps (index_dtype): Number of steps/tokens for each request + req_offsets (index_dtype): Index offset at which the data for each request starts. + If not provided, it is computed using calculate_request_offsets(), + assuming dense packing of tokens (grouped by request). Overriding + this also allows for "request major" indexing into rectangular + tensors. + max_steps (int): The largest value allowed to occur in num_steps. + If not provided, it is computed from 'num_steps'. + step_stride: Additional stride to multiply 'steps_dim' with (defaults to 1). Allows, + e.g., "step major" indexing into rectangular tensors. + """ + self._step_stride = step_stride + super().__init__(num_steps=num_steps, + req_offsets=req_offsets, + max_steps=max_steps, + index_dtype=index_dtype) + + @override + def _compute_index_map(self, index_dtype: torch.dtype, + steps_dim: torch.Tensor, + req_offsets: torch.Tensor) -> torch.Tensor: + if self._step_stride is not None: + steps_dim *= self._step_stride # in-place OK + return req_offsets.unsqueeze(-1) + steps_dim.unsqueeze(0) + + +# In sample_async(), each request contains a different number of output positions +# (a.k.a. 'steps') and 'logits_cuda' (and other tensors derived from it) packs those +# tokens into a single contiguous array, with the 'step' axis being the rapidly +# changing one. +# +# The class below builds an index to simplify selecting the linear indices of the +# tokens associated with a given set of requests. +# +# NB: Consider switching to torch.nested (cf. https://github.com/pytorch/pytorch/issues/80577) +class _PackedStepIndexer(_StridedStepIndexTranslator): + + def __init__( + self, + *, + num_steps: torch.Tensor, + req_offsets: Optional[torch.Tensor] = None, + max_steps: Optional[int] = None, + index_dtype: Optional[torch.dtype] = None, + ): + """Build the index. + + Arguments: + index_dtype: torch.dtype to use for indices (defaults to torch.int32). + num_steps (index_dtype): Number of steps/tokens for each request + req_offsets (index_dtype): Index offset at which the data for each request starts. + If not provided, it is computed using calculate_request_offsets(). + max_steps (int): The largest value allowed to occur in num_steps. + If not provided, it is computed from 'num_steps'. + """ + super().__init__(num_steps=num_steps, + req_offsets=req_offsets, + max_steps=max_steps, + index_dtype=index_dtype) + + +# After gathering results with _PackedStepIndexer in TorchSampler._sample_batched_by_strategy, +# they need to be scattered into result buffers in TorchSampler._unbatch_sampling_results. +# This helper class provides the translation from linear packed request + step/token indices +# to unpacked / rectangular-tensor (but still linearized) request + step/token indices. +# +# NB: Consider switching to torch.nested (cf. 
https://github.com/pytorch/pytorch/issues/80577) +class _UnpackedStepIndexer(_StridedStepIndexTranslator): + + class DimOrder(enum.Enum): + SLOT_MAJOR = enum.auto() + STEP_MAJOR = enum.auto() + + def __init__( + self, + *, + seq_slots: torch.Tensor, + num_steps: torch.Tensor, + dim_order: DimOrder = DimOrder.SLOT_MAJOR, + steps_dim_size: int, + slots_dim_size: Optional[int] = None, + index_dtype: Optional[torch.dtype] = None, + ): + """Build the index. + + Arguments: + index_dtype: torch.dtype to use for indices (defaults to torch.int32). + seq_slots (index_dtype): Request indices in unpacked tensor, enumerated in packed tensor + request order. + num_steps (index_dtype): Number of steps/tokens for each request + dim_order: Memory layout of indexed tensor. + steps_dim_size (int): The extent of the step dimension in the unpacked tensor. + slots_dim_size (int): The extent of the slot dimension in the unpacked tensor. + Required if dim_order is DimOrder.STEP_MAJOR. + """ + if dim_order is self.DimOrder.SLOT_MAJOR: + super().__init__( + num_steps=num_steps, + req_offsets=(steps_dim_size * seq_slots), + max_steps=steps_dim_size, + index_dtype=index_dtype, + ) + elif dim_order is self.DimOrder.STEP_MAJOR: + if slots_dim_size is None: + raise ValueError("slots_dim_size required for step-major order") + super().__init__( + num_steps=num_steps, + req_offsets=seq_slots, # no need for stride here + max_steps=steps_dim_size, + index_dtype=index_dtype, + step_stride=slots_dim_size, + ) + else: + raise ValueError(f"Invalid dim_order: {dim_order}") + + class TorchSampler(Sampler): BEAM = 0 MAX_BEAM_WIDTH = BEAM + 1 + @override def is_generation_model(self) -> bool: return True @@ -376,11 +772,9 @@ class Args: max_draft_len: int max_num_sequences: int max_beam_width: int - enable_mixed_sampler: bool def __init__(self, args: Args): self.max_seq_len = args.max_seq_len - self.enable_mixed_sampler = args.enable_mixed_sampler self.max_tokens = args.max_draft_len + 1 assert args.max_beam_width == self.MAX_BEAM_WIDTH, "TorchSampler only supports beam_width = 1" self.max_num_sequences = args.max_num_sequences @@ -410,6 +804,7 @@ def get_generator(self, device: torch.device) -> torch.Generator: # Fallback to a default seed if not set self._generator = torch.Generator(device=device) self._generator.manual_seed(self._global_seed) + assert self._generator.device == device return self._generator def _meet_max_token_stop_criteria(self, request: LlmRequest): @@ -493,11 +888,12 @@ def _process_draft_tokens_greedy(self, request: LlmRequest, def _process_draft_tokens_rejection_sampling( self, request: LlmRequest, new_tokens: torch.Tensor) -> int: - sampling_strategy = request_strategy(request) + sampling_strategy = _request_strategy(request) generator = self.get_generator(request.py_draft_logits.device) _, draft_probs = sample(sampling_strategy, - request.py_draft_logits[0], + request.py_draft_logits, generator=generator) + draft_probs = draft_probs.squeeze(0) target_probs = request.py_target_probs rejected_indices = get_rejected_indices(draft_probs, target_probs, generator, @@ -541,6 +937,8 @@ def process_draft_tokens(self, request: LlmRequest, return self._process_draft_tokens_rejection_sampling( request, new_tokens) + @override + @torch.inference_mode() def update_requests(self, state: SampleState) -> None: assert isinstance(state, SampleState) if state.sampler_event: @@ -570,7 +968,9 @@ def update_requests(self, state: SampleState) -> None: self.handle_logprobs(req, state, beam=self.BEAM, count=processed) 
req.py_decoding_iter += 1 - def log_probs_host(self, scheduled_requests: ScheduledRequests): + def log_probs_host( + self, + scheduled_requests: ScheduledRequests) -> Optional[torch.Tensor]: """Shape: In lockstep with TRTLLMSampler: https://github.com/NVIDIA/TensorRT-LLM/blob/cea5dd1e3883b18bf50901a7f196f50a9544c28c/cpp/include/tensorrt_llm/runtime/decoderState.h#L103""" if any(req.py_return_log_probs for req in scheduled_requests.all_requests()): @@ -580,17 +980,25 @@ def log_probs_host(self, scheduled_requests: ScheduledRequests): pin_memory=True) return None + @override + @torch.inference_mode() def sample_async(self, scheduled_requests: ScheduledRequests, model_outputs: dict[str, torch.Tensor], num_context_logits_prefix_sum: list[int]) -> SampleState: + requests = scheduled_requests.all_requests() new_tokens = self.store.new_tokens log_probs_host = self.log_probs_host(scheduled_requests) - self._process_requests(scheduled_requests, - model_outputs, - new_tokens, - num_context_logits_prefix_sum, - log_probs_host=log_probs_host) - new_tokens_host = new_tokens.to(device="cpu", non_blocking=True) + seq_slots_host = torch.tensor( + [r.py_seq_slot for r in requests], + dtype=torch.int64, # for index_fill_ + pin_memory=True) + new_tokens_host = self._process_requests(scheduled_requests, + model_outputs, + new_tokens, + num_context_logits_prefix_sum, + seq_slots=seq_slots_host, + log_probs_host=log_probs_host) + sampler_event = torch.cuda.Event() sampler_event.record() return SampleState(scheduled_requests=scheduled_requests, @@ -600,62 +1008,345 @@ def sample_async(self, scheduled_requests: ScheduledRequests, sampler_event=sampler_event) @staticmethod - def append_eagle3(tokens: torch.Tensor, model_outputs): + def _apply_d2t(tokens: torch.Tensor, model_outputs) -> None: + """Applies draft-to-target token translation table. + + Modifies tokens in-place. + """ if "d2t" in model_outputs: d2t = model_outputs["d2t"][tokens] tokens += d2t @staticmethod def _apply_embedding_bias( - logits: torch.Tensor, - requests: list[LlmRequest], - steps_per_request: list[int] = None) -> torch.Tensor: + logits: torch.Tensor, + requests: list[LlmRequest], + request_steps: torch.Tensor, + ) -> None: """Apply embedding bias (aka logit bias) to logits. - If steps_per_request is None, assumes 1 step per request (non-batched path). + + Arguments: + request_steps: Number of steps/tokens for each request. + + Modifies logits in-place. """ - # Collect biases and their associated data - bias_list = [] - bias_data = [] # Either indices (fast path) or steps (batched path) - - for i, req in enumerate(requests): - bias = req._py_embedding_bias_1d - if bias is not None: - bias_list.append(bias) - bias_data.append(i if steps_per_request is - None else steps_per_request[i]) - - if not bias_list: - return logits - - bias_tensor = torch.stack(bias_list).to(logits.device, - non_blocking=True) - logits = logits.clone() - - if steps_per_request is None: - # Fast path: direct indexing - indices = torch.tensor(bias_data, device=logits.device) - logits[indices] += bias_tensor + # NB: Unfortunately, Torch provides no combination of torch.index_select (similar to + # torch.Tensor.gather -- allows one-to-many mapping) and addition, analogous to how + # torch.Tensor.scatter_add_ (and it's variant torch.Tensor.index_add_ -- allows + # many-to-one mapping) combine addition with torch.Tensor.scatter_. 
+ # + # Notwithstanding the previous point, there are two options: + # (i) materialize a permuted bias tensor with repeated consecutive rows via + # torch.repeat_interleave and then use torch.Tensor.index_add_ (poor write + # locality / risk of false sharing) + # (ii) materialize the correctly ordered bias tensor via torch.index_select and then + # perform a masked addition (poor read locality for request batches randomly + # mixing uniform and heterogeneous bias tensors, i.e., mixing slices with high + # and low reuse). + # Since read-caching is expected to help in typical cases, option (ii) is implemented here. + + # Track which logits require logit bias application + logits_bias_mask = torch.zeros((logits.size(0), ), + dtype=torch.bool, + pin_memory=True) + + _next_bias_index = 0 + + def provision_bias_index() -> int: + nonlocal _next_bias_index + bias_index = _next_bias_index + _next_bias_index += 1 + return bias_index + + # Indices of unique bias tensors + # + # NB: hash(torch.Tensor) is equivalent to id(torch.Tensor), and does not + # depend on tensor contents, cf. https://github.com/pytorch/pytorch/issues/2569 + bias_to_index: dict[torch.Tensor, + int] = defaultdict(provision_bias_index) + + # Source indices for bias application + bias_gather_indices: list[int] = [] + + # Collect bias information + req_bias = None + for i, (req, steps) in enumerate(zip(requests, request_steps)): + steps = int(steps.item()) + req_bias = req._py_embedding_bias_1d + if req_bias is not None: + logits_bias_mask[i:(i + steps)] = True + req_bias_index = bias_to_index[req_bias] + bias_gather_indices.extend(repeat(req_bias_index, steps)) + + if not bias_to_index: + return + assert req_bias is not None # otherwise bias_to_index is empty + + bias_gather_indices_cuda = torch.tensor(bias_gather_indices, + pin_memory=True, + dtype=torch.int32).to( + logits.device, + non_blocking=True) + logits_bias_mask_cuda = logits_bias_mask.to(logits.device, + non_blocking=True) + biases_tensor = torch.empty((len(bias_to_index), *req_bias.shape), + pin_memory=True) + biases_tensor = torch.stack( + tuple(bias_to_index.keys()), + out=biases_tensor, + ) + biases_tensor_cuda = biases_tensor.to(logits.device, non_blocking=True) + + biases_tensor_cuda = torch.index_select(biases_tensor_cuda, 0, + bias_gather_indices_cuda) + # NB: Avoiding logits[bias_scatter_indices] += biases_tensor (and torch.Tensor.scatter_add_), because it + # is unclear if this allows for repeated indices, cf. + # https://docs.pytorch.org/docs/2.8/generated/torch.Tensor.index_put_.html#torch-tensor-index-put + # and thus introduces read-after-write dependencies (including possible false + # sharing). + logits[logits_bias_mask_cuda] += biases_tensor_cuda + + def _sample_batched_by_strategy( + self, + logits_cuda: torch.Tensor, + requests: list[LlmRequest], + model_outputs: dict[str, torch.Tensor], + *, + cuda_device: torch.device, + log_probs_host: torch.Tensor | None = None, + req_num_steps: torch.Tensor, + req_offsets: torch.Tensor, + steps_dim_size: int, + token_dtype: torch.dtype, + ) -> _BatchedSamplingResult: + requests_by_strategy = _group_requests_by_sampling_strategy( + requests, pin_memory=True) + generator_cuda = self.get_generator(cuda_device) + + # FIXME: This check should/could be performed in ModelDrafter.prepare_draft_tokens + # + # NB: Currently, "d2t" is applied to draft tokens, but not to draft logits, + # breaking _process_draft_tokens_rejection_sampling. 
+ needs_d2t = "d2t" in model_outputs + if needs_d2t and (len(requests_by_strategy) > 1 or + (requests_by_strategy + and next(iter(requests_by_strategy)) != GREEDY)): + raise ValueError("d2t does not yet support non-greedy sampling") + + # Indexer for accessing tokens in 'logits_cuda', corresponding to the + # requests in 'requests'. + logits_cuda_indexer = _PackedStepIndexer( + num_steps=req_num_steps, + max_steps=steps_dim_size, + req_offsets=req_offsets, + ) + + batched_results: list[tuple[torch.Tensor, torch.Tensor, + torch.Tensor]] = [] + py_draft_logits_indices: list[tuple[LlmRequest, torch.Tensor]] = [ + ] # see _BatchedSamplingResult for details + softmax_index_offset = 0 + for strategy, group_req_indices in requests_by_strategy.items(): + # group_req_indices: Indices of 'requests' entries having the same sampling + # strategy, ordered ascending. + + # Indices of 'group_req_indices' entries corresponding to requests + # with draft logits. + speculation_group_indices = [ + grp_idx + for grp_idx, glob_req_idx in enumerate(group_req_indices) + if requests[glob_req_idx].py_draft_logits is not None + ] + # To skip softmax computation where it is not needed, track + # softmax_req_indices: Indices of 'requests' entries requesting probs + # softmax_grp_indices: Indices of 'speculation_group_indices' entries requesting probs + # speculation_softmax_indices: Indices of 'softmax_grp_indices' entries corresponding + # to requests with draft logits. + if log_probs_host is not None: + softmax_req_indices = group_req_indices + softmax_grp_indices = torch.arange(len(group_req_indices), + dtype=torch.int32) + speculation_softmax_indices = torch.tensor( + speculation_group_indices, dtype=torch.int32) + else: + speculation_group_indices_tensor = torch.tensor( + speculation_group_indices, dtype=torch.int32) + softmax_req_indices = group_req_indices[ + speculation_group_indices_tensor] + softmax_grp_indices = speculation_group_indices_tensor + speculation_softmax_indices = torch.arange( + len(speculation_group_indices), dtype=torch.int32) + + group_logits_cuda_indices = logits_cuda_indexer[group_req_indices] + if group_logits_cuda_indices.numel() != logits_cuda.size(0): + group_logits_cuda_indices_cuda = group_logits_cuda_indices.to( + device=logits_cuda.device, non_blocking=True) + group_logits_cuda = logits_cuda[group_logits_cuda_indices_cuda] + else: + group_logits_cuda = logits_cuda + + # Indexer for accessing tokens in 'group_logits_cuda' (and 'group_next_tokens_cuda') + # corresponding to the requests in 'group_req_indices'. + group_logits_cuda_indexer = _PackedStepIndexer( + num_steps=req_num_steps[group_req_indices], + max_steps=steps_dim_size) + softmax_group_indices_cuda = group_logits_cuda_indexer[ + softmax_grp_indices].to(device=logits_cuda.device, + non_blocking=True) + + # Indexer for accessing tokens in 'group_softmax_cuda' corresponding to the + # requests in 'softmax_req_indices'. 
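+            # If every request in the group needs probs (i.e. log_probs_host is set),
+            # the group logits indexer can be reused as-is.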
+ if softmax_req_indices is not group_req_indices: + group_softmax_cuda_indexer = _PackedStepIndexer( + num_steps=req_num_steps[softmax_req_indices], + max_steps=steps_dim_size) + else: + group_softmax_cuda_indexer = group_logits_cuda_indexer + + py_draft_logits_indices += [( + requests[request_idx], + softmax_index_offset + group_softmax_cuda_indexer[softmax_idx], + ) for softmax_idx, request_idx in zip( + speculation_softmax_indices, + softmax_req_indices[speculation_softmax_indices], + )] + + group_next_tokens_cuda, group_softmax_cuda = sample( + strategy, + group_logits_cuda, + generator_cuda, + softmax_indices=cast(torch.IntTensor, + softmax_group_indices_cuda), + ) + group_next_tokens_cuda_int = group_next_tokens_cuda.to( + dtype=token_dtype, non_blocking=True) + batched_results.append( + (group_req_indices, group_next_tokens_cuda_int, + group_softmax_cuda)) + softmax_index_offset += group_softmax_cuda.size(0) + # Batched sampling results, see _BatchedSamplingResult for details. + if len(batched_results) > 1: + batch_req_indices = torch.cat([res[0] for res in batched_results]) + batch_next_tokens_cuda_int = torch.cat( + [res[1] for res in batched_results]) + batch_softmax_cuda = torch.cat([res[2] for res in batched_results]) else: - # Batched path: expand biases and use boolean mask - expanded_biases = torch.repeat_interleave(bias_tensor, - torch.tensor( - bias_data, - device=logits.device), - dim=0) - - mask = torch.zeros(sum(steps_per_request), - dtype=torch.bool, - device=logits.device) - offset = 0 - for i, req in enumerate(requests): - steps = steps_per_request[i] - if req._py_embedding_bias_1d is not None: - mask[offset:offset + steps] = True - offset += steps + (batch_req_indices, batch_next_tokens_cuda_int, + batch_softmax_cuda), = batched_results + + # FIXME: This should be done in ModelDrafter.prepare_draft_tokens, but for performance + # parity py_draft_tokens might need to be replaced / backed by a torch.Tensor, so + # that d2t can be applied in a batched manner similar to the code below. + if needs_d2t: + # NB: The sampler is either called directly by PyExecutor, for the target model, + # or by ModelDrafter.prepare_draft_tokens(), for the draft model. In the former + # case there are 1 + get_draft_token_length(request) tokens per request. In the + # latter case, only there is always only 1 token per request because draft + # tokens are sampled one-by-one. 
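+            # 'd2t' maps draft-vocabulary token ids onto target-vocabulary ids by adding
+            # a per-token offset; the batched token tensor is updated in place.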
+ self._apply_d2t(batch_next_tokens_cuda_int, model_outputs) + + return _BatchedSamplingResult( + batch_req_indices=batch_req_indices, + batch_next_tokens_cuda_int=batch_next_tokens_cuda_int, + batch_softmax_cuda=batch_softmax_cuda, + py_draft_logits_indices=py_draft_logits_indices, + ) - logits[mask] += expanded_biases + def _unbatch_sampling_results( + self, + batched_sampling_result: _BatchedSamplingResult, + *, + new_tokens_cuda: torch.Tensor, + req_num_steps: torch.Tensor, + seq_slots: torch.Tensor, + log_probs_host: torch.Tensor | None = None, + ) -> torch.Tensor: + beam = self.BEAM + assert beam == 0, "beam_width != 1 not supported" + + batch_req_indices = batched_sampling_result.batch_req_indices + batch_next_tokens_cuda_int = batched_sampling_result.batch_next_tokens_cuda_int + batch_softmax_cuda = batched_sampling_result.batch_softmax_cuda + py_draft_logits_indices = batched_sampling_result.py_draft_logits_indices + + def _dims_canonically_ordered(t: torch.Tensor) -> bool: + return len(t.dim_order( + ambiguity_check=[torch.contiguous_format])) == t.ndim + + # Assert destination tensor dimensions are canonically ordered ("row"-major); this + # matters for element ordering in the .view(...).scatter_(...) calls below. + assert _dims_canonically_ordered(new_tokens_cuda) + assert log_probs_host is None or _dims_canonically_ordered( + log_probs_host) + + # new_tokens_cuda indexed by + # slice(0, steps), slot, beam + # log_probs_host indexed by + # slot, beam, slice(0, steps) + # batch_... tensors indexed by slice(batch_req_index, batch_req_index + steps) + # + if log_probs_host is not None: + assert new_tokens_cuda.size(0) == log_probs_host.size(-2) + + # Construct index mapping from slice indices of computed tensors + # (packed request_idx and step dimensions) to linearized indices + # in (steps, seq_slot). + batch_destination_cuda_indexer = _UnpackedStepIndexer( + seq_slots=seq_slots[batch_req_indices], + num_steps=req_num_steps[batch_req_indices], + steps_dim_size=new_tokens_cuda.size(0), + slots_dim_size=new_tokens_cuda.size(1), + dim_order=_UnpackedStepIndexer.DimOrder.STEP_MAJOR, + index_dtype=torch.int64, # enforced by Tensor.scatter_ + ) - return logits + # Batch update output tensors + batch_dest_indices_1d_cuda = batch_destination_cuda_indexer[:].to( + new_tokens_cuda.device, non_blocking=True) + new_tokens_cuda.view(-1, + *new_tokens_cuda.shape[2:])[:, beam, ...].scatter_( + 0, batch_dest_indices_1d_cuda, + batch_next_tokens_cuda_int) + new_tokens_host = new_tokens_cuda.to("cpu", non_blocking=True) + # NB: In order to avoid a scatter_ on the host and the necessary D2H copy + synchronization, + # the 'step' and 'seq_slot' dimensions are unpacked on GPU and later asynchronously + # copied into the destination buffer. Note that this overwrites all 'step' and token slots for the + # requests in 'requests' (passed to _process_requests). In fact, the current implementation + # even overwrites the destination tensors completely (including slices corresponding to request + # slots not present in 'requests', cf. 'FIXME' below). + if log_probs_host is not None: + # FIXME: If log_probs_host were indexed by request indices, rather than request slots, this + # tensor could be packed densely along the request axis. 
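+            # Scatter log-probs into a device-side staging tensor first; the result is
+            # copied asynchronously into the pinned log_probs_host buffer below.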
+ log_probs_cuda = torch.empty_like( + log_probs_host, device=batch_dest_indices_1d_cuda.device) + # FIXME: Needs a separate indexer because tensor layout differs from new_tokens_cuda + batch_dest_probs_cuda_indexer = _UnpackedStepIndexer( + seq_slots=seq_slots[batch_req_indices], + num_steps=req_num_steps[batch_req_indices], + steps_dim_size=new_tokens_cuda.size(0), + slots_dim_size=new_tokens_cuda.size(1), + dim_order=_UnpackedStepIndexer.DimOrder.SLOT_MAJOR, + index_dtype=torch.int64, # enforced by Tensor.scatter_ + ) + batch_dest_probs_indices_cuda = batch_dest_probs_cuda_indexer[:].to( + batch_softmax_cuda.device, non_blocking=True) + # NB: torch.arange is needed to enable "advanced indexing", + # cf. https://numpy.org/devdocs/user/basics.indexing.html#integer-array-indexing + batch_token_probs = batch_softmax_cuda[ + torch.arange(batch_softmax_cuda.size(0), + device=batch_softmax_cuda.device, + dtype=torch.int32), batch_next_tokens_cuda_int] + log_probs_cuda[:, beam, + ...].view(-1, *log_probs_cuda.shape[3:]).scatter_( + 0, batch_dest_probs_indices_cuda, + torch.log(batch_token_probs)) + log_probs_host.copy_(log_probs_cuda, non_blocking=True) + # For requests with LlmRequest.py_draft_logits, return py_target_probs + for request, batch_softmax_index_cuda in py_draft_logits_indices: + request.py_target_probs = batch_softmax_cuda[ + batch_softmax_index_cuda].clone() + + return new_tokens_host @staticmethod @torch.inference_mode() @@ -687,103 +1378,141 @@ def _apply_min_length_penalty(logits: torch.Tensor, current_offset += num_steps[index] return logits - def _process_requests(self, - scheduled_requests: ScheduledRequests, - model_outputs: dict[str, torch.Tensor], - new_tokens: torch.Tensor, - num_context_logits_prefix_sum: list[int], - *, - log_probs_host: torch.Tensor | None = None): - beam_width = self.MAX_BEAM_WIDTH - beam = self.BEAM - - # raw_logits should contain only the logits from the gen requests. - # If return context logits is requested, fetch only the logits from gen requests. + @staticmethod + def _select_generated_logits( + scheduled_requests: ScheduledRequests, + raw_logits_cuda: torch.Tensor, + *, + req_num_generation_steps: torch.Tensor, + num_context_logits_prefix_sum: list[int], + generation_requests_total_steps: int, + ) -> torch.Tensor: + # raw_logits should contain only the generated logits. + # If return context logits is requested, select only the generated logits. 
+ # + # NB: Context request logits always precede generation request logits, also + # requests == scheduled_requests.context_requests + scheduled_requests.generation_requests if any(r.py_return_context_logits for r in scheduled_requests.context_requests): - gen_logits_indices = [] - total_context_logits = num_context_logits_prefix_sum[-1] - for i in range(len(scheduled_requests.context_requests)): - gen_logits_indices.append(num_context_logits_prefix_sum[i + 1] - - 1) - gen_logits_indices.extend( - range( - total_context_logits, total_context_logits + - len(scheduled_requests.generation_requests))) - raw_logits = model_outputs["logits"][gen_logits_indices] - else: - raw_logits = model_outputs["logits"] + assert len(num_context_logits_prefix_sum) == len( + scheduled_requests.context_requests) + 1 + req_num_generation_steps_cuda = req_num_generation_steps.to( + raw_logits_cuda.device, non_blocking=True) + context_req_offsets_cuda = torch.tensor( + num_context_logits_prefix_sum, + dtype=torch.int32, + pin_memory=True).to(device=raw_logits_cuda.device, + non_blocking=True) + + # Since the goal is to keep the req_num_steps[i] last tokens for each requests[i], + # only end-offsets of the token storage locations matter. + next_context_req_offsets_cuda = context_req_offsets_cuda.roll( + -1) # trailing '0' is overwritten below + # Since logits for generation requests are densely packed, cover them all by a single + # fictituous entry in 'context_req_offsets_cuda'. + if scheduled_requests.generation_requests: + req_num_steps_fictitious_cuda = req_num_generation_steps_cuda[:( + len(scheduled_requests.context_requests) + 1)].clone() + req_num_steps_fictitious_cuda[ + -1] = generation_requests_total_steps + next_context_req_offsets_cuda[-1] = ( + next_context_req_offsets_cuda[-2] + + req_num_steps_fictitious_cuda[-1]) + else: + req_num_steps_fictitious_cuda = req_num_generation_steps_cuda[:len( + scheduled_requests.context_requests)] + next_context_req_offsets_cuda = next_context_req_offsets_cuda[: + -1] + + # Now, the generated tokens for context request i are at indices + # range(next_context_req_offsets_cuda[i] - req_num_steps_fictitious_cuda[i], next_context_req_offsets_cuda[i]) + # And if generation requests are present, those tensors each include a trailing entry selecting + # all tokens/logits generated by all generation requests. 
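+        # torch_multi_arange concatenates those per-request index ranges into a single
+        # flat index tensor, used below to gather only the generated logits.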
+ indices_to_keep_cuda = torch_multi_arange( + starts=(next_context_req_offsets_cuda - + req_num_steps_fictitious_cuda), + ends=next_context_req_offsets_cuda, + ) + + raw_logits_cuda = raw_logits_cuda[indices_to_keep_cuda] + return raw_logits_cuda + + @nvtx_range("_process_requests") + def _process_requests( + self, + scheduled_requests: ScheduledRequests, + model_outputs: dict[str, torch.Tensor], + new_tokens_cuda: torch.Tensor, + num_context_logits_prefix_sum: list[int], + *, + seq_slots: torch.Tensor, + log_probs_host: torch.Tensor | None = None) -> torch.Tensor: + seq_slots = seq_slots.to(dtype=torch.int32) # int32 suffices here + + raw_logits_cuda = model_outputs["logits"] requests = scheduled_requests.all_requests() - num_steps = [1 + get_draft_token_length(req) for req in requests] - raw_logits = self._apply_min_length_penalty(raw_logits, requests, - num_steps) - sum_steps = sum(num_steps) - no_draft_tokens = len(requests) == sum_steps - fast_path = not self.enable_mixed_sampler and no_draft_tokens and log_probs_host is None - - seq_slots_host = torch.as_tensor([r.py_seq_slot for r in requests]) - seq_slots = seq_slots_host.to(device="cuda", non_blocking=True) - - if fast_path: - logits = raw_logits[:len(requests)] - logits = self._apply_embedding_bias(logits, requests) - next_tokens = torch.argmax(logits, dim=-1) - self.append_eagle3(next_tokens, model_outputs) - int_next_tokens = next_tokens.to(torch.int, non_blocking=True) - next_tokens = int_next_tokens.view(1, -1, beam_width) - new_tokens[:1].index_copy_(1, seq_slots, next_tokens) - return + cuda_device = raw_logits_cuda.device + req_num_steps_list = [ + 1 + get_draft_token_length(req) for req in requests + ] + req_num_steps = torch.tensor(req_num_steps_list, + dtype=torch.int32, + pin_memory=True) + # NB: These offsets consider generated tokens _only_ (draft and target, but not context) + # and are thus only correct after _select_generated_logits() below. 
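+        # req_offsets is the exclusive prefix sum of req_num_steps, i.e. the packed offset
+        # at which each request's sampled tokens start; sum_steps is their total count.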
+ req_offsets, sum_steps = _PackedStepIndexer.calculate_request_offsets( + req_num_steps, pin_memory=True) + + raw_logits_cuda = self._select_generated_logits( + scheduled_requests, + raw_logits_cuda, + req_num_generation_steps=req_num_steps, + num_context_logits_prefix_sum=num_context_logits_prefix_sum, + generation_requests_total_steps=( + # NB: requests == scheduled_requests.context_requests + scheduled_requests.generation_requests + sum_steps - cast( + int, req_offsets[len( + scheduled_requests.context_requests)].item()) + if scheduled_requests.generation_requests else 0), + ) - strategies = sampling_strategies(requests) - batched_next_tokens, batched_softmax = None, None - batched_strategy: Strategy | None = GREEDY - if self.enable_mixed_sampler: - assert "d2t" not in model_outputs, "eagle3 does not yet support non-greedy sampling" - if len(set(strategies)) == 1: - batched_strategy = strategies[0] - else: - batched_strategy = None - generator = self.get_generator(raw_logits.device) - if batched_strategy is not None: - logits = raw_logits[:sum_steps] - # Collect steps per request for batched strategy - steps_per_request = [ - 1 + get_draft_token_length(req) for req in requests - ] - logits = self._apply_embedding_bias(logits, requests, - steps_per_request) - batched_next_tokens, batched_softmax = sample( - batched_strategy, logits, generator) - self.append_eagle3(batched_next_tokens, model_outputs) - - offset = 0 - for i, (strategy, slot, steps, request) in enumerate( - zip(strategies, seq_slots_host, num_steps, requests)): - input_slice = slice(offset, offset + steps) - logits = raw_logits[input_slice] - - req = requests[i] - - if batched_next_tokens is None: - logits = self._apply_embedding_bias(logits, [req]) - next_tokens, softmax = sample(strategy, logits, generator) - else: - # Batched processing already applied bias, just use the results - next_tokens = batched_next_tokens[input_slice] - softmax = batched_softmax[input_slice] - current_slice = slice(0, steps), slot, beam - new_tokens[current_slice] = next_tokens - if request.py_draft_logits is not None: - request.py_target_probs = softmax.clone() - if log_probs_host is not None: - assert beam == 0, "The following call relies on beam_width to be 1 - hence the unsqueeze" - token_probs = torch.gather( - softmax, dim=1, index=next_tokens.unsqueeze(1)).squeeze(-1) - log_probs = torch.log(token_probs) - log_probs_host[slot, beam, :steps].copy_(log_probs, - non_blocking=True) - offset += steps + # Handle embedding bias + logits_cuda = raw_logits_cuda[:sum_steps] + self._apply_embedding_bias(logits_cuda, requests, req_num_steps) + + logits_cuda = self._apply_min_length_penalty(logits_cuda, requests, + req_num_steps_list) + + # Perform sampling in batches + batched_sampling_result = self._sample_batched_by_strategy( + logits_cuda, + requests, + model_outputs, + cuda_device=cuda_device, + log_probs_host=log_probs_host, + req_offsets=req_offsets, + steps_dim_size=new_tokens_cuda.size(0), + req_num_steps=req_num_steps, + token_dtype=new_tokens_cuda.dtype, + ) + + # Fill results into output buffers + new_tokens_host = self._unbatch_sampling_results( + batched_sampling_result, + new_tokens_cuda=new_tokens_cuda, + log_probs_host=log_probs_host, + req_num_steps=req_num_steps, + seq_slots=seq_slots, + ) + + # NB: update_requests syncs w/ device computation and async D2H copies + return new_tokens_host + + @override + def should_provide_draft_probs(self, request: LlmRequest) -> bool: + # Do not request draft probs when sampling is greedy. 
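+        # (Greedy acceptance compares draft and target tokens directly, so rejection
+        # sampling -- and hence draft probs -- is not needed.)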
+ return _request_strategy(request) is not GREEDY class Algorithms: @@ -816,6 +1545,7 @@ class TRTLLMSampler(Sampler): MAX_DECODING_TOKENS = 1 # It must be 1 when not in speculative decoding SampleState = SampleStateTRTLLM + @override def is_generation_model(self) -> bool: return True @@ -968,6 +1698,7 @@ def _update_cache_indirection_buffer(self, @torch.inference_mode() @nvtx_range("sample_async") + @override def sample_async( self, scheduled_requests: ScheduledRequests, model_outputs, num_context_logits_prefix_sum: list[int]) -> SampleStateTRTLLM: @@ -1059,6 +1790,7 @@ def sample_async( finalize_events=finalize_events) @torch.inference_mode() + @override def update_requests(self, state: SampleStateTRTLLM): assert isinstance(state, SampleStateTRTLLM) if state.scheduled_requests.batch_size == 0: diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py index 7fb7d9f0736..a018809c41b 100644 --- a/tensorrt_llm/_torch/speculative/eagle3.py +++ b/tensorrt_llm/_torch/speculative/eagle3.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import List, Optional, Set +from typing import TYPE_CHECKING, List, Optional, Set import torch from torch import nn @@ -15,6 +15,9 @@ from .interface import SpecMetadata from .mtp import MTPSampler +if TYPE_CHECKING: + from ...llmapi.llm_args import EagleDecodingConfig + class Eagle3ResourceManager(BaseResourceManager): """ diff --git a/tensorrt_llm/_torch/speculative/model_drafter.py b/tensorrt_llm/_torch/speculative/model_drafter.py index 5e2080c1d71..19673982b0b 100644 --- a/tensorrt_llm/_torch/speculative/model_drafter.py +++ b/tensorrt_llm/_torch/speculative/model_drafter.py @@ -12,14 +12,14 @@ from ..pyexecutor.handle_logits import HandleLogits from ..pyexecutor.llm_request import LlmRequest, LlmRequestState from ..pyexecutor.resource_manager import BaseResourceManager, ResourceManager -from ..pyexecutor.sampler import (Sampler, SampleState, SampleStateTensors, - TorchSampler) +from ..pyexecutor.sampler import Sampler, SampleState, SampleStateTensors from ..pyexecutor.scheduler import ScheduledRequests from ..pyexecutor.seq_slot_manager import SeqSlotManager from ..speculative.mtp import SampleStateTensorsMTP from .drafter import Drafter if TYPE_CHECKING: + from ...llmapi.llm_args import DecodingBaseConfig from ..pyexecutor.model_engine import ModelEngine from .interface import SpeculativeDecodingMode @@ -56,7 +56,7 @@ def __init__( if draft_model_engine is None: raise ValueError("draft_model_engine cannot be None") if max_draft_tokens < 0: - raise ValueError(f"max_draft_tokens must be >= 0") + raise ValueError("max_draft_tokens must be >= 0") # Model and resource management self.draft_model_engine = draft_model_engine @@ -68,30 +68,32 @@ def __init__( self.max_draft_tokens = max_draft_tokens # Sampling self.sampler = sampler - self._request_draft_logits = False - if isinstance(sampler, TorchSampler): - self._request_draft_logits = sampler.enable_mixed_sampler self.guided_decoder = guided_decoder self.use_static_draft_loop = draft_model_engine.model_is_wrapped if self.use_static_draft_loop: # TODO: enable sampling/guided decoding on static draft loop assert guided_decoder is None - assert not sampler.enable_mixed_sampler + assert spec_config._allow_greedy_draft_tokens def _create_draft_request(self, request: LlmRequest, input_tokens: Optional[List]) -> LlmRequest: """Create a draft request with common parameters.""" - return LlmRequest(input_tokens=input_tokens, - 
request_id=request.py_request_id, - max_new_tokens=request.py_max_new_tokens, - sampling_config=request.sampling_config, - guided_decoding_params=request.guided_decoding_params, - target_seq_slot=request.py_seq_slot, - return_perf_metrics=request.return_perf_metrics, - is_streaming=False, - is_draft=True, - return_generation_logits=self._request_draft_logits) + return LlmRequest( + input_tokens=input_tokens, + request_id=request.py_request_id, + max_new_tokens=request.py_max_new_tokens, + sampling_config=request.sampling_config, + guided_decoding_params=request.guided_decoding_params, + target_seq_slot=request.py_seq_slot, + return_perf_metrics=request.return_perf_metrics, + is_streaming=False, + exclude_last_generation_logits= + True, # prepare_draft_tokens uses overlap scheduling + is_draft=True, + # NB: self.sampler is shared with PyExecutor + return_generation_logits=self.sampler.should_provide_draft_probs( + request)) def _initialize_draft_tokens(self, request: LlmRequest) -> Tuple[int, int]: """Initialize draft token tracking for a request.""" @@ -285,23 +287,20 @@ def sample_async(self, draft_batch: ScheduledRequests, outputs: Dict[str, Any]) -> Optional[SampleState]: """Sample tokens from draft model outputs.""" try: - if self.sampler is not None: - num_context_logits_prefix_sum = [0] - prefix_sum = 0 - for request in draft_batch.context_requests: - prefix_sum += request.context_chunk_size if request.py_return_context_logits else 1 - num_context_logits_prefix_sum.append(prefix_sum) - - HandleLogits()( - draft_batch.context_requests, - draft_batch.generation_requests, outputs["logits"], - self.sampler.beam_width(draft_batch.all_requests()), - num_context_logits_prefix_sum, - self.sampler.is_generation_model()) - - return self.sampler.sample_async(draft_batch, outputs, - num_context_logits_prefix_sum) - return None + num_context_logits_prefix_sum = [0] + prefix_sum = 0 + for request in draft_batch.context_requests: + prefix_sum += request.context_chunk_size if request.py_return_context_logits else 1 + num_context_logits_prefix_sum.append(prefix_sum) + + HandleLogits()(draft_batch.context_requests, + draft_batch.generation_requests, outputs["logits"], + self.sampler.beam_width(draft_batch.all_requests()), + num_context_logits_prefix_sum, + self.sampler.is_generation_model()) + + return self.sampler.sample_async(draft_batch, outputs, + num_context_logits_prefix_sum) except Exception as e: logger.error(f"Error in sampling: {str(e)}") return None @@ -317,8 +316,7 @@ def update_request_states(self, def update_requests(self, sample_state: SampleState) -> None: """Update requests with sample state.""" - if self.sampler is not None: - self.sampler.update_requests(sample_state) + self.sampler.update_requests(sample_state) def process_decoded_tokens( self, draft_batch: ScheduledRequests, @@ -334,8 +332,7 @@ def process_decoded_tokens( continue target_model_req.py_draft_tokens.append(req.get_last_tokens(0)) - if self._request_draft_logits: - target_model_req.py_draft_logits = req.py_result.generation_logits + target_model_req.py_draft_logits = req.py_result.generation_logits # forwards Nones if req.state != LlmRequestState.GENERATION_COMPLETE and len( target_model_req.py_draft_tokens ) < target_model_req.py_draft_pages_allocated: diff --git a/tensorrt_llm/evaluate/json_mode_eval.py b/tensorrt_llm/evaluate/json_mode_eval.py index 37360754e50..122cbd6e7e4 100644 --- a/tensorrt_llm/evaluate/json_mode_eval.py +++ b/tensorrt_llm/evaluate/json_mode_eval.py @@ -64,7 +64,8 @@ def generate_samples(self) 
-> Iterable[tuple]: schema["x-guidance"] = {"lenient": True} schema = json.dumps(schema) sampling_args = { - "guided_decoding": GuidedDecodingParams(json=schema) + "guided_decoding": GuidedDecodingParams(json=schema), + "temperature": 0, } yield sample["prompt"], sampling_args, sample["completion"], sample[ "schema"] diff --git a/tensorrt_llm/evaluate/mmlu.py b/tensorrt_llm/evaluate/mmlu.py index 92d7ae1171a..b3b3f4ee7cf 100644 --- a/tensorrt_llm/evaluate/mmlu.py +++ b/tensorrt_llm/evaluate/mmlu.py @@ -219,7 +219,7 @@ def generate_samples(self) -> Iterable[tuple]: include_answer=False) prompt = train_prompt + prompt_end label = test_df.iloc[i, test_df.shape[1] - 1] - yield prompt, None, label, subject + yield prompt, {"temperature": 0}, label, subject def compute_score(self, outputs: List[RequestOutput], references: List[str], subjects: List[str]) -> float: diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index ecf0ffdf362..8e2b1818549 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -356,8 +356,12 @@ class DecodingBaseConfig(StrictBaseModel): # When specified, speculation will be disabled at batch sizes above # this value. Otherwise, speculation will always be on. max_concurrency: Optional[int] = None + load_format: Optional[str] = None + # If set, drafting uses greedy sampling, irrespective of sampling parameters. + _allow_greedy_draft_tokens: bool = PrivateAttr(True) + @classmethod def from_dict(cls, data: dict): # dispatch to the correct decoding config @@ -2169,12 +2173,6 @@ class TorchLlmArgs(BaseLlmArgs): description="Attention backend to use.", status="beta") - enable_mixed_sampler: bool = Field( - default=False, - description= - "If true, will iterate over sampling_params of each request and use the corresponding sampling strategy, e.g. 
top-k, top-p, etc.", - status="beta") - sampler_type: Union[str, SamplerType] = Field( default=SamplerType.auto, description= @@ -2502,7 +2500,6 @@ def get_pytorch_backend_config(self) -> "PyTorchConfig": moe_load_balancer=self.moe_config.load_balancer, attn_backend=self.attn_backend, moe_backend=self.moe_config.backend, - enable_mixed_sampler=self.enable_mixed_sampler, sampler_type=self.sampler_type, kv_cache_dtype=self.kv_cache_config.dtype, mamba_ssm_cache_dtype=self.kv_cache_config.mamba_ssm_cache_dtype, diff --git a/tensorrt_llm/scaffolding/worker.py b/tensorrt_llm/scaffolding/worker.py index f154203d46c..20987c4bfd8 100644 --- a/tensorrt_llm/scaffolding/worker.py +++ b/tensorrt_llm/scaffolding/worker.py @@ -165,7 +165,6 @@ def init_with_new_llm( llm = LLM(model_dir, tokenizer=tokenizer, - enable_mixed_sampler=True, disable_overlap_scheduler=disable_overlap_scheduler, kv_cache_config=kv_cache_config, max_batch_size=max_batch_size, diff --git a/tests/unittest/api_stability/references/llm.yaml b/tests/unittest/api_stability/references/llm.yaml index 51518bbccdb..d534c57618e 100644 --- a/tests/unittest/api_stability/references/llm.yaml +++ b/tests/unittest/api_stability/references/llm.yaml @@ -111,10 +111,6 @@ methods: annotation: str default: TRTLLM status: beta - enable_mixed_sampler: - annotation: bool - default: False - status: beta sampler_type: annotation: Union[str, tensorrt_llm.llmapi.llm_args.SamplerType] default: auto diff --git a/tests/unittest/llmapi/apps/_test_openai_misc.py b/tests/unittest/llmapi/apps/_test_openai_misc.py index 8cc715389f3..7dcac12304a 100644 --- a/tests/unittest/llmapi/apps/_test_openai_misc.py +++ b/tests/unittest/llmapi/apps/_test_openai_misc.py @@ -94,9 +94,12 @@ async def test_request_cancellation(server: RemoteOpenAIServer, # Request about 2 million tokens for _ in range(200): task = asyncio.create_task( + # FIXME: Some requests complete quickly without temperature=0, + # despite min_tokens being specified, cf. 
https://nvbugs/5513423 client.chat.completions.create(messages=chat_input, model=model_name, max_tokens=10000, + temperature=0, extra_body={"min_tokens": 10000})) tasks.append(task) diff --git a/tests/unittest/llmapi/test_llm_pytorch.py b/tests/unittest/llmapi/test_llm_pytorch.py index ef2de1350c1..8ff88068896 100644 --- a/tests/unittest/llmapi/test_llm_pytorch.py +++ b/tests/unittest/llmapi/test_llm_pytorch.py @@ -226,19 +226,9 @@ def test_llm_with_postprocess_parallel_and_result_handler(streaming): tp_size=1) -@pytest.mark.parametrize( - "enable_mixed_sampler,enable_logprobs", - [ - (False, False), # Fast path: no mixed sampler, no logits, greedy - (True, - False), # Batched strategy path: mixed sampler enabled, same strategy - (False, - True), # Per-request path: mixed sampler disabled, logprobs enabled - ]) @pytest.mark.part0 -def test_embedding_bias_with_torch_sampler_strategies(enable_mixed_sampler, - enable_logprobs): - """Test embedding bias application in all 3 TorchSampler paths: fast, batched strategy, and per-request""" +def test_embedding_bias_with_torch_sampler_strategies(): + """Test embedding bias application in TorchSampler.""" tokenizer = AutoTokenizer.from_pretrained(llama_model_path) biased_word_id = tokenizer.encode("Z", add_special_tokens=False)[-1] vocab_size_padded = 32000 @@ -250,17 +240,17 @@ def test_embedding_bias_with_torch_sampler_strategies(enable_mixed_sampler, "embedding_bias": embedding_bias, } - if enable_logprobs: - sampling_kwargs["logprobs"] = 1 # All test cases use greedy sampling for simplicity sampling_params = SamplingParams(**sampling_kwargs) - llm_test_harness(llama_model_path, - prompts, ["Z Z Z Z Z Z"], - sampling_params=sampling_params, - backend="pytorch", - enable_mixed_sampler=enable_mixed_sampler) + llm_test_harness( + llama_model_path, + prompts, + ["Z Z Z Z Z Z"], + sampling_params=sampling_params, + backend="pytorch", + ) def llama_7b_lora_from_dir_test_harness(**llm_kwargs) -> None: @@ -869,7 +859,6 @@ def test_min_tokens(use_speculative: bool): max_batch_size=2, kv_cache_config=global_kvcache_config, max_num_tokens=2048, - enable_mixed_sampler=True, ) if use_speculative:
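
For context on the `_select_generated_logits` change above: `torch_multi_arange(starts, ends)` is used to gather, for each context request, only the logits rows belonging to its final generated step(s), by concatenating the per-request index ranges `[end - steps, end)`. The helper below is a minimal, self-contained sketch of that multi-range indexing idea in plain PyTorch; it is illustrative only and is not the TensorRT-LLM `torch_multi_arange` implementation.

import torch

def multi_arange(starts: torch.Tensor, ends: torch.Tensor) -> torch.Tensor:
    """Concatenate half-open ranges, e.g. starts=[2, 7], ends=[4, 9] -> [2, 3, 7, 8]."""
    lengths = ends - starts
    # Position at which each range begins inside the flattened output.
    out_offsets = torch.cumsum(lengths, dim=0) - lengths
    total = int(lengths.sum().item())
    idx = torch.arange(total, device=starts.device)
    # Map every output position to its source range, then rebase it onto that range's start.
    range_id = torch.repeat_interleave(
        torch.arange(len(starts), device=starts.device), lengths)
    return starts[range_id] + (idx - out_offsets[range_id])

# Toy example (hypothetical shapes): two context requests whose logits end at row
# offsets 5 and 9, each contributing one generated step -> keep rows 4 and 8.
next_offsets = torch.tensor([5, 9])
steps = torch.tensor([1, 1])
keep = multi_arange(starts=next_offsets - steps, ends=next_offsets)
logits = torch.randn(9, 32)
generated_logits = logits[keep]  # shape (2, 32)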