33import os
44import sys
55from collections import UserList
6+ from enum import Enum
67from typing import Any , Dict , List , Optional , Tuple , TypedDict , TypeVar , Union
78
89import pytest
1415 AutoModelForVision2Seq , AutoTokenizer , BatchEncoding ,
1516 BatchFeature )
1617
17- from tests .models .utils import DecoderPromptType
1818from vllm import LLM , SamplingParams
1919from vllm .assets .image import ImageAsset
2020from vllm .config import TokenizerPoolConfig
2121from vllm .connections import global_http_connection
2222from vllm .distributed import (destroy_distributed_environment ,
2323 destroy_model_parallel )
24- from vllm .inputs import TextPrompt
24+ from vllm .inputs import (ExplicitEncoderDecoderPrompt , TextPrompt ,
25+ to_enc_dec_tuple_list , zip_enc_dec_prompts )
2526from vllm .logger import init_logger
2627from vllm .outputs import RequestOutput
2728from vllm .sequence import SampleLogprobs
2829from vllm .utils import (STR_DTYPE_TO_TORCH_DTYPE , cuda_device_count_stateless ,
29- is_cpu , to_enc_dec_tuple_list ,
30- zip_enc_dec_prompt_lists )
30+ is_cpu )
3131
3232logger = init_logger (__name__ )
3333
@@ -124,10 +124,16 @@ def example_prompts() -> List[str]:
124124 return prompts
125125
126126
127+ class DecoderPromptType (Enum ):
128+ """For encoder/decoder models only."""
129+ CUSTOM = 1
130+ NONE = 2
131+ EMPTY_STR = 3
132+
133+
127134@pytest .fixture
128- def example_encoder_decoder_prompts () \
129- -> Dict [DecoderPromptType ,
130- Tuple [List [str ], List [Optional [str ]]]]:
135+ def example_encoder_decoder_prompts (
136+ ) -> Dict [DecoderPromptType , List [ExplicitEncoderDecoderPrompt ]]:
131137 '''
132138 Returns an encoder prompt list and a decoder prompt list, wherein each pair
133139 of same-index entries in both lists corresponds to an (encoder prompt,
@@ -150,11 +156,11 @@ def example_encoder_decoder_prompts() \
150156 # NONE decoder prompt type
151157 return {
152158 DecoderPromptType .NONE :
153- zip_enc_dec_prompt_lists (encoder_prompts , none_decoder_prompts ),
159+ zip_enc_dec_prompts (encoder_prompts , none_decoder_prompts ),
154160 DecoderPromptType .EMPTY_STR :
155- zip_enc_dec_prompt_lists (encoder_prompts , empty_str_decoder_prompts ),
161+ zip_enc_dec_prompts (encoder_prompts , empty_str_decoder_prompts ),
156162 DecoderPromptType .CUSTOM :
157- zip_enc_dec_prompt_lists (encoder_prompts , custom_decoder_prompts ),
163+ zip_enc_dec_prompts (encoder_prompts , custom_decoder_prompts ),
158164 }
159165
160166
@@ -444,7 +450,7 @@ def generate_greedy_logprobs_limit(
444450
445451 def generate_encoder_decoder_greedy_logprobs_limit (
446452 self ,
447- encoder_decoder_prompts : Tuple [ List [str ], List [ str ]],
453+ encoder_decoder_prompts : List [ExplicitEncoderDecoderPrompt [ str , str ]],
448454 max_tokens : int ,
449455 num_logprobs : int ,
450456 ** kwargs : Any ,
@@ -608,7 +614,7 @@ def generate_w_logprobs(
608614
609615 def generate_encoder_decoder_w_logprobs (
610616 self ,
611- encoder_decoder_prompts : Tuple [ List [str ], List [ str ]],
617+ encoder_decoder_prompts : List [ExplicitEncoderDecoderPrompt [ str , str ]],
612618 sampling_params : SamplingParams ,
613619 ) -> List [Tuple [List [int ], str , Optional [SampleLogprobs ]]]:
614620 '''
@@ -653,7 +659,7 @@ def generate_greedy_logprobs(
653659
654660 def generate_encoder_decoder_greedy_logprobs (
655661 self ,
656- encoder_decoder_prompts : Tuple [ List [str ], List [ str ]],
662+ encoder_decoder_prompts : List [ExplicitEncoderDecoderPrompt [ str , str ]],
657663 max_tokens : int ,
658664 num_logprobs : int ,
659665 ) -> List [Tuple [List [int ], str , Optional [SampleLogprobs ]]]:
0 commit comments