 import torch
 import torch.nn as nn
 
-from vllm.model_executor.parallel_utils.communication_op import (
-    tensor_model_parallel_gather)
 from vllm.model_executor.sampling_metadata import (SamplingMetadata,
                                                    SamplingTensors)
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
                            SamplerOutput, SequenceData, SequenceGroupOutput,
                            SequenceOutput)
-from vllm.utils import is_neuron
 
 
 class Sampler(nn.Module):
@@ -30,58 +27,14 @@ class Sampler(nn.Module):
     parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
     """
 
-    def __init__(self,
-                 vocab_size: int,
-                 org_vocab_size: Optional[int] = None) -> None:
-        super().__init__()
-        self.vocab_size = vocab_size
-        # Transformers-neuronx generate outputs as logits directly.
-        self.logits_as_hidden_states = is_neuron()
-        # original vocabulary size (without LoRA).
-        self.org_vocab_size = org_vocab_size or vocab_size
-
-    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
-                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
-        # Get the logits for the next tokens.
-        logits = torch.matmul(hidden_states, embedding.t())
-        if embedding_bias is not None:
-            logits += embedding_bias
-        logits = tensor_model_parallel_gather(logits)
-        # Remove paddings in vocab (if any).
-        if logits is not None:
-            logits = logits[:, :self.org_vocab_size]
-        return logits
-
     def forward(
         self,
-        embedding: torch.Tensor,
-        hidden_states: torch.Tensor,
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-        embedding_bias: Optional[torch.Tensor] = None,
     ) -> Optional[SamplerOutput]:
-        # Get the hidden states that we use for sampling.
-        if self.logits_as_hidden_states:
-            logits = hidden_states
-        else:
-            hidden_states = _prune_hidden_states(hidden_states,
-                                                 sampling_metadata)
-
-            # Get the logits for the next tokens.
-            logits = self._get_logits(hidden_states, embedding, embedding_bias)
-
-        # Only perform sampling in the driver worker.
-        # Note: `_get_logits` is still distributed across TP workers because
-        # the `embedding` weight is distributed across TP workers.
-        # TODO(zhuohan): Change the get_logits part to a separate stage.
-        if not sampling_metadata.perform_sampling:
-            return None
-
         assert logits is not None
         _, vocab_size = logits.shape
 
-        # Apply logits processors (if any).
-        logits = _apply_logits_processors(logits, sampling_metadata)
-
         # Prepare sampling tensors with pinned memory to avoid blocking.
         (sampling_tensors, do_penalties, do_top_p_top_k,
          do_min_p) = SamplingTensors.from_sampling_metadata(
@@ -122,15 +75,6 @@ def forward(
                                      prompt_logprobs, sample_logprobs)
 
 
-def _prune_hidden_states(
-    hidden_states: torch.Tensor,
-    sampling_metadata: SamplingMetadata,
-) -> torch.Tensor:
-    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-    return hidden_states.index_select(0,
-                                      sampling_metadata.selected_token_indices)
-
-
 def _get_bin_counts_and_mask(
     tokens: torch.Tensor,
     vocab_size: int,
@@ -148,30 +92,6 @@ def _get_bin_counts_and_mask(
     return bin_counts, mask
 
 
-def _apply_logits_processors(
-    logits: torch.Tensor,
-    sampling_metadata: SamplingMetadata,
-) -> torch.Tensor:
-    logits_row_idx = 0
-    found_logits_processors = False
-    for seq_ids, sampling_params in sampling_metadata.seq_groups:
-        logits_processors = sampling_params.logits_processors
-        if logits_processors:
-            found_logits_processors = True
-            for seq_id in seq_ids:
-                logits_row = logits[logits_row_idx]
-                token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
-                for logits_processor in logits_processors:
-                    logits_row = logits_processor(token_ids, logits_row)
-                logits[logits_row_idx] = logits_row
-                logits_row_idx += 1
-        else:
-            logits_row_idx += len(seq_ids)
-    if found_logits_processors:
-        assert logits_row_idx == logits.shape[0]
-    return logits
-
-
 def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                      output_tokens_tensor: torch.Tensor,
                      presence_penalties: torch.Tensor,
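For orientation: this diff removes the logits-computation path (`_get_logits`, `_prune_hidden_states`, `_apply_logits_processors`) from `Sampler`, whose `forward` now takes precomputed `logits` instead of `hidden_states` plus the embedding. The sketch below is not part of the change and the names (`compute_logits`, `org_vocab_size`) are assumptions for illustration only; it shows what the removed projection step did in isolation: multiply hidden states by the embedding matrix, add an optional bias, and strip vocab padding. The tensor-parallel gather and the per-request logits processors are deliberately omitted.

```python
# Minimal sketch (not part of this diff) of the logits-computation stage
# that the Sampler no longer performs. Names here are illustrative.
from typing import Optional

import torch


def compute_logits(hidden_states: torch.Tensor,
                   embedding: torch.Tensor,
                   embedding_bias: Optional[torch.Tensor] = None,
                   org_vocab_size: Optional[int] = None) -> torch.Tensor:
    # Project hidden states onto the (possibly padded) vocabulary:
    # [num_tokens, hidden_size] x [hidden_size, padded_vocab] -> logits.
    logits = torch.matmul(hidden_states, embedding.t())
    if embedding_bias is not None:
        logits += embedding_bias
    # Drop padding columns added for alignment, keeping the original vocab.
    if org_vocab_size is not None:
        logits = logits[:, :org_vocab_size]
    return logits


if __name__ == "__main__":
    hidden = torch.randn(3, 16)   # 3 tokens, hidden size 16
    emb = torch.randn(32, 16)     # padded vocab of 32, hidden size 16
    out = compute_logits(hidden, emb, org_vocab_size=30)
    print(out.shape)              # torch.Size([3, 30])
```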