Commit b3fa8cb
refactor
1 parent 654865e commit b3fa8cb

3 files changed: 122 additions & 85 deletions
Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
+"""A layer that computes logits from hidden_states."""
+from typing import Optional
+
+import torch
+import torch.nn as nn
+
+from vllm.utils import is_neuron
+
+from vllm.model_executor.parallel_utils.communication_op import (
+    tensor_model_parallel_gather)
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+
+
+class LogitProcessor(nn.Module):
+    """Process logits and apply logits processors from sampling metadata.
+
+    This layer does the following:
+    1. Gather logits from model hidden_states.
+    2. Scale logits if needed.
+    3. Apply logits processors (if any).
+    """
+
+    def __init__(self,
+                 vocab_size: int,
+                 org_vocab_size: Optional[int] = None,
+                 scale: Optional[float] = 1.0) -> None:
+        """
+        Args:
+            scale: A scaling factor to apply to the logits.
+        """
+        super().__init__()
+        self.scale = scale
+        # Transformers-neuronx generate outputs as logits directly.
+        self.logits_as_hidden_states = is_neuron()
+        # original vocabulary size (without LoRA).
+        self.org_vocab_size = org_vocab_size or vocab_size
+
+    def forward(
+        self,
+        embedding: torch.Tensor,
+        hidden_states: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+        embedding_bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if self.logits_as_hidden_states:
+            logits = hidden_states
+        else:
+            hidden_states = _prune_hidden_states(hidden_states,
+                                                 sampling_metadata)
+
+            # Get the logits for the next tokens.
+            logits = self._get_logits(hidden_states, embedding, embedding_bias)
+
+        logits *= self.scale
+
+        # Only perform sampling in the driver worker.
+        # Note: `_get_logits` is still distributed across TP workers because
+        # the `embedding` weight is distributed across TP workers.
+        if not sampling_metadata.perform_sampling:
+            return None
+
+        # Apply logits processors (if any).
+        logits = _apply_logits_processors(logits, sampling_metadata)
+
+        return logits
+
+    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
+                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
+        # Get the logits for the next tokens.
+        logits = torch.matmul(hidden_states, embedding.t())
+        if embedding_bias is not None:
+            logits += embedding_bias
+        logits = tensor_model_parallel_gather(logits)
+        # Remove paddings in vocab (if any).
+        if logits is not None:
+            logits = logits[:, :self.org_vocab_size]
+        return logits
+
+
+def _prune_hidden_states(
+    hidden_states: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> torch.Tensor:
+    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+    return hidden_states.index_select(0,
+                                      sampling_metadata.selected_token_indices)
+
+
+def _apply_logits_processors(
+    logits: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+) -> torch.Tensor:
+    logits_row_idx = 0
+    found_logits_processors = False
+    for seq_ids, sampling_params in sampling_metadata.seq_groups:
+        logits_processors = sampling_params.logits_processors
+        if logits_processors:
+            found_logits_processors = True
+            for seq_id in seq_ids:
+                logits_row = logits[logits_row_idx]
+                token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
+                for logits_processor in logits_processors:
+                    logits_row = logits_processor(token_ids, logits_row)
+                logits[logits_row_idx] = logits_row
+                logits_row_idx += 1
+        else:
+            logits_row_idx += len(seq_ids)
+    if found_logits_processors:
+        assert logits_row_idx == logits.shape[0]
+    return logits
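
For a concrete picture of what `_get_logits` computes: it is a matmul of the pruned hidden states against the transposed lm_head weight, followed by trimming any vocabulary padding after the tensor-parallel gather. A tiny single-process sketch with made-up shapes and no tensor parallelism (so the gather step is omitted):

import torch

torch.manual_seed(0)
num_tokens, hidden_size = 3, 8
padded_vocab, org_vocab = 12, 10           # padded vs. original vocab size

hidden_states = torch.randn(num_tokens, hidden_size)
embedding = torch.randn(padded_vocab, hidden_size)     # lm_head weight

logits = torch.matmul(hidden_states, embedding.t())    # [num_tokens, padded_vocab]
logits = logits[:, :org_vocab]                         # drop vocab padding
print(logits.shape)                                    # torch.Size([3, 10])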

vllm/model_executor/layers/sampler.py

Lines changed: 1 addition & 81 deletions
@@ -4,15 +4,12 @@
 import torch
 import torch.nn as nn

-from vllm.model_executor.parallel_utils.communication_op import (
-    tensor_model_parallel_gather)
 from vllm.model_executor.sampling_metadata import (SamplingMetadata,
                                                    SamplingTensors)
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
                            SamplerOutput, SequenceData, SequenceGroupOutput,
                            SequenceOutput)
-from vllm.utils import is_neuron


 class Sampler(nn.Module):
@@ -30,58 +27,14 @@ class Sampler(nn.Module):
     parameters (e.g., sampling method, temperature, top-p, top-k, etc.).
     """

-    def __init__(self,
-                 vocab_size: int,
-                 org_vocab_size: Optional[int] = None) -> None:
-        super().__init__()
-        self.vocab_size = vocab_size
-        # Transformers-neuronx generate outputs as logits directly.
-        self.logits_as_hidden_states = is_neuron()
-        # original vocabulary size (without LoRA).
-        self.org_vocab_size = org_vocab_size or vocab_size
-
-    def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor,
-                    embedding_bias: Optional[torch.Tensor]) -> torch.Tensor:
-        # Get the logits for the next tokens.
-        logits = torch.matmul(hidden_states, embedding.t())
-        if embedding_bias is not None:
-            logits += embedding_bias
-        logits = tensor_model_parallel_gather(logits)
-        # Remove paddings in vocab (if any).
-        if logits is not None:
-            logits = logits[:, :self.org_vocab_size]
-        return logits
-
     def forward(
         self,
-        embedding: torch.Tensor,
-        hidden_states: torch.Tensor,
+        logits: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-        embedding_bias: Optional[torch.Tensor] = None,
     ) -> Optional[SamplerOutput]:
-        # Get the hidden states that we use for sampling.
-        if self.logits_as_hidden_states:
-            logits = hidden_states
-        else:
-            hidden_states = _prune_hidden_states(hidden_states,
-                                                 sampling_metadata)
-
-            # Get the logits for the next tokens.
-            logits = self._get_logits(hidden_states, embedding, embedding_bias)
-
-        # Only perform sampling in the driver worker.
-        # Note: `_get_logits` is still distributed across TP workers because
-        # the `embedding` weight is distributed across TP workers.
-        # TODO(zhuohan): Change the get_logits part to a separate stage.
-        if not sampling_metadata.perform_sampling:
-            return None
-
         assert logits is not None
         _, vocab_size = logits.shape

-        # Apply logits processors (if any).
-        logits = _apply_logits_processors(logits, sampling_metadata)
-
         # Prepare sampling tensors with pinned memory to avoid blocking.
         (sampling_tensors, do_penalties, do_top_p_top_k,
          do_min_p) = SamplingTensors.from_sampling_metadata(
@@ -122,15 +75,6 @@ def forward(
                                      prompt_logprobs, sample_logprobs)


-def _prune_hidden_states(
-    hidden_states: torch.Tensor,
-    sampling_metadata: SamplingMetadata,
-) -> torch.Tensor:
-    hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-    return hidden_states.index_select(0,
-                                      sampling_metadata.selected_token_indices)
-
-
 def _get_bin_counts_and_mask(
     tokens: torch.Tensor,
     vocab_size: int,
@@ -148,30 +92,6 @@ def _get_bin_counts_and_mask(
     return bin_counts, mask


-def _apply_logits_processors(
-    logits: torch.Tensor,
-    sampling_metadata: SamplingMetadata,
-) -> torch.Tensor:
-    logits_row_idx = 0
-    found_logits_processors = False
-    for seq_ids, sampling_params in sampling_metadata.seq_groups:
-        logits_processors = sampling_params.logits_processors
-        if logits_processors:
-            found_logits_processors = True
-            for seq_id in seq_ids:
-                logits_row = logits[logits_row_idx]
-                token_ids = sampling_metadata.seq_data[seq_id].output_token_ids
-                for logits_processor in logits_processors:
-                    logits_row = logits_processor(token_ids, logits_row)
-                logits[logits_row_idx] = logits_row
-                logits_row_idx += 1
-        else:
-            logits_row_idx += len(seq_ids)
-    if found_logits_processors:
-        assert logits_row_idx == logits.shape[0]
-    return logits
-
-
 def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
                      output_tokens_tensor: torch.Tensor,
                      presence_penalties: torch.Tensor,
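
The `_apply_logits_processors` helper removed above (it now lives in the new logits layer) drives user-supplied callables from `sampling_params.logits_processors`: each one receives the token ids a sequence has generated so far plus that sequence's logits row, and returns a modified row. A minimal sketch of such a callable; the function name and the banned token id are illustrative, not part of this commit:

from typing import List

import torch


def ban_token_zero(output_token_ids: List[int],
                   logits_row: torch.Tensor) -> torch.Tensor:
    # Illustrative rule: forbid token id 0 from ever being sampled.
    logits_row[0] = -float("inf")
    return logits_row


# Quick check of the contract the loop relies on:
row = ban_token_zero([], torch.zeros(5))
print(row)   # tensor([-inf, 0., 0., 0., 0.])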

vllm/model_executor/models/llama.py

Lines changed: 11 additions & 4 deletions
@@ -37,7 +37,7 @@
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.sampler import Sampler
+from vllm.model_executor.layers.sampler import (Sampler, LogitProcessor)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE)
 from vllm.model_executor.parallel_utils.parallel_state import (
@@ -325,7 +325,11 @@ def __init__(
             # compatibility
             if not lora_config else lora_config.lora_vocab_padding_size,
         )
-        self.sampler = Sampler(self.unpadded_vocab_size, config.vocab_size)
+
+        logit_scale = getattr(config, "logit_scale", 1.0)
+        self.logit_processor = LogitProcessor(self.unpadded_vocab_size,
+                                              config.vocab_size, logit_scale)
+        self.sampler = Sampler()

     def forward(
         self,
@@ -343,8 +347,11 @@ def sample(
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
-                                   sampling_metadata)
+
+        logits = self.logit_processor(self.lm_head.weight, hidden_states,
+                                      sampling_metadata)
+
+        next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens

     def load_weights(self,
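
The `logit_scale = getattr(config, "logit_scale", 1.0)` line means models whose config defines `logit_scale` have their logits multiplied by it inside the new layer (`logits *= self.scale`) before sampling. A standalone PyTorch illustration, not vLLM code, of how that scaling reshapes the sampling distribution:

import torch

logits = torch.tensor([2.0, 1.0, 0.0])
for scale in (0.5, 1.0, 2.0):
    probs = torch.softmax(logits * scale, dim=-1)
    # scale < 1.0 flattens the distribution; scale > 1.0 sharpens it.
    print(f"scale={scale}: {[round(p, 3) for p in probs.tolist()]}")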
