diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md
index 5ab46da90ea6..0880a4530d8c 100644
--- a/docs/source/serving/openai_compatible_server.md
+++ b/docs/source/serving/openai_compatible_server.md
@@ -379,6 +379,10 @@ For chat-like input (i.e. if `messages` is passed), these extra parameters are s
 
 Our Transcriptions API is compatible with [OpenAI's Transcriptions API](https://platform.openai.com/docs/api-reference/audio/createTranscription);
 you can use the [official OpenAI Python client](https://github.com/openai/openai-python) to interact with it.
 
+:::{note}
+To use the Transcriptions API, please install vLLM with the extra audio dependencies using `pip install vllm[audio]`.
+:::
+
 Code example:
 
diff --git a/examples/online_serving/openai_transcription_client.py b/examples/online_serving/openai_transcription_client.py
index bd3c02a8a95e..494e7c8ebe12 100644
--- a/examples/online_serving/openai_transcription_client.py
+++ b/examples/online_serving/openai_transcription_client.py
@@ -1,4 +1,8 @@
 # SPDX-License-Identifier: Apache-2.0
+import asyncio
+import json
+
+import httpx
 from openai import OpenAI
 
 from vllm.assets.audio import AudioAsset
@@ -13,11 +17,50 @@
     api_key=openai_api_key,
     base_url=openai_api_base,
 )
-with open(str(mary_had_lamb), "rb") as f:
-    transcription = client.audio.transcriptions.create(
-        file=f,
-        model="openai/whisper-large-v3",
-        language="en",
-        response_format="text",
-        temperature=0.0)
-    print("transcription result:", transcription)
+
+
+def sync_openai():
+    with open(str(mary_had_lamb), "rb") as f:
+        transcription = client.audio.transcriptions.create(
+            file=f,
+            model="openai/whisper-small",
+            language="en",
+            response_format="json",
+            temperature=0.0)
+        print("transcription result:", transcription.text)
+
+
+sync_openai()
+
+
+# OpenAI Transcription API client does not support streaming.
+async def stream_openai_response():
+    data = {
+        "language": "en",
+        'stream': True,
+        "model": "openai/whisper-large-v3",
+    }
+    url = openai_api_base + "/audio/transcriptions"
+    print("transcription result:", end=' ')
+    async with httpx.AsyncClient() as client:
+        with open(str(winning_call), "rb") as f:
+            async with client.stream('POST', url, files={'file': f},
+                                     data=data) as response:
+                async for line in response.aiter_lines():
+                    # Each line is a JSON object prefixed with 'data: '
+                    if line:
+                        if line.startswith('data: '):
+                            line = line[len('data: '):]
+                        # Last chunk, stream ends
+                        if line.strip() == '[DONE]':
+                            break
+                        # Parse the JSON response
+                        chunk = json.loads(line)
+                        # Extract and print the content
+                        content = chunk['choices'][0].get('delta',
+                                                          {}).get('content')
+                        print(content, end='')
+
+
+# Run the asynchronous function
+asyncio.run(stream_openai_response())
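
For comparison outside of this PR, the same streaming request can also be issued with a plain synchronous client. The sketch below is illustrative only: it assumes a vLLM server at `http://localhost:8000` and a placeholder audio path, and it uses `requests` instead of `httpx`, but it exercises the same multipart form fields and `data:`-prefixed SSE lines as the example above.

```python
# Hypothetical synchronous sketch of the streaming request above, using
# `requests`. The server URL, model name and audio path are assumptions.
import json

import requests

url = "http://localhost:8000/v1/audio/transcriptions"
audio_path = "winning_call.ogg"  # placeholder path

with open(audio_path, "rb") as f:
    response = requests.post(url,
                             files={"file": f},
                             data={
                                 "model": "openai/whisper-small",
                                 "language": "en",
                                 "stream": True,
                             },
                             stream=True)

for raw_line in response.iter_lines():
    if not raw_line:
        continue
    line = raw_line.decode("utf-8")
    if line.startswith("data: "):
        line = line[len("data: "):]
    if line.strip() == "[DONE]":
        break
    chunk = json.loads(line)
    print(chunk["choices"][0].get("delta", {}).get("content", ""), end="")
print()
```
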
diff --git a/tests/entrypoints/openai/test_transcription_validation.py b/tests/entrypoints/openai/test_transcription_validation.py
index 5d4a5de4badd..29571bcd7649 100644
--- a/tests/entrypoints/openai/test_transcription_validation.py
+++ b/tests/entrypoints/openai/test_transcription_validation.py
@@ -3,12 +3,14 @@
 # imports for guided decoding tests
 import io
 import json
+from unittest.mock import patch
 
 import librosa
 import numpy as np
 import openai
 import pytest
 import soundfile as sf
+from openai._base_client import AsyncAPIClient
 
 from vllm.assets.audio import AudioAsset
 
@@ -120,3 +122,73 @@ async def test_completion_endpoints():
         res = await client.completions.create(model=model_name, prompt="Hello")
         assert res.code == 400
         assert res.message == "The model does not support Completions API"
+
+
+@pytest.mark.asyncio
+async def test_streaming_response(winning_call):
+    model_name = "openai/whisper-small"
+    server_args = ["--enforce-eager"]
+    transcription = ""
+    with RemoteOpenAIServer(model_name, server_args) as remote_server:
+        client = remote_server.get_async_client()
+        res_no_stream = await client.audio.transcriptions.create(
+            model=model_name,
+            file=winning_call,
+            response_format="json",
+            language="en",
+            temperature=0.0)
+        # Unfortunately this only works when the openai client is patched
+        # to use streaming mode, not exposed in the transcription api.
+        original_post = AsyncAPIClient.post
+
+        async def post_with_stream(*args, **kwargs):
+            kwargs['stream'] = True
+            return await original_post(*args, **kwargs)
+
+        with patch.object(AsyncAPIClient, "post", new=post_with_stream):
+            client = remote_server.get_async_client()
+            res = await client.audio.transcriptions.create(
+                model=model_name,
+                file=winning_call,
+                language="en",
+                temperature=0.0,
+                extra_body=dict(stream=True))
+            # Reconstruct from chunks and validate
+            async for chunk in res:
+                # just a chunk
+                text = chunk.choices[0]['delta']['content']
+                transcription += text
+
+        assert transcription == res_no_stream.text
+
+
+@pytest.mark.asyncio
+async def test_stream_options(winning_call):
+    model_name = "openai/whisper-small"
+    server_args = ["--enforce-eager"]
+    with RemoteOpenAIServer(model_name, server_args) as remote_server:
+        original_post = AsyncAPIClient.post
+
+        async def post_with_stream(*args, **kwargs):
+            kwargs['stream'] = True
+            return await original_post(*args, **kwargs)
+
+        with patch.object(AsyncAPIClient, "post", new=post_with_stream):
+            client = remote_server.get_async_client()
+            res = await client.audio.transcriptions.create(
+                model=model_name,
+                file=winning_call,
+                language="en",
+                temperature=0.0,
+                extra_body=dict(stream=True,
+                                stream_include_usage=True,
+                                stream_continuous_usage_stats=True))
+            final = False
+            continuous = True
+            async for chunk in res:
+                if not len(chunk.choices):
+                    # final usage sent
+                    final = True
+                else:
+                    continuous = continuous and hasattr(chunk, 'usage')
+            assert final and continuous
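
Both tests repeat the same monkey-patching trick described in the comment above. If it ever needs to be reused, a hypothetical helper (not part of this PR) could wrap it up as a context manager:

```python
# Hypothetical helper factoring out the patch used in both tests: it forces
# the OpenAI async client to issue streaming POST requests, since the
# Transcriptions API surface itself does not expose a `stream=` argument.
from contextlib import contextmanager
from unittest.mock import patch

from openai._base_client import AsyncAPIClient


@contextmanager
def force_streaming_post():
    original_post = AsyncAPIClient.post

    async def post_with_stream(*args, **kwargs):
        kwargs["stream"] = True
        return await original_post(*args, **kwargs)

    with patch.object(AsyncAPIClient, "post", new=post_with_stream):
        yield
```

Each test would then issue its `client.audio.transcriptions.create(..., extra_body=dict(stream=True))` call inside a `with force_streaming_post():` block.
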
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 2c740caf20fb..6b519e1b7041 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -1289,6 +1289,21 @@ class ChatCompletionStreamResponse(OpenAIBaseModel):
     usage: Optional[UsageInfo] = Field(default=None)
 
 
+class TranscriptionResponseStreamChoice(OpenAIBaseModel):
+    delta: DeltaMessage
+    finish_reason: Optional[str] = None
+    stop_reason: Optional[Union[int, str]] = None
+
+
+class TranscriptionStreamResponse(OpenAIBaseModel):
+    id: str = Field(default_factory=lambda: f"trsc-{random_uuid()}")
+    object: Literal["transcription.chunk"] = "transcription.chunk"
+    created: int = Field(default_factory=lambda: int(time.time()))
+    model: str
+    choices: list[TranscriptionResponseStreamChoice]
+    usage: Optional[UsageInfo] = Field(default=None)
+
+
 class BatchRequestInput(OpenAIBaseModel):
     """
     The per-line object of the batch input file.
@@ -1514,6 +1529,15 @@ class TranscriptionRequest(OpenAIBaseModel):
     timestamps incurs additional latency.
     """
 
+    stream: Optional[bool] = False
+    """Custom field not present in the original OpenAI definition. When set,
+    it will enable output to be streamed in a similar fashion to the Chat
+    Completions endpoint.
+    """
+    # Flattened stream options to simplify form data.
+    stream_include_usage: Optional[bool] = False
+    stream_continuous_usage_stats: Optional[bool] = False
+
     # Default sampling parameters for transcription requests.
     _DEFAULT_SAMPLING_PARAMS: dict = {
         "temperature": 0,
@@ -1534,7 +1558,21 @@ def to_sampling_params(
             "temperature", self._DEFAULT_SAMPLING_PARAMS["temperature"])
 
         return SamplingParams.from_optional(temperature=temperature,
-                                            max_tokens=max_tokens)
+                                            max_tokens=max_tokens,
+                                            output_kind=RequestOutputKind.DELTA
+                                            if self.stream \
+                                            else RequestOutputKind.FINAL_ONLY)
+
+    @model_validator(mode="before")
+    @classmethod
+    def validate_stream_options(cls, data):
+        stream_opts = ["stream_include_usage", "stream_continuous_usage_stats"]
+        stream = data.get("stream", False)
+        if any(bool(data.get(so, False)) for so in stream_opts) and not stream:
+            raise ValueError(
+                "Stream options can only be defined when `stream=True`.")
+
+        return data
 
 
 # Transcription response objects
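
To make the wire format concrete, the illustrative snippet below (assuming the models land exactly as defined above) builds a single delta chunk and prints the `data:` line a streaming client would receive:

```python
# Illustrative only: construct one streaming chunk with the new models and
# render it as a server-sent-event line. The text content is made up.
from vllm.entrypoints.openai.protocol import (
    DeltaMessage, TranscriptionResponseStreamChoice,
    TranscriptionStreamResponse)

choice = TranscriptionResponseStreamChoice(
    delta=DeltaMessage(content="Mary had a little lamb"))
chunk = TranscriptionStreamResponse(model="openai/whisper-small",
                                    choices=[choice])

# Each event is a single `data: <json>` line; the stream ends with
# `data: [DONE]`.
print(f"data: {chunk.model_dump_json(exclude_unset=True)}")
```
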
diff --git a/vllm/entrypoints/openai/serving_transcription.py b/vllm/entrypoints/openai/serving_transcription.py
index 402a0bb7a6b0..13565d0ef8dd 100644
--- a/vllm/entrypoints/openai/serving_transcription.py
+++ b/vllm/entrypoints/openai/serving_transcription.py
@@ -1,24 +1,26 @@
 # SPDX-License-Identifier: Apache-2.0
 import asyncio
 import io
+import time
 from collections.abc import AsyncGenerator
-from typing import Optional, Union, cast
+from math import ceil
+from typing import Final, Optional, Union, cast
 
 from fastapi import Request
 
 from vllm.config import ModelConfig
 from vllm.engine.protocol import EngineClient
 from vllm.entrypoints.logger import RequestLogger
-from vllm.entrypoints.openai.protocol import (ErrorResponse,
-                                              RequestResponseMetadata,
-                                              TranscriptionRequest,
-                                              TranscriptionResponse,
-                                              TranscriptionResponseVerbose)
+from vllm.entrypoints.openai.protocol import (
+    DeltaMessage, ErrorResponse, RequestResponseMetadata, TranscriptionRequest,
+    TranscriptionResponse, TranscriptionResponseStreamChoice,
+    TranscriptionStreamResponse, UsageInfo)
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
 from vllm.outputs import RequestOutput
+from vllm.transformers_utils.processor import cached_get_processor
 from vllm.utils import PlaceholderModule
 
 try:
@@ -140,8 +142,6 @@
 # As per https://platform.openai.com/docs/guides/speech-to-text#overview.
 # TODO configurable
 MAX_AUDIO_CLIP_FILESIZE_MB = 25
-# TODO get from processor.feature_extractor.chunk_length
-MAX_AUDIO_CLIP_DURATION_S = 30
 
 
 class OpenAIServingTranscription(OpenAIServing):
@@ -163,6 +163,11 @@ def __init__(
         self.default_sampling_params = (
             self.model_config.get_diff_sampling_param())
 
+        processor = cached_get_processor(model_config.model)
+        self.max_audio_clip_s = processor.feature_extractor.chunk_length
+        self.model_sr = processor.feature_extractor.sampling_rate
+        self.hop_length = processor.feature_extractor.hop_length
+
         if self.default_sampling_params:
             logger.info(
                 "Overwriting default completion sampling param with: %s",
@@ -172,7 +177,7 @@ async def _preprocess_transcription(
         self,
         request: TranscriptionRequest,
         audio_data: bytes,
-    ) -> PromptType:
+    ) -> tuple[PromptType, float]:
         # Validate request
         # TODO language should be optional and can be guessed.
         # For now we default to en. See
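
The three values cached in `__init__` above come straight from the model's feature extractor. A quick standalone check with `transformers` (the printed values are the usual Whisper defaults; other audio models may differ):

```python
# Check of the feature-extractor fields the server now reads at startup.
# For Whisper checkpoints these are typically chunk_length=30 (seconds),
# sampling_rate=16000 (Hz) and hop_length=160 (samples per log-mel frame).
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("openai/whisper-small")
fe = processor.feature_extractor
print(fe.chunk_length, fe.sampling_rate, fe.hop_length)  # 30 16000 160
```
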
@@ -198,9 +203,11 @@ async def _preprocess_transcription(
 
         with io.BytesIO(audio_data) as bytes_:
             y, sr = librosa.load(bytes_)
-        if librosa.get_duration(y=y, sr=sr) > MAX_AUDIO_CLIP_DURATION_S:
+
+        duration = librosa.get_duration(y=y, sr=sr)
+        if duration > self.max_audio_clip_s:
             raise ValueError(
-                f"Maximum clip duration ({MAX_AUDIO_CLIP_DURATION_S}s) "
+                f"Maximum clip duration ({self.max_audio_clip_s}s) "
                 "exceeded.")
 
         prompt = {
@@ -213,13 +220,13 @@ async def _preprocess_transcription(
             "decoder_prompt":
             f"<|startoftranscript|>{lang_token}<|transcribe|><|notimestamps|>{request.prompt}"
         }
-        return cast(PromptType, prompt)
+        return cast(PromptType, prompt), duration
 
     # TODO (varun) : Make verbose response work !
     async def create_transcription(
         self, audio_data: bytes, request: TranscriptionRequest,
         raw_request: Request
-    ) -> Union[TranscriptionResponse, TranscriptionResponseVerbose,
+    ) -> Union[TranscriptionResponse, AsyncGenerator[str, None],
                ErrorResponse]:
         """Transcription API similar to OpenAI's API.
 
@@ -240,8 +247,7 @@ async def create_transcription(
             return self.create_error_response(
                 "Currently only support response_format `text` or `json`")
 
-        # TODO cmpl->transcription?
-        request_id = f"cmpl-{self._base_request_id(raw_request)}"
+        request_id = f"trsc-{self._base_request_id(raw_request)}"
 
         request_metadata = RequestResponseMetadata(request_id=request_id)
         if raw_request:
@@ -261,7 +267,7 @@ async def create_transcription(
                 "Currently do not support PromptAdapter for Transcription."
             )
 
-        prompt = await self._preprocess_transcription(
+        prompt, duration_s = await self._preprocess_transcription(
             request=request,
             audio_data=audio_data,
         )
@@ -293,7 +299,12 @@ async def create_transcription(
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
 
-        # TODO(rob): figure out a way to pipe streaming in.
+        if request.stream:
+            return self.transcription_stream_generator(request,
+                                                       result_generator,
+                                                       request_id,
+                                                       request_metadata,
+                                                       duration_s)
         # Non-streaming response.
         try:
             assert result_generator is not None
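
When `request.stream` is set, the method now returns an async generator of `data: ...` strings rather than a response object. How that generator is exposed over HTTP is up to the API server route; a generic FastAPI sketch of the pattern (not vLLM's actual route code) looks like this:

```python
# Generic sketch of serving an async generator of `data: ...` strings as
# server-sent events; the payloads below are dummies standing in for the
# chunks produced by transcription_stream_generator().
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


async def fake_transcription_stream():
    yield 'data: {"choices": [{"delta": {"content": "Mary had "}}]}\n\n'
    yield 'data: {"choices": [{"delta": {"content": "a little lamb"}}]}\n\n'
    yield "data: [DONE]\n\n"


@app.post("/v1/audio/transcriptions")
async def transcribe():
    return StreamingResponse(fake_transcription_stream(),
                             media_type="text/event-stream")
```
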
@@ -305,3 +316,106 @@ async def create_transcription(
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
+
+    async def transcription_stream_generator(
+            self, request: TranscriptionRequest,
+            result_generator: AsyncGenerator[RequestOutput, None],
+            request_id: str, request_metadata: RequestResponseMetadata,
+            audio_duration_s: float) -> AsyncGenerator[str, None]:
+        created_time = int(time.time())
+        model_name = request.model
+        chunk_object_type: Final = "transcription.chunk"
+
+        completion_tokens = 0
+        num_prompt_tokens = 0
+
+        include_usage = request.stream_include_usage \
+            if request.stream_include_usage else False
+        include_continuous_usage = request.stream_continuous_usage_stats\
+            if include_usage and request.stream_continuous_usage_stats\
+            else False
+
+        try:
+            async for res in result_generator:
+                # On first result.
+                if res.prompt_token_ids is not None:
+                    # Do not account for the 4 tokens `<|startoftranscript|>..`
+                    # Could be negative when language token is not specified.
+                    num_prompt_tokens = max(len(res.prompt_token_ids) - 4, 0)
+                    # NOTE(NickLucche) user can't pass encoder prompts directly
+                    # at least not to Whisper. One indicator of the encoder
+                    # amount of processing is the log-mel spectrogram length.
+                    num_prompt_tokens += ceil(audio_duration_s *
+                                              self.model_sr / self.hop_length)
+
+                # We need to do it here, because if there are exceptions in
+                # the result_generator, it needs to be sent as the FIRST
+                # response (by the try...catch).
+
+                # Just one output (n=1) supported.
+                assert len(res.outputs) == 1
+                output = res.outputs[0]
+
+                delta_message = DeltaMessage(content=output.text)
+                completion_tokens += len(output.token_ids)
+
+                if output.finish_reason is None:
+                    # Still generating, send delta update.
+                    choice_data = TranscriptionResponseStreamChoice(
+                        delta=delta_message)
+                else:
+                    # Model is finished generating.
+                    choice_data = TranscriptionResponseStreamChoice(
+                        delta=delta_message,
+                        finish_reason=output.finish_reason,
+                        stop_reason=output.stop_reason)
+
+                chunk = TranscriptionStreamResponse(id=request_id,
+                                                    object=chunk_object_type,
+                                                    created=created_time,
+                                                    choices=[choice_data],
+                                                    model=model_name)
+
+                # handle usage stats if requested & if continuous
+                if include_continuous_usage:
+                    chunk.usage = UsageInfo(
+                        prompt_tokens=num_prompt_tokens,
+                        completion_tokens=completion_tokens,
+                        total_tokens=num_prompt_tokens + completion_tokens,
+                    )
+
+                data = chunk.model_dump_json(exclude_unset=True)
+                yield f"data: {data}\n\n"
+
+            # Once the final token is handled, if stream_options.include_usage
+            # is sent, send the usage.
+            if include_usage:
+                final_usage = UsageInfo(prompt_tokens=num_prompt_tokens,
+                                        completion_tokens=completion_tokens,
+                                        total_tokens=num_prompt_tokens +
+                                        completion_tokens)
+
+                final_usage_chunk = TranscriptionStreamResponse(
+                    id=request_id,
+                    object=chunk_object_type,
+                    created=created_time,
+                    choices=[],
+                    model=model_name,
+                    usage=final_usage)
+                final_usage_data = (final_usage_chunk.model_dump_json(
+                    exclude_unset=True, exclude_none=True))
+                yield f"data: {final_usage_data}\n\n"
+
+            # Report aggregate usage across all choices to FastAPI middleware.
+            request_metadata.final_usage_info = UsageInfo(
+                prompt_tokens=num_prompt_tokens,
+                completion_tokens=completion_tokens,
+                total_tokens=num_prompt_tokens + completion_tokens)
+
+        except Exception as e:
+            # TODO: Use a vllm-specific Validation Error
+            logger.exception("Error in transcription stream generator.")
+            data = self.create_streaming_error_response(str(e))
+            yield f"data: {data}\n\n"
+        # Send the final [DONE] message after all responses are finished.
+        yield "data: [DONE]\n\n"
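
As a worked example of the prompt-token estimate above: with the usual Whisper defaults (16 kHz sampling rate, hop length of 160 samples), a clip that fills the full 30 s window contributes 3000 log-mel frames to `prompt_tokens`, on top of the decoder prompt tokens:

```python
# Worked example of the audio contribution to prompt_tokens, using the
# Whisper defaults assumed throughout this PR.
from math import ceil

sampling_rate = 16000  # Hz
hop_length = 160       # samples per log-mel frame
duration_s = 30.0      # a clip filling Whisper's full window

print(ceil(duration_s * sampling_rate / hop_length))  # 3000
```
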