
Commit 7eb4a51

[Core] Support serving encoder/decoder models (#7258)
1 parent 0fa1490 commit 7eb4a51
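
In practical terms, this change lets an encoder/decoder model such as BART run through the same `LLM` entry point as decoder-only models. Below is a minimal sketch of the offline flow, assuming the explicit encoder/decoder prompt shape used by the example and tests in this commit; the model choice, sampling settings, and `enforce_eager` flag are illustrative, not taken from this diff.

from vllm import LLM, SamplingParams
from vllm.inputs import ExplicitEncoderDecoderPrompt

# Assumed prompt shape (see the updated example and conftest in this commit);
# everything else here is illustrative only.
prompt = ExplicitEncoderDecoderPrompt(
    encoder_prompt="An encoder prompt",
    decoder_prompt="A decoder prompt",
)

llm = LLM(model="facebook/bart-base", dtype="float", enforce_eager=True)
outputs = llm.generate(prompt, SamplingParams(temperature=0.0, max_tokens=20))
print(outputs[0].outputs[0].text)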

25 files changed: +603 -464 lines changed


.github/workflows/mypy.yaml

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ jobs:
     - name: Install dependencies
       run: |
         python -m pip install --upgrade pip
-        pip install mypy==1.9.0
+        pip install mypy==1.11.1
         pip install types-setuptools
         pip install types-PyYAML
         pip install types-requests

examples/offline_inference_encoder_decoder.py

Lines changed: 4 additions & 4 deletions
@@ -4,8 +4,8 @@
 '''

 from vllm import LLM, SamplingParams
-from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
-from vllm.utils import zip_enc_dec_prompt_lists
+from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
+                         TokensPrompt, zip_enc_dec_prompts)

 dtype = "float"

@@ -61,9 +61,9 @@
 )

 # - Finally, here's a useful helper function for zipping encoder and
-# decoder prompt lists together into a list of ExplicitEncoderDecoderPrompt
+# decoder prompts together into a list of ExplicitEncoderDecoderPrompt
 # instances
-zipped_prompt_list = zip_enc_dec_prompt_lists(
+zipped_prompt_list = zip_enc_dec_prompts(
     ['An encoder prompt', 'Another encoder prompt'],
     ['A decoder prompt', 'Another decoder prompt'])

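
As a quick illustration of the renamed helper, here is a minimal sketch of how the zipped prompts might then be fed to generation; the model name and sampling settings are assumptions based on the rest of this example file, which is not shown in the diff.

from vllm import LLM, SamplingParams
from vllm.inputs import zip_enc_dec_prompts

# zip_enc_dec_prompts pairs the i-th encoder prompt with the i-th decoder
# prompt into a list of ExplicitEncoderDecoderPrompt instances.
zipped_prompt_list = zip_enc_dec_prompts(
    ['An encoder prompt', 'Another encoder prompt'],
    ['A decoder prompt', 'Another decoder prompt'])

# The model and sampling settings below are illustrative assumptions.
llm = LLM(model="facebook/bart-large-cnn", dtype="float")
outputs = llm.generate(zipped_prompt_list, SamplingParams(temperature=0.0))
for output in outputs:
    print(output.outputs[0].text)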

requirements-common.txt

Lines changed: 1 addition & 1 deletion
@@ -19,7 +19,7 @@ prometheus-fastapi-instrumentator >= 7.0.0
 tiktoken >= 0.6.0 # Required for DBRX tokenizer
 lm-format-enforcer == 0.10.3
 outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
-typing_extensions
+typing_extensions >= 4.10
 filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
 pyzmq
 gguf == 0.9.1

requirements-lint.txt

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@ isort==5.13.2
 clang-format==18.1.5

 # type checking
-mypy==1.9.0
+mypy==1.11.1
 types-PyYAML
 types-requests
 types-setuptools

tests/conftest.py

Lines changed: 19 additions & 13 deletions
@@ -3,6 +3,7 @@
 import os
 import sys
 from collections import UserList
+from enum import Enum
 from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union

 import pytest
@@ -14,20 +15,19 @@
                           AutoModelForVision2Seq, AutoTokenizer, BatchEncoding,
                           BatchFeature)

-from tests.models.utils import DecoderPromptType
 from vllm import LLM, SamplingParams
 from vllm.assets.image import ImageAsset
 from vllm.config import TokenizerPoolConfig
 from vllm.connections import global_http_connection
 from vllm.distributed import (destroy_distributed_environment,
                               destroy_model_parallel)
-from vllm.inputs import TextPrompt
+from vllm.inputs import (ExplicitEncoderDecoderPrompt, TextPrompt,
+                         to_enc_dec_tuple_list, zip_enc_dec_prompts)
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
 from vllm.sequence import SampleLogprobs
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless,
-                        is_cpu, to_enc_dec_tuple_list,
-                        zip_enc_dec_prompt_lists)
+                        is_cpu)

 logger = init_logger(__name__)

@@ -124,10 +124,16 @@ def example_prompts() -> List[str]:
     return prompts


+class DecoderPromptType(Enum):
+    """For encoder/decoder models only."""
+    CUSTOM = 1
+    NONE = 2
+    EMPTY_STR = 3
+
+
 @pytest.fixture
-def example_encoder_decoder_prompts() \
-        -> Dict[DecoderPromptType,
-                Tuple[List[str], List[Optional[str]]]]:
+def example_encoder_decoder_prompts(
+) -> Dict[DecoderPromptType, List[ExplicitEncoderDecoderPrompt]]:
     '''
     Returns an encoder prompt list and a decoder prompt list, wherein each pair
     of same-index entries in both lists corresponds to an (encoder prompt,
@@ -150,11 +156,11 @@ def example_encoder_decoder_prompts() \
     # NONE decoder prompt type
     return {
         DecoderPromptType.NONE:
-        zip_enc_dec_prompt_lists(encoder_prompts, none_decoder_prompts),
+        zip_enc_dec_prompts(encoder_prompts, none_decoder_prompts),
         DecoderPromptType.EMPTY_STR:
-        zip_enc_dec_prompt_lists(encoder_prompts, empty_str_decoder_prompts),
+        zip_enc_dec_prompts(encoder_prompts, empty_str_decoder_prompts),
         DecoderPromptType.CUSTOM:
-        zip_enc_dec_prompt_lists(encoder_prompts, custom_decoder_prompts),
+        zip_enc_dec_prompts(encoder_prompts, custom_decoder_prompts),
     }


@@ -444,7 +450,7 @@ def generate_greedy_logprobs_limit(

     def generate_encoder_decoder_greedy_logprobs_limit(
         self,
-        encoder_decoder_prompts: Tuple[List[str], List[str]],
+        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
         max_tokens: int,
         num_logprobs: int,
         **kwargs: Any,
@@ -608,7 +614,7 @@ def generate_w_logprobs(

     def generate_encoder_decoder_w_logprobs(
         self,
-        encoder_decoder_prompts: Tuple[List[str], List[str]],
+        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
         sampling_params: SamplingParams,
     ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
         '''
@@ -653,7 +659,7 @@ def generate_greedy_logprobs(

     def generate_encoder_decoder_greedy_logprobs(
         self,
-        encoder_decoder_prompts: Tuple[List[str], List[str]],
+        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
         max_tokens: int,
         num_logprobs: int,
     ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]:
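
For orientation, here is a small sketch of the two helpers the new fixture relies on. The exact dict keys and the tuple shape are assumptions inferred from how the fixture and the runners use them; they are not spelled out in this diff.

from vllm.inputs import to_enc_dec_tuple_list, zip_enc_dec_prompts

# Assumed behaviour: zip_enc_dec_prompts builds ExplicitEncoderDecoderPrompt
# dicts, and to_enc_dec_tuple_list flattens them back into (encoder, decoder)
# pairs for the HF runner.
prompts = zip_enc_dec_prompts(["An encoder prompt"], [None])
print(prompts)
# e.g. [{'encoder_prompt': 'An encoder prompt', 'decoder_prompt': None}]
print(to_enc_dec_tuple_list(prompts))
# e.g. [('An encoder prompt', None)]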

tests/distributed/test_basic_distributed_correctness_enc_dec.py

Lines changed: 1 addition & 1 deletion
@@ -11,9 +11,9 @@

 import pytest

-from tests.models.utils import DecoderPromptType
 from vllm.utils import cuda_device_count_stateless

+from ..conftest import DecoderPromptType
 from ..models.utils import check_logprobs_close
 from ..utils import fork_new_process_for_each_test


Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
+import openai
+import pytest
+
+from ...utils import RemoteOpenAIServer
+
+MODEL_NAME = "facebook/bart-base"
+
+
+@pytest.fixture(scope="module")
+def server():
+    args = [
+        "--dtype",
+        "bfloat16",
+        "--enforce-eager",
+    ]
+
+    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
+        yield remote_server
+
+
+@pytest.fixture(scope="module")
+def client(server):
+    return server.get_async_client()
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("model_name", [MODEL_NAME])
+async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
+    completion = await client.completions.create(model=model_name,
+                                                  prompt="Hello, my name is",
+                                                  max_tokens=5,
+                                                  temperature=0.0)
+
+    assert completion.id is not None
+    assert completion.choices is not None and len(completion.choices) == 1
+
+    choice = completion.choices[0]
+    assert len(choice.text) >= 5
+    assert choice.finish_reason == "length"
+    assert completion.usage == openai.types.CompletionUsage(
+        completion_tokens=5, prompt_tokens=2, total_tokens=7)
+
+    # test using token IDs
+    completion = await client.completions.create(
+        model=model_name,
+        prompt=[0, 0, 0, 0, 0],
+        max_tokens=5,
+        temperature=0.0,
+    )
+    assert len(completion.choices[0].text) >= 1
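
Outside the test harness, the same flow amounts to launching the OpenAI-compatible server for the BART checkpoint and querying it with the openai client. A minimal sketch follows, assuming a local server on the default port and a placeholder API key; neither detail comes from this diff.

# Assumed server launch, mirroring the fixture's arguments:
#   python -m vllm.entrypoints.openai.api_server \
#       --model facebook/bart-base --dtype bfloat16 --enforce-eager
import asyncio

import openai


async def main() -> None:
    # base_url/port and api_key value are assumptions for a local server.
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    completion = await client.completions.create(model="facebook/bart-base",
                                                  prompt="Hello, my name is",
                                                  max_tokens=5,
                                                  temperature=0.0)
    print(completion.choices[0].text)


asyncio.run(main())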

tests/models/test_bart.py

Lines changed: 27 additions & 11 deletions
@@ -2,6 +2,8 @@

 Run `pytest tests/models/test_bart.py`.
 """
+from typing import List, Optional, Tuple
+
 from vllm.utils import is_cpu

 if not is_cpu():
@@ -11,22 +13,31 @@

 import pytest

-from tests.models.utils import DecoderPromptType
+from vllm.sequence import SampleLogprobs

+from ..conftest import DecoderPromptType
 from .utils import check_logprobs_close

 MODELS = ["facebook/bart-base", "facebook/bart-large-cnn"]

-DECODER_PROMPT_TYPES = ([
-    DecoderPromptType.CUSTOM, DecoderPromptType.EMPTY_STR,
-    DecoderPromptType.NONE
-])
+
+def vllm_to_hf_output(
+    vllm_output: Tuple[List[int], str, Optional[SampleLogprobs]],
+    decoder_prompt_type: DecoderPromptType,
+):
+    """Sanitize vllm output to be comparable with hf output."""
+    output_ids, output_str, out_logprobs = vllm_output
+
+    hf_output_str = output_str + "</s>"
+    if decoder_prompt_type == DecoderPromptType.NONE:
+        hf_output_str = "<s>" + hf_output_str
+
+    return output_ids, hf_output_str, out_logprobs


 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["float", "bfloat16"])
 @pytest.mark.parametrize("max_tokens", [64])
 @pytest.mark.parametrize("num_logprobs", [5])
-@pytest.mark.parametrize("decoder_prompt_type", DECODER_PROMPT_TYPES)
+@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
 def test_models(
     hf_runner,
     vllm_runner,
@@ -146,8 +157,13 @@ def test_models(
     hf_skip_tokens = (1 if decoder_prompt_type == DecoderPromptType.NONE
                       else 0)

-    check_logprobs_close(outputs_0_lst=hf_outputs,
-                         outputs_1_lst=vllm_outputs,
-                         name_0="hf",
-                         name_1="vllm",
-                         num_outputs_0_skip_tokens=hf_skip_tokens)
+    check_logprobs_close(
+        outputs_0_lst=hf_outputs,
+        outputs_1_lst=[
+            vllm_to_hf_output(vllm_output, decoder_prompt_type)
+            for vllm_output in vllm_outputs
+        ],
+        name_0="hf",
+        name_1="vllm",
+        num_outputs_0_skip_tokens=hf_skip_tokens,
+    )
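
To make the sanitizer's effect concrete, here is a tiny illustration with made-up token IDs and output text; the only behaviour taken from the diff is the string handling, and it assumes vllm_to_hf_output and DecoderPromptType from the module above are in scope.

# Hypothetical inputs for illustration only.
sample = ([0, 250, 2], "A short summary.", None)

print(vllm_to_hf_output(sample, DecoderPromptType.CUSTOM))
# -> ([0, 250, 2], 'A short summary.</s>', None)
print(vllm_to_hf_output(sample, DecoderPromptType.NONE))
# -> ([0, 250, 2], '<s>A short summary.</s>', None)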

tests/models/utils.py

Lines changed: 0 additions & 11 deletions
@@ -1,5 +1,4 @@
 import warnings
-from enum import Enum
 from typing import Dict, List, Optional, Sequence, Tuple, Union

 from vllm.sequence import SampleLogprobs
@@ -136,13 +135,3 @@ def check_logprobs_close(
             warnings.simplefilter("always")

             warnings.warn(fail_msg, stacklevel=2)
-
-
-class DecoderPromptType(Enum):
-    '''
-    For encoder/decoder models only -
-
-    '''
-    CUSTOM = 1
-    NONE = 2
-    EMPTY_STR = 3

tests/test_inputs.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@

 import pytest

-from vllm.inputs import parse_and_batch_prompt
+from vllm.inputs.parse import parse_and_batch_prompt

 STRING_INPUTS = [
     '',
