
Commit 8b43a1e

Went-Liang authored and hissu-hyvarinen committed
[Model] Support math-shepherd-mistral-7b-prm model (vllm-project#9697)
Signed-off-by: Went-Liang <[email protected]>
1 parent c0276c0 commit 8b43a1e

File tree: 14 files changed (+312 −62 lines)

vllm/config.py

Lines changed: 85 additions & 30 deletions
@@ -112,38 +112,58 @@ class ModelConfig:
         Defaults to 'auto' which defaults to 'hf'.
         mm_processor_kwargs: Arguments to be forwarded to the model's processor
             for multi-modal data, e.g., image processor.
+        pooling_type: Used to configure the pooling method in the embedding
+            model.
+        pooling_norm: Used to determine whether to normalize the pooled
+            data in the embedding model.
+        pooling_softmax: Used to determine whether to softmax the pooled
+            data in the embedding model.
+        pooling_step_tag_id: When pooling_step_tag_id is not -1, it indicates
+            that the score corresponding to the pooling_step_tag_id in the
+            generated sentence should be returned. Otherwise, it returns
+            the scores for all tokens.
+        pooling_returned_token_ids: pooling_returned_token_ids represents a
+            list of indices for the vocabulary dimensions to be extracted,
+            such as the token IDs of good_token and bad_token in the
+            math-shepherd-mistral-7b-prm model.
     """
 
-    def __init__(self,
-                 model: str,
-                 task: Union[TaskOption, _Task],
-                 tokenizer: str,
-                 tokenizer_mode: str,
-                 trust_remote_code: bool,
-                 dtype: Union[str, torch.dtype],
-                 seed: int,
-                 revision: Optional[str] = None,
-                 code_revision: Optional[str] = None,
-                 rope_scaling: Optional[dict] = None,
-                 rope_theta: Optional[float] = None,
-                 tokenizer_revision: Optional[str] = None,
-                 max_model_len: Optional[int] = None,
-                 spec_target_max_model_len: Optional[int] = None,
-                 quantization: Optional[str] = None,
-                 quantization_param_path: Optional[str] = None,
-                 enforce_eager: Optional[bool] = None,
-                 max_context_len_to_capture: Optional[int] = None,
-                 max_seq_len_to_capture: Optional[int] = None,
-                 max_logprobs: int = 20,
-                 disable_sliding_window: bool = False,
-                 skip_tokenizer_init: bool = False,
-                 served_model_name: Optional[Union[str, List[str]]] = None,
-                 limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
-                 use_async_output_proc: bool = True,
-                 override_neuron_config: Optional[Dict[str, Any]] = None,
-                 config_format: ConfigFormat = ConfigFormat.AUTO,
-                 chat_template_text_format: str = "string",
-                 mm_processor_kwargs: Optional[Dict[str, Any]] = None) -> None:
+    def __init__(
+        self,
+        model: str,
+        task: Union[TaskOption, _Task],
+        tokenizer: str,
+        tokenizer_mode: str,
+        trust_remote_code: bool,
+        dtype: Union[str, torch.dtype],
+        seed: int,
+        revision: Optional[str] = None,
+        code_revision: Optional[str] = None,
+        rope_scaling: Optional[dict] = None,
+        rope_theta: Optional[float] = None,
+        tokenizer_revision: Optional[str] = None,
+        max_model_len: Optional[int] = None,
+        spec_target_max_model_len: Optional[int] = None,
+        quantization: Optional[str] = None,
+        quantization_param_path: Optional[str] = None,
+        enforce_eager: Optional[bool] = None,
+        max_context_len_to_capture: Optional[int] = None,
+        max_seq_len_to_capture: Optional[int] = None,
+        max_logprobs: int = 20,
+        disable_sliding_window: bool = False,
+        skip_tokenizer_init: bool = False,
+        served_model_name: Optional[Union[str, List[str]]] = None,
+        limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
+        use_async_output_proc: bool = True,
+        override_neuron_config: Optional[Dict[str, Any]] = None,
+        config_format: ConfigFormat = ConfigFormat.AUTO,
+        chat_template_text_format: str = "string",
+        mm_processor_kwargs: Optional[Dict[str, Any]] = None,
+        pooling_type: Optional[str] = None,
+        pooling_norm: Optional[bool] = None,
+        pooling_softmax: Optional[bool] = None,
+        pooling_step_tag_id: Optional[int] = None,
+        pooling_returned_token_ids: Optional[List[int]] = None) -> None:
         self.model = model
         self.tokenizer = tokenizer
         self.tokenizer_mode = tokenizer_mode

@@ -224,6 +244,13 @@ def __init__(self,
         supported_tasks, task = self._resolve_task(task, self.hf_config)
         self.supported_tasks = supported_tasks
         self.task: Final = task
+        self.pooler_config = self._init_pooler_config(
+            pooling_type,
+            pooling_norm,
+            pooling_softmax,
+            pooling_step_tag_id,
+            pooling_returned_token_ids,
+        )
 
         self._verify_quantization()
         self._verify_cuda_graph()

@@ -242,6 +269,23 @@ def _init_multimodal_config(
 
         return None
 
+    def _init_pooler_config(
+        self,
+        pooling_type: Optional[str] = None,
+        pooling_norm: Optional[bool] = None,
+        pooling_softmax: Optional[bool] = None,
+        pooling_step_tag_id: Optional[int] = None,
+        pooling_returned_token_ids: Optional[List[int]] = None
+    ) -> Optional["PoolerConfig"]:
+        if self.task == "embedding":
+            return PoolerConfig(
+                pooling_type=pooling_type,
+                pooling_norm=pooling_norm,
+                pooling_softmax=pooling_softmax,
+                pooling_step_tag_id=pooling_step_tag_id,
+                pooling_returned_token_ids=pooling_returned_token_ids)
+        return None
+
     def _init_attention_free(self) -> bool:
         architectures = getattr(self.hf_config, "architectures", [])
         return ModelRegistry.is_attention_free_model(architectures)

@@ -1660,6 +1704,17 @@ class MultiModalConfig:
     # TODO: Add configs to init vision tower or not.
 
 
+@dataclass
+class PoolerConfig:
+    """Controls the behavior of pooler in embedding model"""
+
+    pooling_type: Optional[str] = None
+    pooling_norm: Optional[bool] = None
+    pooling_softmax: Optional[bool] = None
+    pooling_step_tag_id: Optional[int] = None
+    pooling_returned_token_ids: Optional[List[int]] = None
+
+
 _STR_DTYPE_TO_TORCH_DTYPE = {
     "half": torch.float16,
     "float16": torch.float16,

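For context, a minimal sketch of the configuration object that _init_pooler_config produces when the engine is running an embedding task; the token ids below are illustrative assumptions, not values taken from this commit:

from vllm.config import PoolerConfig

# Only built when task == "embedding"; fields left as None keep the
# per-model Pooler defaults.
pooler_config = PoolerConfig(
    pooling_type="STEP",
    pooling_step_tag_id=12902,              # assumed step-tag token id
    pooling_returned_token_ids=[648, 387],  # assumed good/bad token ids
)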
vllm/engine/arg_utils.py

Lines changed: 64 additions & 0 deletions
@@ -184,6 +184,13 @@ class EngineArgs:
     mm_processor_kwargs: Optional[Dict[str, Any]] = None
     scheduling_policy: Literal["fcfs", "priority"] = "fcfs"
 
+    # Pooling configuration.
+    pooling_type: Optional[str] = None
+    pooling_norm: Optional[bool] = None
+    pooling_softmax: Optional[bool] = None
+    pooling_step_tag_id: Optional[int] = None
+    pooling_returned_token_ids: Optional[List[int]] = None
+
     def __post_init__(self):
         if not self.tokenizer:
             self.tokenizer = self.model

@@ -850,6 +857,58 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             'priority (lower value means earlier handling) and time of '
             'arrival deciding any ties).')
 
+        parser.add_argument(
+            '--pooling-type',
+            choices=['LAST', 'ALL', 'CLS', 'STEP'],
+            default=None,
+            help='Used to configure the pooling method in the embedding model.'
+        )
+
+        parser.add_argument('--pooling-norm',
+                            default=None,
+                            action='store_true',
+                            help="Used to determine whether to normalize "
+                            "the pooled data in the embedding model.")
+
+        parser.add_argument('--no-pooling-norm',
+                            default=None,
+                            action='store_false',
+                            dest='pooling_norm',
+                            help="Used to determine whether to normalize "
+                            "the pooled data in the embedding model.")
+
+        parser.add_argument('--pooling-softmax',
+                            default=None,
+                            action='store_true',
+                            help="Used to determine whether to softmax "
+                            "the pooled data in the embedding model.")
+
+        parser.add_argument('--no-pooling-softmax',
+                            default=None,
+                            action='store_false',
+                            dest='pooling_softmax',
+                            help="Used to determine whether to softmax "
+                            "the pooled data in the embedding model.")
+
+        parser.add_argument(
+            '--pooling-step-tag-id',
+            type=int,
+            default=None,
+            help="When pooling-step-tag-id is not -1, it indicates "
+            "that the score corresponding to the step-tag-ids in the "
+            "generated sentence should be returned. Otherwise, it "
+            "returns the scores for all tokens.")
+
+        parser.add_argument(
+            '--pooling-returned-token-ids',
+            nargs='+',
+            type=int,
+            default=None,
+            help="pooling-returned-token-ids represents a list of "
+            "indices for the vocabulary dimensions to be extracted, "
+            "such as the token IDs of good_token and bad_token in "
+            "the math-shepherd-mistral-7b-prm model.")
+
         return parser
 
     @classmethod

@@ -891,6 +950,11 @@ def create_model_config(self) -> ModelConfig:
             override_neuron_config=self.override_neuron_config,
             config_format=self.config_format,
             mm_processor_kwargs=self.mm_processor_kwargs,
+            pooling_type=self.pooling_type,
+            pooling_norm=self.pooling_norm,
+            pooling_softmax=self.pooling_softmax,
+            pooling_step_tag_id=self.pooling_step_tag_id,
+            pooling_returned_token_ids=self.pooling_returned_token_ids,
         )
 
     def create_load_config(self) -> LoadConfig:
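Taken together, the new engine flags could be combined on a server command line roughly as follows. This is a sketch only: the model repository name and token ids are assumptions for illustration, and it presumes the existing --task option of this vLLM version.

python -m vllm.entrypoints.openai.api_server \
    --model peiyi9979/math-shepherd-mistral-7b-prm \
    --task embedding \
    --pooling-type STEP \
    --pooling-step-tag-id 12902 \
    --pooling-returned-token-ids 648 387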

vllm/engine/llm_engine.py

Lines changed: 3 additions & 1 deletion
@@ -257,7 +257,8 @@ def __init__(
             "num_scheduler_steps=%d, chunked_prefill_enabled=%s "
             "multi_step_stream_outputs=%s, enable_prefix_caching=%s, "
             "use_async_output_proc=%s, use_cached_outputs=%s, "
-            "chat_template_text_format=%s, mm_processor_kwargs=%s)",
+            "chat_template_text_format=%s, mm_processor_kwargs=%s, "
+            "pooler_config=%r)",
             VLLM_VERSION,
             model_config.model,
             speculative_config,

@@ -294,6 +295,7 @@ def __init__(
             use_cached_outputs,
             model_config.chat_template_text_format,
             model_config.mm_processor_kwargs,
+            model_config.pooler_config,
         )
         # TODO(woosuk): Print more configs in debug mode.
         self.model_config = model_config

vllm/entrypoints/llm.py

Lines changed: 10 additions & 0 deletions
@@ -159,6 +159,11 @@ def __init__(
         mm_processor_kwargs: Optional[Dict[str, Any]] = None,
         # After positional args are removed, move this right below `model`
         task: TaskOption = "auto",
+        pooling_type: Optional[str] = None,
+        pooling_norm: Optional[bool] = None,
+        pooling_softmax: Optional[bool] = None,
+        pooling_step_tag_id: Optional[int] = None,
+        pooling_returned_token_ids: Optional[List[int]] = None,
         **kwargs,
     ) -> None:
         '''

@@ -193,6 +198,11 @@ def __init__(
             disable_custom_all_reduce=disable_custom_all_reduce,
             disable_async_output_proc=disable_async_output_proc,
             mm_processor_kwargs=mm_processor_kwargs,
+            pooling_type=pooling_type,
+            pooling_norm=pooling_norm,
+            pooling_softmax=pooling_softmax,
+            pooling_step_tag_id=pooling_step_tag_id,
+            pooling_returned_token_ids=pooling_returned_token_ids,
             **kwargs,
         )
         self.llm_engine = LLMEngine.from_engine_args(
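A minimal offline-usage sketch of the new keyword arguments on the LLM entrypoint. The Hugging Face repo name, token ids, and prompt are assumptions for illustration and are not specified in this diff:

from vllm import LLM

llm = LLM(
    model="peiyi9979/math-shepherd-mistral-7b-prm",  # assumed repo name
    task="embedding",
    pooling_type="STEP",
    pooling_step_tag_id=12902,              # assumed step-tag token id
    pooling_returned_token_ids=[648, 387],  # assumed good/bad token ids
)
# encode() returns pooled outputs (here, per-step scores) instead of text.
outputs = llm.encode("Question: 1 + 1 = ? Step 1: 1 + 1 = 2.")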

vllm/model_executor/layers/pooler.py

Lines changed: 58 additions & 4 deletions
@@ -1,8 +1,10 @@
 from enum import IntEnum
+from typing import List, Optional
 
 import torch
 import torch.nn as nn
 
+from vllm.config import PoolerConfig
 from vllm.model_executor.pooling_metadata import (PoolingMetadata,
                                                   PoolingTensors)
 from vllm.sequence import EmbeddingSequenceGroupOutput, PoolerOutput

@@ -13,6 +15,7 @@ class PoolingType(IntEnum):
     LAST = 0
     ALL = 1
     CLS = 2
+    STEP = 3
 
 
 class Pooler(nn.Module):

@@ -28,15 +31,47 @@ class Pooler(nn.Module):
         normalize: Whether to normalize the pooled data.
     """
 
-    def __init__(self,
-                 pooling_type: PoolingType,
-                 normalize: bool,
-                 softmax: bool = False):
+    def __init__(
+        self,
+        pooling_type: PoolingType,
+        normalize: bool,
+        softmax: bool,
+        step_tag_id: Optional[int] = None,
+        returned_token_ids: Optional[List[int]] = None,
+    ):
         super().__init__()
 
         self.pooling_type = pooling_type
         self.normalize = normalize
         self.softmax = softmax
+        self.step_tag_id = step_tag_id
+        self.returned_token_ids = returned_token_ids
+
+    @classmethod
+    def from_config_with_defaults(
+        cls,
+        pooler_config: PoolerConfig,
+        pooling_type: PoolingType,
+        normalize: bool,
+        softmax: bool,
+        step_tag_id: Optional[int] = None,
+        returned_token_ids: Optional[List[int]] = None,
+    ) -> Optional["Pooler"]:
+        if pooler_config is None:
+            return None
+        return cls(
+            pooling_type=PoolingType[pooler_config.pooling_type]
+            if pooler_config.pooling_type is not None else pooling_type,
+            normalize=pooler_config.pooling_norm
+            if pooler_config.pooling_norm is not None else normalize,
+            softmax=pooler_config.pooling_softmax
+            if pooler_config.pooling_softmax is not None else softmax,
+            step_tag_id=pooler_config.pooling_step_tag_id
+            if pooler_config.pooling_step_tag_id is not None else step_tag_id,
+            returned_token_ids=pooler_config.pooling_returned_token_ids
+            if pooler_config.pooling_returned_token_ids is not None else
+            returned_token_ids,
+        )
 
     def forward(
         self,
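A hypothetical sketch of how a process-reward model could use from_config_with_defaults, letting user-supplied PoolerConfig fields override the model's own STEP defaults; the attribute name and default values shown are assumptions, not code from this commit:

# Inside an embedding/reward model's __init__ (illustrative only):
self._pooler = Pooler.from_config_with_defaults(
    pooler_config,
    pooling_type=PoolingType.STEP,
    normalize=False,
    softmax=False,
    step_tag_id=12902,               # assumed default step-tag token id
    returned_token_ids=[648, 387],   # assumed default good/bad token ids
)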
@@ -62,6 +97,25 @@ def forward(
             for prompt_len in prompt_lens:
                 pooled_data.append(hidden_states[offset:offset + prompt_len])
                 offset += prompt_len
+        elif self.pooling_type == PoolingType.STEP:
+            if self.returned_token_ids is not None and len(
+                    self.returned_token_ids) > 0:
+                logits = hidden_states[:,
+                                       self.returned_token_ids].softmax(dim=-1)
+            else:
+                logits = hidden_states.softmax(dim=-1)
+            offset = 0
+            pooled_data = []
+            for prompt_len, seq_data_i in zip(
+                    prompt_lens, pooling_metadata.seq_data.values()):
+                if self.step_tag_id is None:
+                    pooled_data.append(logits[offset:offset + prompt_len])
+                else:
+                    step_idxs = torch.tensor(
+                        seq_data_i.prompt_token_ids) == self.step_tag_id
+                    pooled_data.append(logits[offset:offset +
+                                              prompt_len][step_idxs])
+                offset += prompt_len
         else:
             raise ValueError(f"Invalid pooling type: {self.pooling_type}")
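To make the STEP branch above concrete, here is a self-contained sketch of the same selection logic on toy tensors; the vocabulary size and token ids are arbitrary assumptions chosen only for the example:

import torch

hidden_states = torch.randn(6, 32000)        # one 6-token prompt, vocab of 32000
returned_token_ids = [648, 387]              # e.g. ids of good_token / bad_token
step_tag_id = 12902                          # token that marks the end of a step
prompt_token_ids = torch.tensor([1, 5, 12902, 9, 7, 12902])

# Keep only the good/bad columns, then renormalize them per position.
logits = hidden_states[:, returned_token_ids].softmax(dim=-1)
# Return scores only at the step-tag positions.
step_scores = logits[prompt_token_ids == step_tag_id]
print(step_scores.shape)                     # torch.Size([2, 2])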

0 commit comments
